cassini-titan / async-fl

This project is forked from nuaa-smartsensing/async-fl.


A federated learning framework supporting asynchronous FL, semi-asynchronous FL, synchronous FL, and personalized FL.

License: MIT


Async-FL


This document is also available in: 中文 | English

keywords: federated-learning, asynchronous, synchronous, semi-asynchronous, personalized


Original Intention

The original intention of this project was to build an asynchronous federated learning framework and run experiments on it during my undergraduate thesis.

However, when I searched GitHub for related open-source projects, I found that the field of asynchronous federated learning is largely closed-source, with almost no open-source implementations available. Mainstream frameworks likewise lack support for asynchronous FL and only offer synchronous FL. Thus, this project was born.

Git Branch Description

The master branch is the main branch with the latest code, but some of the commits are dirty commits and not guaranteed to run properly. It is recommended to use tagged versions for better stability.

The checkout branch retains the ability to add clients to the system during training, which has been removed from the main branch. The checkout branch is not actively maintained and only supports synchronous and asynchronous FL.

Requirements

Python 3.8 + PyTorch, developed on macOS.

The framework has also been validated on Linux and supports both single-GPU and multi-GPU training.

Getting Started

Experiments

You can run python main.py directly (main.py lives in the fl directory). The program automatically reads the config.json file in the root directory and stores the results, together with the configuration file used, under the specified path in results.

You can also specify the configuration file explicitly, e.g. python main.py ../../config.json. Please note that the path of config.json is interpreted relative to main.py.
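For example (paths are relative to main.py, as noted above):

cd src/fl
python main.py                                   # reads the root config.json
python main.py ../../config/FedAvg-config.json   # an explicit configuration file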

The config folder in the root directory provides some algorithm configuration files proposed in papers. The following algorithm implementations are currently available:

FedAvg
FedAsync
FedProx
FedAT
FedLC
FedDL
M-Step AsyncFL

Docker

You can now pull and run a Docker image directly; the commands are as follows:

docker pull desperadoccy/async-fl
docker run -it async-fl config/FedAvg-config.json

Similarly, it supports passing a config file path as a parameter. You can also build the Docker image yourself.

cd docker
docker build -t async-fl .
docker run -it async-fl config/FedAvg-config.json 

Features

  • Asynchronous Federated Learning
  • Support model and dataset replacement
  • Support scheduling algorithm replacement
  • Support aggregation algorithm replacement
  • Support loss function replacement
  • Support client replacement
  • Synchronous federated learning
  • Semi-asynchronous federated learning
  • Provide test loss information
  • Custom label heterogeneity
  • Custom data heterogeneity
  • Support Dirichlet distribution
  • wandb visualization
  • Support for leaf-related datasets
  • Support for multiple GPUs
  • Docker deployment

Project Directory

.
├── config                                    Common algorithm configuration files
│   ├── FedAT-config.json
│   ├── FedAsync-config.json
│   ├── FedAvg-config.json
│   ├── FedDL-config.json
│   ├── FedLC-config.json
│   ├── FedProx-config.json
│   ├── MSTEPAsync-config.json
│   ├── config.json
│   └── model_config
│       ├── CIFAR10-config.json
│       ├── ResNet18-config.json
│       └── ResNet50-config.json
├── config.json
├── config_semi.json
├── config_semi_test.json
├── config_sync.json
├── config_sync_test.json
├── config_test.json
├── doc
│   ├── params.docx
│   ├── pic
│   │   ├── fedsemi.png
│   │   ├── framework.png
│   │   └── header.png
│   ├── readme-zh.md
│   └── 参数.docx
├── docker
│   └── Dockerfile
├── license
├── readme.md
├── requirements.txt
└── src
    ├── checker                                checker implementation
    │   ├── AllChecker.py
    │   ├── CheckerCaller.py
    │   ├── SyncChecker.py
    │   └── __init__.py
    ├── client                                 client implementation
    │   ├── ActiveClient.py
    │   ├── Client.py
    │   ├── DLClient.py
    │   ├── NormalClient.py
    │   ├── ProxClient.py
    │   ├── SemiClient.py
    │   ├── TestClient.py
    │   └── __init__.py
    ├── clientmanager                           client manager implementation
    │   ├── BaseClientManager.py
    │   ├── NormalClientManager.py
    │   └── __init__.py
    ├── compressor                              compressor algorithm class
    │   ├── QSGD.py
    │   └── __init__.py
    ├── data
    ├── dataset
    │   ├── CIFAR10.py
    │   ├── FashionMNIST.py
    │   ├── MNIST.py
    │   └── __init__.py
    ├── exception
    │   ├── ClientSumError.py
    │   └── __init__.py
    ├── fl                                       wandb running directory
    │   ├── __init__.py
    │   ├── main.py
    │   └── wandb
    ├── group                                    group algorithm class
    │   ├── AbstractGroup.py
    │   ├── DelayGroup.py
    │   ├── GroupCaller.py
    │   ├── OneGroup.py
    │   └── __init__.py
    ├── groupmanager                             group manager implementation
    │   ├── BaseGroupManager.py
    │   ├── NormalGroupManager.py
    │   └── __init__.py
    ├── loss                                     loss algorithm class
    │   ├── FedLC.py
    │   ├── LossFactory.py
    │   └── __init__.py
    ├── model
    │   ├── CNN.py
    │   └── __init__.py
    ├── numgenerator                             num generator algorithm class
    │   ├── AbstractNumGenerator.py
    │   ├── NumGeneratorFactory.py
    │   ├── StaticNumGenerator.py
    │   └── __init__.py
    ├── queuemanager                             queuemanager implementation
    │   ├── AbstractQueueManager.py
    │   ├── BaseQueueManger.py
    │   ├── QueueListManager.py
    │   ├── SingleQueueManager.py
    │   └── __init__.py
    ├── receiver                                 receiver implementation
    │   ├── MultiQueueReceiver.py
    │   ├── NoneReceiver.py
    │   ├── NormalReceiver.py
    │   ├── ReceiverCaller.py
    │   └── __init__.py
    ├── results
    ├── schedule                                 scheduling algorithm class
    │   ├── AbstractSchedule.py
    │   ├── FullSchedule.py
    │   ├── NoSchedule.py
    │   ├── RandomSchedule.py
    │   ├── RoundRobin.py
    │   ├── ScheduleCaller.py
    │   └── __init__.py
    ├── scheduler                                scheduler implementation
    │   ├── AsyncScheduler.py
    │   ├── BaseScheduler.py
    │   ├── SemiAsyncScheduler.py
    │   ├── SyncScheduler.py
    │   └── __init__.py
    ├── server                                   server implementation
    │   ├── AsyncServer.py
    │   ├── BaseServer.py
    │   ├── SemiAsyncServer.py
    │   ├── SyncServer.py
    │   └── __init__.py
    ├── test                                     for test
    │   ├── __init__.py
    │   ├── test.ipynb
    │   └── test.py
    ├── update                                   update algorithm class
    │   ├── AbstractUpdate.py
    │   ├── AsyncAvg.py
    │   ├── FedAT.py
    │   ├── FedAsync.py
    │   ├── FedAvg.py
    │   ├── FedDL.py
    │   ├── StepAsyncAvg.py
    │   ├── UpdateCaller.py
    │   └── __init__.py
    ├── updater                                 updater implementation
    │   ├── AsyncUpdater.py
    │   ├── BaseUpdater.py
    │   ├── SemiAsyncUpdater.py
    │   ├── SyncUpdater.py
    │   └── __init__.py
    └── utils
        ├── ConfigManager.py
        ├── GlobalVarGetter.py
        ├── IID.py
        ├── JsonTool.py
        ├── ModelTraining.py
        ├── ModuleFindTool.py
        ├── Plot.py
        ├── ProcessTool.py
        ├── Queue.py
        ├── Random.py
        ├── Time.py
        ├── Tools.py
        └── __init__.py

The "Time" file under the "utils" package is an implementation of a multi-threaded time acquisition class, and the "Queue" file is an implementation of related functionalities for the "queue" module, as some functionalities of the "queue" module are not yet implemented on macOS.

Framework

(Framework diagrams: see doc/pic/framework.png and doc/pic/fedsemi.png.)

Code Explanations

Receiver Class

In synchronous and semi-asynchronous federated learning, the receiver checks whether the updates received during the current global iteration satisfy the configured conditions, for example whether all designated clients have uploaded their updates. If the conditions are met, the receiver triggers the updater to perform global aggregation.

Checker Class

In synchronous and semi-asynchronous federated learning, after a client completes its training it uploads its weights to the checker class, which decides, based on its own logic, whether the update meets the upload criteria and accepts or discards it accordingly.
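As a rough illustration, a custom checker might look like the sketch below. This is hypothetical: the class name, constructor, and check method are assumptions about the interface, so consult the classes in src/checker (e.g. SyncChecker.py) for the actual one.

class StalenessChecker:
    def __init__(self, config):
        # maximum allowed staleness, supplied via the checker "params" item (illustrative)
        self.max_staleness = config.get("max_staleness", 10)

    def check(self, update):
        # accept the update only if it is based on a recent enough global model
        return update["current_epoch"] - update["base_epoch"] <= self.max_staleness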

Configuration

Example configuration files are provided for each mode:

  • async mode example
  • sync mode example
  • semi-async mode example

Parameter explanation

| parameter | type | explanation |
| --- | --- | --- |
| **wandb** |  |  |
| enabled | bool | whether to enable wandb |
| project | string | project name |
| name | string | the name of this run |
| **global** |  |  |
| use_file_system | bool | whether to use the file system as the torch multiprocessing sharing strategy |
| multi_gpu | bool | whether to enable multi-GPU (see the Multi-GPU section) |
| experiment | string | the name of this run |
| stale |  | staleness settings (see Staleness Settings) |
| dataset.path | string | the path of the dataset |
| dataset.params | dict | required parameters |
| iid |  | data distribution settings (see Data Distribution Settings) |
| client_num | int | the number of clients |
| **server** |  |  |
| path | string | the path of the server |
| epochs | int | the number of global epochs |
| model.path | string | the path of the model |
| model.params | dict | required parameters |
| **scheduler** |  |  |
| path | string | the path of the scheduler |
| schedule.path | string | the path of the schedule algorithm |
| schedule.params | dict | required parameters |
| other_params | * | other parameters |
| **updater** |  |  |
| path | string | the path of the updater |
| update.path | string | the path of the update (aggregation) algorithm |
| update.params | dict | required parameters |
| loss |  | loss settings (see Adding Loss Function) |
| num_generator |  | num generator settings |
| **group** |  |  |
| path | string | the path of the group algorithm |
| params | dict | required parameters |
| **client_manager** |  |  |
| path | string | the path of the client manager |
| **group_manager** |  |  |
| path | string | the path of the group manager |
| group_method.path | string | the path of the group method |
| group_method.params | dict | required parameters |
| **queue_manager** |  |  |
| path | string | the path of the queue manager |
| **receiver** |  |  |
| path | string | the path of the receiver |
| params | dict | required parameters |
| **checker** |  |  |
| path | string | the path of the checker |
| params | dict | required parameters |
| **client** |  |  |
| path | string | the path of the client |
| epochs | int | the number of local epochs |
| batch_size | int | batch size |
| model.path | string | the path of the model |
| model.params | dict | required parameters |
| loss |  | loss settings (see Adding Loss Function) |
| mu | float | coefficient of the proximal term (used by FedProx) |
| optimizer.path | string | the path of the optimizer |
| optimizer.params | dict | required parameters |
| other_params | * | other parameters |

Adding New Algorithm

To allow clients/servers to call your own algorithms or implementation classes (note: all algorithm implementations must be in class form), the following steps are required:

  • Add your own implementation to the corresponding package (dataset, model, schedule, update, client, loss).
  • Import the class in the __init__.py file of the corresponding package, for example from model import CNN.
  • Declare it in the configuration file: the corresponding path item (for example, the model path) should point to where the new algorithm is located.
  • For the checker, group, receiver, schedule, and update modules, add an invocation method to the corresponding Caller class.
  • For the loss and numgenerator modules, add an invocation method to the corresponding factory class.

In addition, parameters that the algorithm needs to use can be declared in the params configuration item.
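For example, a new model could be added as a class like the following sketch (the file name MyCNN.py and the class body are hypothetical):

import torch.nn as nn

# Hypothetical model class, placed at src/model/MyCNN.py
class MyCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.classifier = nn.Linear(16, num_classes)

    def forward(self, x):
        return self.classifier(self.features(x))

It would then be imported in the model package's __init__.py and referenced from the configuration via its path, with params forwarded to the constructor.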

The model, optim, and loss modules now also accept built-in implementation classes from torch and torchvision, for example:

"model": {
      "path": "torchvision.models.resnet18",
      "params": {
        "pretrained": true,
        "num_classes": 10 
      }
}
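Conceptually, the path string is resolved to a class or callable and params are passed as keyword arguments. A minimal sketch of that pattern follows (illustrative; the framework's actual resolution logic appears to live in utils/ModuleFindTool.py). One caveat with torchvision specifically: pretrained weights assume the default number of classes, so combining "pretrained": true with "num_classes": 10 will fail to load; set pretrained to false when changing num_classes.

import importlib

def find_class(path):
    # "torchvision.models.resnet18" -> module "torchvision.models", attribute "resnet18"
    module_name, attr = path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), attr)

model_conf = {"path": "torchvision.models.resnet18", "params": {"num_classes": 10}}
model = find_class(model_conf["path"])(**model_conf["params"])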

Adding Loss Function

The loss function is now created by the LossFactory class. You can either use built-in algorithms from torch or implement your own.

The loss configuration supports three settings. The first option is using a string format commonly used in the configuration file:

"loss": "torch.nn.functional.cross_entropy"

In this case, the program will directly generate a loss function using the functional approach.

The second option is to generate an object-based loss:

"loss": {
    "path": "loss.myloss.MyLoss",
    "params": {}
}

Here, you specify the path to your custom loss class and provide any necessary parameters in the params field.

The third option is to generate a loss based on the type:

"loss": {
        "type": "func",
        "path": "loss.myloss.MyLoss",
        "params": {}
    }

With this option, you also provide the type field as "func", and the rest of the process is similar to the object-based approach.
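A minimal sketch of how a factory might dispatch between these three settings (illustrative; see loss/LossFactory.py for the real implementation):

import importlib

def resolve(path):
    module_name, attr = path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), attr)

def create_loss(conf):
    if isinstance(conf, str):           # option 1: plain functional loss
        return resolve(conf)
    if conf.get("type") == "func":      # option 3: explicitly functional
        return resolve(conf["path"])
    return resolve(conf["path"])(**conf.get("params", {}))  # option 2: object-based

loss_fn = create_loss("torch.nn.functional.cross_entropy")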

Staleness Settings

stale has three settings; the first one is shown below and also appears in the example configuration files.

"stale": {
      "step": 5,
      "shuffle": true,
      "list": [10, 10, 10, 5, 5, 5, 5]
    }

The program generates a list of random integers based on the provided step and list. In the example above, it generates 10 zeros, 10 random integers in (0, 5), 10 in [5, 10), and 5 each in the subsequent ranges of width 5, shuffling them if shuffle is set to true. Each value is then assigned to a client, and the client sleeps for that many seconds after each round of training. When the JSON file is stored with the experimental results, this setting is automatically converted to the third form below.
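A sketch of this expansion under the interpretation above (the helper is hypothetical; the actual generation lives inside the framework):

import random

def expand_stale(step, counts, shuffle=True):
    # counts[0] clients get delay 0; counts[i] clients (i >= 1) get a random
    # delay of roughly (i - 1) * step to i * step seconds
    delays = [0] * counts[0]
    for i, n in enumerate(counts[1:], start=1):
        delays += [random.randint((i - 1) * step, i * step - 1) for _ in range(n)]
    if shuffle:
        random.shuffle(delays)
    return delays

print(expand_stale(5, [10, 10, 10, 5, 5, 5, 5]))  # 50 per-client delays in seconds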

The second option is to set it to false, in which case the program will set the delay for each client to 0.

"stale": false

The third option is a list of random integers, and the program will directly assign the delay settings from the list to the clients.

"stale": [1, 2, 3, 1, 4]

Data Distribution Settings

iid

When iid is set to true (in fact, iid is also the default behavior when it is set to false), the data is distributed to each client in an independent and identically distributed (iid) way.

"iid": true

dirichlet non-iid

When customize in iid is set to false or not set at all, the data is distributed to the clients according to a Dirichlet distribution.

Here beta is the concentration parameter of the Dirichlet distribution.

"iid": {
    "customize": false,
    "beta": 0.5
}

or

"iid": {
    "beta": 0.5
}

customize non-iid

Customized non-iid settings are divided into two parts: label non-iid and data-quantity non-iid. Currently, only random generation is provided for the data quantity; personalized settings will be introduced in future versions.

When enabling the customized setting, you need to set customize to true and set label and data separately.

"iid": {
    "customize": true
}

label distribution

The label setting is similar to the staleness settings and supports three modes. The first one is shown below and also appears in the example configuration files.

"label": {
    "step": 1,
    "list": [10, 10, 30]
}

The above configuration generates 10 clients holding data from 1 label, 10 clients with 2 labels, and 30 clients with 3 labels.

If step were set to 2 instead, the program would generate 10 clients with 1 label, 10 clients with 3 labels, and 30 clients with 5 labels.

The second option is a per-client array of labels (written as an object keyed by client index), which the program assigns to the clients directly.

"label": {
    "0": [1, 2, 3, 8],
    "1": [2, 4],
    "2": [4, 7],
    "3": [0, 2, 3, 6, 9],
    "4": [5]
}

The third option is a one-dimensional array giving the number of labels each client holds; its length must equal the number of clients.

"label": {
  "list": [4, 5, 10, 1, 2, 3, 4]
}

The above configuration sets the number of labels per client: client 0 holds data from 4 labels, client 1 from 5 labels, and so on.

Currently, there are two randomization methods for generating label non-iid data. Pure randomization may (with extremely low probability) leave some label absent from every client, which lowers accuracy. The alternative uses a shuffle algorithm to guarantee that each label is selected, but it cannot produce severely uneven label distributions. The shuffle algorithm is controlled by the shuffle parameter, as shown below:

"label": {
  "shuffle": true,
  "list": [4, 5, 10, 1, 2, 3, 4]
}

data distribution

The data setting is simpler; currently there are two methods. The first is to leave it empty:

"data": {}

That is, no non-iid setting is performed on the data quantity.

The second method is mentioned in the configuration file.

"data": {
    "max": 500,
    "min": 400
}

That is, each client's data quantity is randomly distributed between 400 and 500 and, by default, evenly distributed among its labels.
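Putting the pieces together, a fully customized non-iid block can combine the label and data examples above:

"iid": {
    "customize": true,
    "label": {
        "step": 1,
        "list": [10, 10, 30]
    },
    "data": {
        "max": 500,
        "min": 400
    }
}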

The data quantity distribution is still relatively simple at this point, and will be gradually improved in the future.

Adding New Client Class

Currently, a replacement client class needs to inherit from AsyncClient or SyncClient, and any new parameters are passed into the class through the client configuration item.
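A minimal sketch of such a subclass, using NormalClient from src/client as a stand-in base class (the constructor signature and the extra parameter are assumptions; check the classes in src/client for the real interface):

from client.NormalClient import NormalClient

# Hypothetical client subclass; my_weight is an illustrative parameter
# that would be supplied through the "client" configuration item's params.
class MyClient(NormalClient):
    def __init__(self, *args, my_weight=0.1, **kwargs):
        super().__init__(*args, **kwargs)
        self.my_weight = my_weight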

Multi-GPU

The multi-GPU feature of this project is not multi-GPU parallel computing: each client still trains on a single GPU, but viewed macroscopically the clients run across multiple GPUs. That is, the training tasks of the clients are distributed evenly across the GPUs visible to the program. The GPU bound to each client is fixed at initialization rather than reassigned on each round of training, so a serious imbalance in GPU load is still possible.

This feature is controlled by the multi_gpu switch in the global settings.
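For example, in the global section of the configuration file:

"global": {
    "multi_gpu": true
}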

Existing Bugs

Currently, there is a core issue in the framework: communication between clients and the server is implemented with multiprocessing queues. However, when a CUDA tensor is put on a queue and retrieved by another thread, it can cause a memory leak and may crash the program.

This bug is caused by the interaction of PyTorch with the multiprocessing queue, and the current workaround is to put only non-CUDA tensors on the queue and convert them back to CUDA tensors during aggregation. Therefore, when adding aggregation algorithms, code like the following is needed:

import torch

# Client weights arrive on the CPU (CUDA tensors must not pass through the
# multiprocessing queue); clone them and move them back to the GPU before aggregating.
updated_parameters = {}
for key, var in client_weights.items():
    updated_parameters[key] = var.clone()
    if torch.cuda.is_available():
        updated_parameters[key] = updated_parameters[key].cuda()

Contributors

desperadoccy
jzj007

Contact us

QQ: 527707607

email: [email protected]

Suggestions for the project are welcome!

If you'd like to contribute to this project, please contact us.

