Una implementación de Pytorch de la red neuronal de Graph Capsule (ICLR 2019).

Las integridades de nodos de alta calidad aprendidas de las redes neuronales Graph (GNN) se han aplicado a una amplia gama de aplicaciones basadas en nodos y algunas de ellas han alcanzado el rendimiento de vanguardia (SOTA). Sin embargo, al aplicar incrustaciones de nodos aprendidos de GNN para generar incrustaciones de gráficos, la representación del nodo escalar puede no ser suficiente para preservar las propiedades de nodo/gráfico de manera eficiente, lo que resulta en incrustaciones de gráficos subóptimas. Inspirado en la red neuronal de la cápsula (CAPSNET), proponemos la red neuronal de la gráfica de la cápsula (CAPSGNN), que adopta el concepto de cápsulas para abordar la debilidad en los algoritmos existentes de incrustaciones de gráficos basados en GNN. Al extraer características de nodo en forma de cápsulas, el mecanismo de enrutamiento se puede utilizar para capturar información importante a nivel de gráfico. Como resultado, nuestro modelo genera múltiples incrustaciones para cada gráfico para capturar las propiedades de los gráficos de diferentes aspectos. El módulo de atención incorporado en Capsgnn se utiliza para abordar gráficos con varios tamaños que también permite que el modelo se centre en partes críticas de los gráficos. Nuestras extensas evaluaciones con 10 conjuntos de datos estructurados en gráficos demuestran que CapSGNN tiene un poderoso mecanismo que opera para capturar propiedades macroscópicas de todo el gráfico mediante datos basados en datos. Supera a otras técnicas SOTA en varias tareas de clasificación de gráficos, en virtud del nuevo instrumento.
Este repositorio proporciona una implementación de Pytorch de Capsgnn como se describe en el documento:
Capsule Graph Network Neural. Zhang Xinyi, Lihui Chen. ICLR, 2019. [Documento]
La implementación de la red neuronal cápsula central adaptada está disponible [aquí].
La base de código se implementa en Python 3.5.2. Las versiones de paquetes utilizadas para el desarrollo están justo debajo.
networkx 2.4
tqdm 4.28.1
numpy 1.15.4
pandas 0.23.4
texttable 1.5.0
scipy 1.1.0
argparse 1.1.0
torch 1.1.0
torch-scatter 1.4.0
torch-sparse 0.4.3
torch-cluster 1.4.5
torch-geometric 1.3.2
torchvision 0.3.0
El código toma gráficos para entrenar desde una carpeta de entrada donde cada gráfico se almacena como un JSON. Los gráficos utilizados para las pruebas también se almacenan como archivos JSON. Cada ID de nodo y etiqueta de nodo debe indexarse a partir de 0. Las claves de los diccionarios se almacenan en cadenas para hacer posible la serialización de JSON.
Cada archivo JSON tiene la siguiente estructura de valor clave:
{ "edges" : [ [ 0 , 1 ] , [ 1 , 2 ] , [ 2 , 3 ] , [ 3 , 4 ] ] ,
"labels" : { "0" : "A" , "1" : "B" , "2" : "C" , "3" : "A" , "4" : "B" } ,
"target" : 1 }La tecla ** Bordes ** tiene un valor de lista de borde que descarte la estructura de conectividad. La clave ** Etiquetas ** tiene etiquetas para cada nodo que se almacenan como un diccionario, dentro de este diccionario anidado son valores, los identificadores de nodos son claves. La clave ** Target ** tiene un valor entero que es la membresía de la clase.
Las predicciones se guardan en el directorio `salida/`. Cada incrustación tiene un encabezado y una columna con los identificadores de gráficos. Finalmente, las predicciones son ordenadas por la columna Identificador.
La capacitación de un modelo Capsgnn se maneja mediante el script `src/main.py` que proporciona los siguientes argumentos de línea de comandos.
--training-graphs STR Training graphs folder. Default is `input/train/`.
--testing-graphs STR Testing graphs folder. Default is `input/test/`.
--prediction-path STR Output predictions file. Default is `output/watts_predictions.csv`.
--epochs INT Number of epochs. Default is 100.
--batch-size INT Number fo graphs per batch. Default is 32.
--gcn-filters INT Number of filters in GCNs. Default is 20.
--gcn-layers INT Number of GCNs chained together. Default is 2.
--inner-attention-dimension INT Number of neurons in attention. Default is 20.
--capsule-dimensions INT Number of capsule neurons. Default is 8.
--number-of-capsules INT Number of capsules in layer. Default is 8.
--weight-decay FLOAT Weight decay of Adam. Defatuls is 10^-6.
--lambd FLOAT Regularization parameter. Default is 0.5.
--theta FLOAT Reconstruction loss weight. Default is 0.1.
--learning-rate FLOAT Adam learning rate. Default is 0.01.
Los siguientes comandos aprenden un modelo y guardan las predicciones. Capacitar a un modelo en el conjunto de datos predeterminado:
$ python src/main.py
Entrenamiento de un modelo Capsgnnn para 100 épocas.
$ python src/main.py --epochs 100Cambiar el tamaño del lote.
$ python src/main.py --batch-size 128Licencia