Une implémentation Pytorch de Capsule Graph Neural Network (ICLR 2019).

Les incorporations de nœuds de haute qualité apprises par les réseaux de neurones graphiques (GNNS) ont été appliqués à un large éventail d'applications basées sur des nœuds et certaines d'entre elles ont atteint des performances de pointe (SOTA). Cependant, lors de l'application des incorporations de nœuds apprises de GNNS pour générer des incorporations graphiques, la représentation du nœud scalaire peut ne pas suffire pour préserver efficacement les propriétés de nœud / graphique, ce qui entraîne des intérêts graphiques sous-optimaux. Inspiré par le réseau neuronal de capsule (CapsNet), nous proposons le réseau neuronal de graphe Capsule (CapsGNN), qui adopte le concept de capsules pour résoudre la faiblesse des algorithmes d'incorporation de graphiques GNN existants basés sur GNN. En extrayant des caractéristiques de nœud sous forme de capsules, le mécanisme de routage peut être utilisé pour capturer des informations importantes au niveau du graphique. En conséquence, notre modèle génère plusieurs intérêts pour chaque graphique pour capturer les propriétés du graphique à partir de différents aspects. Le module d'attention incorporé dans CAPSGNN est utilisé pour lutter contre les graphiques de différentes tailles qui permet également au modèle de se concentrer sur les parties critiques des graphiques. Nos évaluations approfondies avec 10 ensembles de données structurés graphiques démontrent que CAPSGNN a un mécanisme puissant qui fonctionne pour capturer les propriétés macroscopiques de l'ensemble du graphique par les données. Il surpasse les autres techniques SOTA sur plusieurs tâches de classification des graphiques, en raison du nouvel instrument.
Ce référentiel fournit une implémentation pytorch de CapSGNN comme décrit dans l'article:
Capsule Graph Network Network. Zhang Xinyi, Lihui Chen. ICLR, 2019. [Papier]
L'implémentation de réseau neuronal de la capsule principale adaptée est disponible [ici].
La base de code est implémentée dans Python 3.5.2. Les versions de package utilisées pour le développement sont juste en dessous.
networkx 2.4
tqdm 4.28.1
numpy 1.15.4
pandas 0.23.4
texttable 1.5.0
scipy 1.1.0
argparse 1.1.0
torch 1.1.0
torch-scatter 1.4.0
torch-sparse 0.4.3
torch-cluster 1.4.5
torch-geometric 1.3.2
torchvision 0.3.0
Le code prend des graphiques pour la formation d'un dossier d'entrée où chaque graphique est stocké en JSON. Les graphiques utilisés pour les tests sont également stockés comme fichiers JSON. Chaque ID de nœud et étiquette de nœud doit être indexé à partir de 0. Les clés des dictionnaires sont stockées afin de rendre la sérialisation JSON possible.
Chaque fichier JSON a la structure de valeur clé suivante:
{ "edges" : [ [ 0 , 1 ] , [ 1 , 2 ] , [ 2 , 3 ] , [ 3 , 4 ] ] ,
"labels" : { "0" : "A" , "1" : "B" , "2" : "C" , "3" : "A" , "4" : "B" } ,
"target" : 1 }La touche ** ADGES ** a une valeur de liste de bord qui décrit la structure de connectivité. La clé ** Étiquettes ** a des étiquettes pour chaque nœud qui sont stockées en tant que dictionnaire - dans ce dictionnaire imbriqué, les étiquettes sont des valeurs, les identificateurs de nœud sont des clés. La clé ** Target ** a une valeur entière qui est l'adhésion à la classe.
Les prédictions sont enregistrées dans le répertoire `` Output / '. Chaque intégration a un en-tête et une colonne avec les identificateurs de graphiques. Enfin, les prédictions sont triées par la colonne d'identifiant.
Formation Un modèle Capsgnn est géré par le script `src / main.py` qui fournit les arguments de ligne de commande suivants.
--training-graphs STR Training graphs folder. Default is `input/train/`.
--testing-graphs STR Testing graphs folder. Default is `input/test/`.
--prediction-path STR Output predictions file. Default is `output/watts_predictions.csv`.
--epochs INT Number of epochs. Default is 100.
--batch-size INT Number fo graphs per batch. Default is 32.
--gcn-filters INT Number of filters in GCNs. Default is 20.
--gcn-layers INT Number of GCNs chained together. Default is 2.
--inner-attention-dimension INT Number of neurons in attention. Default is 20.
--capsule-dimensions INT Number of capsule neurons. Default is 8.
--number-of-capsules INT Number of capsules in layer. Default is 8.
--weight-decay FLOAT Weight decay of Adam. Defatuls is 10^-6.
--lambd FLOAT Regularization parameter. Default is 0.5.
--theta FLOAT Reconstruction loss weight. Default is 0.1.
--learning-rate FLOAT Adam learning rate. Default is 0.01.
Les commandes suivantes apprennent un modèle et enregistrent les prédictions. Formation d'un modèle sur l'ensemble de données par défaut:
$ python src/main.py
Formation d'un modèle Capsgnnn pour 100 époques.
$ python src/main.py --epochs 100Modification de la taille du lot.
$ python src/main.py --batch-size 128Licence