[RNN] SEQ2SEQ

I. Encoder-decoder networks

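  • SEQ2SEQ: both the first half (encoder) and the second half (decoder) are LSTMs. The loss is computed between the output and the label. (I -> C -> O; loss between O and the label)
  • AE (autoencoder): both the first half and the second half are CNNs or transposed convolutions. The loss is computed between the output and the input itself. (I -> C -> O; loss between O and I)
  • VAE: rarely used and largely replaced by GANs; it is hard to get working and the results are not as good.
  • GAN: works very well.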
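
The first two bullets differ only in what the output is compared against. A minimal sketch in plain Python, with made-up numbers and a hypothetical `mse` helper (MSE is also the loss used in the training code of this post):

```python
def mse(pred, target):
    """Mean squared error between two equal-length lists."""
    assert len(pred) == len(target)
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


# Made-up values, for illustration only.
inp = [0.2, 0.8, 0.5]      # I: the network input
out = [0.1, 0.9, 0.5]      # O: the network output
label = [0.0, 1.0, 0.0]    # the supervised label

seq2seq_loss = mse(out, label)   # SEQ2SEQ: compare O with the label
ae_loss = mse(out, inp)          # AE: compare O with the input I itself
print(seq2seq_loss, ae_loss)
```

In a real network O has the shape of whatever it is compared with: the label for SEQ2SEQ, the input for an autoencoder.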

II. SEQ2SEQ

[Figures omitted]

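III. Captcha recognition

1. Shape transformations

  • First, cut the NCHW image into a sequence along its width. The original image size is w, h = 240, 60 and the format is NCHW; convert NCHW -> NSV.
  • In NSV, N is the batch size (1 for a single image) and S is 240, one step per pixel column. H = 60 is unchanged, so V is 60x3 = 180 (3 channels, which move through the RNN together). The result is (N,240,180).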

[Figure omitted]

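  • Next comes a fully connected layer, which turns each image column into a feature vector. Before it, reshape NSV -> NV: (N,240,180) -> (Nx240,180). Multiplying (Nx240,180) by the (180,128) weight gives the fully connected output (Nx240,128).
  • Then turn the two dimensions back into three: (Nx240,128) -> (N,240,128).
  • Then comes the LSTM. Feed (N,240,128) into the RNN; the output holds one 128-dimensional vector per step, (N,240,128). Keep only the last step (index -1 on the S axis) to get the encoding C: (N,128).
  • Then feed the feature (N,128) into a second LSTM for decoding. Taking the label 0051 as an example, its one-hot encoding has shape (4,10) (4 characters, 10 probabilities each), or (N,4,10) with the batch dimension, so (N,128) has to be turned into an (N,4,10) result before the loss can be computed: (N,128) -> (N,1,128) -> broadcast -> (N,4,128) -> LSTM -> (N,4,128) -> fully connected ((Nx4,128)x(128,10) -> (Nx4,10)) -> (N,4,10). Now the loss can be computed against the label, and the prediction is read off the output.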
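
The shape chain above can be checked with plain arrays. The sketch below uses NumPy only to follow the shapes: the weights are random stand-ins, and the two LSTMs are not modelled at all (the encoder is reduced to "keep the last step", the decoder is skipped).

```python
import numpy as np

N = 2                                   # batch size (any value works)
x = np.random.rand(N, 3, 60, 240)       # NCHW: 3 channels, H=60, W=240

# NCHW -> NSV: merge C and H into V=180, then move the width axis forward as S.
seq = x.reshape(N, 180, 240).transpose(0, 2, 1)        # (N, 240, 180)
# Step s holds pixel column s: entry c*60+h of its vector is x[n, c, h, s].
assert seq[1, 7, 2 * 60 + 5] == x[1, 2, 5, 7]

# Fully connected layer applied to every step: (N*240, 180) x (180, 128).
w_fc = np.random.rand(180, 128)         # stand-in weights, not trained
feat = (seq.reshape(-1, 180) @ w_fc).reshape(N, 240, 128)

# Stand-in for the encoder LSTM: only "keep the last step" is shown.
c = feat[:, -1, :]                      # (N, 128)

# Decoder input: (N, 128) -> (N, 1, 128) -> broadcast to (N, 4, 128).
dec_in = np.broadcast_to(c.reshape(N, 1, 128), (N, 4, 128))

# Output layer applied to every step: (N*4, 128) x (128, 10) -> (N, 4, 10).
w_out = np.random.rand(128, 10)         # stand-in weights, not trained
out = (dec_in.reshape(-1, 128) @ w_out).reshape(N, 4, 10)

print(seq.shape, feat.shape, c.shape, dec_in.shape, out.shape)
```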

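2. Code

Generate -> sample -> (network -> train)

(1) Generating the images

  • Image size 60x240 (height x width), 3 channels.
  • Each image shows 4 digits, each drawn at random from 0-9.
  • The color of each digit is random. The digit color is treated as the foreground color.
  • The background is random too: every pixel gets its own random color.
  • Steps:
  • First, create a white image, then loop over it and fill the pixels one by one, top to bottom within a column and column by column from left to right.
  • Next, generate 4 random digits, give each a random color, and draw them onto the image at a fixed spacing.
  • Finally, apply a blur.
  • Three things are random: the digits, the background color, and the digit (foreground) color.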
from PIL import Image, ImageDraw, ImageFont, ImageFilter
import random
import os


# Random digit character
def ranNum():
    # ASCII 48-57 are '0'-'9' (65-90 would be 'A'-'Z', 97-122 'a'-'z')
    return chr(random.randint(48, 57))


# Random color 1: lighter, used for the background (the roles could be swapped)
def ranColor1():
    return (random.randint(65, 255),
            random.randint(65, 255),
            random.randint(65, 255))


# Random color 2: darker, used for the foreground digits (or the other way round)
def ranColor2():
    return (random.randint(32, 127),
            random.randint(32, 127),
            random.randint(32, 127))


# Image size: 240 x 60
w = 240
h = 60

image_path = "./code"
os.makedirs(image_path, exist_ok=True)

# "arial.ttf" must be findable on this machine; otherwise pass a full path to a .ttf file
font = ImageFont.truetype("arial.ttf", 40)
for i in range(1000):
    image = Image.new("RGB", (w, h), (255, 255, 255))
    draw = ImageDraw.Draw(image)

    # Background: a random color for every pixel
    for x in range(w):
        for y in range(h):
            draw.point((x, y), fill=ranColor1())

    # Foreground: 4 random digits, 60 px apart; the digits double as the filename (label)
    filename = ""
    for j in range(4):
        ch = ranNum()
        filename += ch
        draw.text((60 * j + 10, 10), ch, font=font, fill=ranColor2())

    # Blur
    image = image.filter(ImageFilter.BLUR)
    # image.show()
    # Note: images with the same digits share a filename, so the later one overwrites the earlier
    image.save("{0}/{1}.jpg".format(image_path, filename))
    print(i)
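
Because each image is saved under its 4-digit label, two iterations that draw the same digits write to the same file, so the 1000 iterations usually leave fewer than 1000 images on disk. A rough estimate, assuming the four digits are drawn uniformly and independently:

```python
import random

labels = 10 ** 4      # possible 4-digit labels: "0000" .. "9999"
draws = 1000          # iterations of the generation loop

# Expected number of distinct labels after `draws` uniform draws.
expected_unique = labels * (1 - (1 - 1 / labels) ** draws)
print(round(expected_unique))   # about 952

# Quick simulation of one run (the count varies from run to run).
seen = {"".join(str(random.randint(0, 9)) for _ in range(4)) for _ in range(draws)}
print(len(seen))
```

If exactly 1000 samples are wanted, one option (a suggestion, not what the code above does) is to append the loop index to the filename and read only the first four characters back as the label.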

(2) Sampling

import os
import torch
import numpy as np
from PIL import Image
import torch.utils.data as data
from torchvision import transforms

data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])
])


class Sampling(data.Dataset):
    def __init__(self,root):
        self.transform = data_transforms
        self.imgs = []
        self.labels = []
        for filename in os.listdir(root):
            x = os.path.join(root,filename)
            y = filename.split('.')[0]  # part 0 is the 4-character label, part 1 is "jpg"
            # print(y)
            self.imgs.append(x)
            self.labels.append(y)


    def __len__(self):
        return len(self.imgs)
    def __getitem__(self, index):
        img_path = self.imgs[index]
        img = Image.open(img_path)
        img = self.transform(img)
        label = self.labels[index]
        label = self.one_hot(label)
        return img,label
    def one_hot(self,x):
        z = np.zeros(shape=[4,10])
        for i in range(4):
            index = int(x[i])  # the i-th character of the 4-character label (e.g. "5286")
            z[i][index] = 1  # e.g. for "5286", row 0 of the zero array gets a 1 at index 5
        return z

if __name__ == '__main__':
    sampling = Sampling("./code")
    dataloader = data.DataLoader(sampling,10,shuffle=True,drop_last=True)
    for i,(img,label) in enumerate(dataloader):
        print(i)
        print(img.shape)
        # print(label)
        print(label.shape)

(3) Network and training

  • Encoding

  • Reshape and reorder: (N,3,60,240) -> (N,V,S), i.e. (N,180,240) -> (N,240,180)

x = x.reshape(-1,180,240).permute(0,2,1)
  • Reshape: (N,240,180) -> (Nx240,180)
x = x.reshape(-1,180)
  • Linear transform: (Nx240,180)x(180,128) -> (Nx240,128)
self.fc1 = nn.Sequential(
    nn.Linear(180,128),
    nn.BatchNorm1d(num_features=128),
    nn.ReLU()
)
fc1 = self.fc1(x)
  • Reshape: (Nx240,128) -> (N,S,V)
fc1 = fc1.reshape(-1, 240, 128)
  • Feed into the LSTM
self.lstm = nn.LSTM(input_size=128,
                    hidden_size=128,
                    num_layers=1,
                    batch_first=True)
lstm,(h_n,h_c) = self.lstm(fc1,None)  # None: no initial state is given, so it defaults to zeros; h_c is the final cell state
  • Keep the last step as the encoding: (N,240,128) -> (N,128)
out = lstm[:,-1,:]
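
For a single-layer, unidirectional `nn.LSTM`, the output at the last time step is the same as the final hidden state `h_n`; the second element of the returned tuple is the final cell state (usually written `c_n`, named `h_c` in this post). A small check, assuming PyTorch is installed and using random input in place of real features:

```python
import torch
from torch import nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=128, hidden_size=128, num_layers=1, batch_first=True)

fc1 = torch.randn(2, 240, 128)        # (N, S, V) stand-in for the FC output
output, (h_n, c_n) = lstm(fc1)        # omitting the state is the same as passing None

print(output.shape)   # torch.Size([2, 240, 128]): one vector per step
print(h_n.shape)      # torch.Size([1, 2, 128]): (num_layers, N, hidden_size)

# Encoding C: last step of the output, equal to the final hidden state.
c = output[:, -1, :]                  # (N, 128)
print(torch.allclose(c, h_n[-1]))
```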
  • Decoding

  • Reshape: (N,128) -> (N,1,128)

x = x.reshape(-1,1,128)
  • Broadcast: (N,1,128) -> (N,4,128)
x = x.expand(-1,4,128)
  • Feed into the LSTM: (N,S,V)
self.lstm = nn.LSTM(input_size=128,
                    hidden_size=128,
                    num_layers=1,
                    batch_first=True)
lstm,(h_n,h_c) = self.lstm(x,None)
  • Reshape: (N,S,V) -> (NxS,V), i.e. (N,4,128) -> (Nx4,128)
y1 = lstm.reshape(-1,128)
  • Fully connected: (Nx4,128)x(128,10) -> (Nx4,10)
self.out = nn.Linear(128,10)
out = self.out(y1)
  • Reshape: (Nx4,10) -> (N,4,10). 10 values per character, one score per digit (no softmax is applied).
output = out.reshape(-1,4,10)
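
The decoder steps can be run end to end on a random encoding. This is a sketch with untrained layers, assuming PyTorch is installed; it shows that although `expand` feeds the same vector to all 4 steps, the LSTM state changes from step to step, so the 4 output positions differ.

```python
import torch
from torch import nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=128, hidden_size=128, num_layers=1, batch_first=True)
fc = nn.Linear(128, 10)

c = torch.randn(2, 128)                          # stand-in for the encoding C
x = c.reshape(-1, 1, 128).expand(-1, 4, 128)     # (N, 4, 128), a view without copying
same_input = torch.equal(x[:, 0], x[:, 3])       # every step gets the same vector

y, _ = lstm(x)                                   # (N, 4, 128)
out = fc(y.reshape(-1, 128)).reshape(-1, 4, 10)  # (N*4, 128) -> (N*4, 10) -> (N, 4, 10)

# The hidden state evolves across steps, so the per-step outputs are not identical.
same_output = torch.allclose(out[:, 0], out[:, 3])
print(out.shape, same_input, same_output)
```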
  • Network code:
from torch import nn

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(180,128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU()
        )
        self.lstm = nn.LSTM(input_size=128,
                            hidden_size=128,
                            num_layers=1,
                            batch_first=True)

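    def forward(self, x):
        x = x.reshape(-1,180,240).permute(0,2,1)  # (N,3,60,240) -> (N,180,240) -> (N,240,180)
        x = x.reshape(-1,180)                     # (N*240,180)
        fc1 = self.fc1(x)                         # (N*240,128)
        fc1 = fc1.reshape(-1, 240, 128)           # (N,240,128)
        lstm,(h_n,h_c) = self.lstm(fc1,None)      # lstm: (N,240,128)
        out = lstm[:,-1,:]                        # last step only: (N,128)

        return out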


class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.lstm = nn.LSTM(input_size=128,
                            hidden_size=128,
                            num_layers=1,
                            batch_first=True)
        self.out = nn.Linear(128,10)

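    def forward(self,x):
        x = x.reshape(-1,1,128)             # (N,128) -> (N,1,128)
        x = x.expand(-1,4,128)              # broadcast to (N,4,128)
        lstm,(h_n,h_c) = self.lstm(x,None)  # lstm: (N,4,128)
        y1 = lstm.reshape(-1,128)           # (N*4,128)
        out = self.out(y1)                  # (N*4,10)
        output = out.reshape(-1,4,10)       # (N,4,10)
        return output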


class MainNet (nn.Module):
    def __init__(self):
        super(MainNet, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        encoder = self.encoder(x)
        decoder = self.decoder(encoder)

        return decoder
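
As a sanity check on the network's size, the learnable parameters can be counted by hand. The helper names below are illustrative; the LSTM formula assumes PyTorch's layout of four gates with two bias vectors each (`bias_ih` and `bias_hh`), and BatchNorm's running statistics are not counted because they are buffers, not parameters.

```python
def linear_params(n_in, n_out):
    return n_in * n_out + n_out                  # weight + bias


def batchnorm_params(n):
    return 2 * n                                 # learnable scale and shift


def lstm_params(n_in, n_hidden):
    # 4 gates, each with an input weight, a hidden weight and two bias vectors.
    return 4 * (n_in * n_hidden + n_hidden * n_hidden + 2 * n_hidden)


encoder = linear_params(180, 128) + batchnorm_params(128) + lstm_params(128, 128)
decoder = lstm_params(128, 128) + linear_params(128, 10)
print(encoder, decoder, encoder + decoder)       # 155520 133386 288906
```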
  • Training code:
import os
import numpy as np
import torch
from torch import nn
import torch.utils.data as data
import matplotlib.pyplot as plt
import Sampling_train_num  # presumably the file holding the Sampling class from step (2)

if __name__ == '__main__':
    BATCH = 32
    EPOCH = 100
    save_path = r'params/seq2seq.pth'

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = MainNet().to(device)
    opt = torch.optim.Adam(net.parameters())
    loss_func = nn.MSELoss()

    # Resume from saved parameters if there are any
    if os.path.exists(save_path):
        net.load_state_dict(torch.load(save_path, map_location=device))
    else:
        print("No Params!")
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

    train_data = Sampling_train_num.Sampling(root="./code")
    train_loader = data.DataLoader(dataset=train_data,
        batch_size=BATCH, shuffle=True, drop_last=True, num_workers=4)

    losses = []
    for epoch in range(EPOCH):

        for i, (x, y) in enumerate(train_loader):
            batch_x = x.to(device)
            batch_y = y.float().to(device)

            output = net(batch_x)
            loss = loss_func(output, batch_y)

            opt.zero_grad()
            loss.backward()
            opt.step()

            if i % 5 == 0:
                losses.append(loss.item())  # store a plain float, not a tensor attached to the graph
                # y is (N,4,10); argmax over axis 2 gives the digit index per character: (N,4)
                label_y = torch.argmax(y, 2).detach().numpy()
                out_y = torch.argmax(output, 2).cpu().detach().numpy()

                accuracy = np.sum(out_y == label_y, dtype=np.float32) / (BATCH * 4)  # per-character accuracy
                print("epoch:{},i:{},loss:{:.4f},acc:{:.2f}%".format(epoch, i, loss.item(), accuracy * 100))
                print("label_y:", label_y[0])
                print("out_y:", out_y[0])

                plt.clf()
                plt.plot(losses)
                plt.pause(0.01)

        torch.save(net.state_dict(), save_path)
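
The accuracy printed during training is per character. A captcha only counts as solved when all 4 characters are right, so the whole-captcha accuracy is lower. A small comparison in plain Python, with made-up predictions and illustrative function names:

```python
def char_accuracy(preds, labels):
    """Fraction of individual characters predicted correctly."""
    pairs = [(p, t) for pred, label in zip(preds, labels) for p, t in zip(pred, label)]
    return sum(p == t for p, t in pairs) / len(pairs)


def captcha_accuracy(preds, labels):
    """Fraction of captchas with all characters correct."""
    return sum(pred == label for pred, label in zip(preds, labels)) / len(labels)


# Made-up predictions, for illustration only.
labels = ["0051", "5286", "1234", "9070"]
preds = ["0051", "5288", "1234", "9170"]

print(char_accuracy(preds, labels))      # 0.875
print(captcha_accuracy(preds, labels))   # 0.5
```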

IV. Some good blog posts

https://www.jianshu.com/p/b2b95f945a98

https://zhuanlan.zhihu.com/p/40920384


Reposted from blog.csdn.net/sinat_39783664/article/details/104582558