تطبيق Pytorch للشبكة العصبية الرسم البياني للكبسولة (ICLR 2019).

تم تطبيق تضمينات العقدة عالية الجودة المستفادة من الشبكات العصبية الرسم البياني (GNNS) على مجموعة واسعة من التطبيقات القائمة على العقدة وحقق بعضها أداء أحدث (SOTA). ومع ذلك ، عند تطبيق تضمينات العقدة المستفادة من GNNs لإنشاء تضمينات للرسم البياني ، قد لا يكفي تمثيل العقدة العددية للحفاظ على خصائص العقدة/الرسم البياني بكفاءة ، مما يؤدي إلى تضمينات الرسم البياني دون المستوى الأمثل. مستوحاة من الشبكة العصبية للكبسولة (CAPSNET) ، نقترح الشبكة العصبية الرسم البياني للكبسولة (CAPSGNN) ، التي تعتمد مفهوم الكبسولات لمعالجة الخوارزميات الموجودة في خوارزميات الرسم البياني القائم على GNN. عن طريق استخراج ميزات العقدة في شكل كبسولات ، يمكن استخدام آلية التوجيه لالتقاط معلومات مهمة على مستوى الرسم البياني. نتيجة لذلك ، يولد نموذجنا تضمينات متعددة لكل رسم بياني لالتقاط خصائص الرسم البياني من جوانب مختلفة. يتم استخدام وحدة الانتباه المدمجة في Capsgnn لمعالجة الرسوم البيانية بأحجام مختلفة والتي تمكن أيضًا النموذج من التركيز على الأجزاء الحرجة من الرسوم البيانية. تُظهر تقييماتنا الشاملة مع 10 مجموعات بيانات منظمة الرسم البياني أن Capsgnn لديها آلية قوية تعمل لالتقاط خصائص العيان من الرسم البياني بأكمله بواسطة البيانات. يتفوق على تقنيات SOTA الأخرى على العديد من مهام تصنيف الرسم البياني ، بحكم الأداة الجديدة.
يوفر هذا المستودع تطبيق Pytorch من Capsgnn كما هو موضح في الورقة:
كبسولة الرسم البياني الشبكة العصبية. Zhang Xinyi ، Lihui Chen. ICLR ، 2019. [ورقة]
يتوفر تنفيذ الشبكة العصبية Capsule Capsule Capsule [هنا].
يتم تنفيذ قاعدة الشفرة في Python 3.5.2. إصدارات الحزمة المستخدمة للتطوير هي أدناه.
networkx 2.4
tqdm 4.28.1
numpy 1.15.4
pandas 0.23.4
texttable 1.5.0
scipy 1.1.0
argparse 1.1.0
torch 1.1.0
torch-scatter 1.4.0
torch-sparse 0.4.3
torch-cluster 1.4.5
torch-geometric 1.3.2
torchvision 0.3.0
يأخذ الرمز الرسوم البيانية للتدريب من مجلد إدخال حيث يتم تخزين كل رسم بياني كـ JSON. يتم تخزين الرسوم البيانية المستخدمة للاختبار أيضًا كملفات JSON. يجب فهرسة كل معرف العقدة وعلامة العقدة من 0. يتم تخزين مفاتيح القواميس سلاسل من أجل جعل JSON Serialization ممكنة.
يحتوي كل ملف JSON على بنية قيمة المفاتيح التالية:
{ "edges" : [ [ 0 , 1 ] , [ 1 , 2 ] , [ 2 , 3 ] , [ 3 , 4 ] ] ,
"labels" : { "0" : "A" , "1" : "B" , "2" : "C" , "3" : "A" , "4" : "B" } ,
"target" : 1 }يحتوي مفتاح ** الحواف ** على قيمة قائمة الحافة التي تنحدر بنية الاتصال. يحتوي مفتاح ** Labels ** على كل عقدة يتم تخزينها كقاموس - ضمن ملصقات القاموس المتداخلة هذه هي قيم ، معرفات العقدة هي مفاتيح. يحتوي مفتاح ** Target ** على قيمة عدد صحيح وهي عضوية الفصل.
يتم حفظ التنبؤات في الدليل `الإخراج/`. كل تضمين له رأس وعمود مع معرفات الرسم البياني. أخيرًا ، يتم فرز التنبؤات بواسطة عمود المعرف.
يتم التعامل مع نموذج Capsgnn بواسطة البرنامج النصي src/main.py` الذي يوفر وسيطات سطر الأوامر التالية.
--training-graphs STR Training graphs folder. Default is `input/train/`.
--testing-graphs STR Testing graphs folder. Default is `input/test/`.
--prediction-path STR Output predictions file. Default is `output/watts_predictions.csv`.
--epochs INT Number of epochs. Default is 100.
--batch-size INT Number fo graphs per batch. Default is 32.
--gcn-filters INT Number of filters in GCNs. Default is 20.
--gcn-layers INT Number of GCNs chained together. Default is 2.
--inner-attention-dimension INT Number of neurons in attention. Default is 20.
--capsule-dimensions INT Number of capsule neurons. Default is 8.
--number-of-capsules INT Number of capsules in layer. Default is 8.
--weight-decay FLOAT Weight decay of Adam. Defatuls is 10^-6.
--lambd FLOAT Regularization parameter. Default is 0.5.
--theta FLOAT Reconstruction loss weight. Default is 0.1.
--learning-rate FLOAT Adam learning rate. Default is 0.01.
تتعلم الأوامر التالية نموذجًا وحفظ التنبؤات. تدريب نموذج على مجموعة البيانات الافتراضية:
$ python src/main.py
تدريب نموذج capsgnnn ل 100 عصر.
$ python src/main.py --epochs 100تغيير حجم الدُفعة.
$ python src/main.py --batch-size 128رخصة