これは部分的には、分散型データからのディープネットワークのコミュニケーション効率の高い学習の論文の再現です
MNISTとCIFAR10(IIDと非IIDの両方)に関する実験のみが生成されます。
注:並列コンピューティングの実装がなければ、スクリプトは遅くなります。
Python> = 3.6
pytorch> = 0.4
MLPモデルとCNNモデルは、次のように生成されます。
python main_nn.py
MLPおよびCNNを使用した連合学習は、次のように作成されます。
python main_fed.py
Options.pyの引数を参照してください。
例えば:
python main_fed.py - dataset mnist -iid - num_channels 1 - model cnn -epochs 50 -gpu 0
--all_clientsすべてのクライアントモデルを平均化するためのclients
NB:CIFAR-10の場合、 num_channels 3でなければなりません。
結果を表1と表2に示し、パラメーターc = 0.1、b = 10、e = 5を示します。
表1。0.01の学習率で10エポックトレーニングの結果
| モデル | acc。 IIDの | acc。非IIDの |
|---|---|---|
| FEDAVG-MLP | 94.57% | 70.44% |
| FedAvg-CNN | 96.59% | 77.72% |
表2。0.01の学習率で50エポックトレーニングの結果
| モデル | acc。 IIDの | acc。非IIDの |
|---|---|---|
| FEDAVG-MLP | 97.21% | 93.03% |
| FedAvg-CNN | 98.60% | 93.81% |
謝辞はYoukaichaoに与えます。
McMahan、Brendan、Eider Moore、Daniel Ramage、Seth Hampson、Blaise Aguera Y Arcas。分散データからのディープネットワークの通信効率の高い学習。人工知能と統計(Aistats)、2017年。
shaoxiong ji。 (2018年3月30日)。連合学習のPytorch実装。ゼノド。 http://doi.org/10.5281/zenodo.4321561