from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import math
import random
import cv2
import numpy as np
from Queue import Queue
from threading import Thread
from yolo.dataset.dataset import DataSet
class TextDataSet(DataSet):
    """Dataset that streams images and box labels listed in a text file.

    Every non-blank line of the text file describes one image::

        image_path xmin1 ymin1 xmax1 ymax1 class1 xmin2 ymin2 xmax2 ymax2 class2

    A producer thread cycles through the parsed records (reshuffling them at
    the start of every pass) and feeds ``record_queue``. ``thread_num``
    consumer threads convert those records into
    ``[image, labels, object_num]`` items on ``image_label_queue``, from
    which ``batch`` assembles mini-batches.
    """

    def __init__(self, common_params, dataset_params):
        """Parse the annotation file and start the background loader threads.

        Args:
            common_params: A dict providing 'image_size', 'batch_size' and
                'max_objects_per_image'.
            dataset_params: A dict providing 'path' (the annotation text
                file) and 'thread_num' (number of consumer threads).

        Raises:
            ValueError: if the annotation file contains no records, or a
                coordinate/class field cannot be converted to float.
        """
        # Process params (trailing values are examples only).
        self.data_path = str(dataset_params['path'])  # e.g. data/pascal_voc.txt
        self.width = int(common_params['image_size'])  # e.g. 448
        self.height = int(common_params['image_size'])  # e.g. 448
        self.batch_size = int(common_params['batch_size'])  # e.g. 16
        self.thread_num = int(dataset_params['thread_num'])  # e.g. 5
        self.max_objects = int(common_params['max_objects_per_image'])  # e.g. 20

        # Raw records waiting to be decoded, and decoded samples waiting to
        # be batched.
        self.record_queue = Queue(maxsize=10000)
        self.image_label_queue = Queue(maxsize=512)

        # List of [image_path, xmin, ymin, xmax, ymax, class, ...] records.
        self.record_list = []
        with open(self.data_path, 'r') as input_file:
            for line in input_file:
                # split() without arguments drops the trailing newline and
                # tolerates repeated spaces or tabs between fields.
                fields = line.split()
                if not fields:
                    # Blank line: nothing to load.
                    continue
                # Everything after the path is numeric: string -> float.
                fields[1:] = [float(num) for num in fields[1:]]
                self.record_list.append(fields)

        self.record_point = 0
        # Number of images.
        self.record_number = len(self.record_list)
        if self.record_number == 0:
            # Fail fast: otherwise the producer thread dies on a modulo by
            # zero and batch() blocks forever on an empty queue.
            raise ValueError('no records found in %s' % self.data_path)
        self.num_batch_per_epoch = int(self.record_number / self.batch_size)

        # Producer: moves records from record_list into record_queue and
        # reshuffles record_list after every full pass.
        t_record_producer = Thread(target=self.record_producer)
        t_record_producer.daemon = True
        t_record_producer.start()

        # Consumers: take records from record_queue, convert them into
        # [image_data, labels, object_num] and put them on image_label_queue.
        for _ in range(self.thread_num):
            t = Thread(target=self.record_customer)
            t.daemon = True
            t.start()

    def record_producer(self):
        """Feed ``record_queue`` forever, reshuffling at the start of each pass.

        Intended to run in a daemon thread; ``put`` blocks while the queue
        is full.
        """
        while True:
            if self.record_point % self.record_number == 0:
                # Start of a pass over the data: shuffle and rewind.
                random.shuffle(self.record_list)
                self.record_point = 0
            self.record_queue.put(self.record_list[self.record_point])
            self.record_point += 1

    def record_process(self, record):
        """Load one record's image and build its padded label list.

        Args:
            record: A list [image_path, xmin1, ymin1, xmax1, ymax1, class1,
                xmin2, ...] with numeric box fields.

        Returns:
            A list [image, labels, object_num] where
            image: 3-D ndarray resized to the configured width/height, RGB
            labels: 2-D list [self.max_objects, 5] of
                (xcenter, ycenter, w, h, class_num) in resized-image pixels,
                zero-padded after the first object_num rows
            object_num: number of objects stored in labels (int)

        Raises:
            IOError: if the image cannot be read.
        """
        image = cv2.imread(record[0])
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError('failed to read image: %s' % record[0])
        # OpenCV loads images as BGR; swap the B and R channels to get RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        h = image.shape[0]
        w = image.shape[1]
        # Scale factors that map original pixel coordinates onto the
        # resized image.
        width_rate = self.width * 1.0 / w
        height_rate = self.height * 1.0 / h

        # cv2.resize takes dsize as (width, height).
        image = cv2.resize(image, (self.width, self.height))

        # One independent zero row per object slot (no shared inner list).
        labels = [[0, 0, 0, 0, 0] for _ in range(self.max_objects)]

        object_num = 0
        for i in range(1, len(record), 5):
            if object_num >= self.max_objects:
                # Boxes beyond the per-image limit are dropped.
                break
            xmin = record[i]
            ymin = record[i + 1]
            xmax = record[i + 2]
            ymax = record[i + 3]
            class_num = record[i + 4]

            # Box centre in resized-image coordinates.
            xcenter = (xmin + xmax) * 1.0 / 2 * width_rate
            ycenter = (ymin + ymax) * 1.0 / 2 * height_rate
            # Box width and height in resized-image coordinates.
            box_w = (xmax - xmin) * width_rate
            box_h = (ymax - ymin) * height_rate

            labels[object_num] = [xcenter, ycenter, box_w, box_h, class_num]
            object_num += 1

        # Image data, labels (max_objects x 5) and the object count.
        return [image, labels, object_num]

    def record_customer(self):
        """Consume ``record_queue`` forever, filling ``image_label_queue``.

        Converts each [image_path, xmin, ymin, xmax, ymax, class, ...]
        record into [image_data, labels, object_num]. Intended to run in a
        daemon thread.
        """
        while True:
            item = self.record_queue.get()
            out = self.record_process(item)
            self.image_label_queue.put(out)

    def batch(self):
        """Assemble the next mini-batch from ``image_label_queue``.

        Blocks until ``batch_size`` decoded samples are available.

        Returns:
            images: 4-D ndarray [batch_size, height, width, 3], float32,
                scaled to the range [-1, 1]
            labels: 3-D ndarray [batch_size, max_objects, 5], float32
            objects_num: 1-D ndarray [batch_size], int32
        """
        images = []
        labels = []
        objects_num = []
        for _ in range(self.batch_size):
            image, label, object_num = self.image_label_queue.get()
            images.append(image)
            labels.append(label)
            objects_num.append(object_num)

        images = np.asarray(images, dtype=np.float32)
        # Normalise: /255 maps pixel values to [0, 1]; *2 - 1 then maps
        # them to [-1, 1].
        images = images / 255 * 2 - 1
        labels = np.asarray(labels, dtype=np.float32)
        objects_num = np.asarray(objects_num, dtype=np.int32)
        return images, labels, objects_num
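# nillboy/yolo code walkthrough 4: /yolo/dataset/text_dataset.py
# Reposted from blog.csdn.net/weixin_38900691/article/details/79587448
# (Blog page navigation residue: "You may also like", "Today's picks", "Weekly ranking")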