# 在机器学习和深度学习中创建属于自己的数据集
# (Creating your own dataset for machine learning and deep learning)

def _load_split(split_dir, loader):
    """Load every image under ``split_dir/<class_name>/`` and label it by class.

    Returns a ``(data, labels)`` pair of equal-length lists.

    Class directories, and the files inside them, are visited in sorted order
    so the result is reproducible (``os.listdir`` order is arbitrary).
    Non-directory entries at the split level are ignored. Files for which
    ``loader`` returns ``None`` are skipped with a warning, instead of being
    stored as ``None`` samples with a valid label.
    """
    import warnings  # local import: only needed to report unreadable files

    data, labels = [], []
    for class_name in sorted(os.listdir(split_dir)):
        class_dir = os.path.join(split_dir, class_name)
        if not os.path.isdir(class_dir):
            # e.g. a stray README or .DS_Store next to the class folders
            continue
        for file_name in sorted(os.listdir(class_dir)):
            img_path = os.path.join(class_dir, file_name)
            if not os.path.isfile(img_path):
                continue
            image = loader(img_path)  # read the image data
            if image is None:
                # cv2.imread reports "cannot read/decode" by returning None;
                # keeping such entries would yield unusable samples.
                warnings.warn(f"skipping unreadable image: {img_path}")
                continue
            data.append(image)
            labels.append(class_name)
    return data, labels


def CreateDataSet(file_path, loader=None):
    """Build train/val/test image datasets from a class-per-folder layout.

    Expected layout::

        file_path/
            train/
                Classification_1/
                    img_1.jpg
                    img_2.jpg
                    ...
                Classification_2/
                    ...
            val/
                Classification_1/
                ...
            test/
                ...

    Args:
        file_path: Root directory containing the ``train``, ``val`` and
            ``test`` sub-directories. A trailing path separator is optional.
        loader: Callable that maps an image path to image data, returning
            ``None`` when the file cannot be read. Defaults to
            ``cv2.imread``.

    Returns:
        ``(train_data, train_label, val_data, val_label, test_data,
        test_label)``. Each ``*_data`` list holds the loaded images. The
        matching ``*_label`` list holds, per image, the name of the class
        directory it was found in.

    Raises:
        FileNotFoundError: if the root or any split directory is missing.
    """
    if loader is None:
        loader = cv2.imread  # default keeps the original OpenCV behaviour

    train_data, train_label = _load_split(os.path.join(file_path, 'train'), loader)
    val_data, val_label = _load_split(os.path.join(file_path, 'val'), loader)
    test_data, test_label = _load_split(os.path.join(file_path, 'test'), loader)

    return (train_data, train_label,
            val_data, val_label,
            test_data, test_label)

# 猜你喜欢 ("You may also like" — leftover blog navigation text)

# 转载自 (reposted from) blog.csdn.net/Stybill_LV_/article/details/110847179