FedAvg代码详解
代码位置:https://github.com/shaoxiongji/federated-learning
README
首先查看作者给的README文档:
这一部分给出了论文的地址,如果有兴趣的话可以读一下论文,讲的就是FedAvg的思想以及对独立同分布数据和非独立同分布数据的一些研究。
下来是requirements,就是代码的依赖库,需要安装
如果有不会安装pytorch的同学可以看我之前的博客,从安装CUDA和CUDNN开始的。
运行:
给出了运行的命令,如果使用MLP和CNN模型单独训练就执行main_nn.py文件,使用联邦学习的训练就执行main_fed.py文件。并且给出了命令行参数:–dataset设置训练集;–iid设置数据是否为独立同分布;以及–num_channels设置数据的通道数量,如果是MNIST数据集就是1,CIFAR-10是3;–epochs设置训练的轮数;–gpu设置是否使用GPU;–all_client设置平局所有客户端模型。
作者给出的结果:
项目目录
- data文件夹下存放数据集,有MNIST和CIFAR-10
- models文件夹下存放和模型相关的文件
- save存放训练的结果
- utils存放一些工具
分析代码
所有_init_.py文件都不需要分析
Models目录下代码
Fed.py
import copy
import torch
from torch import nn
def FedAvg(w):
w_avg = copy.deepcopy(w[0]) # 创建一个深拷贝,用于平均参数的累加
for k in w_avg.keys(): # 遍历参数字典的键
for i in range(1, len(w)): # 遍历参与平均的参数列表
w_avg[k] += w[i][k] # 将每个参数按键累加到平均参数中
w_avg[k] = torch.div(w_avg[k], len(w)) # 将累加的参数值除以参与平均的参数个数,得到平均值
return w_avg # 返回平均参数
Fed文件主要做的事情就是对权重求平均
这段代码可能难以理解的就是这个深拷贝,我这里给出一个测试代码,大家可以自行理解,这里就不展开了,有问题可以写在评论区或者私信我。
import copy
a = 1 # =赋值 不可变元素如字符串、数值等
b = a
print('a和b的值:', a, b)
print('a和b的id:', id(a), id(b))
a = 2
print('修改a后a和b的值:', a, b)
print('修改a后a和b的id:', id(a), id(b))
# a和b的值: 1 1
# a和b的id: 2279821764912 2279821764912
# 修改a后a和b的值: 2 1
# 修改a后a和b的id: 2279821764944 2279821764912
c = [1,2,3] # =赋值 可变元素如列表、字典等
d = c
print('c和d的值:', c, d)
print('c和d的id:', id(c), id(d))
c.append(4)
print('修改c后c和d的值:', c, d)
print('修改c后c和d的id:', id(c), id(d))
# c和d的值: [1, 2, 3] [1, 2, 3]
# c和d的id: 2279827005184 2279827005184
# 修改c后c和d的值: [1, 2, 3, 4] [1, 2, 3, 4]
# 修改c后c和d的id: 2279827005184 2279827005184
orignal_list = [1, 2, 3, [4]] # 定义一个有嵌套层次的列表
copy_list = copy.copy(orignal_list) # 浅拷贝
deepcopy_list = copy.deepcopy(orignal_list) # 深拷贝
print(copy_list, deepcopy_list)
print(id(orignal_list), id(copy_list), id(deepcopy_list))
orignal_list.append(5)
orignal_list[-2].append(6) # 修改嵌套内的可变对象
print(orignal_list, copy_list, deepcopy_list)
# [1, 2, 3, [4]] [1, 2, 3, [4]]
# 1985362690688 1985365369664 1985365370112
# [1, 2, 3, [4, 6], 5] [1, 2, 3, [4, 6]] [1, 2, 3, [4]] # copy和deepcopy的差别就是:copy只复制了第一层 而更深的层次还是一种引用;deepcopy是递归复制 所有层次完全复制到新的内存空间
Nets.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import torch
from torch import nn
import torch.nn.functional as F
class MLP(nn.Module):
def __init__(self, dim_in, dim_hidden, dim_out):
super(MLP, self).__init__()
self.layer_input = nn.Linear(dim_in, dim_hidden) # 输入层
self.relu = nn.ReLU() # ReLU激活函数
self.dropout = nn.Dropout() # Dropout层
self.layer_hidden = nn.Linear(dim_hidden, dim_out) # 隐藏层
def forward(self, x):
x = x.view(-1, x.shape[1] * x.shape[-2] * x.shape[-1]) # 将输入展平
x = self.layer_input(x) # 输入层
x = self.dropout(x) # Dropout层,用于防止过拟合
x = self.relu(x) # ReLU激活函数
x = self.layer_hidden(x) # 隐藏层
return x
class CNNMnist(nn.Module):
def __init__(self, args):
super(CNNMnist, self).__init__()
self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5) # 第一个卷积层,输入通道数为args.num_channels,输出通道数为10
self.conv2 = nn.Conv2d(10, 20, kernel_size=5) # 第二个卷积层,输入通道数为10,输出通道数为20
self.conv2_drop = nn.Dropout2d() # 二维Dropout层
self.fc1 = nn.Linear(320, 50) # 全连接层1,输入大小为320,输出大小为50
self.fc2 = nn.Linear(50, args.num_classes) # 全连接层2,输入大小为50,输出大小为args.num_classes
def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2)) # 第一个卷积层后接ReLU激活函数和最大池化层
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) # 第二个卷积层后接ReLU激活函数、Dropout层和最大池化层
x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3]) # 将输入展平
x = F.relu(self.fc1(x)) # 全连接层1后接ReLU激活函数
x = F.dropout(x, training=self.training) # Dropout层,用于防止过拟合
x = self.fc2(x) # 全连接层2,输出最终结果
return x
class CNNCifar(nn.Module):
def __init__(self, args):
super(CNNCifar, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5) # 第一个卷积层,输入通道数为3,输出通道数为6
self.pool = nn.MaxPool2d(2, 2) # 最大池化层
self.conv2 = nn.Conv2d(6, 16, 5) # 第二个卷积层,输入通道数为6,输出通道数为16
self.fc1 = nn.Linear(16 * 5 * 5, 120) # 全连接层1,输入大小为16*5*5,输出大小为120
self.fc2 = nn.Linear(120, 84) # 全连接层2,输入大小为120,输出大小为84
self.fc3 = nn.Linear(84, args.num_classes) # 全连接层3,输入大小为84,输出大小为args.num_classes
def forward(self, x):
x = self.pool(F.relu(self.conv1(x))) # 第一个卷积层后接ReLU激活函数和最大池化层
x = self.pool(F.relu(self.conv2(x))) # 第二个卷积层后接ReLU激活函数和最大池化层
x = x.view(-1, 16 * 5 * 5) # 将输入展平
x = F.relu(self.fc1(x)) # 全连接层1后接ReLU激活函数
x = F.relu(self.fc2(x)) # 全连接层2后接ReLU激活函数
x = self.fc3(x) # 全连接层3,输出最终结果
return x
Nets.py是定义模型的文件代码,我们在这里实现我们需要使用的模型就可以
test.py
def test_img(net_g, datatest, args):
net_g.eval()
# 将模型设置为评估模式 eval函数的作用就是不启用dropout和BN 否则测试的时候不会是训练好的权重
test_loss = 0
correct = 0
data_loader = DataLoader(datatest, batch_size=args.bs)
# 创建测试集的数据加载器
for idx, (data, target) in enumerate(data_loader):
if args.gpu != -1:
data, target = data.cuda(), target.cuda()
# 将数据和标签移到GPU上(如果可用)
log_probs = net_g(data) # 模型会自动调用forward函数前向传播 对于分类任务 返回值就是每一类的概率 也就是权重w
test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
# 计算批次损失的总和
y_pred = log_probs.data.max(1, keepdim=True)[1]
# 在第一个维度中寻找最大值并不改变维度
correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()
# 计算正确预测的数量
test_loss /= len(data_loader.dataset)
# 计算平均损失
accuracy = 100.00 * correct / len(data_loader.dataset)
# 计算准确率
if args.verbose:
print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
test_loss, correct, len(data_loader.dataset), accuracy))
# 打印测试结果(平均损失和准确率)
return accuracy, test_loss
# 返回准确率和损失
test.py中只有一个test_img函数,输入模型和数据集,返回准确率和测试的损失
Update.py
class DatasetSplit(Dataset):
# Dataset的子类 可以创建一个数据子集对象 只包含原始数据集中特定索引的样本 在分割数据集用于训练和测试时使用
def __init__(self, dataset, idxs):
self.dataset = dataset
self.idxs = list(idxs)
def __len__(self):
return len(self.idxs)
def __getitem__(self, item):
image, label = self.dataset[self.idxs[item]]
return image, label
这个类在实例化的时候,接收参数dataset和idxs,数据集和要分割的索引,将数据集按照索引进行分割
class LocalUpdate(object): # 本地模型更新 根据本地数据集进行训练并返回
def __init__(self, args, dataset=None, idxs=None):
self.args = args
self.loss_func = nn.CrossEntropyLoss() # 交叉熵损失
self.selected_clients = [] # 选择的客户端节点
self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)
# 加载数据集的子集 通过idxs分割 本地的batch_size由命令行参数给出
def train(self, net): # 本地训练
net.train() # 设置为训练模式 会启用dropout函数和BN函数
# train and update
optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum) # 梯度下降算法使用随机梯度下降
epoch_loss = []
for iter in range(self.args.local_ep): # 根据local_ep确定本地训练轮数
batch_loss = []
for batch_idx, (images, labels) in enumerate(self.ldr_train):
images, labels = images.to(self.args.device), labels.to(self.args.device)
net.zero_grad()
log_probs = net(images)
loss = self.loss_func(log_probs, labels)
loss.backward()
# 反向传播 用于计算损失函数对模型参数的梯度 从而更新模型参数以最小化损失函数
optimizer.step() # 根据
if self.args.verbose and batch_idx % 10 == 0:
print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
iter, batch_idx * len(images), len(self.ldr_train.dataset),
100. * batch_idx / len(self.ldr_train), loss.item()))
batch_loss.append(loss.item())
epoch_loss.append(sum(batch_loss)/len(batch_loss))
return net.state_dict(), sum(epoch_loss) / len(epoch_loss)
LocalUpdate类主要定义了联邦学习中本地更新的过程,每个客户端使用自己的本地模型进行训练,并将更新后的模型参数传回到中央服务器进行聚合,代码中的形式就是return,net.state_dict()就是模型参数,返回值的第二项就是所有轮次的平均损失。
Utils目录下代码
options.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import argparse
def args_parser():
parser = argparse.ArgumentParser()
# federated arguments
parser.add_argument('--epochs', type=int, default=10, help="rounds of training")
parser.add_argument('--num_users', type=int, default=100, help="number of users: K")
parser.add_argument('--frac', type=float, default=0.1, help="the fraction of clients: C")
parser.add_argument('--local_ep', type=int, default=5, help="the number of local epochs: E")
parser.add_argument('--local_bs', type=int, default=10, help="local batch size: B")
parser.add_argument('--bs', type=int, default=128, help="test batch size")
parser.add_argument('--lr', type=float, default=0.01, help="learning rate")
parser.add_argument('--momentum', type=float, default=0.5, help="SGD momentum (default: 0.5)")
parser.add_argument('--split', type=str, default='user', help="train-test split type, user or sample")
# model arguments
parser.add_argument('--model', type=str, default='mlp', help='model name')
parser.add_argument('--kernel_num', type=int, default=9, help='number of each kind of kernel')
parser.add_argument('--kernel_sizes', type=str, default='3,4,5',
help='comma-separated kernel size to use for convolution')
parser.add_argument('--norm', type=str, default='batch_norm', help="batch_norm, layer_norm, or None")
parser.add_argument('--num_filters', type=int, default=32, help="number of filters for conv nets")
parser.add_argument('--max_pool', type=str, default='True',
help="Whether use max pooling rather than strided convolutions")
# other arguments
parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")
parser.add_argument('--iid', action='store_true', help='whether i.i.d or not')
parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
parser.add_argument('--num_channels', type=int, default=3, help="number of channels of imges")
parser.add_argument('--gpu', type=int, default=0, help="GPU ID, -1 for CPU")
parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping')
parser.add_argument('--verbose', action='store_true', help='verbose print')
parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
parser.add_argument('--all_clients', action='store_true', help='aggregation over all clients')
args = parser.parse_args()
return args
这个代码就没什么好说的,命令行的参数含义
sampling.py
def mnist_iid(dataset, num_users):
"""
Sample I.I.D. client data from MNIST dataset
:param dataset:
:param num_users:
:return: dict of image index
"""
num_items = int(len(dataset)/num_users) # 根据数据集长度及客户端数算出需要分成几份
dict_users, all_idxs = {
}, [i for i in range(len(dataset))]
# dict_users是一个空字典 all_idxs是一个从0到len(dataset)-1的连续整数的列表
for i in range(num_users):
dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
# 为每一个客户端随机选择num_items个索引 并且设置不重复选择 然后将这些索引作为一个集合set存在字典中 使用i作为键 这样每个客户端在在 dict_users 字典中都有一个对应的图像索引集合,包含了随机选择的 num_items 个不重复的图像索引。通过这种方式,确保了每个客户端的数据集合起来能够覆盖整个数据集,并且每个客户端的数据量相等。
all_idxs = list(set(all_idxs) - dict_users[i]) # 每给一个用户分配数据后 将分配过的数据删去
return dict_users # 返回一个用户序号-数据集的字典
从MNIST数据集中抽取独立同分布IID的客户端数据,将数据集划分为多个子数据集给每个客户端使用(但是根据这个代码,我发现其实这个函数并没有考虑数据的同分布问题,只是单纯的数量相同)
import numpy as np
def mnist_noniid(dataset, num_users):
"""
从MNIST数据集中采样非独立同分布(non-I.I.D.)的客户端数据
:param dataset: MNIST数据集
:param num_users: 客户端数量
:return: 包含客户端数据索引的字典
"""
num_shards, num_imgs = 200, 300 # 数据集划分的分片数和每个分片的图像数量
idx_shard = [i for i in range(num_shards)] # 分片索引列表
dict_users = {
i: np.array([], dtype='int64') for i in range(num_users)} # 存储客户端数据索引的字典
idxs = np.arange(num_shards * num_imgs) # 所有图像的索引
labels = dataset.train_labels.numpy() # 所有图像的标签
# 根据标签对索引进行排序
idxs_labels = np.vstack((idxs, labels)) # 将两个数据idxs和labels按垂直方向堆叠
idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()] # 按照图像标签进行排序
idxs = idxs_labels[0, :] # 取出第一行 即排序后的所有图像索引
# 划分并分配数据
for i in range(num_users):
rand_set = set(np.random.choice(idx_shard, 2, replace=False)) # 随机选择两个分片
idx_shard = list(set(idx_shard) - rand_set) # 从分片索引列表中移除已选择的分片
for rand in rand_set:
dict_users[i] = np.concatenate((dict_users[i], idxs[rand * num_imgs:(rand + 1) * num_imgs]), axis=0)
# 将所选分片的图像索引添加到相应客户端的数据索引列表中
return dict_users
作用就是采取非独立同分布的客户端数据,不像iid那样均分数据,因为取数据的时候先将数据按照类别排好序之后再进行分配,将数据排好序后,根据随机的分片索引在随机的位置上取连续的数据,这样取到的数据很可能是同一个标签的。
def cifar_iid(dataset, num_users):
"""
Sample I.I.D. client data from CIFAR10 dataset
:param dataset:
:param num_users:
:return: dict of image index
"""
num_items = int(len(dataset)/num_users)
dict_users, all_idxs = {
}, [i for i in range(len(dataset))]
for i in range(num_users):
dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
all_idxs = list(set(all_idxs) - dict_users[i])
return dict_users
和mnist_iid函数完全一致,我也不知道为什么要放两个一模一样的函数。
核心代码
main_fed.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import copy
import numpy as np
from torchvision import datasets, transforms
import torch
from utils.sampling import mnist_iid, mnist_noniid, cifar_iid
from utils.options import args_parser
from models.Update import LocalUpdate
from models.Nets import MLP, CNNMnist, CNNCifar
from models.Fed import FedAvg
from models.test import test_img
if __name__ == '__main__':
# parse args
args = args_parser() # 解析命令行参数
args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
# load dataset and split users
if args.dataset == 'mnist': # 加载MNIST数据集
trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist)
dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True, transform=trans_mnist)
# sample users # 划分数据集给客户端
if args.iid:
dict_users = mnist_iid(dataset_train, args.num_users)
else:
dict_users = mnist_noniid(dataset_train, args.num_users)
elif args.dataset == 'cifar': # 加载CIFAR-10数据集
trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # 因为CIFAR数据集是3通道
dataset_train = datasets.CIFAR10('../data/cifar', train=True, download=True, transform=trans_cifar)
dataset_test = datasets.CIFAR10('../data/cifar', train=False, download=True, transform=trans_cifar)
if args.iid:
dict_users = cifar_iid(dataset_train, args.num_users)
else:
exit('Error: only consider IID setting in CIFAR10')
# 代码没有实现CIFAR的非独立同分布
else:
exit('Error: unrecognized dataset')
img_size = dataset_train[0][0].shape # 输出数据集的图像格式
# build model # 调用模型
if args.model == 'cnn' and args.dataset == 'cifar':
net_glob = CNNCifar(args=args).to(args.device)
elif args.model == 'cnn' and args.dataset == 'mnist':
net_glob = CNNMnist(args=args).to(args.device)
elif args.model == 'mlp':
len_in = 1
for x in img_size:
len_in *= x
net_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device)
else:
exit('Error: unrecognized model')
print(net_glob) # 输出网络的前向函数的返回值,也就是一个张量
net_glob.train() # 设置为训练模式
# copy weights
w_glob = net_glob.state_dict() # 取出模型参数即权重
# 训练
loss_train = []
cv_loss, cv_acc = [], []
val_loss_pre, counter = 0, 0
net_best = None
best_loss = None
val_acc_list, net_list = [], []
if args.all_clients: # 如果命令行参数指定了全部客户端
print("Aggregation over all clients")
w_locals = [w_glob for i in range(args.num_users)] # 聚合所有的本地模型参数
for iter in range(args.epochs):
loss_locals = []
if not args.all_clients:
w_locals = []
m = max(int(args.frac * args.num_users), 1)
# 计算每轮训练中参与训练的用户的数量
idxs_users = np.random.choice(range(args.num_users), m, replace=False)
# 随机选择m个不重复的用户索引
for idx in idxs_users:
local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
# 实例化LocalUpdate类
w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
# 调用LocalUpdate类的函数train 使用深拷贝取出刚才的模型 这里就因为每个客户端都需要一个单独的模型来训练 所以必须使用深拷贝
if args.all_clients:
w_locals[idx] = copy.deepcopy(w)
else:
w_locals.append(copy.deepcopy(w))
loss_locals.append(copy.deepcopy(loss)) # 计算本地损失
# update global weights
w_glob = FedAvg(w_locals) # 使用FedAvg算法聚合模型参数
# copy weight to net_glob
net_glob.load_state_dict(w_glob) # 加载全局模型
# print loss
loss_avg = sum(loss_locals) / len(loss_locals) # 输出平均损失
print('Round {:3d}, Average loss {:.3f}'.format(iter, loss_avg))
loss_train.append(loss_avg)
# plot loss curve
plt.figure() #使用plt可视化损失变化
plt.plot(range(len(loss_train)), loss_train)
plt.ylabel('train_loss')
plt.savefig('./save/fed_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac, args.iid))
# testing
net_glob.eval() # 测试模型 设置为eval评估模式
acc_train, loss_train = test_img(net_glob, dataset_train, args) # 调用test_img测试函数
acc_test, loss_test = test_img(net_glob, dataset_test, args)
print("Training accuracy: {:.2f}".format(acc_train))
print("Testing accuracy: {:.2f}".format(acc_test))
main_nn.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from torchvision import datasets, transforms
from utils.options import args_parser
from models.Nets import MLP, CNNMnist, CNNCifar
def test(net_g, data_loader):
# 计算测试时的平均损失和准确率 实际上和test.py中的test_img代码基本相同
# testing
net_g.eval()
test_loss = 0
correct = 0
l = len(data_loader)
for idx, (data, target) in enumerate(data_loader):
data, target = data.to(args.device), target.to(args.device)
log_probs = net_g(data)
test_loss += F.cross_entropy(log_probs, target).item()
y_pred = log_probs.data.max(1, keepdim=True)[1]
correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()
test_loss /= len(data_loader.dataset)
print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
test_loss, correct, len(data_loader.dataset),
100. * correct / len(data_loader.dataset)))
return correct, test_loss
if __name__ == '__main__':
# parse args
args = args_parser() # 解析命令行参数
args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
torch.manual_seed(args.seed)
# load dataset and split users # 这一部分加载数据和main_fed一致
if args.dataset == 'mnist':
dataset_train = datasets.MNIST('./data/mnist/', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))
img_size = dataset_train[0][0].shape
elif args.dataset == 'cifar':
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
dataset_train = datasets.CIFAR10('./data/cifar', train=True, transform=transform, target_transform=None, download=True)
img_size = dataset_train[0][0].shape
else:
exit('Error: unrecognized dataset')
# build model
if args.model == 'cnn' and args.dataset == 'cifar':
net_glob = CNNCifar(args=args).to(args.device)
elif args.model == 'cnn' and args.dataset == 'mnist':
net_glob = CNNMnist(args=args).to(args.device)
elif args.model == 'mlp':
len_in = 1
for x in img_size:
len_in *= x
net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes).to(args.device)
else:
exit('Error: unrecognized model')
print(net_glob)
# training # 训练过程 直接用整个数据集进行训练
optimizer = optim.SGD(net_glob.parameters(), lr=args.lr, momentum=args.momentum)
train_loader = DataLoader(dataset_train, batch_size=64, shuffle=True)
list_loss = []
net_glob.train()
for epoch in range(args.epochs):
batch_loss = []
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(args.device), target.to(args.device)
optimizer.zero_grad()
output = net_glob(data)
loss = F.cross_entropy(output, target)
loss.backward()
optimizer.step()
if batch_idx % 50 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
batch_loss.append(loss.item())
loss_avg = sum(batch_loss)/len(batch_loss)
print('\nTrain loss:', loss_avg)
list_loss.append(loss_avg)
# plot loss
plt.figure()
plt.plot(range(len(list_loss)), list_loss)
plt.xlabel('epochs')
plt.ylabel('train loss')
plt.savefig('./save/nn_{}_{}_{}.png'.format(args.dataset, args.model, args.epochs))
# testing # 调用上面的test函数进行测试
if args.dataset == 'mnist':
dataset_test = datasets.MNIST('./data/mnist/', train=False, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))
test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
elif args.dataset == 'cifar':
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
dataset_test = datasets.CIFAR10('./data/cifar', train=False, transform=transform, target_transform=None, download=True)
test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
else:
exit('Error: unrecognized dataset')
print('test on', len(dataset_test), 'samples')
test_acc, test_loss = test(net_glob, test_loader)
我们可以看到main_nn就是正常的一个训练过程,main_fed是联邦学习的训练测试过程。代码部分讲解就到这里,我们给出一个自己画的main_fed.py代码流程图。