Eine Pytorch -Implementierung des neuronalen Netzwerks des Kapselgrafiks (ICLR 2019).

Der hochwertige Knoten-Einbettungsdings, den die Grafik Neural Networks (GNNs) gelernt haben, wurden auf eine breite Palette von notenbasierten Anwendungen angewendet, und einige von ihnen haben die Leistung des neuesten Stand der Technik (SOTA) erzielt. Bei der Anwendung von Knoteneinbetten, die aus GNNs zur Erzeugung von Graphenbettendings gelernt wurden, reicht die skalare Knotendarstellung möglicherweise nicht aus, um die Noten-/Grapheneigenschaften effizient zu erhalten, was zu suboptimalen Graphen-Einbettungen führt. Inspiriert vom Capsule Neural Network (CAPSNet) schlagen wir das Capsule Graph Neural Network (CapSGNN) vor, das das Konzept von Kapseln annimmt, um die Schwäche in vorhandenen GNN-basierten Graphen-Einbettungsalgorithmen anzugehen. Durch das Extrahieren von Knotenmerkmalen in Form von Kapseln kann der Routing -Mechanismus verwendet werden, um wichtige Informationen auf Diagrammebene zu erfassen. Infolgedessen generiert unser Modell für jedes Diagramm mehrere Einbettungen, um die Grapheneigenschaften aus verschiedenen Aspekten zu erfassen. Das in CapSGNN enthaltene Aufmerksamkeitsmodul wird verwendet, um Grafiken mit verschiedenen Größen anzugehen, mit denen sich das Modell auch auf kritische Teile der Grafiken konzentrieren kann. Unsere umfangreichen Bewertungen mit 10 graph-strukturierten Datensätzen zeigen, dass CapSGNN einen leistungsstarken Mechanismus aufweist, der die makroskopischen Eigenschaften des gesamten Diagramms durch datengesteuerte erfasst. Es übertrifft andere SOTA -Techniken bei mehreren Diagrammklassifizierungsaufgaben aufgrund des neuen Instruments.
Dieses Repository bietet eine Pytorch -Implementierung von CapSgnn, wie im Papier beschrieben:
Kapselgrafik Neuronales Netzwerk. Zhang Xinyi, Lihui Chen. ICLR, 2019. [Papier]
Die angepasste Kernkapsel -Netzwerk -Implementierung ist [hier] verfügbar.
Die Codebasis ist in Python 3.5.2 implementiert. Paketversionen, die für die Entwicklung verwendet werden, finden Sie direkt unten.
networkx 2.4
tqdm 4.28.1
numpy 1.15.4
pandas 0.23.4
texttable 1.5.0
scipy 1.1.0
argparse 1.1.0
torch 1.1.0
torch-scatter 1.4.0
torch-sparse 0.4.3
torch-cluster 1.4.5
torch-geometric 1.3.2
torchvision 0.3.0
Der Code nimmt Diagramme für das Training aus einem Eingabedordner an, in dem jedes Diagramm als JSON gespeichert ist. Diagramme, die zum Testen verwendet werden, werden ebenfalls als JSON -Dateien gespeichert. Jede Knoten -ID und Knotenbezeichnung muss aus 0 in indiziertem Wörterbüchern indexiert werden, um die JSON -Serialisierung zu ermöglichen.
Jede JSON-Datei hat die folgende Schlüsselwertstruktur:
{ "edges" : [ [ 0 , 1 ] , [ 1 , 2 ] , [ 2 , 3 ] , [ 3 , 4 ] ] ,
"labels" : { "0" : "A" , "1" : "B" , "2" : "C" , "3" : "A" , "4" : "B" } ,
"target" : 1 }Die Taste ** Kanten ** hat einen Kantenlistenwert, der die Konnektivitätsstruktur entfaltet. Die Taste ** Labels ** enthält Etiketten für jeden Knoten, die als Wörterbuch gespeichert werden - innerhalb dieser verschachtelten Wörterbuchetiketten sind Werte, Knotenkennungen sind Schlüssel. Der ** Ziel ** Taste hat einen Ganzzahlwert, der die Klassenmitgliedschaft ist.
Die Vorhersagen werden im Verzeichnis "Ausgang/" gespeichert. Jede Einbettung hat einen Header und eine Spalte mit den Grafikkennung. Schließlich werden die Vorhersagen nach der Identifikator -Spalte sortiert.
Training A CapSGNN -Modell wird vom Skript "src/main.py" geleitet, das die folgenden Befehlszeilenargumente liefert.
--training-graphs STR Training graphs folder. Default is `input/train/`.
--testing-graphs STR Testing graphs folder. Default is `input/test/`.
--prediction-path STR Output predictions file. Default is `output/watts_predictions.csv`.
--epochs INT Number of epochs. Default is 100.
--batch-size INT Number fo graphs per batch. Default is 32.
--gcn-filters INT Number of filters in GCNs. Default is 20.
--gcn-layers INT Number of GCNs chained together. Default is 2.
--inner-attention-dimension INT Number of neurons in attention. Default is 20.
--capsule-dimensions INT Number of capsule neurons. Default is 8.
--number-of-capsules INT Number of capsules in layer. Default is 8.
--weight-decay FLOAT Weight decay of Adam. Defatuls is 10^-6.
--lambd FLOAT Regularization parameter. Default is 0.5.
--theta FLOAT Reconstruction loss weight. Default is 0.1.
--learning-rate FLOAT Adam learning rate. Default is 0.01.
Die folgenden Befehle lernen ein Modell und speichern die Vorhersagen. Training eines Modells auf dem Standard -Datensatz:
$ python src/main.py
Training a CapSgnnn -Modell für 100 Epochen.
$ python src/main.py --epochs 100Ändern der Chargengröße.
$ python src/main.py --batch-size 128Lizenz