# nillboy/yolo code walkthrough, part 4: /yolo/dataset/text_dataset.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import math
import random
import cv2
import numpy as np
from Queue import Queue 
from threading import Thread

from yolo.dataset.dataset import DataSet 

class TextDataSet(DataSet):
  """DataSet backed by a plain-text annotation file.

  Each non-empty line of the file describes one image:
    image_path xmin1 ymin1 xmax1 ymax1 class1 xmin2 ymin2 xmax2 ymax2 class2 ...
  Fields are whitespace-separated, so image_path must not contain spaces.
  Box coordinates are pixel coordinates in the original image.

  Construction parses the file and starts background daemon threads:
    * one producer (record_producer) that cycles over the records,
      reshuffling at the start of every pass, and feeds record_queue;
    * thread_num consumers (record_customer) that load and resize each image,
      convert its boxes, and push [image, labels, object_num] into
      image_label_queue.
  batch() then assembles batch_size processed items into numpy arrays.
  """

  def __init__(self, common_params, dataset_params):
    """
    Args:
      common_params: A dict; reads 'image_size', 'batch_size' and
        'max_objects_per_image'.
      dataset_params: A dict; reads 'path' (the annotation text file) and
        'thread_num' (number of image-processing threads).
    Raises:
      ValueError: if the annotation file is malformed or contains no records.
    """
    # process params
    self.data_path = str(dataset_params['path'])
    self.width = int(common_params['image_size'])
    self.height = int(common_params['image_size'])
    self.batch_size = int(common_params['batch_size'])
    self.thread_num = int(dataset_params['thread_num'])
    self.max_objects = int(common_params['max_objects_per_image'])

    # raw-record queue (producer -> consumers) and processed-item queue
    # (consumers -> batch()); both bounded so the threads apply backpressure.
    self.record_queue = Queue(maxsize=10000)
    self.image_label_queue = Queue(maxsize=512)

    # Each entry: [image_path, xmin, ymin, xmax, ymax, class, ...] with every
    # field after the path already converted to float.
    self.record_list = self._read_records(self.data_path)

    self.record_point = 0
    # number of images
    self.record_number = len(self.record_list)
    if self.record_number == 0:
      # Without this the producer thread would compute `x % 0` and die,
      # leaving batch() blocked forever.
      raise ValueError('no records found in %s' % self.data_path)

    self.num_batch_per_epoch = self.record_number // self.batch_size

    # Producer: moves records from record_list into record_queue, shuffling
    # record_list at the start of every pass.
    t_record_producer = Thread(target=self.record_producer)
    t_record_producer.daemon = True
    t_record_producer.start()
    # Consumers: turn raw records into [image, labels, object_num] and put
    # them on image_label_queue.
    for _ in range(self.thread_num):
      t = Thread(target=self.record_customer)
      t.daemon = True
      t.start()

  @staticmethod
  def _read_records(path):
    """Parse the annotation file at `path` into a list of records.

    Blank lines are skipped. Each kept record is [image_path, float, ...].

    Raises:
      ValueError: if a line's numeric fields are not a multiple of 5, or a
        numeric field is not a valid float.
    """
    records = []
    with open(path, 'r') as input_file:
      for line_no, line in enumerate(input_file, 1):
        # split() with no argument tolerates repeated spaces/tabs and strips
        # the trailing newline.
        ss = line.split()
        if not ss:
          continue
        if (len(ss) - 1) % 5 != 0:
          raise ValueError(
              '%s:%d: expected image_path followed by groups of 5 numbers, '
              'got %d fields' % (path, line_no, len(ss)))
        ss[1:] = [float(num) for num in ss[1:]]
        records.append(ss)
    return records

  def record_producer(self):
    """Feed record_queue forever (runs in its own thread).

    Walks record_list in order and reshuffles it whenever a pass starts, so
    every record is visited once per pass in a new random order. Blocks
    while record_queue is full.
    """
    while True:
      if self.record_point % self.record_number == 0:
        random.shuffle(self.record_list)
        self.record_point = 0
      self.record_queue.put(self.record_list[self.record_point])
      self.record_point += 1

  def record_process(self, record):
    """Load one image and convert its boxes to the label format.

    Args:
      record: [image_path, xmin1, ymin1, xmax1, ymax1, class1, ...] with
        box coordinates in original-image pixels.
    Returns:
      [image, labels, object_num] where
        image: 3-D ndarray [self.height, self.width, 3] in RGB channel order.
        labels: 2-D list [self.max_objects, 5] of
          (xcenter, ycenter, w, h, class_num) in resized-image pixels;
          unused rows are [0, 0, 0, 0, 0].
        object_num: number of valid rows in labels; boxes beyond
          self.max_objects are dropped.
    Raises:
      IOError: if the image file cannot be read.
    """
    image = cv2.imread(record[0])
    if image is None:
      # cv2.imread reports failure by returning None instead of raising.
      raise IOError('could not read image: %s' % record[0])
    # OpenCV loads images in BGR channel order; convert to RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    h, w = image.shape[:2]

    # scale factors from original-image pixels to resized-image pixels
    width_rate = self.width * 1.0 / w
    height_rate = self.height * 1.0 / h
    # cv2.resize expects dsize as (width, height).
    image = cv2.resize(image, (self.width, self.height))

    # one independent row per slot (no shared list object between rows)
    labels = [[0, 0, 0, 0, 0] for _ in range(self.max_objects)]
    object_num = 0
    for i in range(1, len(record), 5):
      if object_num >= self.max_objects:
        break
      xmin, ymin, xmax, ymax, class_num = record[i:i + 5]
      # box center
      xcenter = (xmin + xmax) * 1.0 / 2 * width_rate
      ycenter = (ymin + ymax) * 1.0 / 2 * height_rate
      # box width and height
      box_w = (xmax - xmin) * width_rate
      box_h = (ymax - ymin) * height_rate

      labels[object_num] = [xcenter, ycenter, box_w, box_h, class_num]
      object_num += 1
    return [image, labels, object_num]

  def record_customer(self):
    """Consume record_queue forever (runs in its own thread).

    Converts each raw record into [image, labels, object_num] via
    record_process and puts it on image_label_queue. An exception raised by
    record_process (e.g. an unreadable image) ends this thread.
    """
    while True:
      item = self.record_queue.get()
      out = self.record_process(item)
      self.image_label_queue.put(out)

  def batch(self):
    """Get one batch of processed images.

    Returns:
      images: 4-D float32 ndarray [batch_size, height, width, 3], with pixel
        values scaled from [0, 255] to [-1, 1].
      labels: 3-D float32 ndarray [batch_size, max_objects, 5]
      objects_num: 1-D int32 ndarray [batch_size]
    """
    images = []
    labels = []
    objects_num = []
    for _ in range(self.batch_size):
      image, label, object_num = self.image_label_queue.get()
      images.append(image)
      labels.append(label)
      objects_num.append(object_num)
    images = np.asarray(images, dtype=np.float32)
    # x / 255 maps to [0, 1]; * 2 - 1 then maps that to [-1, 1].
    images = images / 255 * 2 - 1
    labels = np.asarray(labels, dtype=np.float32)
    objects_num = np.asarray(objects_num, dtype=np.int32)
    return images, labels, objects_num

# Reposted from blog.csdn.net/weixin_38900691/article/details/79587448