에세이 강의 | MindSpore 기반 다중 도메인 프로토타입 비교 학습에서 일반화된 연합 프로토타입 학습 구현

저자: 리 루이펑

논문 제목

도메인 전환을 통한 연합 학습 재고: 프로토타입 보기

용지 공급

CVPR 2023

종이 링크

https://openaccess.thecvf.com/content/CVPR2023/papers/Huang_Rethinking_Federated_Learning_With_Domain_Shift_A_Prototype_View_CVPR_2023_paper.pdf

코드 링크

https://github.com/yuhangchen0/FPL_MS

오픈 소스 AI 프레임워크인 MindSpore는 산학연-연구 및 개발자에게 전체 시나리오 장치-에지-클라우드 협업, 미니멀리스트 개발, 최고의 성능, 초대형 AI 사전 훈련, 미니멀리스트 개발 및 안전하고 신뢰할 수 있는 솔루션을 제공합니다. 경험, 2020.3.28 오픈 소스는 500만 회 이상의 다운로드를 보유하고 있습니다. MindSpore는 수백 개의 AI 최고 컨퍼런스 논문을 지원하고 상위 100개 이상의 대학에서 강의했으며 HMS를 통해 5000개 이상의 앱에서 상업적으로 사용 가능합니다. AI 컴퓨팅 분야에서는 센터, 금융, 지능형 제조, 금융, 클라우드, 무선, 데이터 통신, 에너지, 소비자 1+8+N, 스마트 자동차 및 기타 엔드 에지 클라우드 자동차 시나리오가 점차 널리 사용되고 있습니다. Gitee 지수가 가장 높은 오픈 소스 소프트웨어입니다. 누구나 오픈 소스 기여, 키트, 모델 크라우드 인텔리전스, 산업 혁신 및 응용, 알고리즘 혁신, 학술 협력, AI 도서 협력 등에 참여하고 클라우드 측, 장치 측, 엣지 측 및 애플리케이션 사례에 기여할 수 있습니다. 보안 분야.

과학기술계, 학계, 산업계의 SunSilicon MindSpore의 광범위한 지원으로 SunSilicon MindSpore를 기반으로 한 AI 논문은 2023년 전체 AI 프레임워크의 7%를 차지하여 2년 연속 세계 2위를 차지했습니다. 모든 대학은 교사들의 지원을 받아 AI 연구와 혁신을 위해 계속 함께 열심히 노력할 것입니다. MindSpore 커뮤니티는 최고의 컨퍼런스 논문 연구를 지원하고 독창적인 AI 결과를 지속적으로 구축합니다. 나는 때때로 우수한 논문을 선택하여 추진하고 해석할 것입니다. 업계, 학계 및 연구 분야의 더 많은 전문가가 Shengsi MindSpore와 협력하여 독창적인 AI 연구를 촉진하기를 바랍니다. MindSpore AI 서밋 논문 시리즈의 18번째 기사에서는 우한 대학교 컴퓨터 과학부의 Ye Mang 선생님 팀의 논문을 해석 하기로 했습니다 . 모든 전문가, 교수 및 동급생에게 감사드립니다. 그들의 기여에 대해.

MindSpore는 손쉬운 개발, 효율적인 실행, 전체 시나리오 적용이라는 세 가지 주요 목표를 달성하는 것을 목표로 합니다. 사용 경험을 통해 딥러닝 프레임워크인 MindSpore는 빠르게 발전하고 있으며 다양한 API의 디자인은 보다 합리적이고 완전하며 강력한 방향으로 지속적으로 최적화되고 있습니다. 또한 Shengsi에서 지속적으로 등장하는 다양한 개발 도구도 이 생태계를 지원하여 모델 아키텍처를 다이어그램 형식으로 제시하고 다양한 측면을 동적으로 모니터링할 수 있는 MindSpore Insight와 같이 보다 편리하고 강력한 개발 방법을 만들 수 있습니다. 런타임 중 모델의 표시기와 매개변수가 변경되면 개발 프로세스가 더욱 편리해집니다.

01

연구배경

디지털 세계에서 데이터 개인 정보 보호 및 보안은 점점 더 우려되는 핵심 문제가 되었습니다. 이러한 배경에서 연합 학습은 데이터 프라이버시를 보호하는 분산 기계 학습 방법으로 등장했습니다. 핵심 아이디어는 여러 장치 또는 서버가 원본 데이터를 공유하지 않고 모델을 공동으로 훈련할 수 있도록 하는 것입니다. 이 접근 방식은 특히 데이터 개인 정보 보호 및 보안 요구 사항이 높은 경우 여러 모바일 장치에서 기계 학습 작업을 처리할 수 있습니다.

연합 학습에는 해결해야 할 중요한 문제가 있습니다. 바로 데이터 이질성입니다. 일반적으로 학습과 관련된 각 노드(예: 장치, 서버 또는 조직)가 보유하는 데이터가 크게 다를 수 있다는 사실을 나타냅니다. 이러한 차이에는 데이터의 분포, 품질, 수량 ​​및 기능 유형과 같은 측면이 포함될 수 있습니다. 데이터 이질성 문제는 연합 학습에서 특히 중요합니다. 모델의 학습 효과와 일반화 능력에 직접적인 영향을 미칠 수 있기 때문입니다.

본 논문에서는 데이터 이질성을 위해 기존 솔루션이 주로 동일한 도메인의 모든 개인 데이터에 중점을 두고 있음을 지적합니다. 분산 데이터가 다른 도메인에서 발생하는 경우 프라이빗 모델은 다른 도메인(도메인 오프셋 포함)에서 성능 저하를 나타내는 경향이 있으며 글로벌 신호는 풍부하고 공정한 도메인 정보를 캡처할 수 없습니다. 따라서 저자는 최적화된 글로벌 모델이 연합학습 과정에서 여러 도메인에 대한 일반화 성능을 안정적으로 제공할 수 있을 것으로 기대한다.

본 논문에서 저자는 도메인 전환 시 연합 학습을 위한 "Federated Prototype Learning"(FPL)을 제안합니다. 핵심 아이디어는 풍부한 도메인 지식과 공정한 융합 목표를 제공하는 클러스터링된 프로토타입과 편견 없는 프로토타입을 구축하는 것입니다. 한편으로, 샘플 임베딩은 다른 카테고리의 클러스터 프로토타입에서 멀어지고 동일한 의미의 클러스터 프로토타입에 더 가깝게 이동됩니다. 반면, 로컬 인스턴스를 해당하는 편견 없는 프로토타입과 정렬하기 위해 일관성 정규화가 도입되었습니다.

본 논문에서는 MindSpore를 기반으로 프레임워크 개발 및 실험을 수행하며 Digits 및 Office Caltech 과제 등의 실험 결과를 통해 제안된 솔루션의 효율성과 핵심 모듈의 효율성을 입증합니다.

02

팀 소개

논문의 제1저자인 황원커(Huang Wenke) 는 현재 우한대학교에서 석사 및 박사과정(2021~현재)을 공부하고 있으며, 그의 멘토는 두보(Du Bo) 교수와 예망(Ye Mang) 교수이다. 우한대학교에서 학사 학위를 취득했으며, 주요 연구 방향은 연합 학습, 그래프 학습, 금융 기술 등입니다. 현재 CVPR, IJCAI, ACM MM 등 주요 국제 학회에서 제1저자로 4편의 논문을 발표했습니다. 대학원 기간 동안 그는 Guotai Junan 장학금 및 우수 대학원생과 같은 타이틀을 획득했습니다. 알리바바그룹, 마이크로소프트 리서치 아시아 등에서 연구원으로 근무

논문의 교신저자인 예망(Ye Mang)은 우한대학교 컴퓨터과학과 교수이자 박사 지도교수이며, 국가 수준의 고급 청년 인재이자 중국 과학기술협회가 추천하는 청년 후보자이다. 그는 Emirates Origin 인공 지능 연구소의 연구 과학자와 미국 컬럼비아 대학교의 방문 학자로 재직했습니다. 그의 주요 연구 방향에는 컴퓨터 비전, 멀티미디어 검색, 연합 학습 등이 포함됩니다. 그는 국제 저널 및 컨퍼런스에 80편 이상의 논문을 발표했으며, ESI에서 많이 인용된 논문 10편을 발표했으며, Google Scholar에서 5,600회 이상 인용되었습니다. CVPR24, ACM MM23 등 학술회의 현장의장 역임. 후베이성 핵심 R&D 계획, 중국 국립자연과학재단 등 과학 연구 프로젝트를 주관합니다. Google 우수 장학금, 최고 국제 컴퓨터 비전 컨퍼런스인 ICCV2021에서 드론 표적 재식별 트랙 챔피언, 2021~2022 스탠포드 순위에서 '세계 최고 과학자 상위 2%', 2022 Baidu AI Chinese Young Scholar를 수상했습니다. .

연구팀 MARS는 예망 교수가 지휘하며 감시 영상 보행자/행동 분석, 비지도/반지도 학습, 교차 모드 이해 및 추론, 연합 학습에 중점을 두고 있습니다.

03

논문 소개

3.1 소개

앞서 언급한 연구 배경을 바탕으로 이 논문에서는 연합 다중 도메인 일반화 문제를 해결하기 위해 연합 프로토타입 학습을 제안합니다. 개인 데이터는 서로 다른 분야에서 오고, 서로 다른 클라이언트는 매우 다른 특징 분포를 갖습니다. , 개인 모델은 다른 도메인에서 잘 수행되지 않습니다. 예를 들어 회색조 이미지 MNIST에 대해 훈련된 로컬 모델 A는 서버에 의해 집계된 후 컬러 이미지 SVHN 데이터 세트와 같은 다른 클라이언트에서 정상적으로 수행할 수 없습니다. 이 로컬 모델 A는 SVHN 도메인 정보를 학습할 수 없기 때문에 성능이 저하됩니다. 하락.

글로벌 신호는 여러 분야의 지식 정보를 표현할 수 없고, 지배적인 분야의 정보 쪽으로 치우칠 수 있으므로 일반화 능력이 떨어진다. 모델이 풍부한 다중 도메인 지식을 학습하고 공유 신호를 사용하여 다중 도메인에 정보를 제공하여 일반화 능력을 향상시킬 수 있도록 하기 위해 본 논문에서는 클러스터 프로토타입을 사용하여 다양한 도메인의 정보를 표현하고 대조 학습을 사용하여 공통성을 향상시키는 것을 제안합니다. 잠재적인 지배적 도메인에 대한 최적화를 방지하고 일부 도메인의 능력을 향상시키기 위해 클러스터 프로토타입(Cluster Prototypes Contrastive Learning)이라고 함. 이를 편견 없는 프로토타입 일관성 정규화(Unbiased Prototypes Consistency Regularization)라고 합니다.

3.2 방법

3.2.1 준비

연합 학습

그림일반적인 연합 학습 설정에는 다음과 같이 표현되는 참가자와 해당 개인 데이터가 있습니다 .

그림

그 중 그림로컬 데이터 규모를 나타냅니다. 이기종 연합 학습 환경에서 조건부 기능 분포는 일관 그림되더라도 참가자마다 다르므로 그림도메인 이동으로 이어집니다. 도메인 오프셋을 다음과 같이 정의합니다.

그림

이는 개인 데이터에 도메인 오프셋이 있음을 의미합니다. 특히, 동일한 레이블 공간의 경우 여러 참여자 간에 고유한 기능 분포가 있습니다.

그림그림 1 로컬 클라이언트 데이터 소스 도메인은 다르며 그 차이가 큽니다.

또한 모든 참가자는 합의에 도달하고 동일한 아키텍처로 모델을 공유합니다. 이 모델은 특징 추출기와 분류기라는 두 가지 주요 부분으로 볼 수 있습니다. 로 표시된 특징 추출기는 그림샘플 x를 그림다음과 같이 표현되는 특징 공간의 1차원 특징 벡터로 인코딩합니다.

그림

분류자는 기능을 로지트 출력에 매핑하며 그림, 이는 후속 공식에서 그림분류 범주를 나타냅니다. 최적화 목표는 연합 학습 프로세스를 통해 여러 도메인에서 좋은 성능을 보이는 일반화 가능한 글로벌 모델을 학습하는 것입니다.

기능 프로토타입

후속 프로토타입 관련 메서드를 구현하기 위해 이 기사에서는 먼저 프로토타입의 정의를 구성합니다.

그림

client-th의 레이블을 그림나타내는 프로토타입은 client-th의 레이블을 갖는 모든 샘플의 특징 벡터의 평균을 계산하여 얻어지며, 이는 이 클라이언트의 레이블이 나타내는 도메인 정보를 직관적으로 나타냅니다.그림그림그림그림그림

먼저 이 문서의 방법을 무시하는 경우 가장 일반적인 방법은 모든 클라이언트 태그의 도메인 정보를 직접 평균화하고 그림, 모든 클라이언트가 이 정보를 학습하도록 하고, 로컬 클라이언트 업데이트를 제한하는 것입니다.

그림

이는 전체 페더레이션 시스템에서와 같이 라벨이 붙은 다양한 필드의 모든 샘플에 대한 평균 도메인 정보를 그림나타냅니다 . 그림그러나 이러한 전역적 관점은 다양한 필드의 정보를 정확하게 표현할 수 없으며 그림 2a에서 볼 수 있듯이 지배적인 필드에 편향되어 일부 필드를 무시할 수도 있습니다.

그림그림 2 다양한 유형의 프로토타입 표현

3.2.2 클러스터 프로토타입 비교학습

글로벌 프로토타입의 문제를 해결하기 위해 본 논문에서는 먼저 비지도 클러스터링을 위한 FINCH 방법을 사용하여 광범위한 도메인 지식(각 샘플의 특징 벡터)을 비지도 없이 분리합니다. 특징 벡터를 사용하면 서로 다른 필드가 서로 다른 클러스터로 클러스터링되고 그림 2b와 같이 이 클러스터의 프로토타입이 동일한 클러스터 내에서 계산되어 여러 도메인이 서로 평균을 낸 후 모든 유용한 도메인 지식에서 멀어지는 것을 방지할 수 있습니다. .

그림

위 수식에서는 클러스터링된 그림라벨 의 클러스터 프로토타입 집합을 나타냅니다.그림그림

이를 기반으로 본 논문에서는 새로운 손실항을 추가하여 클러스터 프로토타입 비교학습을 구현한다. 그림특정 샘플에 속하는 샘플 그림의 특징 벡터는 입니다 . 본 기사 에서는 샘플과 동일한 의미에 속하는 모든 프로토타입 간의 거리를 최대한 그림줄이거나 동일 하의 유사성을 향상시키기 위해 비교 학습을 사용합니다. 그림동시에, 속하지 않는 모든 프로토타입 그림( 으로 표시됨 그림)과의 유사성을 최대한 줄입니다. 이 방법을 통해 로컬 업데이트 중에 다양한 분야의 풍부한 지식을 학습하고 일반화 능력을 향상시킬 수 있습니다. 향상. 저자는 샘플 특징 벡터와 프로토타입 간의 유사성을 다음과 같이 정의합니다.

그림

그런 다음 클러스터 프로토타입 비교 학습을 구현하는 손실 항을 구성합니다.

그림

이 접근 방식이 작동하는 이유는 무엇입니까? 저자는 다음과 같은 분석을 제공합니다.

그림

그림이 손실 함수를 최소화하는 것은 샘플 특징 벡터를 할당된 포지티브 클러스터 프로토타입에 가깝게 가져오고 특징 벡터를 다른 네거티브 프로토타입에서 멀리 이동시키는 것과 같습니다 그림. 이는 다양한 도메인 왜곡에 대한 불변성을 유지할 뿐만 아니라 의미론의 확산 특성을 향상시켜 특징 공간이 일반화 가능하고 식별 가능하도록 보장함으로써 연합 학습에서 만족스러운 일반화 성능을 달성합니다.

3.2.3 편향되지 않은 프로토타입 일관성 정규화

클러스터 프로토타입은 도메인 이전 시 가소성에 다양한 도메인 지식을 가져오지만 비지도 클러스터링 방식으로 인해 각 통신마다 클러스터 프로토타입이 동적으로 생성되고 규모가 변경됩니다. 따라서 클러스터 프로토타입은 서로 다른 통신 시대에 안정적인 융합 방향을 제시할 수 없습니다. 본 논문에서는 공정하고 안정적인 편견 없는 프로토타입을 구축하고 여러 클러스터 프로토타입과 편견 없는 프로토타입 사이의 거리를 제한함으로써 지속적인 다중 영역 공정성을 보장하는 두 번째 방법을 제안합니다.

특히, 클러스터링된 동일한 레이블 아래의 여러 클러스터 프로토타입은 그림그림 2c에 표시된 대로 레이블 아래의 편향되지 않은 수렴 목표를 나타내기 위해 평균화됩니다.

그림

이 기사에서는 두 번째 손실 항을 소개하고 일관성 정규화 항을 사용하여 샘플의 특징 벡터를 해당 편견 없는 프로토타입에 더 가깝게 가져와 수렴 그림불안정성 문제를 해결하기 위한 비교적 공정하고 안정적인 최적화 지점을 제공합니다.

그림

3.2.4 전체 알고리즘

위의 두 가지 손실 외에도 기존 모델 훈련에서 사용되는 교차 엔트로피 손실 함수가 본 논문에서 제안하는 연합 프로토타입 학습의 손실 함수로 사용됩니다.

그림

학습 과정:

그림

알고리즘 :

그림

04

실험 결과

4.1 최신기술의 실험결과와의 비교

이 기사는 Digits 및 Office Caltech 데이터 세트에서 테스트되었습니다. 전자는 4개의 동일한 레이블과 다른 데이터 소스가 있는 디지털 데이터 세트이고, 후자는 4개의 동일한 레이블과 다른 데이터 소스가 있는 실제 데이터 세트입니다. 실험 결과 제안된 FPL은 단일 분야의 성능과 여러 분야의 평균 성능 모두에서 현재 SOTA보다 우수하다는 것을 보여주었다.

그림

 

4.2 절제 실험

그림

대부분의 경우 CPCL과 UPCR이 함께 작동하여 더 나은 성능을 생성하는 것을 볼 수 있습니다.

그림

일반적인 글로벌 프로토타입과 제안된 프로토타입을 사용한 두 가지 방법으로 입증된 실험 결과를 비교하면 클러스터링된 프로토타입과 편견 없는 프로토타입의 효율성이 입증됩니다.

4.3 MindSpore 코드 표시

이 프레임워크는 MindSpore를 기반으로 개발되었습니다.

4.3.1 MindSpore는 클러스터 프로토타입의 비교 학습을 구현합니다.

def calculate_infonce(self, f_now, label, all_f, all_global_protos_keys):
        pos_indices = 0
        neg_indices = []
        for i, k in enumerate(all_global_protos_keys):
            if k == label.item():
                pos_indices = i
            else:
                neg_indices.append(i)

        f_pos = Tensor(all_f[pos_indices][0]).reshape(1,512)
        f_neg = ops.cat([Tensor(all_f[i]).reshape(-1, 512) for i in neg_indices], axis=0)
        #aaa
        f_proto = ops.cat((f_pos, f_neg), axis=0)
        f_now = f_now.reshape(1,512)

        f_now_np = f_now.asnumpy()
        f_proto_np = f_proto.asnumpy()
        def cosine_similarity_numpy(vec_a, vec_b):
            dot_product = np.dot(vec_a, vec_b.T)
            norm_a = np.linalg.norm(vec_a, axis=1, keepdims=True)
            norm_b = np.linalg.norm(vec_b, axis=1)
            return dot_product / (norm_a * norm_b)
        l_np = cosine_similarity_numpy(f_now_np, f_proto_np)
        l = Tensor(l_np)

        #l = ops.cosine_similarity(f_now, f_proto, dim=1)
        l = ops.div(l, self.infoNCET)

        exp_l = ops.exp(l).reshape(1, -1)

        pos_num = f_pos.shape[0]
        neg_num = f_neg.shape[0]
        pos_mask = Tensor([1] * pos_num + [0] * neg_num).reshape(1, -1)

        pos_l = exp_l * pos_mask
        sum_pos_l = ops.sum(pos_l, dim=1)
        sum_exp_l = ops.sum(exp_l, dim=1)
        infonce_loss = -ops.log(sum_pos_l / sum_exp_l)
        return Tensor(infonce_loss)

4.3.2  MindSpore는 편견 없는 프로토타입 일관성 정규화를 실현합니다 **

def hierarchical_info_loss(self, f_now, label, mean_f, all_global_protos_keys):


        pos_indices = 0
        for i, k in enumerate(all_global_protos_keys):
            if k == label.item():
                pos_indices = i



        mean_f_pos = Tensor(mean_f[pos_indices])
        f_now = Tensor(f_now)

        cu_info_loss = self.loss_mse(f_now, mean_f_pos)

        return cu_info_loss

4.3.3 클라이언트 로컬 모델 훈련

 def _train_net(self, index, net, train_loader):

        if len(self.global_protos) != 0:
            all_global_protos_keys = np.array(list(self.global_protos.keys()))
            all_f = []
            mean_f = []
            for protos_key in all_global_protos_keys:
                temp_f = self.global_protos[protos_key]
                all_f.append(copy.deepcopy(temp_f))
                mean_f.append(copy.deepcopy(np.mean(temp_f, axis=0)))
            all_f = [item.copy() for item in all_f]
            mean_f = [item.copy() for item in mean_f]
        else:
            all_f = []
            mean_f = []
            all_global_protos_keys = []        

        optimizer = nn.SGD(net.trainable_params(), learning_rate=self.local_lr, momentum=0.9, weight_decay=1e-5)
        criterion1 = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        criterion = CustomLoss(criterion1, self.loss2)
        self.loss_mse = mindspore.nn.MSELoss()
        train_net= nn.TrainOneStepCell(nn.WithLossCell(net,criterion), optimizer=optimizer)
        train_net.set_train(True)

        iterator = tqdm(range(self.local_epoch))
        for iter in iterator:

            agg_protos_label = {}
            for di in train_loader.create_dict_iterator():
                images = di["image"]
                labels = di["label"]

                #   train_net.set_train(False)
                f = net.features(images)
                #train_net.set_train(True)

                if len(self.global_protos) == 0:
                    loss_InfoNCE = 0 
                else:
                    i = 0
                    loss_InfoNCE = None

                    for label in labels:
                        if label in all_global_protos_keys:

                            f_now = f[i]
                            cu_info_loss = self.hierarchical_info_loss(f_now, label, mean_f, all_global_protos_keys)
                            xi_info_loss = self.calculate_infonce(f

05

요약 및 전망

본 논문에서는 이종 연합 학습에서 도메인 이전에 따른 일반화 및 안정성 문제를 탐구합니다. 우리 연구에서는 FPL(Federated Prototype Learning)이라는 간단하면서도 효과적인 연합 학습 알고리즘을 소개합니다. 우리는 프로토타입(클래스의 표준 표현)을 활용하여 클러스터링된 프로토타입과 편향되지 않은 프로토타입의 보완적인 이점, 즉 다양한 도메인 지식과 안정적인 수렴 신호를 즐기면서 이 두 가지 문제를 해결합니다. 우리는 Sunthink MindSpore 아키텍처를 사용하여 FPL 프레임워크를 구현했으며 효율성과 정확성의 장점을 입증했습니다.

FPL 프레임워크를 개발하기 위해 Shengsi MindSpore를 사용할 때 우리는 Shengsi MindSpore 커뮤니티가 매우 활발하다는 것을 알게 되었고, 프레임워크 구축 시 직면했던 어려움에 대해 많은 Huawei 개발자와 사용자가 큰 도움을 주었습니다. 뿐만 아니라 MindSpore에서 제공하는 풍부한 문서와 튜토리얼은 물론 커뮤니티의 실제 사례와 모범 사례를 통해 우리는 많은 잠재적인 함정을 피하고 연구 목표를 더 빨리 달성했습니다.

1990년대에 태어난 프로그래머가 비디오 포팅 소프트웨어를 개발하여 1년도 안 되어 700만 개 이상의 수익을 올렸습니다. 결말은 매우 처참했습니다! Google은 Flutter, Dart 및 Python 팀의 중국 코더의 "35세 저주"와 관련된 정리해고를 확인했습니다 . | Daily Windows 1.0용 Arc Browser가 3개월 만에 공식적으로 GA Windows 10 시장 점유율이 70%에 도달했으며 Windows 11 GitHub는 AI 기본 개발 도구 GitHub Copilot Workspace JAVA를 계속해서 출시했습니다 . OLTP+OLAP을 처리할 수 있는 유일한 강력한 유형의 쿼리입니다. 우리는 너무 늦게 만났습니다 .
{{o.이름}}
{{이름}}

추천

출처my.oschina.net/u/4736317/blog/11072527