split_train_val.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. # coding:utf-8
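  2. # brief: Randomly split the annotation files found under the given directory (the script currently scans for .txt label files) into trainval/train/val/test sets, and write the corresponding image paths to txt files in the dataSet directory.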
  3. import os
  4. import random
  5. import json
  6. # directory-主路径
  7. # fileType-指定文件类型
  8. # fileList-目标类型文件列表(路径+文件名)
  9. def SearchFiles(directory, fileType):
  10. fileList = []
  11. for root, subDirs, files in os.walk(directory):
  12. for fileName in files:
  13. if fileName.endswith(fileType):
  14. # json_file = open(directory + '/' + fileName, 'r', encoding='UTF-8')
  15. # json_data = json.load(json_file)
  16. # jsonName = json_data['imagePath']
  17. # jsonName = jsonName.replace('汽车 ', 'car')
  18. # jsonName = jsonName.replace('(', '_')
  19. # jsonName = jsonName.replace(')', '')
  20. # json_data['imagePath'] = jsonName
  21. # json_file = open(directory + '/' + fileName, 'w', encoding='UTF-8')
  22. # json.dump(json_data, json_file, indent=2, ensure_ascii=False)
  23. # print(jsonName)
  24. fileList.append(fileName)
  25. # for fileName in fileList:
  26. # if fileName.find('汽车 '):
  27. # newName = fileName.replace('汽车 ', 'car')
  28. # newName = newName.replace('(', '_')
  29. # newName = newName.replace(')', '')
  30. # os.rename(fileName, newName)
  31. return fileList
  32. if __name__ == '__main__':
  33. run_path = 'D:/DeepLearning/pytorch-gpu117/xm_lidar/'
  34. last_file_path = '/home/zx/doc/private_hub/yolov8/ultralytics-main/examples/train_xm_lidar/labels/'
  35. # last_file_path = 'D:/DeepLearning/pytorch-gpu117/yolov8_study/train-seg/labels/'
  36. txt_save_path = run_path + 'dataSet'
  37. if not os.path.exists(txt_save_path):
  38. os.makedirs(txt_save_path)
  39. trainval_percent = 1
  40. train_percent = 0.9
  41. total_json = SearchFiles(run_path + 'labels', '.txt')
  42. print(total_json)
  43. num = len(total_json)
  44. list_index = range(num)
  45. tv = int(num * trainval_percent)
  46. tr = int(tv * train_percent)
  47. trainval = random.sample(list_index, tv)
  48. train = random.sample(trainval, tr)
  49. file_trainval = open(txt_save_path + '/trainval.txt', 'w')
  50. file_test = open(txt_save_path + '/test.txt', 'w')
  51. file_train = open(txt_save_path + '/train.txt', 'w')
  52. file_val = open(txt_save_path + '/val.txt', 'w')
  53. for i in list_index:
  54. name = last_file_path + total_json[i][:-3] + 'jpg\n'
  55. if i in trainval:
  56. file_trainval.write(name)
  57. if i in train:
  58. file_train.write(name)
  59. else:
  60. file_val.write(name)
  61. else:
  62. file_test.write(name)
  63. file_trainval.close()
  64. file_train.close()
  65. file_val.close()
  66. file_test.close()