标题
ADDA
本次仿真是根据论文Adversarial Discriminative Domain Adaptation和在此基础上提出的ATLA进行的。由于ATLA最后采取分类器的分类结果输出,一般维度只有num_classes,感觉对于discriminator来说太少了,很难达到对抗迁移的效果。所以本次仿真在ADDA的基础上取前面CNN网络作为特征提取器,其输出为特征直接输入到判别器上。实验过程中发现由于源数据和目标数据的长度可能不一样,暂时想到先对目标数据进行插值,实现长度相等,经过相同结构的CNN网络投影到同一个特征空间。但是ADDA中提出非对称网络结构会有更好的效果。
代码
代码参考了https://github.com/jvanvugt/pytorch-domain-adaptation/,该作者使用了对称的神经网络(相同)作为特征提取,其数据大小也是一样的,不需要考虑投影空间和输入信号尺寸的关系。
网络部分
import torch
import torch.nn as nn
class model(nn.Module):
    """CNN feature extractor plus dense classifier for I/Q modulation signals.

    Accepts a batch shaped (N, 2, 128) or (N, 1, 2, 128) and returns
    (N, 11) class logits. The two conv stages flatten to a 5900-dim
    feature vector (50 channels * 1 * 118).
    """

    def __init__(self):
        super(model, self).__init__()
        self.feature_extractor = nn.Sequential(
            nn.ZeroPad2d((2, 2, 0, 0)),  # pad width only: 128 -> 132
            nn.Conv2d(1, 50, kernel_size=(1, 8), stride=1, padding='valid'),
            nn.ReLU(),
            nn.BatchNorm2d(50),
            nn.Conv2d(50, 50, kernel_size=(2, 8), stride=1, padding='valid'),
            nn.ReLU(),
            nn.BatchNorm2d(50),
            nn.Flatten(1),  # -> (N, 50 * 1 * 118) = (N, 5900)
        )
        self.classifier = nn.Sequential(
            nn.Linear(5900, 256),
            nn.Dropout(),
            nn.Linear(256, 11),
        )

    def forward(self, x):
        """Return class logits for a batch of signals."""
        # Insert the missing channel dimension for (N, 2, 128) inputs.
        if x.dim() == 3:
            x = x.unsqueeze(1)
        return self.classifier(self.feature_extractor(x))
class discriminator(nn.Module):
    """Binary domain discriminator: 5900-dim feature vector -> P(source).

    Architecture: Linear(5900, 1024) -> ReLU -> Linear(1024, 1) -> Sigmoid,
    so the output lies in (0, 1) and pairs with nn.BCELoss.
    """

    def __init__(self):
        super(discriminator, self).__init__()
        # NOTE(review): self.bn is registered but never used in forward();
        # kept so existing checkpoints retain matching state_dict keys.
        self.bn = nn.BatchNorm1d(500)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.dense1 = nn.Linear(in_features=5900, out_features=1024)
        self.dense3 = nn.Linear(in_features=1024, out_features=1)

    def forward(self, x):
        """Return the probability that each row of x is a source-domain feature."""
        out = self.dense1(x)
        out = self.relu(out)
        # Bug fix: the original forward() called self.dense2 here, but dense2
        # was commented out of __init__ (and dense3 already takes 1024 inputs),
        # so every forward pass raised AttributeError. The middle layer call is
        # removed to match the constructor.
        out = self.dense3(out)
        out = self.sigmoid(out)
        return out
训练源模型
from torch.utils.data import DataLoader
from load_data import load_data
from train import *
from net import *
import torch.utils.data as Data
import torch.nn as nn
import torch
import tqdm
# --- Supervised training of the source model on RML2016.10a ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
file_path = '../data/RML2016.10a/RML2016.10a_dict.pkl'
ratio_s = 0.7        # train/test split ratio for the source domain
batch_size = 256
itrs = 100           # training epochs
learning_rate = 0.001

# Load the dataset and wrap the tensors in DataLoaders.
x_s, x_test, y_s, y_test, snrs, classes = load_data(file_path, device, ratio_s)
train_dataset = Data.TensorDataset(x_s, y_s)
test_dataset = Data.TensorDataset(x_test, y_test)
train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size, shuffle=True)

model = model().to(device)  # NOTE(review): rebinds the class name to the instance
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50)
loss_fun = nn.CrossEntropyLoss()
best_acc = 0

########## training ##########
for itr in tqdm.trange(1, itrs + 1):
    s_loss, s_acc = train(model, train_dataloader, itr, optimizer, loss_fun)
    test_acc = test(model, test_dataloader)
    if test_acc > best_acc:
        best_acc = test_acc
        print("best acc: {}".format(best_acc))
        # Save the whole model (feature extractor + classifier) for ADDA.
        torch.save(model, "source_model.pth")
    # Bug fix: StepLR.step(epoch) is deprecated (the epoch argument was
    # removed in recent PyTorch); call step() once per epoch instead.
    lr_scheduler.step()
ADDA代码
import datetime
from torch.utils.data import DataLoader
from load_data import load_data
from train import *
from net import *
import torch.utils.data as Data
import torch.nn as nn
import torch
import torch.nn.functional as F
from tqdm import trange
# --- ADDA: adversarial adaptation of the target feature extractor ---
time = datetime.datetime.now()
month = time.month
day = time.day

# NOTE(review): anomaly detection slows training noticeably; enable only
# while debugging autograd issues.
torch.autograd.set_detect_anomaly(True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
file_path = '../data/RML2016.10a/RML2016.10a_dict.pkl'
ratio_s = 0.7
ratio_t = 0.7  # target dataset is insufficient
batch_size = 256
itrs = 100
learning_rate = 0.0001        # target feature-extractor LR
discriminator_lr = 0.00005    # discriminator LR (kept weaker on purpose)
alpha = 0.5
k_disc = 1  # discriminator updates per batch
k_clf = 1   # target-extractor updates per batch

###################### dataset ######################
x_s, _, y_s, _, snrs, classes = load_data(file_path, device, ratio_s)
x_t, x_test, y_t, y_test, _, _ = load_data(file_path, device, ratio_t)
# Downsample to length 64 and upsample back to 128 so the target signals
# simulate a different sampling rate while matching the source input size.
# (assumes load_data returns 4-D (N, C, 2, 128) tensors — TODO confirm)
x_t, x_test = F.interpolate(x_t, (2, 64)), F.interpolate(x_test, (2, 64))
x_t, x_test = F.interpolate(x_t, (2, 128)), F.interpolate(x_test, (2, 128))
source_dataset = Data.TensorDataset(x_s, y_s)
target_dataset = Data.TensorDataset(x_t, y_t)
mix_dataset = Data.TensorDataset(x_s, x_t)  # requires len(x_s) == len(x_t)
source_dataloader = DataLoader(source_dataset, batch_size, shuffle=True)
target_dataloader = DataLoader(target_dataset, batch_size, shuffle=True)
mix_dataloader = DataLoader(mix_dataset, batch_size, shuffle=True)
test_dataset = Data.TensorDataset(x_test, y_test)
test_dataloader = DataLoader(test_dataset, batch_size, shuffle=True)

###################### model ######################
# NOTE(review): torch.load unpickles arbitrary code — only load trusted files.
source_model = torch.load('source_model_8002.pth').to(device)
clf = source_model                              # keep the classifier head
source_model = source_model.feature_extractor   # frozen source extractor
model = model().to(device)
target_model = model.feature_extractor           # trainable target extractor
discriminator = discriminator().to(device)
target_optimizer = torch.optim.Adam(target_model.parameters(), learning_rate)
target_scheduler = torch.optim.lr_scheduler.StepLR(target_optimizer, step_size=200)
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), discriminator_lr)
loss_fund = nn.BCELoss()
loss_func = nn.CrossEntropyLoss()
source_model.eval()

########## adversarial adaptation ##########
best_acc = 0
for itr in trange(1, itrs + 1):
    discriminator_accuracy = 0
    total_loss = 0
    generator_loss = 0
    for i, (sx, tx) in enumerate(mix_dataloader):
        mini_batch = sx.shape[0]
        valid = torch.ones(mini_batch, 1).to(device)  # label for source features
        fake = torch.zeros(mini_batch, 1).to(device)  # label for target features
        #################### Discriminator ####################
        for _ in range(k_disc):
            target_model.eval()
            discriminator.train()
            # Bug fix: compute the features without grad so the discriminator
            # step does not build (and retain) graphs through the feature
            # extractors. The original needed backward(retain_graph=True) and
            # leaked gradients into the frozen source model; the discriminator
            # update itself is numerically unchanged.
            with torch.no_grad():
                s_feat = source_model(sx).view(mini_batch, -1)
                t_feat = target_model(tx).view(mini_batch, -1)
            s_output = discriminator(s_feat)
            t_output = discriminator(t_feat)
            discriminator_accuracy += (s_output > 0.5).sum() + (t_output < 0.5).sum()
            discriminator_optimizer.zero_grad()
            s_loss = loss_fund(s_output, valid)
            t_loss = loss_fund(t_output, fake)
            d_loss = s_loss + t_loss
            d_loss.backward()
            discriminator_optimizer.step()
            total_loss += d_loss.item()
        #################### Target Model ####################
        for _ in range(k_clf):
            target_model.train()
            discriminator.eval()
            target_optimizer.zero_grad()
            t_output = target_model(tx)
            # Train the target extractor to fool the discriminator ("valid").
            t_loss = loss_fund(discriminator(t_output), valid)
            t_loss.backward()
            target_optimizer.step()
            generator_loss += t_loss.item()
    target_scheduler.step()
    discriminator_accuracy = discriminator_accuracy / (2 * x_s.shape[0])
    print(discriminator_accuracy,
          round(total_loss, 5),
          round(generator_loss, 5))
    # Evaluate: plug the adapted extractor into the source classifier head.
    clf.feature_extractor = target_model
    epoch_acc = test(clf, test_dataloader)
    if epoch_acc > best_acc:
        best_acc = epoch_acc
        torch.save(clf, 'target_model.pth')
结果
最终是可以达到60%左右的识别率的,但是判别器的判别率一直在99%左右,实在是太强了。不过判别误差和生成误差一直在降低。
经过调低学习率(0.001->0.00005和0.0001)识别率最好可以达到0.6794,和监督学习的0.8002很接近。判别器识别率在96%左右徘徊。
而用WGAN的设定时判别器准确率直接在50%左右,因此判别器应该不能太好也不能太差。