程序師世界是廣大編程愛好者互助、分享、學習的平台,程序師世界有你更精彩!
首頁
編程語言
C語言|JAVA編程
Python編程
網頁編程
ASP編程|PHP編程
JSP編程
數據庫知識
MYSQL數據庫|SqlServer數據庫
Oracle數據庫|DB2數據庫
您现在的位置: 程式師世界 >> 編程語言 >  >> 更多編程語言 >> Python

python劃分數據集並使各類別的數目相近

編輯:Python

最近項目拿到了一個別人標注但沒有劃分的數據集,有13類,不過經過統計發現各類別的數目差距較大,最多的一類有五萬多張圖片,最少的一類只有兩千多張,如果使用傳統的劃分方法,對所有的數據進行隨機劃分,將會導致樣本嚴重不均衡的問題,甚至可能出現訓練集中不存在某一類圖片,因此考慮以最少的一類圖片數目為基准,對每一類都選擇兩千張左右的圖片,並且使用蓄水池算法保證選取的隨機性,考慮到同一張圖片中可能存在多個目標,並且目標也不一定是同類,因此對每一張圖片的標注文件只參考其第一個標注的目標類別(如果標注文件中有沒有標注的目標,需要先判斷),最後對每一類圖片按照數據集劃分的比例隨機劃分到訓練集、驗證集、測試集中,雖然無法保證最終劃分的數據集每一類圖片數目非常相近,但大致差別不會太大,並且保證了訓練集、驗證集、測試集中每一類都會存在一定數目的圖片。

import os
import xml.dom.minidom
import random
master_root = os.path.abspath(os.path.join(os.getcwd(), "../../"))
data_root = os.path.join(master_root, "name of your dataset") # data_root = os.path.join(master_root, "coco")
ImageSets_path = os.path.join(data_root, "ImageSets/Main")
train_txt_path = os.path.join(ImageSets_path, "train.txt")
val_txt_path = os.path.join(ImageSets_path, "val.txt")
test_txt_path = os.path.join(ImageSets_path, "test.txt")
none_tag_path = os.path.join(ImageSets_path, "none_tag.txt")
xml_path = os.path.join(data_root, "Annotations/")
classes = ['classes of your dataset']
# classes = [ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat']
files = os.listdir(xml_path)
# 蓄水池抽樣算法
def add(list, size, len, file):
if (len < size):
list.append(file)
else:
i = random.randint(0, len)
if (i < size):
list[i] = file
def create_imagesets_train_val_test(lists, traintxt_full_path, valtxt_full_path, testtxt_full_path):
# 訓練集比例
train_percent = 0.6
# 驗證集比例
val_percent = 0.2
# 測試集比例
test_percent = 0.2
ftrain = open(traintxt_full_path, 'w')
fval = open(valtxt_full_path, 'w')
ftest = open(testtxt_full_path, 'w')
trainList = []
valList = []
testList = []
for list in lists:
num = len(list)
num_train = int(num * train_percent) # 訓練集個數
num_val = int(num * val_percent) # 驗證集個數
# 隨機選num_train個train文件
train_list = random.sample(list, num_train)
for i in train_list:
trainList.append(i)
list.remove(i)
val_list = random.sample(list, num_val)
for j in val_list:
valList.append(j)
list.remove(j)
test_list = list
for k in test_list:
testList.append(k)
trainList.sort()
valList.sort()
testList.sort()
for i in trainList:
ftrain.write(i) # train.txt文件寫入
for j in valList:
fval.write(j) # val.txt文件寫入
for k in testList:
ftest.write(k) # test.txt文件寫入
ftrain.close() # 關閉train.txt
fval.close() # 關閉val.txt
ftest.close() # 關閉test.txt
if __name__ == '__main__':
lists = [[] for i in range(len(classes))]
sizes = []
length = [0 for i in range(len(classes))]
for i in range(len(classes)):
sizes.append(random.randint(1950, 2250)) # 大概數目
# 記錄沒標注的圖片
none_tag = []
none = open(none_tag_path, 'w')
# 遍歷所有標注文件
for file in files:
xmlfile = xml_path + file
dom = xml.dom.minidom.parse(xmlfile) # 讀取xml文檔
root = dom.documentElement # 得到文檔元素對象
objectlist = root.getElementsByTagName("object")
if len(objectlist) == 0:
none_tag.append(os.path.splitext(file)[0] + '\n')
else:
# 如果有標注就按第一個標注的對象分類
object = objectlist[0]
namelist = object.getElementsByTagName("name")
objectname = namelist[0].childNodes[0].data
if objectname in classes:
cls_id = classes.index(objectname)
add(lists[cls_id], sizes[cls_id], length[cls_id], os.path.splitext(file)[0] + '\n') # 使用蓄水池算法實現隨機選取樣本
length[cls_id] += 1
for n in none_tag:
none.write(n) # none_tag.txt文件寫入
none.close() # 關閉none_tag.txt
create_imagesets_train_val_test(lists, train_txt_path, val_txt_path, test_txt_path)

  1. 上一篇文章:
  2. 下一篇文章:
Copyright © 程式師世界 All Rights Reserved