lucaswychan / federated-edge-ai-for-6g

A combination of federated learning algorithms and 6G technology.

License: MIT License


Federated Edge AI for 6G

Abstract

Federated Learning (FL) is a decentralized approach to machine learning that addresses the crucial issue of data privacy. However, decentralization brings its own challenges, such as data and system heterogeneity and difficulties with the communication channel. By combining state-of-the-art 6G technology with advances in FL computation algorithms, the computational and communication complexities inherent in FL systems can be effectively addressed. This project combines FedDyn (Federated Dynamic Regularization), which tackles the computation problem, with RIS-FL (Reconfigurable Intelligent Surface-assisted Federated Learning), which tackles the communication problem; the resulting scheme is named RIS-AirFedDyn. Together, these methods can enhance the efficiency and performance of FL systems while preserving data privacy and security.
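FedDyn's dynamic regularization can be illustrated on a toy problem. The sketch below follows the FedDyn update rules from the paper credited below, but it is not this repository's implementation: it assumes quadratic client losses L_k(theta) = 0.5 * ||theta - c_k||^2, full client participation, and an exact closed-form local solve instead of SGD epochs. The function name `feddyn_quadratic` and its arguments are hypothetical.

```python
import numpy as np


def feddyn_quadratic(centers, alpha=0.5, rounds=50):
    """Toy FedDyn run where client k minimizes L_k(theta) = 0.5 * ||theta - c_k||^2.

    Assumptions: all clients participate every round, and each local problem
    is solved exactly in closed form rather than with SGD epochs.
    """
    centers = np.asarray(centers, dtype=float)
    m, d = centers.shape
    server = np.zeros(d)            # global model theta
    grad_state = np.zeros((m, d))   # per-client linear term g_k, initialized to 0
    h = np.zeros(d)                 # server-side correction state

    for _ in range(rounds):
        # Local step: argmin_theta L_k(theta) - <g_k, theta> + alpha/2 * ||theta - server||^2.
        # Setting the gradient to zero gives theta = (c_k + g_k + alpha * server) / (1 + alpha).
        local = (centers + grad_state + alpha * server) / (1.0 + alpha)
        # Client state update: g_k <- g_k - alpha * (theta_k - server), using the old server model.
        grad_state -= alpha * (local - server)
        # Server state update: h <- h - alpha/m * sum_k (theta_k - server).
        h -= alpha / m * np.sum(local - server, axis=0)
        # New global model: average of the local models, corrected by h.
        server = local.mean(axis=0) - h / alpha
    return server


print(feddyn_quadratic([[1.0, 0.0], [3.0, 4.0], [-1.0, 2.0]]))
```

In this toy setting the server model converges to the mean of the centers, which is the minimizer of the summed client losses, and each `g_k` converges to client k's gradient at that point.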

Credit

Federated Learning Based on Dynamic Regularization
Paper
Code

Reconfigurable Intelligent Surface Enabled Federated Learning: A Unified Communication-Learning Design Approach
Paper
Code

Dependencies

  • Python >= 3.6
  • numpy==1.21.6 (required for the Dirichlet partitioning case)
  • torch
  • torchvision
  • cvxpy
  • matplotlib

Alternatively, install all the packages with:

pip install -r requirements.txt

Instructions

There are four algorithms available: FedDyn, FedAvg, FedProx, and SCAFFOLD.
The default algorithm is FedDyn; to use a different one, add --algorithm_name={FedAvg, FedProx, FedDyn, SCAFFOLD} to the command below.

For more details on the parameters, see the Parameters section.

python3 fl_main.py
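The actual flag definitions live in args.py. As a hedged illustration of how an `--algorithm_name` flag restricted to the four algorithms could be declared with Python's standard argparse module, here is a hypothetical stand-in parser (`build_parser` is not taken from args.py, and only two of the flags are shown):

```python
import argparse


def build_parser():
    # Hypothetical stand-in for the parser in args.py; only two flags are shown.
    parser = argparse.ArgumentParser(description="Federated learning runner")
    parser.add_argument(
        "--algorithm_name",
        type=str,
        default="FedDyn",
        choices=["FedDyn", "FedAvg", "FedProx", "SCAFFOLD"],
        help="algorithm for training",
    )
    parser.add_argument("--n_clients", type=int, default=30, help="number of clients")
    return parser


args = build_parser().parse_args(["--algorithm_name=SCAFFOLD"])
print(args.algorithm_name, args.n_clients)
```

With this sketch, `--algorithm_name=SCAFFOLD` selects SCAFFOLD, omitting the flag falls back to FedDyn, and a name outside `choices` makes argparse exit with a usage error.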

Add new algorithm

The code base is object-oriented, so a new algorithm can be added on top of the existing code with little effort.

  1. Create a new Python file in the algorithm directory.

    E.g. fedsplit.py

  2. In that file, define a new class that inherits from Algorithm.

    E.g.
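class FedSplit(Algorithm):
    # Algorithm is the repository's base class and must be imported in fedsplit.py.
    def __init__(self, lr, lr_decay_per_round, batch_size, epoch, weight_decay, model_func, n_param,
                 max_norm, noiseless, dataset_name, save_period, print_per, new_parameter):
        # Pass the algorithm name and the shared hyperparameters to the base class.
        super().__init__("FedSplit", lr, lr_decay_per_round, batch_size, epoch, weight_decay, model_func,
                         n_param, max_norm, noiseless, dataset_name, save_period, print_per)

        # Hyperparameter specific to FedSplit.
        self.new_parameter = new_parameter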



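3. Override the local_train and aggregate methods.

E.g.

class FedSplit(Algorithm):
    # ... __init__ as in step 2 ...
    # Client and Server are the repository's own classes and must be imported in this file.

    def local_train(self, client: Client, inputs: dict):
        # Client local training goes here.
        ...

    def __train_model(self, *args, **kwargs):
        # Optional private helper for local_train; choose whatever signature it needs.
        ...

    def aggregate(self, server: Server, inputs: dict):
        # Aggregate the client updates into the global model here.
        ...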
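As a hedged illustration of what an aggregate step might compute, here is a FedAvg-style weighted average of client models. It assumes each client model is available as a flattened 1-D NumPy array and that clients are weighted by their local dataset sizes; `weighted_average` is a hypothetical helper, not part of the repository's Server or Algorithm API.

```python
import numpy as np


def weighted_average(client_params, n_samples):
    """FedAvg-style aggregation: average flattened client models, weighted by data size."""
    params = np.stack([np.asarray(p, dtype=float) for p in client_params])  # (n_clients, n_param)
    weights = np.asarray(n_samples, dtype=float)
    weights = weights / weights.sum()  # normalize so the weights sum to 1
    return weights @ params            # (n_param,) weighted sum over clients


print(weighted_average([[1.0, 1.0], [3.0, 5.0]], [1, 3]))
```

A new algorithm would typically replace or extend this averaging rule inside its own aggregate method.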



4. Register the new class in AlgorithmFactory, which is defined in algorithm/algorithm_factory.py.

E.g.

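# FedSplit must also be imported at the top of algorithm/algorithm_factory.py.
class AlgorithmFactory:
    def __init__(self, args):
        ...

    def create_algorithm(self, algorithm_name) -> Algorithm:
        ...
        elif algorithm_name == "FedAvg":
            ...

        elif algorithm_name == "FedSplit":
            new_parameter = 0.3  # FedSplit-specific hyperparameter, hard-coded in this example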

            algorithm = FedSplit(
                self.args.lr,
                self.args.lr_decay_per_round,
                self.args.batch_size,
                self.args.epoch,
                self.args.weight_decay,
                self.args.model_func,
                self.args.n_param,
                self.args.max_norm,
                self.args.noiseless,
                self.args.data_obj.dataset,
                self.args.save_period,
                self.args.print_per,
                new_parameter
            )



5. Create the parameters this algorithm requires in fl_main.py and pass them to it through the inputs dictionary.

E.g.

# these lines should be in fl_main.py (assumes numpy is imported as np)
required_parameter = np.ones((args.n_clients, n_param))
# ...
inputs["required_parameter"] = required_parameter

Note that if inputs["required_parameter"] is updated in client.local_train or server.aggregate, the updated value must be copied back explicitly in fl_main.py, for example:

# these lines should be in fl_main.py
if args.algorithm_name == "FedSplit":
    required_parameter = inputs["required_parameter"]
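The explicit copy-back matters whenever the algorithm assigns a new object to the dictionary entry instead of modifying the existing array in place: the local name in fl_main.py keeps pointing at the old array. A minimal, self-contained illustration, where `local_train_rebind` and `local_train_inplace` are hypothetical stand-ins for an algorithm's local_train:

```python
import numpy as np


def local_train_rebind(inputs):
    # Assigns a NEW array to the key; names still bound to the old array do not see it.
    inputs["required_parameter"] = inputs["required_parameter"] + 1.0


def local_train_inplace(inputs):
    # ndarray += mutates the existing array in place, so every reference sees the change.
    inputs["required_parameter"] += 1.0


required_parameter = np.ones((2, 3))
inputs = {"required_parameter": required_parameter}

local_train_rebind(inputs)
before_copy_back = required_parameter               # still the original array of ones
required_parameter = inputs["required_parameter"]   # explicit copy-back, as in fl_main.py

local_train_inplace(inputs)  # in-place updates need no copy-back
print(before_copy_back[0, 0], required_parameter[0, 0])
```

Copying back unconditionally, as the fl_main.py snippet above does, is safe in both cases.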



6. Have fun playing with your algorithm!

Parameters

The algorithms take the following parameters. For the full definitions, see args.py.

| Parameter Name | Meaning | Default Value | Type | Choice/Range |
|---|---|---|---|---|
| algorithm_name | algorithm for training | FedDyn | str | {FedDyn, FedAvg, FedProx, SCAFFOLD} |
| n_clients | number of clients | 30 | int | [1, inf) |
| comm_rounds | number of communication rounds | 50 | int | [1, inf) |
| lr | learning rate | 0.03 | float | (0, inf) |
| act_prob | probability of randomly choosing active clients | 0.9 | float | [0, 1] |
| lr_decay_per_round | learning rate decay per round | 0.99 | float | [0, inf) |
| batch_size | number of samples per batch | 50 | int | [1, total data size] |
| epoch | number of local epochs for client training | 5 | int | [1, inf) |
| weight_decay | weight decay | 0.01 | float | (0, inf) |
| max_norm | max norm for gradient clipping | 10.0 | float | (0, inf) |
| model_name | model for training; the name also selects the corresponding dataset | cifar10 | str | {linear, mnist, emnist, cifar10, cifar100, resnet18, shakespeare} |
| rule | data partitioning rule | iid | str | {iid, dirichlet} |
| rand_seed | random seed | 1 | int | [0, inf) |
| save_period | period (in rounds) for saving the models | 1 | int | [1, comm_rounds] |
| print_per | period for printing the training results | 5 | int | [1, epoch] |
| n_RIS_ele | number of RIS elements | 40 | int | [0, inf) |
| n_receive_ant | number of receive antennas | 5 | int | [0, inf) |
| alpha_direct | path loss component | 3.76 | float | [0, inf) |
| SNR | noise variance/0.1 W, in dB | 90.0 | float | [0, inf) |
| location_range | location range between clients and RIS | 30 | int | [0, inf) |
| Jmax | maximum number of Gibbs outer loops | 50 | int | [1, inf) |
| tau | SCA regularization term | 0.03 | float | [0, inf) |
| nit | I_max, maximum number of SCA loops | 100 | int | [1, inf) |
| threshold | epsilon, SCA early-stopping criterion | 0.01 | float | [0, inf) |
| transmit_power | transmit power of clients | 0.003 | float | [0, inf) |
| noiseless | whether the channel is noiseless | False | bool | {True, False} |
| rison | whether the RIS is present | 1 | int | {0, 1} |
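To make lr, lr_decay_per_round, act_prob and rand_seed concrete, the sketch below assumes that the decay is applied multiplicatively once per communication round (lr_t = lr * lr_decay_per_round ** t) and that each client is independently active with probability act_prob in every round. Both are assumptions, not a description of fl_main.py, and `lr_at_round` and `sample_active_clients` are hypothetical names.

```python
import numpy as np


def lr_at_round(lr, lr_decay_per_round, t):
    # Assumed exponential schedule: the learning rate is multiplied by the decay once per round.
    return lr * lr_decay_per_round ** t


def sample_active_clients(n_clients, act_prob, rng):
    # Each client is independently active with probability act_prob; resample until
    # at least one client is selected (assumes act_prob > 0, otherwise this never ends).
    while True:
        active = np.flatnonzero(rng.uniform(size=n_clients) <= act_prob)
        if active.size > 0:
            return active  # sorted indices of the active clients


rng = np.random.default_rng(1)  # seeded with the rand_seed default
for t in range(3):
    clients = sample_active_clients(30, 0.9, rng)
    print(f"round {t}: lr={lr_at_round(0.03, 0.99, t):.5f}, active clients={clients.size}")
```

Under these assumptions and the default values, the learning rate after 50 rounds would be 0.03 × 0.99^50 ≈ 0.0182, and on average 0.9 × 30 = 27 clients would be active per round.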

License

This project is licensed under the MIT License - see the LICENSE file for details.
