tensorflow keras数据集的读取 fit_generator的使用,以及模型编译保存

一、数据集的样式以及读取函数

数据集以x,y的形式分别保存检测图像和标签,其中X存放png和jpg格式的图像
读取的时候用model.fit_generator函数载入数据集,关键点则在于生成器的构造

二、步骤

1.制作一个数据生成器代码
2.使用yield返回值
3.接受值并给予model.fit_generator函数

三、代码(类)

import numpy as np
import matplotlib
from matplotlib import pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers,optimizers,losses
from tensorflow.keras.callbacks import EarlyStopping

import cv2
import PIL
import json, os
import sys
from PIL import Image
import labelme
import labelme.utils as utils
import glob

以下是整个模型类的一部分 ,方便理解就基本都贴出来了

class Net(): 
	  def __init__(self):#存储列表
      self.input_width=input_width
      self.input_height=input_height
      self.num_classes=num_classes
      self.train_images=train_images
      self.train_instances=train_instances
      self.val_images=val_images
      self.epochs=epochs
      self.lr=lr
      self.lr_decay=lr_decay
      self.batch_size=batch_size
      self.save_path=save_path
    def build_model():#定义模型的生成样式
    pass
    
    def train(self):#############训练的方法
    G_train = self.dataGenerator(mode='training')
    G_eval  = self.dataGenerator(mode='validation')
    model =self.build_mode()#构建模型的方法 具体可以看keras官方文档
    model.summary()
    model.compile(#模型的编译
      optimizers=keras.optimizers.Adam(self.lr,self.lr_decay),
      loss = 'categorical_crossentropy',
      metrics=['categorical_accuracy','recall','AUC']
    )
    #使用model.fit_generator载入函数,必须有个数据生成器不断读取函数
   model.fit_generator(G_train,5,validation_data=G_eval,validation_steps=5,epochs=self.epochs)
#保存模型
    model.save(self.save_path)



    #数据生成器函数
  def dataGenerator(self,mode):
    if mode =='training':#训练集
   #读取文件
      images = glob.glob(self.train_images+'*.jpg')#读取列表
      images.sort()#排序
      instances= glob.glob(self.train_instances +'*.png')
      instances.sort()
      zipped = inertools.cycle(zip(images,instances))#用zip包装,cycle循环
      while True :
        x_train=[]#必须定义个空集,使张量量的维度增加一维
        y_train=[]
        for _ in range(self.batch_size):
          img,seg = next(zipped)
          img = cv2. resize(cv2.imread(img,1),(self.input_width,self.input_height))
          seg = keras.utils.to_categorical(cv2.imread(seg,0),num_classes=self.num_classes)
          x_train.append(img)
          y_train.append(seg)
        yield np.array(x_train),np.array(y_train)#使用yield返回值
    if mode == 'validation':#测试集同上,读取的地方不一样
      images = glob.glob(self.train_images + '*.jpg')
      images.sort()
      instances = glob.glob(self.train_instances + '*.png')
      instances.sort()
      zipped = inertools.cycle(zip(images, instances))
      while True:
        x_eval = []
        y_eval = []
        for _ in range(self.batch_size):
          img, seg = next(zipped)
          img = cv2.resize(cv2.imread(img, 1), (self.input_width, self.input_height))
          seg = keras.utils.to_categorical(cv2.imread(seg, 0), num_classes=self.num_classes)
          x_eval.append(img)
          y_eval.append(seg)
        yield np.array(x_eval), np.array(y_eval)
       
发布了11 篇原创文章 · 获赞 3 · 访问量 822

猜你喜欢

转载自blog.csdn.net/qq_44930937/article/details/104565155