【Introduction de l'article】- STFL : un cadre d'apprentissage fédéré spatio-temporel pour les réseaux de neurones de graphes

Informations papier

STFL : un cadre d'apprentissage fédéré spatio-temporel pour les réseaux de neurones de graphes
insérez la description de l'image ici

Adresse d'origine : STFL : A Spatial-Temporal Federated Learning Framework for Graph Neural Networks : https://arxiv.org/abs/2111.06750
Code source : https://github.com/JW9MsjwjnpdRLFw/TSFL

Résumé

Nous présentons un cadre d'apprentissage fédéré spatio-temporel pour les réseaux de neurones de graphes, à savoir STFL. Le cadre explore la corrélation sous-jacente des données spatio-temporelles d'entrée et la transforme à la fois en caractéristiques de nœud et en matrice de contiguïté. Le cadre d'apprentissage fédéré dans le cadre garantit la confidentialité des données tout en réalisant une bonne généralisation du modèle. Les résultats des expériences sur l'ensemble de données sur les phases de sommeil, ISRUC_S3, illustrent l'efficacité de STFL sur les tâches de prédiction de graphes.

Nous proposons STFL, un cadre d'apprentissage spatio-temporel fédéré pour les réseaux de neurones de graphes. Le cadre exploite les corrélations sous-jacentes des données spatio-temporelles d'entrée et les transforme en caractéristiques de nœud et en matrice de contiguïté. Le cadre d'apprentissage fédéré dans le cadre permet une bonne généralisation du modèle tout en garantissant la confidentialité des données. Les résultats expérimentaux sur l'ensemble de données sur les phases de sommeil ISRUC S3 illustrent l'efficacité de STFL sur les tâches de prédiction de graphes.

Contributions

  1. Nous implémentons d'abord un générateur de graphes pour le traitement de données spatio-temporelles, y compris l'extraction de caractéristiques et l'exploration de corrélation de nœuds ;
  2. En intégrant le générateur de graphes dans le STFL proposé, un cadre d'apprentissage fédéré de bout en bout de GNN spatio-temporels sur des tâches de classification au niveau du graphe est conçu ;
  3. Des expériences approfondies ont été menées sur l'ensemble de données sur le sommeil réel ISRUC S3 ;
  4. Publiez le code source de STFL sur Github1.

Méthodologie

Cadre STFL :
insérez la description de l'image ici

Génération de graphes

Traitez la série spatio-temporelle comme une entrée brute. Définition Une séquence multivariée insérez la description de l'image iciest définie comme l'ensemble de séries chronologiques avec un total de T horodatages, dont chacun a une fréquence de signal si ∈ RDde dimension . Puisqu'il n'y a pas de concept de nœud dans les données spatio-temporelles, nous exploitons les canaux spatiaux et les traitons comme des nœuds, ce qui signifie que s'il y a N canaux, il y aura N nœuds dans la structure de données du graphe transformé.

En supposant que chaque canal a un ensemble de séries temporelles S, la série spatio-temporelle avec des canaux complets est notée insérez la description de l'image ici.
Ensuite, les données spatio-temporelles d'origine sont converties en une représentation matricielle d'entités à l'aide d'un réseau d'extraction d'entités basé sur CNN, et la sortie du réseau d'extraction d'entités est , insérez la description de l'image icioù d représente la dimension de l'entité. insérez la description de l'image iciUn instantané de est représenté par insérez la description de l'image ici.
Après avoir obtenu la matrice de caractéristiques raffinée insérez la description de l'image ici, la corrélation entre les canaux (nœuds) doit être révélée. À ce stade, il est naturel de traiter XT ∈ RN×dasla matrice des caractéristiques des nœuds et de récupérer les corrélations potentielles entre eux. Ensuite, nous définissons la fonction de corrélation de nœud, qui prend une matrice de caractéristiques de nœud en entrée et produit une matrice de contiguïté AT∈RN×N::
insérez la description de l'image ici
où Corr( ) calcule la corrélation ou la dépendance de chaque canal (nœud) sur la base de XT . Il existe plusieurs options pour la fonction de corrélation nodale, telles que la fonction de corrélation de Pearson ou la fonction de valeur à verrouillage de phase.

Réseau neuronal graphique

Le long de la dimension temporelle, nous obtenons {G1, ..., GT} comme ensemble de données de graphique complet, indiquant les données de graphique générées à chaque horodatage, et nous utilisons {y1, ..., yT} pour correspondre aux étiquettes de graphique. Nous formulons ici la tâche de prédiction de graphe, où la sortie du générateur de graphe devrait être correctement prédite. Pour la simplicité de la notation, nous utilisons VT pour représenter l'ensemble de nœuds dans chaque GT, et le nombre de nœuds V est fondamentalement le même que le numéro de ligne dans la matrice de caractéristiques des nœuds XT. Pour chaque v ∈ V, les caractéristiques des nœuds correspondants sont écrites sous la forme xv∈ Rd.

On utilise ne[v] pour désigner le voisinage du nœud v, dont les valeurs associées peuvent être récupérées à partir de la matrice d'adjacence A. Ensuite, nous formulons les étapes de passage et de lecture du message de GNN. entraîneurSoit les plongements de nœuds dans la couche l. Le message passant du nœud v de la couche l à la couche l+1 peut être formalisé comme :
insérez la description de l'image icioù, insérez la description de l'image icireprésente la matrice de transformation apprenable de la couche l+1, et σ représente la fonction d'activation. Les GNN mettent à jour les nœuds hl 1vintégrés .

Pour obtenir une représentation de l'ensemble du graphe après la couche de transmission de messages de niveau L, GNN effectue une opération de lecture pour dériver la représentation graphique finale des incorporations de nœuds, qui peut être formulée comme suit : Readout( ) est une opération invariante de permutation,
insérez la description de l'image ici
qui peut être simplement La fonction moyenne peut également être une fonction de regroupement au niveau graphique plus complexe, telle que MLP.
Dans le cadre entièrement supervisé, nous utilisons un réseau de neurones peu profond pour apprendre une correspondance entre les plongements de graphes et l'espace d'étiquettes Y. σ( ) est une transformation non linéaire qui peut être généralisée comme suit : insérez la description de l'image ici
De plus, nous utilisons une fonction d'entropie croisée binaire basée sur un graphe pour calculer la perte L dans le cadre supervisé. La formule de la fonction de perte est :
insérez la description de l'image ici

apprentissage fédéré

STFL forme les GNN de différents clients dans un cadre d'apprentissage fédéré. STFL se compose d'un serveur central S et de n clients C. Chaque client déploie un GNN qui apprend le client à partir des données graphiques locales et télécharge les poids du GNN sur un serveur central. Le serveur central reçoit les poids de tous les clients, met à jour les poids WS du modèle GNN global et redistribue les poids mis à jour à chaque client. Dans ce travail, nous choisissons FedAvg comme fonction d'agrégation, qui fait la moyenne des poids de chaque client pour générer les poids du GNN global sur le serveur.
insérez la description de l'image ici

Expérience

base de données

Dans nos expériences, ISRUC S3 (Khalighi et al., 2016) est utilisé comme ensemble de données de référence. ISRUC S3 collecte les enregistrements polysomnographiques (PSG) de 10 canaux de 10 sujets sains (c'est-à-dire les participants à l'expérience du sommeil). Ces enregistrements PSG ont été étiquetés pour cinq stades de sommeil distincts, y compris le réveil, N1, N2, N3 et REM, selon les critères AASM (Jia et al., 2020). Comme décrit dans la section précédente, nous utilisons un réseau d'extraction de caractéristiques basé sur CNN (Jia et al., 2021) pour générer les caractéristiques initiales des nœuds. Pour générer la matrice de contiguïté, quatre fonctions d'association de nœuds différentes sont implémentées et discutées séparément. Pour évaluer l'efficacité de STFL, nous suivons le cadre de données non iid (Zhang et al., 2020) et attribuons différentes phases de sommeil aux clients pour vérifier l'efficacité de notre cadre proposé.

Fonctions de corrélation de nœud

  • DB est la fonction de distance euclidienne utilisée pour mesurer la distance spatiale entre les paires d'électrodes.
  • K-NN (Jiang et al. 2013) génère une matrice d'adjacence qui sélectionne uniquement les k voisins les plus proches de chaque nœud pour représenter les dépendances de nœud d'un graphe.
  • PCC (Pearson et Lee 1903) est connue sous le nom de fonction de corrélation de Pearson et est utilisée pour mesurer la similarité entre chaque paire de nœuds.
  • PLV (Aydore, Pantazis et Leahy 2013) est une fonction de corrélation de nœuds variant dans le temps qui mesure le signal de chaque paire de nœuds.

Analyse comparative des performances

  1. Pour évaluer l'efficacité des quatre fonctions liées aux nœuds, nous comparons l'effet de chaque fonction liée aux nœuds sur GCN dans le cadre fédéré, puisque GCN a la structure la plus simple parmi les trois modèles GNN. Comme le montre la figure 2, PCC et PLV fonctionnent bien dans le cadre conjoint, avec des taux de convergence plus rapides, en particulier au cours des deux premières époques. De plus, par rapport aux autres fonctions liées aux nœuds, comme le montre le tableau 2, les scores F1 de PLV des 3 modèles fédérés sont les plus élevés, suivis de PCC, et DB est le pire. Cela peut être dû à la couche de regroupement dans le modèle CNN (réseau d'extraction de fonctionnalités), qui examine une petite fenêtre temporelle de la séquence d'entrée, à partir de laquelle la corrélation correcte pour chaque paire de nœuds peut être extraite à l'aide de PLV.
    insérez la description de l'image iciinsérez la description de l'image ici

  2. Pour évaluer l'efficacité de STFL, nous avons testé ses performances sous différents angles. Dans nos expériences, nous évaluons d'abord le modèle de graphe fédéré sur ISRUC S3 avec PLV, puisque PLV se forme mieux pour chacune des quatre fonctions liées aux nœuds discutées dans RQ1. Comme le montre le tableau 3, sous STFL, les trois modèles GNN produisent des résultats raisonnables. Surtout dans le cadre commun, GAT obtient le score et la précision F1 les plus élevés sur PLV, et GraphSage arrive en deuxième position.
    insérez la description de l'image iciinsérez la description de l'image iciDe plus, nous examinons les résultats des modèles centralisés pour ces trois réseaux de graphes, et les résultats sont également présentés dans le tableau 3. Dans cette partie, les hyperparamètres sont maintenus constants avec les expériences conjointes. Pour le fractionnement des données, les données de test sont les mêmes que les données des expériences d'apprentissage fédéré. Les données de formation sont échantillonnées au hasard à partir des données agrégées de tous les clients, et la taille des données de formation est la même que celle d'un client. Pour tous les GNN dans le cadre centralisé, GraphSage obtient le score et la précision F1 les plus élevés, suivi du GCN. De plus, tous les modèles entraînés dans le cadre conjoint obtiennent de meilleurs résultats (score F1 et précision) par rapport au cadre centralisé. Cela indique que les modèles formés sous STFL génèrent avec succès des distributions de données dans des contextes non IID. Une autre constatation est que le meilleur modèle GNN dans un cadre centralisé n'est pas nécessairement le meilleur dans un cadre fédéré.

  3. Pour déterminer la meilleure correspondance entre les GNN et STFL, trois GNN ont été testés sur ISRUC S3 et PLV dans le cadre conjoint, car il a été observé que PLV obtenait les meilleurs résultats parmi toutes les fonctions liées aux nœuds, dont les détails sont analysés dans RQ1 . Comme le montre la figure 3, GCN converge le plus rapidement mais est plus instable que les deux autres. Nous constatons également que GraphSage converge le plus lentement à la première époque, mais atteint une réduction constante des pertes pendant la phase de test. Il a également constaté que les trois modèles ont finalement convergé vers la même perte, fluctuant autour de 0,15. De plus, nous évaluons le score F1 de chaque classe à l'aide de PLV. Le tableau 4 montre que pour REM, GraphSage obtient les meilleurs résultats, tandis que GCN obtient les meilleurs résultats dans les quatre autres catégories. insérez la description de l'image ici
    Fait intéressant, la perte d'entraînement des trois modèles fluctue dans une large gamme, en particulier au cours des trois dernières époques. C'est probablement parce que le cadre commun distribue le modèle global à chaque client dans chaque lot de formation. Dans les dernières étapes de la formation, chaque client ne peut pas bien ajuster ses propres données dans le modèle global généralisé, en particulier pour les modèles qui sont sujets au surajustement.

Je suppose que tu aimes

Origine blog.csdn.net/weixin_43598687/article/details/131141861
conseillé
Classement