split_train_val.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. # coding:utf-8
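# brief: collect the .jpg files under <run_path>/labels, randomly split them into trainval/test (trainval_percent) and, within trainval, into train/val (train_percent), and write the file paths to trainval.txt, train.txt, val.txt and test.txt in <run_path>/dataSet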
  3. import os
  4. import random
  5. import json
  6. # directory-主路径
  7. # fileType-指定文件类型
  8. # fileList-目标类型文件列表(路径+文件名)
  9. def SearchFiles(directory, fileType):
  10. fileList=[]
  11. for root, subDirs, files in os.walk(directory):
  12. for fileName in files:
  13. if fileName.endswith(fileType):
  14. # json_file = open(directory + '/' + fileName, 'r', encoding='UTF-8')
  15. # json_data = json.load(json_file)
  16. # jsonName = json_data['imagePath']
  17. # jsonName = jsonName.replace('汽车 ', 'car')
  18. # jsonName = jsonName.replace('(', '_')
  19. # jsonName = jsonName.replace(')', '')
  20. # json_data['imagePath'] = jsonName
  21. # json_file = open(directory + '/' + fileName, 'w', encoding='UTF-8')
  22. # json.dump(json_data, json_file, indent=2, ensure_ascii=False)
  23. # print(jsonName)
  24. fileList.append(directory + '/' + fileName)
  25. # for fileName in fileList:
  26. # if fileName.find('汽车 '):
  27. # newName = fileName.replace('汽车 ', 'car')
  28. # newName = newName.replace('(', '_')
  29. # newName = newName.replace(')', '')
  30. # os.rename(fileName, newName)
  31. return fileList
  32. if __name__ == '__main__':
  33. run_path = 'D:/DeepLearning/pytorch-gpu117/yolov8_study/train/'
  34. txt_save_path = run_path + 'dataSet'
  35. if not os.path.exists(txt_save_path):
  36. os.makedirs(txt_save_path)
  37. trainval_percent = 1
  38. train_percent = 0.9
  39. total_json = SearchFiles(run_path + 'labels', '.jpg')
  40. print(total_json)
  41. num = len(total_json)
  42. list_index = range(num)
  43. tv = int(num * trainval_percent)
  44. tr = int(tv * train_percent)
  45. trainval = random.sample(list_index, tv)
  46. train = random.sample(trainval, tr)
  47. file_trainval = open(txt_save_path + '/trainval.txt', 'w')
  48. file_test = open(txt_save_path + '/test.txt', 'w')
  49. file_train = open(txt_save_path + '/train.txt', 'w')
  50. file_val = open(txt_save_path + '/val.txt', 'w')
  51. for i in list_index:
  52. name = total_json[i] + '\n'
  53. if i in trainval:
  54. file_trainval.write(name)
  55. if i in train:
  56. file_train.write(name)
  57. else:
  58. file_val.write(name)
  59. else:
  60. file_test.write(name)
  61. file_trainval.close()
  62. file_train.close()
  63. file_val.close()
  64. file_test.close()