moon's Introduction

Model-Contrastive Federated Learning

This is the code for the paper Model-Contrastive Federated Learning.

Abstract: Federated learning enables multiple parties to collaboratively train a machine learning model without communicating their local data. A key challenge in federated learning is to handle the heterogeneity of local data distribution across parties. Although many studies have been proposed to address this challenge, we find that they fail to achieve high performance in image datasets with deep learning models. In this paper, we propose MOON: model-contrastive federated learning. MOON is a simple and effective federated learning framework. The key idea of MOON is to utilize the similarity between model representations to correct the local training of individual parties, i.e., conducting contrastive learning in model-level. Our extensive experiments show that MOON significantly outperforms the other state-of-the-art federated learning algorithms on various image classification tasks.

Dependencies

  • PyTorch >= 1.0.0
  • torchvision >= 0.2.1
  • scikit-learn >= 0.23.1

Parameters

Parameter        Description
model            The model architecture. Options: simple-cnn, resnet50.
alg              The training algorithm. Options: moon, fedavg, fedprox, local_training.
dataset          The dataset to use. Options: cifar10, cifar100, tinyimagenet.
lr               Learning rate.
batch-size       Batch size.
epochs           Number of local epochs.
n_parties        Number of parties.
sample_fraction  The fraction of parties to be sampled in each round.
comm_round       Number of communication rounds.
partition        The partition approach. Options: noniid, iid.
beta             The concentration parameter of the Dirichlet distribution for non-IID partition.
mu               The parameter for MOON and FedProx.
temperature      The temperature parameter for MOON.
out_dim          The output dimension of the projection head.
datadir          The path of the dataset.
logdir           The path to store the logs.
device           Specify the device to run the program.
seed             The initial seed.

Usage

Here is an example to run MOON on CIFAR-10 with a simple CNN:

python main.py --dataset=cifar10 \
    --model=simple-cnn \
    --alg=moon \
    --lr=0.01 \
    --mu=5 \
    --epochs=10 \
    --comm_round=100 \
    --n_parties=10 \
    --partition=noniid \
    --beta=0.5 \
    --logdir='./logs/' \
    --datadir='./data/'

Tiny-ImageNet

You can download Tiny-ImageNet here. Then, you can follow the instructions to reformat the validation folder.

Hyperparameters

If you use the same setting as our paper, you can simply adopt the hyperparameters reported there. If you try a different setting, please tune the hyperparameters of MOON. You may tune mu from {0.001, 0.01, 0.1, 1, 5, 10}. If you have sufficient computing resources, you may also tune temperature from {0.1, 0.5, 1.0} and the output dimension of the projection head from {64, 128, 256}.
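The tuning ranges above can be enumerated mechanically. The following is a minimal sketch that only builds the command lines for a full grid over mu, temperature, and out_dim; it does not run anything. The helper name build_commands and the base_args dictionary are illustrative assumptions, while the flag names come from the parameter table.

```python
import itertools
import shlex

# Search ranges quoted from the Hyperparameters section.
MU = [0.001, 0.01, 0.1, 1, 5, 10]
TEMPERATURE = [0.1, 0.5, 1.0]
OUT_DIM = [64, 128, 256]


def build_commands(base_args):
    """Return one `python main.py ...` command string per grid point.

    `base_args` (hypothetical) maps flag names to the values that stay fixed
    across the sweep, e.g. dataset, model, and alg.
    """
    commands = []
    for mu, temperature, out_dim in itertools.product(MU, TEMPERATURE, OUT_DIM):
        # Override the tuned flags on top of the fixed ones.
        args = dict(base_args, mu=mu, temperature=temperature, out_dim=out_dim)
        parts = ["python", "main.py"] + [f"--{key}={value}" for key, value in args.items()]
        commands.append(shlex.join(parts))
    return commands


base_args = {"dataset": "cifar10", "model": "simple-cnn", "alg": "moon"}
commands = build_commands(base_args)
print(len(commands), "commands; first:", commands[0])
```

A full grid is 6 x 3 x 3 = 54 runs, so tuning mu first and the other two only when resources allow, as the text suggests, keeps the cost down.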

Citation

Please cite our paper if you find this code useful for your research.

@inproceedings{li2021model,
      title={Model-Contrastive Federated Learning}, 
      author={Qinbin Li and Bingsheng He and Dawn Song},
      booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
      year={2021},
}

moon's Issues

Question about dirichlet non-iid

Hi, thanks for your excellent work. I am confused about the construction of non-IID data using the Dirichlet distribution, at line 126 in utils.py. What is the function of line 126? Could you please give me some intuition? I do not understand what "len(idx_j) < N / n_parties" means. Thanks a lot!
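One plausible reading of that expression, sketched here with the standard library only and not copied from utils.py: when per-class proportions are drawn from a Dirichlet distribution, the comparison acts as a 0/1 mask. A party that already holds at least N / n_parties samples gets its proportion multiplied by False (that is, 0), so it receives nothing more from the current class, and the remaining proportions are renormalised. The function names below are assumptions made for illustration.

```python
import random


def dirichlet(alpha, n, rng):
    # A symmetric Dirichlet sample is a vector of Gamma(alpha, 1) draws
    # normalised to sum to 1.
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(n)]
    total = sum(draws)
    return [d / total for d in draws]


def partition_noniid(labels, n_parties, beta, seed=0):
    """Split sample indices across parties, class by class (illustrative sketch)."""
    rng = random.Random(seed)
    cap = len(labels) / n_parties          # plays the role of N / n_parties
    idx_batch = [[] for _ in range(n_parties)]
    for k in sorted(set(labels)):
        idx_k = [i for i, y in enumerate(labels) if y == k]
        rng.shuffle(idx_k)
        props = dirichlet(beta, n_parties, rng)
        # The mask: parties that already reached the cap get proportion 0.
        props = [p * (len(idx_j) < cap) for p, idx_j in zip(props, idx_batch)]
        total = sum(props)
        props = [p / total for p in props]
        # Turn cumulative proportions into cut points inside idx_k.
        bounds, acc = [0], 0.0
        for p in props[:-1]:
            acc += p
            bounds.append(min(int(acc * len(idx_k)), len(idx_k)))
        bounds.append(len(idx_k))
        for j in range(n_parties):
            idx_batch[j].extend(idx_k[bounds[j]:bounds[j + 1]])
    return idx_batch


labels = [i % 10 for i in range(1000)]     # toy labels, 10 balanced classes
parts = partition_noniid(labels, n_parties=5, beta=0.5, seed=1)
print([len(p) for p in parts])
```

Under this reading, the mask limits how far any single party can exceed an even share of the data, while the class mix inside each party stays skewed by the Dirichlet draw.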

Some questions about the metrics.

Hi, thanks very much for the code. Recently, I re-ran the released code under the default setting. However, I cannot find the metrics presented in your paper. It seems that the released code can only compute the metrics of the global network or of each local network. Could you please tell us which metric is used in your paper?

Same label for positive and negative cases

Hi,

Thank you for your interesting work. I have a question about the labels used for the contrastive loss. Why do the positive and negative cases have the same label, 0 (lines 309, 297, 303)? Shouldn't it be 0 for positive pairs and 1 for negative pairs?

Another question is about the SCAFFOLD results published in your paper. This code does not include SCAFFOLD. Did you run it from https://github.com/Xtra-Computing/NIID-Bench?

Thank you in advance for your feedback.

Convergence theory analysis

Hello, I think this article is very novel, and the application of contrastive learning ideas is simply brilliant. But when I tried to integrate it into my own work, I ran into the problem of convergence analysis. Is there a convergence analysis for MOON?

Question about the code

Hi, I have read your paper and code and found it to be an interesting work!

I have a question. In your code,

loss2 = mu * criterion(logits, labels) # (main.py line311).

I know it is used to calculate the con_loss (Eq. 3 in your paper), but why is it implemented as a cross-entropy loss with labels set to a tensor of zeros?
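A small numerical check may help here. It assumes, as the quoted training code suggests, that each row of logits holds the positive similarity in column 0 followed by the negative similarities, all divided by the temperature. Under that assumption the zero label is a class index rather than a binary target: it tells cross-entropy that column 0 is the "correct" entry, and the result is the usual -log(exp(pos) / (exp(pos) + sum of exp(neg))) contrastive form. The sketch uses plain Python rather than PyTorch, and the similarity values are made up.

```python
import math


def cross_entropy(logits, label):
    # Softmax cross-entropy for one sample: -log softmax(logits)[label],
    # computed with the log-sum-exp trick for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[label]


def contrastive_loss(sim_pos, sims_neg, temperature):
    # Contrastive form written out directly:
    # -log( exp(pos/t) / (exp(pos/t) + sum_j exp(neg_j/t)) )
    numerator = math.exp(sim_pos / temperature)
    denominator = numerator + sum(math.exp(s / temperature) for s in sims_neg)
    return -math.log(numerator / denominator)


temperature = 0.5
sim_pos = 0.8            # made-up cosine similarity to the global model
sims_neg = [0.3]         # made-up cosine similarity to a previous local model

# Column 0 is the positive pair, the remaining columns are negatives.
logits = [sim_pos / temperature] + [s / temperature for s in sims_neg]
loss_ce = cross_entropy(logits, label=0)
loss_direct = contrastive_loss(sim_pos, sims_neg, temperature)
print(loss_ce, loss_direct)
```

The two values agree up to floating-point error, and the loss is strictly positive whenever a negative is present, so a label of zero does not make the loss zero.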

About w_i^t in paper.

I have learned a lot and this is an excellent piece of work!

I have some questions about the paper and code:

  1. In the paper, I do not really understand how to get w_i^t. (Step 10 of Algorithm 1 shows that we get w_i^t from w_t; how does this work?)

  2. In main.py line 499,
    local_train_net(nets_this_round, args, net_dataidx_map, train_dl=train_dl, test_dl=test_dl, global_model = global_model, prev_model_pool=old_nets_pool, round=round, device=device),
    I think w_i^t is in nets_this_round, and this is initialized from nets(). Does this (w_i^t) not change with training?

L2 norm code issue

Hi, I am very interested in your work. Is there no code for the loss function with the L2 norm?

Questions about the reported test accuracy.

Hi! Thanks for your inspiring work.
I have a few questions about the reported accuracy in your paper, because I am not familiar with the field of federated learning.

  1. In Section 4.2, Accuracy Comparison, do all clients participate in the training (participation ratio = 1.0)?
  2. Under the partial participation setting, we need to evaluate the current model in each round, so which dataset is the evaluation conducted on? Is it conducted on the selected clients or on all clients? If we adopt the first strategy (closer to the practical scenario), how do we get the reported top accuracy? Is the reported top accuracy the average accuracy among all clients?

Questions about settings of negative samples

Hi, I have read your paper and I am very interested in your work! I think it's a very good article!
I have some questions about the settings of negative samples. I wonder why z and zpre are pushed as far apart as possible.
Does this speed up training or improve accuracy?
Do you think the increase in accuracy has anything to do with this? Or is it due to the contrastive loss?
I look forward to your kind reply.

Questions about SCAFFOLD code

Hi, I am very interested in your work. I found some SCAFFOLD results in your paper, but there is no code for it. Could you tell me how you reproduced it with your code? If convenient, could you release your SCAFFOLD code?

getting debug message

When I apply the code to a different dataset, I get the debug messages below in the log file:
02-15 15:45 INFO cuda:0
02-15 15:45 INFO ####################################################################################################
02-15 15:45 INFO Partitioning data
02-15 15:45 INFO Data statistics: {0: {0: 1383}, 1: {0: 6, 1: 441, 2: 21}, 2: {0: 457, 1: 428}, 3: {0: 1, 1: 24, 2: 566}, 4: {0: 17, 1: 39, 2: 345}}
02-15 15:45 INFO Initializing nets
02-15 15:45 INFO in comm round:0
02-15 15:45 INFO Training network 0. n_training: 1383
02-15 15:45 INFO Training network 0
02-15 15:45 INFO n_training: 21
02-15 15:45 INFO n_test: 37
02-15 15:45 DEBUG STREAM b'IHDR' 16 13
02-15 15:45 DEBUG STREAM b'tIME' 41 7
02-15 15:45 DEBUG b'tIME' 41 7 (unknown)
02-15 15:45 DEBUG STREAM b'IDAT' 60 8192
02-15 15:45 DEBUG STREAM b'IHDR' 16 13
02-15 15:45 DEBUG STREAM b'tIME' 41 7
02-15 15:45 DEBUG b'tIME' 41 7 (unknown)
02-15 15:45 DEBUG STREAM b'IDAT' 60 8192
02-15 15:45 DEBUG STREAM b'IHDR' 16 13
02-15 15:45 DEBUG STREAM b'tIME' 41 7
02-15 15:45 DEBUG b'tIME' 41 7 (unknown)
02-15 15:45 DEBUG STREAM b'IDAT' 60 8192
02-15 15:45 DEBUG STREAM b'IHDR' 16 13
02-15 15:45 DEBUG STREAM b'tIME' 41 7
02-15 15:45 DEBUG b'tIME' 41 7 (unknown)
02-15 15:45 DEBUG STREAM b'IDAT' 60 8192
02-15 15:45 DEBUG STREAM b'IHDR' 16 13
02-15 15:45 DEBUG STREAM b'tIME' 41 7
02-15 15:45 DEBUG b'tIME' 41 7 (unknown)
02-15 15:45 DEBUG STREAM b'IDAT' 60 8192
02-15 15:45 DEBUG STREAM b'IHDR' 16 13
02-15 15:45 DEBUG STREAM b'tIME' 41 7

Hi,

Thank you for your wonderful work.
I am a little confused about the code:

    for previous_net in previous_nets:
        previous_net.cuda()
        _, pro3, _ = previous_net(x)
        nega = cos(pro1, pro3)
        logits = torch.cat((logits, nega.reshape(-1,1)), dim=1)

        previous_net.to('cpu')

First, in the paper, the negative representation seems to come from the local model of the last round, instead of all previous rounds. I want to know if the code here makes sense. Second, in "logits = torch.cat((logits, nega.reshape(-1,1)), dim=1)", what is the purpose of the concatenation? Looking forward to your reply, thank you!

The code seems inconsistent with the algorithm in the paper

Thanks to the authors for providing the code of MOON; it is very useful to me. But one thing confuses me in these lines:

MOON/main.py

Lines 309 to 311 in 6c7a4ed

labels = torch.zeros(x.size(0)).cuda().long()
loss2 = mu * criterion(logits, labels)

Could you help me understand why "labels" as a vector of zeros is necessary in line 309?
And I think loss2 always returns 0 in every iteration, doesn't it?
I hope to receive your response.

About the contrastive loss

This is an excellent piece of work!
However, when I read the code, I wondered: does combining logits with an all-zero label tensor cause the contrastive loss to always be 0?

    _, pro1, out = net(x)
    _, pro2, _ = global_net(x)

    posi = cos(pro1, pro2)
    for previous_net in previous_nets:
        previous_net.cuda()
        _, pro3, _ = previous_net(x)
        nega = cos(pro1, pro3)
        logits = torch.cat((logits, nega.reshape(-1,1)), dim=1)
        previous_net.to('cpu')
    logits /= temperature
    labels = torch.zeros(x.size(0)).cuda().long()
    loss2 = mu * criterion(logits, labels)  # here

Question about FedAvg code

Thanks for sharing your code, it is an amazing work!
But I have a question on the implementation of FedAvg:

MOON/main.py

Line 587 in dbf6344

global_w[key] += net_para[key] * fed_avg_freqs[net_id]

Why is the aggregation here done without a "mean" operation?
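A possible explanation, sketched under an assumption that the quoted line alone does not confirm: if fed_avg_freqs holds each party's share of the training samples used in the round, the weights sum to 1, and the accumulation is already a weighted mean, so no extra division is needed. The function name weighted_average and the toy parameters below are illustrative and not taken from main.py.

```python
def weighted_average(params_per_party, sizes):
    """Aggregate per-party parameters with sample-count weights.

    `params_per_party` is a list of {name: value} dicts (scalars here for
    brevity) and `sizes` the number of training samples per party.
    """
    total = sum(sizes)
    # Assumed meaning of fed_avg_freqs: data share of each party, summing to 1.
    fed_avg_freqs = [size / total for size in sizes]
    global_w = {key: 0.0 for key in params_per_party[0]}
    for net_id, net_para in enumerate(params_per_party):
        for key in net_para:
            # Same accumulation pattern as the line quoted above.
            global_w[key] += net_para[key] * fed_avg_freqs[net_id]
    return global_w


params = [{"w": 1.0}, {"w": 3.0}]
equal = weighted_average(params, sizes=[10, 10])      # equal shares
skewed = weighted_average(params, sizes=[30, 10])     # first party holds 75%
print(equal, skewed)
```

With equal data sizes the result coincides with the plain arithmetic mean; with unequal sizes, parties holding more data pull the average toward their parameters.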

About the model Resnet50

What's the difference between the "models.resnet50(pretrained=False)" and "ResNet50_cifar10()" in model.py?

Time for Training on CIFAR-100 and Tiny-ImageNet

Hello, thanks for the good work. I'm trying to reproduce the results shown in the paper. Training on CIFAR-100 and Tiny-ImageNet seems to be very slow. I'm using Titan Xp. For CIFAR-100, it took 4 hours to train for 4 global epochs; the test accuracy right now is 7.1%. For Tiny-ImageNet, it took 3 hours to train 1 global epoch; the test accuracy right now is 0.5%. I followed the preprocessing steps you suggested. The command lines are exactly what you had on the GitHub project page. I also followed the hyperparameters outlined in the paper.

Here are the command lines I used:

python main.py --dataset=cifar100 --alg=moon --lr=0.01 --mu=1 --epochs=10 --comm_round=100 --n_parties=10 --partition=noniid --beta=0.5 --logdir='./logs/' --datadir='./data/'

python main.py --dataset=tinyimagenet --alg=moon --lr=0.01 --mu=1 --epochs=10 --comm_round=20 --n_parties=10 --partition=noniid --beta=0.5 --logdir='./logs/' --datadir='./data/'

Is it normal to take such a long time to train on CIFAR-100 and Tiny-ImageNet? Can I ask how long it took you to finish training on CIFAR-100 and Tiny-ImageNet?

Processing of Datasets

Hey, I have a question about the dataset processing: does the final version of MOON's transformation in the dataloader apply only 'transforms.ToTensor()' and nothing like 'transforms.Normalize()'?
I would really appreciate your help so that I can follow up on this work.

For tinyimagenet, the test acc is very low, 0.009

I ran the model the same way as the author for tinyimagenet: python main.py --dataset=tinyimagenet --model=resnet50 --alg=moon --lr=0.01 --mu=1 --epochs=10 --comm_round=20 --n_parties=10 --partition=noiid --beta=0.5

Questions about T-SNE

Hello,

Thanks for sharing the interesting work. I just have one question about the t-SNE part. Could you share more detailed information about how you generated such amazing t-SNE results? I want to reproduce them. I have tried the TSNE of sklearn and open-tsne, but they did not work.
