Реализация нейронной сети капсул Графа ( ICLR 2019).

Высококачественные встраиваемые узлы, изученные из нейронных сетей графов (GNN), были применены к широкому спектру приложений на основе узлов, и некоторые из них достигли современной (SOTA) производительности. Однако при применении вторжений узлов, полученных от GNN для генерации встроений графика, масло-представления узла может быть недостаточно для эффективного сохранения свойств узла/графика, что приводит к суб-оптимальным встрокам графиков. Вдохновленная капсульной нейронной сетью (CAPSNET), мы предлагаем нейронную сеть капсул Графа (CAPSGNN), которая принимает концепцию капсул для устранения слабости в существующих алгоритмах Entgdings на основе GNN. Извлекая функции узла в виде капсул, механизм маршрутизации может быть использован для захвата важной информации на уровне графика. В результате наша модель генерирует несколько встраиваний для каждого графика для захвата свойств графика из разных аспектов. Модуль внимания, включенный в Capsgnn, используется для борьбы с графиками с различными размерами, что также позволяет модели сосредоточиться на критических частях графиков. Наши обширные оценки с 10 наборами данных по структуре графика показывают, что Capsgnn имеет мощный механизм, который работает для захвата макроскопических свойств всего графика с помощью данных, управляемых данными. Он превосходит другие методы SOTA по нескольким задачам классификации графиков в силу нового инструмента.
Этот репозиторий обеспечивает реализацию Pytorch Capsgnn, как описано в статье:
Капсула График Нейронная сеть. Чжан Синьи, Лихуи Чен. ICLR, 2019. [Paper]
Адаптированная адаптированная внедренная капсула нейронная сеть доступна [здесь].
Кодовая база реализована в Python 3.5.2. Версии, используемые для разработки, чуть ниже.
networkx 2.4
tqdm 4.28.1
numpy 1.15.4
pandas 0.23.4
texttable 1.5.0
scipy 1.1.0
argparse 1.1.0
torch 1.1.0
torch-scatter 1.4.0
torch-sparse 0.4.3
torch-cluster 1.4.5
torch-geometric 1.3.2
torchvision 0.3.0
Код принимает графики для обучения из входной папки, где каждый график хранится как JSON. Графики, используемые для тестирования, также хранятся в виде файлов JSON. Каждый идентификатор узла и метка узла должен быть проиндексирован из 0. Ключи словарей хранятся в строках, чтобы сделать возможной сериализацию JSON.
Каждый файл JSON имеет следующую структуру ключевых значений:
{ "edges" : [ [ 0 , 1 ] , [ 1 , 2 ] , [ 2 , 3 ] , [ 3 , 4 ] ] ,
"labels" : { "0" : "A" , "1" : "B" , "2" : "C" , "3" : "A" , "4" : "B" } ,
"target" : 1 }Ключ ** края ** имеет значение списка краев, которое снижает структуру подключения. Ключ ** Метки ** имеет метки для каждого узла, которые хранятся в качестве словаря - в рамках этих вложенных словарных этикетков - значения, идентификаторы узлов являются ключами. Ключ ** Target ** имеет целочисленное значение, которое является членством в классе.
Прогнозы сохраняются в каталоге «выходной/` `». Каждое встраивание имеет заголовок и столбец с идентификаторами графика. Наконец, прогнозы отсортируются по столбцу идентификатора.
Обучение модели Capsgnn обрабатывается сценарием `src/main.py`, который предоставляет следующие аргументы командной строки.
--training-graphs STR Training graphs folder. Default is `input/train/`.
--testing-graphs STR Testing graphs folder. Default is `input/test/`.
--prediction-path STR Output predictions file. Default is `output/watts_predictions.csv`.
--epochs INT Number of epochs. Default is 100.
--batch-size INT Number fo graphs per batch. Default is 32.
--gcn-filters INT Number of filters in GCNs. Default is 20.
--gcn-layers INT Number of GCNs chained together. Default is 2.
--inner-attention-dimension INT Number of neurons in attention. Default is 20.
--capsule-dimensions INT Number of capsule neurons. Default is 8.
--number-of-capsules INT Number of capsules in layer. Default is 8.
--weight-decay FLOAT Weight decay of Adam. Defatuls is 10^-6.
--lambd FLOAT Regularization parameter. Default is 0.5.
--theta FLOAT Reconstruction loss weight. Default is 0.1.
--learning-rate FLOAT Adam learning rate. Default is 0.01.
Следующие команды изучают модель и сохраняют прогнозы. Обучение модели в наборе данных по умолчанию:
$ python src/main.py
Обучение модели Capsgnnn для 100 эпох.
$ python src/main.py --epochs 100Изменение размера партии.
$ python src/main.py --batch-size 128Лицензия