Implementasi Pytorch dari Capsule Graph Neural Network (ICLR 2019).

Embeddings simpul berkualitas tinggi yang dipelajari dari grafik Neural Networks (GNNs) telah diterapkan pada berbagai aplikasi berbasis node dan beberapa di antaranya telah mencapai kinerja canggih (SOTA). Namun, ketika menerapkan embeddings simpul yang dipelajari dari GNNs untuk menghasilkan embeddings grafik, representasi node skalar mungkin tidak cukup untuk melestarikan sifat simpul/grafik secara efisien, menghasilkan embeddings grafik sub-optimal. Terinspirasi oleh Capsule Neural Network (CAPSNET), kami mengusulkan jaringan saraf grafik kapsul (CAPSGNN), yang mengadopsi konsep kapsul untuk mengatasi kelemahan dalam algoritma embeddings grafik berbasis GNN yang ada. Dengan mengekstraksi fitur simpul dalam bentuk kapsul, mekanisme perutean dapat digunakan untuk menangkap informasi penting di tingkat grafik. Akibatnya, model kami menghasilkan beberapa embeddings untuk setiap grafik untuk menangkap properti grafik dari berbagai aspek. Modul perhatian yang tergabung dalam capsgnn digunakan untuk menangani grafik dengan berbagai ukuran yang juga memungkinkan model untuk fokus pada bagian -bagian penting dari grafik. Evaluasi ekstensif kami dengan 10 set data terstruktur grafik menunjukkan bahwa CapSgnn memiliki mekanisme yang kuat yang beroperasi untuk menangkap sifat makroskopik dari seluruh grafik dengan data yang digerakkan. Ini mengungguli teknik SOTA lainnya pada beberapa tugas klasifikasi grafik, berdasarkan instrumen baru.
Repositori ini menyediakan implementasi CapSgnn Pytorch seperti yang dijelaskan dalam makalah:
Jaringan saraf grafik kapsul. Zhang Xinyi, Lihui Chen. Iclr, 2019. [Kertas]
Implementasi jaringan saraf kapsul inti yang diadaptasi tersedia [di sini].
Basis kode diimplementasikan dalam Python 3.5.2. Versi paket yang digunakan untuk pengembangan tepat di bawah ini.
networkx 2.4
tqdm 4.28.1
numpy 1.15.4
pandas 0.23.4
texttable 1.5.0
scipy 1.1.0
argparse 1.1.0
torch 1.1.0
torch-scatter 1.4.0
torch-sparse 0.4.3
torch-cluster 1.4.5
torch-geometric 1.3.2
torchvision 0.3.0
Kode mengambil grafik untuk pelatihan dari folder input di mana setiap grafik disimpan sebagai JSON. Grafik yang digunakan untuk pengujian juga disimpan sebagai file JSON. Setiap Node ID dan label node harus diindeks dari 0. Kunci kamus disimpan string untuk memungkinkan serialisasi JSON menjadi mungkin.
Setiap file JSON memiliki struktur nilai kunci berikut:
{ "edges" : [ [ 0 , 1 ] , [ 1 , 2 ] , [ 2 , 3 ] , [ 3 , 4 ] ] ,
"labels" : { "0" : "A" , "1" : "B" , "2" : "C" , "3" : "A" , "4" : "B" } ,
"target" : 1 }Tombol ** Tepi ** memiliki nilai daftar tepi yang menggambarkan struktur konektivitas. Label ** ** Kunci memiliki label untuk setiap node yang disimpan sebagai kamus - di dalam label kamus bersarang ini adalah nilai, pengidentifikasi simpul adalah kunci. Kunci target ** ** memiliki nilai integer yang merupakan keanggotaan kelas.
Prediksi disimpan di direktori `output/`. Setiap embedding memiliki header dan kolom dengan pengidentifikasi grafik. Akhirnya, prediksi diurutkan berdasarkan kolom pengidentifikasi.
Melatih model capsgnn ditangani oleh skrip `src/main.py` yang menyediakan argumen baris perintah berikut.
--training-graphs STR Training graphs folder. Default is `input/train/`.
--testing-graphs STR Testing graphs folder. Default is `input/test/`.
--prediction-path STR Output predictions file. Default is `output/watts_predictions.csv`.
--epochs INT Number of epochs. Default is 100.
--batch-size INT Number fo graphs per batch. Default is 32.
--gcn-filters INT Number of filters in GCNs. Default is 20.
--gcn-layers INT Number of GCNs chained together. Default is 2.
--inner-attention-dimension INT Number of neurons in attention. Default is 20.
--capsule-dimensions INT Number of capsule neurons. Default is 8.
--number-of-capsules INT Number of capsules in layer. Default is 8.
--weight-decay FLOAT Weight decay of Adam. Defatuls is 10^-6.
--lambd FLOAT Regularization parameter. Default is 0.5.
--theta FLOAT Reconstruction loss weight. Default is 0.1.
--learning-rate FLOAT Adam learning rate. Default is 0.01.
Perintah berikut mempelajari model dan menyimpan prediksi. Melatih model pada dataset default:
$ python src/main.py
Melatih model capsgnnn untuk 100 zaman.
$ python src/main.py --epochs 100Mengubah ukuran batch.
$ python src/main.py --batch-size 128Lisensi