Annuaire d'articles
Informations papier
STFL : un cadre d'apprentissage fédéré spatio-temporel pour les réseaux de neurones de graphes
Adresse d'origine : STFL : A Spatial-Temporal Federated Learning Framework for Graph Neural Networks : https://arxiv.org/abs/2111.06750
Code source : https://github.com/JW9MsjwjnpdRLFw/TSFL
Résumé
Nous présentons un cadre d'apprentissage fédéré spatio-temporel pour les réseaux de neurones de graphes, à savoir STFL. Le cadre explore la corrélation sous-jacente des données spatio-temporelles d'entrée et la transforme à la fois en caractéristiques de nœud et en matrice de contiguïté. Le cadre d'apprentissage fédéré dans le cadre garantit la confidentialité des données tout en réalisant une bonne généralisation du modèle. Les résultats des expériences sur l'ensemble de données sur les phases de sommeil, ISRUC_S3, illustrent l'efficacité de STFL sur les tâches de prédiction de graphes.
Nous proposons STFL, un cadre d'apprentissage spatio-temporel fédéré pour les réseaux de neurones de graphes. Le cadre exploite les corrélations sous-jacentes des données spatio-temporelles d'entrée et les transforme en caractéristiques de nœud et en matrice de contiguïté. Le cadre d'apprentissage fédéré dans le cadre permet une bonne généralisation du modèle tout en garantissant la confidentialité des données. Les résultats expérimentaux sur l'ensemble de données sur les phases de sommeil ISRUC S3 illustrent l'efficacité de STFL sur les tâches de prédiction de graphes.
Contributions
- Nous implémentons d'abord un générateur de graphes pour le traitement de données spatio-temporelles, y compris l'extraction de caractéristiques et l'exploration de corrélation de nœuds ;
- En intégrant le générateur de graphes dans le STFL proposé, un cadre d'apprentissage fédéré de bout en bout de GNN spatio-temporels sur des tâches de classification au niveau du graphe est conçu ;
- Des expériences approfondies ont été menées sur l'ensemble de données sur le sommeil réel ISRUC S3 ;
- Publiez le code source de STFL sur Github1.
Méthodologie
Cadre STFL :
Génération de graphes
Traitez la série spatio-temporelle comme une entrée brute. Définition Une séquence multivariée est définie comme l'ensemble de séries chronologiques avec un total de T horodatages, dont chacun a une fréquence de signal de dimension . Puisqu'il n'y a pas de concept de nœud dans les données spatio-temporelles, nous exploitons les canaux spatiaux et les traitons comme des nœuds, ce qui signifie que s'il y a N canaux, il y aura N nœuds dans la structure de données du graphe transformé.
En supposant que chaque canal a un ensemble de séries temporelles S, la série spatio-temporelle avec des canaux complets est notée .
Ensuite, les données spatio-temporelles d'origine sont converties en une représentation matricielle d'entités à l'aide d'un réseau d'extraction d'entités basé sur CNN, et la sortie du réseau d'extraction d'entités est , où d représente la dimension de l'entité. Un instantané de est représenté par .
Après avoir obtenu la matrice de caractéristiques raffinée , la corrélation entre les canaux (nœuds) doit être révélée. À ce stade, il est naturel de traiter la matrice des caractéristiques des nœuds et de récupérer les corrélations potentielles entre eux. Ensuite, nous définissons la fonction de corrélation de nœud, qui prend une matrice de caractéristiques de nœud en entrée et produit une matrice de contiguïté :
où Corr( ) calcule la corrélation ou la dépendance de chaque canal (nœud) sur la base de XT . Il existe plusieurs options pour la fonction de corrélation nodale, telles que la fonction de corrélation de Pearson ou la fonction de valeur à verrouillage de phase.
Réseau neuronal graphique
Le long de la dimension temporelle, nous obtenons {G1, ..., GT} comme ensemble de données de graphique complet, indiquant les données de graphique générées à chaque horodatage, et nous utilisons {y1, ..., yT} pour correspondre aux étiquettes de graphique. Nous formulons ici la tâche de prédiction de graphe, où la sortie du générateur de graphe devrait être correctement prédite. Pour la simplicité de la notation, nous utilisons VT pour représenter l'ensemble de nœuds dans chaque GT, et le nombre de nœuds V est fondamentalement le même que le numéro de ligne dans la matrice de caractéristiques des nœuds XT. Pour chaque v ∈ V, les caractéristiques des nœuds correspondants sont écrites sous la forme .
On utilise ne[v] pour désigner le voisinage du nœud v, dont les valeurs associées peuvent être récupérées à partir de la matrice d'adjacence A. Ensuite, nous formulons les étapes de passage et de lecture du message de GNN. Soit les plongements de nœuds dans la couche l. Le message passant du nœud v de la couche l à la couche l+1 peut être formalisé comme :
où, représente la matrice de transformation apprenable de la couche l+1, et σ représente la fonction d'activation. Les GNN mettent à jour les nœuds intégrés .
Pour obtenir une représentation de l'ensemble du graphe après la couche de transmission de messages de niveau L, GNN effectue une opération de lecture pour dériver la représentation graphique finale des incorporations de nœuds, qui peut être formulée comme suit : Readout( ) est une opération invariante de permutation,
qui peut être simplement La fonction moyenne peut également être une fonction de regroupement au niveau graphique plus complexe, telle que MLP.
Dans le cadre entièrement supervisé, nous utilisons un réseau de neurones peu profond pour apprendre une correspondance entre les plongements de graphes et l'espace d'étiquettes Y. σ( ) est une transformation non linéaire qui peut être généralisée comme suit :
De plus, nous utilisons une fonction d'entropie croisée binaire basée sur un graphe pour calculer la perte L dans le cadre supervisé. La formule de la fonction de perte est :
apprentissage fédéré
STFL forme les GNN de différents clients dans un cadre d'apprentissage fédéré. STFL se compose d'un serveur central S et de n clients C. Chaque client déploie un GNN qui apprend le client à partir des données graphiques locales et télécharge les poids du GNN sur un serveur central. Le serveur central reçoit les poids de tous les clients, met à jour les poids WS du modèle GNN global et redistribue les poids mis à jour à chaque client. Dans ce travail, nous choisissons FedAvg comme fonction d'agrégation, qui fait la moyenne des poids de chaque client pour générer les poids du GNN global sur le serveur.
Expérience
base de données
Dans nos expériences, ISRUC S3 (Khalighi et al., 2016) est utilisé comme ensemble de données de référence. ISRUC S3 collecte les enregistrements polysomnographiques (PSG) de 10 canaux de 10 sujets sains (c'est-à-dire les participants à l'expérience du sommeil). Ces enregistrements PSG ont été étiquetés pour cinq stades de sommeil distincts, y compris le réveil, N1, N2, N3 et REM, selon les critères AASM (Jia et al., 2020). Comme décrit dans la section précédente, nous utilisons un réseau d'extraction de caractéristiques basé sur CNN (Jia et al., 2021) pour générer les caractéristiques initiales des nœuds. Pour générer la matrice de contiguïté, quatre fonctions d'association de nœuds différentes sont implémentées et discutées séparément. Pour évaluer l'efficacité de STFL, nous suivons le cadre de données non iid (Zhang et al., 2020) et attribuons différentes phases de sommeil aux clients pour vérifier l'efficacité de notre cadre proposé.
Fonctions de corrélation de nœud
- DB est la fonction de distance euclidienne utilisée pour mesurer la distance spatiale entre les paires d'électrodes.
- K-NN (Jiang et al. 2013) génère une matrice d'adjacence qui sélectionne uniquement les k voisins les plus proches de chaque nœud pour représenter les dépendances de nœud d'un graphe.
- PCC (Pearson et Lee 1903) est connue sous le nom de fonction de corrélation de Pearson et est utilisée pour mesurer la similarité entre chaque paire de nœuds.
- PLV (Aydore, Pantazis et Leahy 2013) est une fonction de corrélation de nœuds variant dans le temps qui mesure le signal de chaque paire de nœuds.
Analyse comparative des performances
-
Pour évaluer l'efficacité des quatre fonctions liées aux nœuds, nous comparons l'effet de chaque fonction liée aux nœuds sur GCN dans le cadre fédéré, puisque GCN a la structure la plus simple parmi les trois modèles GNN. Comme le montre la figure 2, PCC et PLV fonctionnent bien dans le cadre conjoint, avec des taux de convergence plus rapides, en particulier au cours des deux premières époques. De plus, par rapport aux autres fonctions liées aux nœuds, comme le montre le tableau 2, les scores F1 de PLV des 3 modèles fédérés sont les plus élevés, suivis de PCC, et DB est le pire. Cela peut être dû à la couche de regroupement dans le modèle CNN (réseau d'extraction de fonctionnalités), qui examine une petite fenêtre temporelle de la séquence d'entrée, à partir de laquelle la corrélation correcte pour chaque paire de nœuds peut être extraite à l'aide de PLV.
-
Pour évaluer l'efficacité de STFL, nous avons testé ses performances sous différents angles. Dans nos expériences, nous évaluons d'abord le modèle de graphe fédéré sur ISRUC S3 avec PLV, puisque PLV se forme mieux pour chacune des quatre fonctions liées aux nœuds discutées dans RQ1. Comme le montre le tableau 3, sous STFL, les trois modèles GNN produisent des résultats raisonnables. Surtout dans le cadre commun, GAT obtient le score et la précision F1 les plus élevés sur PLV, et GraphSage arrive en deuxième position.
De plus, nous examinons les résultats des modèles centralisés pour ces trois réseaux de graphes, et les résultats sont également présentés dans le tableau 3. Dans cette partie, les hyperparamètres sont maintenus constants avec les expériences conjointes. Pour le fractionnement des données, les données de test sont les mêmes que les données des expériences d'apprentissage fédéré. Les données de formation sont échantillonnées au hasard à partir des données agrégées de tous les clients, et la taille des données de formation est la même que celle d'un client. Pour tous les GNN dans le cadre centralisé, GraphSage obtient le score et la précision F1 les plus élevés, suivi du GCN. De plus, tous les modèles entraînés dans le cadre conjoint obtiennent de meilleurs résultats (score F1 et précision) par rapport au cadre centralisé. Cela indique que les modèles formés sous STFL génèrent avec succès des distributions de données dans des contextes non IID. Une autre constatation est que le meilleur modèle GNN dans un cadre centralisé n'est pas nécessairement le meilleur dans un cadre fédéré. -
Pour déterminer la meilleure correspondance entre les GNN et STFL, trois GNN ont été testés sur ISRUC S3 et PLV dans le cadre conjoint, car il a été observé que PLV obtenait les meilleurs résultats parmi toutes les fonctions liées aux nœuds, dont les détails sont analysés dans RQ1 . Comme le montre la figure 3, GCN converge le plus rapidement mais est plus instable que les deux autres. Nous constatons également que GraphSage converge le plus lentement à la première époque, mais atteint une réduction constante des pertes pendant la phase de test. Il a également constaté que les trois modèles ont finalement convergé vers la même perte, fluctuant autour de 0,15. De plus, nous évaluons le score F1 de chaque classe à l'aide de PLV. Le tableau 4 montre que pour REM, GraphSage obtient les meilleurs résultats, tandis que GCN obtient les meilleurs résultats dans les quatre autres catégories.
Fait intéressant, la perte d'entraînement des trois modèles fluctue dans une large gamme, en particulier au cours des trois dernières époques. C'est probablement parce que le cadre commun distribue le modèle global à chaque client dans chaque lot de formation. Dans les dernières étapes de la formation, chaque client ne peut pas bien ajuster ses propres données dans le modèle global généralisé, en particulier pour les modèles qui sont sujets au surajustement.