Unsupervised Transfer Learning with ADDA

ADDA

This experiment is based on the paper Adversarial Discriminative Domain Adaptation (ADDA) and on ATLA, which builds on it. ATLA feeds the classifier's final output into the discriminator; that output only has num_classes dimensions, which seems too small for the discriminator to drive effective adversarial transfer. So here, following ADDA, the front-end CNN is used as the feature extractor and its output features are fed directly into the discriminator. During the experiments it turned out that the source and target signals may have different lengths; the interim solution is to interpolate the target data to the source length, so that a CNN with the same structure projects both domains into the same feature space (see the sketch below). ADDA, however, argues that an asymmetric network structure can work better.
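
A minimal sketch of that length-alignment step, assuming the signals are stored as 4-D tensors of shape (batch, 1, 2, length); the batch size and target length used here are only illustrative:

import torch
import torch.nn.functional as F

# hypothetical target batch whose signals are shorter than the source length of 128
x_t = torch.randn(256, 1, 2, 64)

# resample the target signals to the source length so both domains can share
# the same CNN feature extractor
x_t = F.interpolate(x_t, size=(2, 128))
print(x_t.shape)  # torch.Size([256, 1, 2, 128])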

Code

The code is adapted from https://github.com/jvanvugt/pytorch-domain-adaptation/. That author uses symmetric (identical) networks as feature extractors, and the source and target data have the same size, so there is no need to consider the relation between the projection space and the input signal size.

Network

import torch
import torch.nn as nn


class model(nn.Module):
    def __init__(self):
        super(model, self).__init__()
        self.feature_extractor = nn.Sequential(
            nn.ZeroPad2d((2, 2, 0, 0)),   # (1, 2, 128) -> (1, 2, 132)
            nn.Conv2d(in_channels=1, out_channels=50, kernel_size=(1, 8), stride=1, padding='valid'),   # -> (50, 2, 125)
            nn.ReLU(),
            nn.BatchNorm2d(50),
            nn.Conv2d(in_channels=50, out_channels=50, kernel_size=(2, 8), stride=1, padding='valid'),   # -> (50, 1, 118)
            nn.ReLU(),
            nn.BatchNorm2d(50),
            nn.Flatten(1),                # -> 50 * 1 * 118 = 5900
        )
        self.classifier = nn.Sequential(
            nn.Linear(5900, 256),
            nn.Dropout(),
            nn.Linear(256, 11)            # 11 modulation classes in RML2016.10a
        )

    def forward(self, x):
        # add a channel dimension when the input comes in as (batch, 2, length)
        if len(x.shape) == 3:
            x = x.unsqueeze(1)
        features = self.feature_extractor(x)
        logits = self.classifier(features)

        return logits


class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.dense1 = nn.Linear(in_features=5900, out_features=1024)
        self.dense3 = nn.Linear(in_features=1024, out_features=1)

    def forward(self, x):
        # x: flattened features from the feature extractor, shape (batch, 5900)
        out = self.dense1(x)
        out = self.relu(out)
        out = self.dense3(out)
        out = self.sigmoid(out)   # probability that the features come from the source domain
        return out
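
A quick sanity check (not part of the original post) that the feature extractor really produces 5900-dimensional features for a 2 x 128 RML2016.10a signal, matching the Linear(5900, 256) classifier and the discriminator input. It assumes the two classes above are saved in net.py, which the training scripts below import from:

import torch
from net import model, discriminator

net = model()
d = discriminator()
net.eval()
d.eval()

x = torch.randn(4, 1, 2, 128)        # (batch, channel, I/Q, samples)
feats = net.feature_extractor(x)     # ZeroPad2d -> 2x Conv2d -> Flatten
print(feats.shape)                   # torch.Size([4, 5900])
print(d(feats).shape)                # torch.Size([4, 1]) domain probabilities
print(net(x).shape)                  # torch.Size([4, 11]) class logits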

Training the source model

from torch.utils.data import DataLoader
from load_data import load_data
from train import *
from net import *
import torch.utils.data as Data
import torch.nn as nn
import torch
import tqdm


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
file_path = '../data/RML2016.10a/RML2016.10a_dict.pkl'

ratio_s = 0.7
batch_size = 256
itrs = 100
learning_rate = 0.001

x_s, x_test, y_s, y_test, snrs, classes = load_data(file_path, device, ratio_s)
train_dataset = Data.TensorDataset(x_s, y_s)
test_dataset = Data.TensorDataset(x_test, y_test)

train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size, shuffle=True)


model = model().to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50)
loss_fun = nn.CrossEntropyLoss()

best_acc = 0

########## training ##########
for itr in tqdm.trange(1, itrs+1):
    s_loss, s_acc = train(model, train_dataloader, itr, optimizer, loss_fun)
    test_acc = test(model, test_dataloader)

    if test_acc > best_acc:
        best_acc = test_acc
        print("best acc: {}".format(best_acc))
        torch.save(model, "source_model.pth")

    lr_scheduler.step()
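
train.py is not included in the post; below is a minimal sketch of what train() and test() could look like, matching how they are called above and assuming integer class labels (a hypothetical reconstruction, not the author's original code):

import torch


def train(model, dataloader, itr, optimizer, loss_fun):
    """One epoch of supervised training; returns (mean loss, accuracy)."""
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    for x, y in dataloader:
        optimizer.zero_grad()
        logits = model(x)
        loss = loss_fun(logits, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        correct += (logits.argmax(dim=1) == y).sum().item()
        total += y.shape[0]
    return total_loss / len(dataloader), correct / total


def test(model, dataloader):
    """Classification accuracy on a labelled dataloader."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in dataloader:
            correct += (model(x).argmax(dim=1) == y).sum().item()
            total += y.shape[0]
    return correct / total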

ADDA code

import datetime
from torch.utils.data import DataLoader
from load_data import load_data
from train import *
from net import *
import torch.utils.data as Data
import torch.nn as nn
import torch
import torch.nn.functional as F
from tqdm import trange

time = datetime.datetime.now()
month = time.month
day = time.day
torch.autograd.set_detect_anomaly(True)   # debugging aid; slows training and can be disabled

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
file_path = '../data/RML2016.10a/RML2016.10a_dict.pkl'

ratio_s = 0.7
ratio_t = 0.7   # target dataset is insufficient
batch_size = 256
itrs = 100
learning_rate = 0.0001
discriminator_lr = 0.00005
alpha = 0.5
k_disc = 1
k_clf = 1

###################### dataset ######################
x_s, _, y_s, _, snrs, classes = load_data(file_path, device, ratio_s)
x_t, x_test, y_t, y_test, _, _ = load_data(file_path, device, ratio_t)
# simulate a target domain whose signals have a different length: downsample to 64
# samples, then interpolate back to the source length of 128 for the shared CNN
x_t, x_test = F.interpolate(x_t, (2, 64)), F.interpolate(x_test, (2, 64))
x_t, x_test = F.interpolate(x_t, (2, 128)), F.interpolate(x_test, (2, 128))
source_dataset = Data.TensorDataset(x_s, y_s)
target_dataset = Data.TensorDataset(x_t, y_t)
mix_dataset = Data.TensorDataset(x_s, x_t)

source_dataloader = DataLoader(source_dataset, batch_size, shuffle=True)
target_dataloader = DataLoader(target_dataset, batch_size, shuffle=True)
mix_dataloader = DataLoader(mix_dataset, batch_size, shuffle=True)

test_dataset = Data.TensorDataset(x_test, y_test)
test_dataloader = DataLoader(test_dataset, batch_size, shuffle=True)

###################### model ######################
source_model = torch.load('source_model_8002.pth').to(device)   # pretrained source model (the 0.8002-accuracy checkpoint)
clf = source_model                            # full model (extractor + classifier), kept for evaluation
source_model = source_model.feature_extractor # source encoder, never updated during adaptation
model = model().to(device)
target_model = model.feature_extractor        # target encoder to be adapted
discriminator = discriminator().to(device)

target_optimizer = torch.optim.Adam(target_model.parameters(), learning_rate)
# discriminator_optimizer = torch.optim.RMSprop(discriminator.parameters(), lr=learning_rate)
target_scheduler = torch.optim.lr_scheduler.StepLR(target_optimizer, step_size=200)
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), discriminator_lr)

loss_fund = nn.BCELoss()
loss_func = nn.CrossEntropyLoss()

source_model.eval()
########## adversarial domain adaptation (unsupervised) ##########
best_acc = 0
for itr in trange(1, itrs+1):

    discriminator_accuracy = 0
    total_loss = 0
    generator_loss = 0

    for i, (sx, tx) in enumerate(mix_dataloader):

        mini_batch = sx.shape[0]

        valid = torch.ones(mini_batch, 1).to(device)
        fake = torch.zeros(mini_batch, 1).to(device)

        #################### Discriminator ####################
        for _ in range(k_disc):
            target_model.eval()
            discriminator.train()

            s_output = discriminator(source_model(sx).view(mini_batch, -1))
            t_output = discriminator(target_model(tx).view(mini_batch, -1))

            discriminator_accuracy += ((s_output > 0.5).sum() + (t_output < 0.5).sum()).item()

            # Discriminator
            discriminator_optimizer.zero_grad()
            s_loss = loss_fund(s_output, valid)
            t_loss = loss_fund(t_output, fake)
            d_loss = s_loss + t_loss
            d_loss.backward(retain_graph=True)
            discriminator_optimizer.step()

            total_loss += d_loss.item()

        #################### Target Model ####################
        for _ in range(k_clf):
            target_model.train()
            discriminator.eval()

            target_optimizer.zero_grad()
            t_output = target_model(tx)

            t_loss = loss_fund(discriminator(t_output), valid)
            t_loss.backward()
            target_optimizer.step()

            generator_loss += t_loss.item()

    target_scheduler.step()

    discriminator_accuracy = discriminator_accuracy / (2 * x_s.shape[0])
    print(discriminator_accuracy,
          round(total_loss, 5),
          round(generator_loss, 5))

    clf.feature_extractor = target_model
    epoch_acc = test(clf, test_dataloader)
    if epoch_acc > best_acc:
        best_acc = epoch_acc
        torch.save(clf, 'target_model.pth')




Results

In the end the recognition accuracy reaches roughly 60%, but the discriminator accuracy stays around 99%, which is far too strong, even though both the discriminator loss and the generator loss keep decreasing.
After lowering the learning rates (0.001 -> 0.00005 and 0.0001, as in the code above), the best recognition accuracy reaches 0.6794, which is close to the 0.8002 obtained with supervised learning; the discriminator accuracy hovers around 96%.
With the WGAN setting the discriminator accuracy drops straight to about 50%, so the discriminator should apparently be neither too strong nor too weak.
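
For reference, a rough sketch of the "WGAN setting" mentioned above (an assumed reconstruction, not the author's exact code): the sigmoid is dropped from the discriminator so it acts as a critic, and its weights are clipped after every update; clip_value=0.01 is the default from the original WGAN paper.

import torch


def wgan_critic_step(critic, optimizer, source_feat, target_feat, clip_value=0.01):
    """One WGAN-style critic update: linear output (no sigmoid), weight clipping."""
    optimizer.zero_grad()
    # maximise the score gap between source-domain and target-domain features
    loss = -(critic(source_feat).mean() - critic(target_feat).mean())
    loss.backward()
    optimizer.step()
    # clip the weights to keep the critic roughly Lipschitz
    for p in critic.parameters():
        p.data.clamp_(-clip_value, clip_value)
    return loss.item()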

Reposted from blog.csdn.net/weixin_45121008/article/details/129802002