Author: Li Ruifeng
Paper title: Rethinking Federated Learning with Domain Shift: A Prototype View
Paper source: CVPR 2023
Paper link:
Code link: https://github.com/yuhangchen0/FPL_MS
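MindSpore is an open-source AI framework that gives industry, academia, research institutes, and developers device-edge-cloud collaboration across all scenarios, minimalist development, ultimate performance, ultra-large-scale AI pre-training, and a secure and trustworthy experience. Since it was open-sourced on March 28, 2020, it has been downloaded more than 5 million times. MindSpore has supported hundreds of papers at top AI conferences, is taught at more than 100 top universities, and is used commercially in over 5,000 apps through HMS. It has a large developer community and is increasingly applied across device, edge, cloud, and vehicle scenarios, including AI computing centers, finance, smart manufacturing, cloud, wireless, datacom, energy, consumer 1+8+N, and smart vehicles. It is the open-source software with the highest Gitee index. Everyone is welcome to contribute to the open-source project, including kits, crowd-sourced models, industry innovation and applications, algorithm innovation, academic cooperation, and AI book cooperation, and to share application cases on the cloud, device, edge, and security fronts.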
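With broad support from the technology community, academia, and industry, AI papers based on MindSpore accounted for 7% of papers across all AI frameworks in 2023, ranking second worldwide for the second consecutive year. We thank CAAI and university faculty for their support, and we will keep working together on AI research and innovation. The MindSpore community supports research published at top conferences and continues to build original AI results. I will periodically select and interpret outstanding papers, and I hope more experts from industry, academia, and research will collaborate with MindSpore to advance original AI research; the MindSpore community will continue to support AI innovation and AI applications. This article is the 18th in the MindSpore series on top-conference AI papers. I selected a paper from Professor Ye Mang's team at the School of Computer Science, Wuhan University, and I thank all the experts, professors, and students for their contributions.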
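MindSpore aims to achieve three goals: easy development, efficient execution, and all-scenario coverage. In our experience, MindSpore has evolved rapidly, and its APIs keep improving toward a more reasonable, complete, and powerful design. In addition, MindSpore's growing set of development tools makes the ecosystem more convenient and powerful. For example, MindSpore Insight can display a model's architecture as a graph and dynamically monitor how metrics and parameters change while the model runs, which makes development easier.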
01
Research Background
In the digital world, data privacy and security have become core issues of increasing concern. It is against this background that federated learning emerged as a distributed machine learning method that protects data privacy. Its core idea is to allow multiple devices or servers to jointly train a model without sharing original data. This approach can handle machine learning tasks on multiple mobile devices, especially when data privacy and security requirements are high.
An important open problem in federated learning is data heterogeneity: the data held by each participating node (such as a device, server, or organization) may differ greatly in distribution, quality, quantity, and feature types. Data heterogeneity matters especially in federated learning because it can directly affect the model's learning performance and generalization ability.
The paper points out that existing solutions to data heterogeneity mainly assume that all private data come from the same domain. When the distributed data originate from different domains, private models tend to perform poorly on other domains (that is, under domain shift), and the global signal cannot capture rich and fair domain information. The authors therefore aim for an optimized global model that delivers stable generalization performance across multiple domains during federated learning.
In this paper, the authors propose Federated Prototype Learning (FPL) for federated learning under domain shift. The core idea is to build cluster prototypes and unbiased prototypes, which provide rich domain knowledge and fair convergence targets, respectively. On the one hand, sample embeddings are pushed away from cluster prototypes of other classes and pulled toward cluster prototypes with the same semantics. On the other hand, consistency regularization aligns local instances with the corresponding unbiased prototypes.
The framework was developed and the experiments were run with MindSpore. Experimental results on tasks such as Digits and Office Caltech demonstrate the effectiveness of the proposed solution and the efficiency of its key modules.
02
Team Introduction
Huang Wenke, the first author of the paper, has been pursuing a combined master's and doctoral degree at Wuhan University since 2021, supervised by Professor Du Bo and Professor Ye Mang. He received his bachelor's degree from Wuhan University. His research interests include federated learning, graph learning, and financial technology. He has published 4 first-author papers at top international conferences such as CVPR, IJCAI, and ACM MM. During his graduate studies he has received honors such as the Guotai Junan Scholarship and the Outstanding Graduate Student title, and he has worked as a research intern at Alibaba Group and Microsoft Research Asia, among others.
Ye Mang, the corresponding author of the paper, is a professor and doctoral supervisor at the School of Computer Science, Wuhan University, a national-level high-level young talent, and a young-talent candidate recommended by the China Association for Science and Technology. He previously worked as a research scientist at the Inception Institute of Artificial Intelligence in the UAE and as a visiting scholar at Columbia University in the United States. His research interests include computer vision, multimedia retrieval, and federated learning. He has published more than 80 papers in international journals and conferences, including 10 ESI highly cited papers, with more than 5,600 Google Scholar citations. He has served as an area chair for conferences such as CVPR 2024 and ACM MM 2023, and he leads research projects such as the Hubei Provincial Key R&D Program and National Natural Science Foundation of China grants. His honors include the Google Excellence Scholarship, first place in the drone target re-identification track at ICCV 2021 (a top international computer vision conference), inclusion in Stanford's "World's Top 2% Scientists" list for 2021-2022, and recognition as a 2022 Baidu AI Chinese Young Scholar.
The research team MARS is directed by Professor Ye Mang and focuses on surveillance video pedestrian/behavior analysis, unsupervised/semi-supervised learning, cross-modal understanding and reasoning, and federated learning.
03
Introduction to the paper
3.1 Introduction
Building on the research background above, this paper proposes Federated Prototype Learning to address federated multi-domain generalization, where private data come from different domains and different clients have widely different feature distributions. Because each local model overfits its local distribution, the private models fail to perform well in other domains. For example, a local model A trained on the grayscale MNIST images cannot perform properly, after server aggregation, on another client holding the color SVHN dataset, because model A never learned SVHN's domain information, which leads to performance degradation.
Because the global signal cannot represent knowledge from multiple domains and may be biased toward the dominant domain, generalization suffers. To let the model learn rich multi-domain knowledge and to use shared signals that carry information from multiple domains, this paper represents different domains with cluster prototypes and uses contrastive learning to strengthen the commonality of the same class across domains while enlarging the differences between classes; this component is called Cluster Prototypes Contrastive Learning (CPCL). To avoid optimizing toward a potential dominant domain and to improve performance on minority domains, the paper uses unbiased prototypes to provide fair and stable information; this component is called Unbiased Prototypes Consistent Regularization (UPCR).
3.2 Method
3.2.1 Preparation
Federated learning
In a typical federated learning setting, there are $M$ participants (indexed by $m$), each holding private data:
$$D_m = \{(x_i, y_i)\}_{i=1}^{N_m}$$
where $N_m$ is the local data scale. In a heterogeneous federated learning environment, the conditional feature distribution $P_m(x \mid y)$ varies across participants even when the label distribution $P_m(y)$ is consistent, which leads to domain shift. Domain shift is defined as:
$$P_m(x \mid y) \neq P_n(x \mid y) \quad (P_m(y) = P_n(y))$$
This means that the private data exhibit domain shift: for the same label space, different participants have their own distinct feature distributions.
Figure 1: The data of the local clients come from different source domains that differ substantially.
Furthermore, all participants reach a consensus and share a model with the same architecture. The model can be viewed as two main parts: a feature extractor and a classifier. The feature extractor $f: \mathcal{X} \rightarrow \mathcal{Z}$ encodes a sample $x$ into a one-dimensional feature vector in the feature space:
$$z = f(x) \in \mathbb{R}^{d}$$
The classifier $g: \mathcal{Z} \rightarrow \mathbb{R}^{|I|}$ maps features to the logits output $l = g(z)$, where $I$ denotes the set of classification categories in the formulas that follow. The optimization goal is to learn, through federated learning, a generalizable global model that performs well across multiple domains.
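To make the split between $f$ and $g$ concrete, here is a minimal NumPy sketch. The single linear-plus-ReLU extractor, the input size, and the names `feature_extractor` and `classifier` are hypothetical stand-ins rather than the backbone used in the paper; $d = 512$ is chosen only because the MindSpore listing in Section 4.3 reshapes features to 512 dimensions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: 784-dim flattened input, d = 512 features, |I| = 10 classes.
IN_DIM, FEAT_DIM, NUM_CLASSES = 784, 512, 10

# f: one linear layer + ReLU stands in for the real backbone.
W_f = rng.normal(scale=0.01, size=(IN_DIM, FEAT_DIM))
# g: a linear classifier mapping features to logits.
W_g = rng.normal(scale=0.01, size=(FEAT_DIM, NUM_CLASSES))

def feature_extractor(x):
    """f: x -> z in R^d, applied to a batch of samples."""
    return np.maximum(x @ W_f, 0.0)

def classifier(z):
    """g: z -> logits in R^{|I|}, applied to a batch of features."""
    return z @ W_g

x = rng.normal(size=(4, IN_DIM))   # a batch of 4 toy samples
z = feature_extractor(x)           # shape (4, 512)
logits = classifier(z)             # shape (4, 10)
print(z.shape, logits.shape)  # (4, 512) (4, 10)
```

Keeping $f$ separate from $g$ matters here because every prototype below lives in the output space of $f$, not in the logit space.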
Feature prototype
To support the prototype-based methods that follow, the paper first defines a prototype. The prototype of label $k$ on the $m$-th client is the average of the feature vectors of all samples with label $k$ on that client:
$$p_m^k = \frac{1}{|S_m^k|} \sum_{(x_i, y_i) \in S_m^k} f(x_i)$$
where $S_m^k$ is the set of samples of client $m$ with label $k$. Intuitively, $p_m^k$ represents the domain information that label $k$ carries on this client.
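The prototype definition is just a class-wise mean of features on one client. A minimal NumPy sketch, where the function name `local_prototypes` and the toy features are hypothetical:

```python
import numpy as np

def local_prototypes(features, labels):
    """Return {k: p_m^k}: the mean feature vector of each label k on one client."""
    protos = {}
    for k in np.unique(labels):
        protos[int(k)] = features[labels == k].mean(axis=0)
    return protos

# Toy client: 5 samples with 3-dim features and labels 0/1.
feats = np.array([[1.0, 0.0, 0.0],
                  [3.0, 0.0, 0.0],
                  [0.0, 2.0, 0.0],
                  [0.0, 4.0, 0.0],
                  [0.0, 6.0, 0.0]])
labels = np.array([0, 0, 1, 1, 1])
protos = local_prototypes(feats, labels)
print(protos[0], protos[1])  # [2. 0. 0.] [0. 4. 0.]
```

Only these per-label means, not the raw samples, need to leave the client, which is what makes prototypes compatible with the privacy goal of federated learning.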
Setting the method of this paper aside for a moment, the most common approach is to average the label-$k$ domain information of all clients directly, let all clients learn this information, and use it to constrain local updates:
$$\mathcal{G}^k = \frac{1}{M} \sum_{m=1}^{M} p_m^k$$
Here $\mathcal{G}^k$ represents the average domain information of all samples labeled $k$, from different domains, in the entire federated system. However, this global view is biased: it cannot correctly represent the information of each domain and may lean toward dominant domains while ignoring minority domains, as shown in Figure 2a.
Figure 2: Representations of different types of prototypes.
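A small worked example shows why $\mathcal{G}^k$ leans toward the dominant domain. The four 2-D prototypes are made up for illustration: three clients share domain A and one client holds domain B.

```python
import numpy as np

# Hypothetical local prototypes p_m^k for one label k from M = 4 clients.
client_protos = np.array([
    [1.0, 0.0],   # client 1, domain A
    [1.0, 0.0],   # client 2, domain A
    [1.0, 0.0],   # client 3, domain A
    [0.0, 1.0],   # client 4, domain B
])

# G^k = (1/M) * sum_m p_m^k
global_proto = client_protos.mean(axis=0)
print(global_proto)  # [0.75 0.25]

# Euclidean distance from G^k to each domain's prototype
dist_a = np.linalg.norm(global_proto - client_protos[0])
dist_b = np.linalg.norm(global_proto - client_protos[3])
print(dist_a < dist_b)  # True
```

Because each client counts once, domain A receives three times the weight of domain B, so the averaged target sits much closer to A.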
3.2.2 Cluster prototype contrastive learning
To address the problem of the global prototype, the paper first applies the unsupervised clustering method FINCH to the local prototypes uploaded by the clients, separating the broad domain knowledge without supervision. Because samples from different domains have distinct feature vectors, prototypes from different domains fall into different clusters, and a prototype is then computed within each cluster, as shown in Figure 2b. This prevents the domains from averaging each other out and drifting away from all useful domain knowledge.
$$\{p_m^k\}_{m=1}^{M} \xrightarrow{\text{FINCH}} \mathcal{C}^k = \{c_j^k\}_{j=1}^{J}$$
In the formula above, $\mathcal{C}^k$ denotes the set of $J$ cluster prototypes of label $k$ obtained after clustering.
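The clustering step can be sketched as follows. This is a simplified stand-in, not the full FINCH algorithm: FINCH (First Integer Neighbor Clustering Hierarchy) links two points when one is the other's first (nearest) neighbor or when both share the same first neighbor, takes connected components as clusters, and then repeats this on the cluster means to build a hierarchy. The sketch implements only the first partition; cosine distance, the function names, and the toy prototypes are assumptions, and this article does not state which hierarchy level FPL uses.

```python
import numpy as np

def first_neighbor_partition(x):
    """Simplified FINCH step: cluster = connected component of the first-neighbor graph."""
    xn = x / np.linalg.norm(x, axis=1, keepdims=True)
    dist = 1.0 - xn @ xn.T              # cosine distance between prototypes
    np.fill_diagonal(dist, np.inf)      # a point is never its own first neighbor
    first_nn = dist.argmin(axis=1)      # kappa_i

    parent = list(range(len(x)))        # union-find over prototype indices

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    # Linking i with kappa_i covers all three FINCH rules: j = kappa_i,
    # i = kappa_j, and kappa_i = kappa_j (both end up in kappa_i's component).
    for i, j in enumerate(first_nn):
        parent[find(i)] = find(j)

    roots = [find(i) for i in range(len(x))]
    _, labels = np.unique(roots, return_inverse=True)
    return labels

def cluster_prototypes(local_protos):
    """C^k: mean of the local prototypes p_m^k inside each cluster."""
    labels = first_neighbor_partition(local_protos)
    return np.stack([local_protos[labels == c].mean(axis=0) for c in np.unique(labels)])

# Hypothetical local prototypes of one label from M = 5 clients:
# clients 0-1 resemble domain A, clients 2-4 resemble domain B.
local_protos = np.array([[1.0, 0.1, 0.0],
                         [0.9, 0.0, 0.0],
                         [0.0, 1.0, 0.1],
                         [0.1, 1.0, 0.0],
                         [0.0, 0.9, 0.0]])
print(first_neighbor_partition(local_protos))  # [0 0 1 1 1]
print(cluster_prototypes(local_protos).shape)  # (2, 3)
```

Unlike the plain average, each domain now keeps its own prototype, and the number of clusters $J$ is decided by the data rather than fixed in advance.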
Building on this, the paper implements cluster prototype contrastive learning by adding a new loss term. For a sample $x_i$ with label $y_i = k$, its feature vector is $z_i = f(x_i)$. Using contrastive learning, the paper pulls $z_i$ as close as possible to all cluster prototypes with the same semantics, $\mathcal{C}^k$, which increases the similarity among samples of the same label from different domains. At the same time, it pushes $z_i$ away from all prototypes that do not belong to label $k$, denoted $\mathcal{N}^k = \mathcal{C} \setminus \mathcal{C}^k$, where $\mathcal{C}$ is the set of cluster prototypes of all labels. In this way, rich knowledge from different domains is learned during local updates and generalization improves. The authors define the similarity between a sample feature vector and a prototype as:
$$\mathcal{S}(z_i, c) = \frac{z_i \cdot c}{\|z_i\| \, \|c\| \, \tau}$$
where $\tau$ is a temperature hyperparameter. The loss term that implements cluster prototype contrastive learning is then:
$$\mathcal{L}_{CPCL} = -\log \frac{\sum_{c \in \mathcal{C}^k} \exp(\mathcal{S}(z_i, c))}{\sum_{c \in \mathcal{C}^k} \exp(\mathcal{S}(z_i, c)) + \sum_{c \in \mathcal{N}^k} \exp(\mathcal{S}(z_i, c))}$$
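The similarity and the CPCL loss for a single sample can be written in a few lines of NumPy. This sketch follows the formula above and the structure of `calculate_infonce` in Section 4.3.1; it uses a log-sum-exp shift for numerical stability, which is a design choice of this sketch. The prototypes and $\tau = 0.5$ are illustrative values only.

```python
import numpy as np

def similarity(z, protos, tau):
    """S(z, c) = (z . c) / (||z|| ||c|| tau) for every row c of `protos`."""
    return (protos @ z) / (np.linalg.norm(protos, axis=1) * np.linalg.norm(z) * tau)

def logsumexp(a):
    m = a.max()
    return m + np.log(np.exp(a - m).sum())

def cpcl_loss(z, pos_protos, neg_protos, tau):
    """L_CPCL = -log( sum_{C^k} exp(S) / (sum_{C^k} exp(S) + sum_{N^k} exp(S)) )."""
    s_pos = similarity(z, pos_protos, tau)
    s_neg = similarity(z, neg_protos, tau)
    return logsumexp(np.concatenate([s_pos, s_neg])) - logsumexp(s_pos)

pos = np.array([[1.0, 0.0], [0.8, 0.6]])    # cluster prototypes of the sample's label
neg = np.array([[0.0, 1.0], [-1.0, 0.0]])   # cluster prototypes of other labels
loss_near = cpcl_loss(np.array([1.0, 0.0]), pos, neg, tau=0.5)
loss_far = cpcl_loss(np.array([0.0, 1.0]), pos, neg, tau=0.5)
print(loss_near < loss_far)  # True
```

A feature aligned with its own label's clusters gets a small loss (about 0.088 here), while one aligned with another label's cluster gets a much larger one.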
Why does this approach work? The author gives the following analysis:
Minimizing this loss is equivalent to pulling the sample feature vector $z_i$ toward its assigned positive cluster prototypes $c \in \mathcal{C}^k$ and pushing it away from the negative prototypes $c \in \mathcal{N}^k$. This not only keeps the features invariant to various domain distortions but also enhances the spread of semantics, so that the feature space is both generalizable and discriminative, which yields satisfactory generalization performance in federated learning.
3.2.3 Unbiased prototype consistency regularization
Cluster prototypes bring diverse domain knowledge and thus plasticity under domain shift. However, because they come from unsupervised clustering, cluster prototypes are regenerated in every communication round and their number changes, so they cannot provide a stable convergence direction across communication rounds. The paper therefore proposes a second component that ensures sustained multi-domain fairness by building a fair and stable unbiased prototype and constraining the distance between local sample features and this unbiased prototype.
Specifically, the cluster prototypes of label $k$ are averaged to represent the unbiased convergence target for label $k$, as shown in Figure 2c:
$$\mathcal{U}^k = \frac{1}{|\mathcal{C}^k|} \sum_{c \in \mathcal{C}^k} c$$
The paper introduces a second loss term, a consistency regularizer that pulls the sample feature vector $z_i$ toward its corresponding unbiased prototype $\mathcal{U}^k$, providing a relatively fair and stable optimization target that addresses convergence instability:
$$\mathcal{L}_{UPCR} = \sum_{v=1}^{d} \left( z_{i,v} - \mathcal{U}^k_v \right)^2$$
where $v$ indexes the dimensions of the feature vector.
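The effect of averaging over clusters instead of clients can be seen with the same kind of toy data as before. In this sketch the clustering result is simply assumed (one cluster per domain), and the prototypes and the name `upcr_loss` are hypothetical.

```python
import numpy as np

# Hypothetical local prototypes of one label: 3 clients from domain A, 1 from domain B.
client_protos = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
# Assume clustering has already produced one prototype per domain (C^k).
cluster_protos = np.array([[1.0, 0.0],   # domain A cluster
                           [0.0, 1.0]])  # domain B cluster

naive_global = client_protos.mean(axis=0)   # G^k: weights domains by client count
unbiased = cluster_protos.mean(axis=0)      # U^k: weights each cluster equally
print(naive_global, unbiased)  # [0.75 0.25] [0.5 0.5]

def upcr_loss(z, u):
    """L_UPCR = sum_v (z_v - U_v)^2."""
    return float(((z - u) ** 2).sum())

print(upcr_loss(np.array([1.0, 0.0]), unbiased))  # 0.5
```

Note that the MindSpore code in Section 4.3.2 uses `nn.MSELoss`, whose default `'mean'` reduction averages over the $d$ dimensions; that only rescales this summed form by $1/d$.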
3.2.4 Overall algorithm
In addition to the two losses above, the cross-entropy loss used in conventional model training is included, giving the overall loss of federated prototype learning:
$$\mathcal{L} = \mathcal{L}_{CE} + \mathcal{L}_{CPCL} + \mathcal{L}_{UPCR}$$
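For a whole batch, the three terms can be computed together. This vectorized NumPy sketch assumes the cluster prototypes of all labels are stacked in one array with a parallel array of their labels; the names `fpl_objective`, `protos`, `proto_labels`, and all toy values are hypothetical, and the MindSpore code below computes the prototype terms per sample instead.

```python
import numpy as np

def log_softmax(a):
    a = a - a.max(axis=-1, keepdims=True)
    return a - np.log(np.exp(a).sum(axis=-1, keepdims=True))

def fpl_objective(logits, z, y, protos, proto_labels, tau):
    """Batch mean of L_CE + L_CPCL + L_UPCR.

    protos: (P, d) cluster prototypes of all labels; proto_labels: (P,) their labels.
    """
    n = len(y)
    ce = -log_softmax(logits)[np.arange(n), y]
    # S(z_i, c) for every sample/prototype pair: cosine similarity divided by tau.
    zn = z / np.linalg.norm(z, axis=1, keepdims=True)
    cn = protos / np.linalg.norm(protos, axis=1, keepdims=True)
    log_p = log_softmax(zn @ cn.T / tau)                 # log(exp(S) / sum_all exp(S))
    pos = proto_labels[None, :] == y[:, None]            # mask of C^{y_i}
    cpcl = -np.log((np.exp(log_p) * pos).sum(axis=1))
    # U^{y_i}: mean of the cluster prototypes of each sample's label.
    u = np.stack([protos[proto_labels == k].mean(axis=0) for k in y])
    upcr = ((z - u) ** 2).sum(axis=1)
    return float((ce + cpcl + upcr).mean())

protos = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
proto_labels = np.array([0, 0, 1, 1])
z = np.array([[1.0, 0.05], [0.05, 1.0]])
logits = np.array([[2.0, -2.0], [-2.0, 2.0]])
good = fpl_objective(logits, z, np.array([0, 1]), protos, proto_labels, tau=0.5)
bad = fpl_objective(logits, z, np.array([1, 0]), protos, proto_labels, tau=0.5)
print(good < bad)  # True
```

The three terms are simply summed, matching the overall loss above; swapping the labels makes all three terms grow at once.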
Figure: learning process of FPL.
Algorithm: FPL.
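The algorithm figure is not reproduced here, but the server side of each round combines model aggregation with the prototype aggregation described in 3.2.2 and 3.2.3. The sketch below shows only the model part and assumes FedAvg-style weighting by the local data scale $N_m$, which is the usual choice rather than something this article states; the parameter name `fc.weight` and the values are hypothetical.

```python
import numpy as np

def fedavg(client_params, client_sizes):
    """w = sum_m (N_m / N) * w_m, applied to each named parameter array."""
    total = float(sum(client_sizes))
    return {name: sum((n / total) * params[name]
                      for params, n in zip(client_params, client_sizes))
            for name in client_params[0]}

# Two hypothetical clients with N_1 = 100 and N_2 = 300 samples.
clients = [{"fc.weight": np.array([1.0, 1.0])},
           {"fc.weight": np.array([3.0, 3.0])}]
global_params = fedavg(clients, client_sizes=[100, 300])
print(global_params["fc.weight"])  # [2.5 2.5]
```

In a full round, the server would also collect each client's local prototypes $p_m^k$, cluster them into $\mathcal{C}^k$, average those into $\mathcal{U}^k$, and broadcast both with the aggregated model.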
04
Experimental results
4.1 Comparison with state-of-the-art methods
The method is evaluated on the Digits and Office Caltech datasets. Digits consists of four digit domains that share the same label space but come from different data sources, and Office Caltech is a real-world dataset of four domains that share the same labels but come from different sources. Experiments show that FPL outperforms the current state-of-the-art methods both in per-domain performance and in average performance across domains.
4.2 Ablation experiment
The ablation results show that, in most cases, CPCL and UPCR work together to achieve better performance.
Comparing the results of the two components built on ordinary global prototypes with those built on the proposed prototypes demonstrates the effectiveness of cluster prototypes and unbiased prototypes.
4.3 MindSpore code demonstration
This framework is developed with MindSpore.
4.3.1 Implementing cluster prototype contrastive learning in MindSpore
```python
import mindspore as ms
from mindspore import Tensor, ops

def calculate_infonce(self, f_now, label, all_f, all_global_protos_keys):
    """CPCL loss for one sample: -log(sum_pos exp(S) / sum_all exp(S))."""
    # Locate the sample's label; every other label contributes negative prototypes.
    pos_indices = 0
    neg_indices = []
    for i, k in enumerate(all_global_protos_keys):
        if k == label.item():
            pos_indices = i
        else:
            neg_indices.append(i)
    # Positives: all cluster prototypes of the sample's label (512-dim features).
    f_pos = Tensor(all_f[pos_indices], ms.float32).reshape(-1, 512)
    # Negatives: the cluster prototypes of every other label.
    f_neg = ops.cat([Tensor(all_f[i], ms.float32).reshape(-1, 512) for i in neg_indices], axis=0)
    f_proto = ops.cat((f_pos, f_neg), axis=0)
    f_now = f_now.reshape(1, 512)
    # Cosine similarity computed with MindSpore ops; a NumPy round trip via
    # asnumpy() would detach the result from automatic differentiation.
    dot = ops.sum(f_now * f_proto, dim=1)
    norms = ops.sqrt(ops.sum(f_now * f_now, dim=1)) * ops.sqrt(ops.sum(f_proto * f_proto, dim=1))
    l = dot / (norms + 1e-8)
    l = ops.div(l, self.infoNCET)  # S(z, c) = cos(z, c) / tau
    exp_l = ops.exp(l).reshape(1, -1)
    pos_num = f_pos.shape[0]
    neg_num = f_neg.shape[0]
    # 1 for positive prototypes, 0 for negatives.
    pos_mask = Tensor([1] * pos_num + [0] * neg_num, ms.float32).reshape(1, -1)
    sum_pos_l = ops.sum(exp_l * pos_mask, dim=1)
    sum_exp_l = ops.sum(exp_l, dim=1)
    infonce_loss = -ops.log(sum_pos_l / sum_exp_l)
    return infonce_loss
```
4.3.2 Implementing unbiased prototype consistency regularization in MindSpore
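```python
import mindspore as ms
from mindspore import Tensor

def hierarchical_info_loss(self, f_now, label, mean_f, all_global_protos_keys):
    """UPCR loss for one sample: MSE between its feature and the unbiased prototype U^k."""
    pos_indices = 0
    for i, k in enumerate(all_global_protos_keys):
        if k == label.item():
            pos_indices = i
    # mean_f[i] is the average of label i's cluster prototypes, i.e. U^k.
    mean_f_pos = Tensor(mean_f[pos_indices], ms.float32).reshape(f_now.shape)
    # nn.MSELoss averages over the feature dimensions by default.
    cu_info_loss = self.loss_mse(f_now, mean_f_pos)
    return cu_info_loss
```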
4.3.3 Client local model training
```python
import copy
import numpy as np
from mindspore import nn
from tqdm import tqdm

def _train_net(self, index, net, train_loader):
    # Unpack the global prototypes broadcast by the server: all_f[i] holds the cluster
    # prototypes of label all_global_protos_keys[i], and mean_f[i] their average (U^k).
    if len(self.global_protos) != 0:
        all_global_protos_keys = np.array(list(self.global_protos.keys()))
        all_f = []
        mean_f = []
        for protos_key in all_global_protos_keys:
            temp_f = self.global_protos[protos_key]
            all_f.append(copy.deepcopy(temp_f))
            mean_f.append(copy.deepcopy(np.mean(temp_f, axis=0)))
    else:
        # No global prototypes yet (first round).
        all_f = []
        mean_f = []
        all_global_protos_keys = []
    optimizer = nn.SGD(net.trainable_params(), learning_rate=self.local_lr, momentum=0.9, weight_decay=1e-5)
    criterion1 = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    criterion = CustomLoss(criterion1, self.loss2)  # CustomLoss is defined outside this excerpt
    self.loss_mse = nn.MSELoss()
    train_net = nn.TrainOneStepCell(nn.WithLossCell(net, criterion), optimizer=optimizer)
    train_net.set_train(True)
    iterator = tqdm(range(self.local_epoch))
    for iter in iterator:
        agg_protos_label = {}
        for di in train_loader.create_dict_iterator():
            images = di["image"]
            labels = di["label"]
            f = net.features(images)  # z = f(x) for the whole batch
            if len(self.global_protos) == 0:
                loss_InfoNCE = 0
            else:
                i = 0
                loss_InfoNCE = None
                for label in labels:
                    if label.item() in all_global_protos_keys:
                        f_now = f[i]
                        cu_info_loss = self.hierarchical_info_loss(f_now, label, mean_f, all_global_protos_keys)
                        xi_info_loss = self.calculate_infonce(f
                        # ... (the rest of this listing is truncated in the source article)
```
05
Summary and Outlook
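In this paper, we explored the generalization and stability problems caused by domain shift in heterogeneous federated learning. We introduced a simple yet effective federated learning algorithm, Federated Prototype Learning (FPL). We use prototypes (typical representations of classes) to address both problems, benefiting from the complementary strengths of cluster prototypes and unbiased prototypes: diverse domain knowledge and stable convergence signals. We implemented the FPL framework with MindSpore and demonstrated its advantages in efficiency and accuracy.
While developing the FPL framework with MindSpore, we found the MindSpore community very active: many Huawei developers and users gave us a great deal of help with the difficulties we ran into while building the framework. Moreover, MindSpore's extensive documentation and tutorials, together with real-world cases and best practices from the community, helped us avoid many potential pitfalls and reach our research goals faster.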