
frank-xwang / cld-unsupervisedlearning


[CVPR 2021] Code release for "Unsupervised Feature Learning by Cross-Level Instance-Group Discrimination."

License: MIT License

Python 98.41% Shell 1.59%


cld-unsupervisedlearning's Issues

BYOL+CLD

I greatly admire your work. I am planning to combine SSL with federated learning, using the BYOL framework locally. Could you share the BYOL+CLD code so that I can refer to it? Thank you very much!

failed to load checkpoint

I was trying to load the mocov2-cld checkpoint, but I got this error:

Traceback (most recent call last):
  File "inference.py", line 87, in <module>
    model.load_state_dict(checkpoint['state_dict'])
  File "/home/chris/anaconda3/envs/cld/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1224, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for MoCo:
	Missing key(s) in state_dict: "queue", "queue_ptr", "encoder_q.conv1.weight", "encoder_q.bn1.weight", "encoder_q.bn1.bias", "encoder_q.bn1.running_mean", "encoder_q.bn1.running_var", "encoder_q.layer1.0.conv1.weight", "encoder_q.layer1.0.bn1.weight", "encoder_q.layer1.0.bn1.bias", "encoder_q.layer1.0.bn1.running_mean", "encoder_q.layer1.0.bn1.running_var", "encoder_q.layer1.0.conv2.weight", "encoder_q.layer1.0.bn2.weight", "encoder_q.layer1.0.bn2.bias", "encoder_q.layer1.0.bn2.running_mean", "encoder_q.layer1.0.bn2.running_var", "encoder_q.layer1.0.conv3.weight", "encoder_q.laye...

Here is the code that I used to load the checkpoint:


import torch

# builder, ResNet, args_train and args_test are defined elsewhere in the
# poster's inference.py; they are not part of this snippet.
print("=> creating model '{}'".format(args_train['arch']))
model = builder.MoCo(ResNet.__dict__[args_train['arch']],
    args_test['moco_dim'], args_test['moco_k'], args_test['moco_m'],
    args_test['moco_t'], args_test['mlp'], two_branch=True,
    normlinear=args_test['normlinear'])
print(model)

chpt_moco_merge = 'checkpoint/chris-imagenet-full_c200_bz512/mocov2+cld/lr0.2-Lambda0.25-cld_t0.4-clusters200-NormNLP-epochs200/mocov2_cld_mlp_e200.pth.tar'
checkpoint = torch.load(chpt_moco_merge, map_location='cuda:0')
# Inspect the top-level checkpoint entries before loading the weights.
print(checkpoint.keys())
model.load_state_dict(checkpoint['state_dict'])
model.eval()
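
One way to narrow down a "Missing key(s)" error like the one above is to compare the key names the model expects with the key names stored in the checkpoint before calling load_state_dict. The sketch below works on plain collections of key names, so it needs no PyTorch. The helper name diff_state_dict_keys and the example keys carrying a "module." prefix are assumptions for illustration; they are not taken from the repository or from this particular checkpoint.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, sorted for stable output."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)      # expected by the model, absent in the file
    unexpected = sorted(checkpoint_keys - model_keys)   # present in the file, unknown to the model
    return missing, unexpected


# Hypothetical key names: the stored keys differ from the expected ones only by a prefix.
expected = ["queue", "queue_ptr", "encoder_q.conv1.weight"]
stored = ["module.queue", "module.queue_ptr", "module.encoder_q.conv1.weight"]

missing, unexpected = diff_state_dict_keys(expected, stored)
print("missing:", missing)
print("unexpected:", unexpected)
```

With a real model, the two inputs would be model.state_dict().keys() and checkpoint['state_dict'].keys(). If every unexpected key is a missing key with a common prefix in front, the checkpoint was probably saved from a wrapped model and the keys need renaming before loading.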

About the training process

Hi, thanks for your excellent work!

I have some questions about the training process after reading the code:
1. What role do DistributedShufle.forward_shuffle() and DistributedShufle.backward_shuffle() play here? Couldn't you simply set shuffle=True in the dataloader? Are they needed because of the distributed training mode?

2. I found that you build three augmented samples (x1, x2, x3) for x. Why not use only two augmented samples, as MoCo does? Isn't the general contrastive loss for the positive pair {x1, x2} usually computed as "[L(f(x1), f(x2))+L(f(x2), f(x1))]/2"?

3. The paper mentions that "memory bank v is computed as the average feature of all the augmented versions of x_i seen so far".
How should I understand "average feature" here? I can't seem to find the relevant implementation in the code.
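
For the third question, one literal reading of "average feature of all the augmented versions of x_i seen so far" is an incremental mean kept per sample index. The sketch below shows that reading only. The class name RunningMeanBank and its methods are hypothetical and this is not the repository's implementation; methods in this family often use an exponential moving average instead, and whether this code base does so is not established in this thread.

```python
import numpy as np


class RunningMeanBank:
    """Hypothetical memory bank that stores, per sample, the mean of all features seen so far."""

    def __init__(self, num_samples, dim):
        self.bank = np.zeros((num_samples, dim))
        self.counts = np.zeros(num_samples, dtype=int)

    def update(self, index, feature):
        # Incremental mean: v <- v + (f - v) / n, with n the number of views seen for this sample.
        self.counts[index] += 1
        self.bank[index] += (feature - self.bank[index]) / self.counts[index]
        return self.bank[index]


bank = RunningMeanBank(num_samples=2, dim=2)
bank.update(0, np.array([1.0, 0.0]))  # feature of a first augmented view of sample 0
bank.update(0, np.array([0.0, 1.0]))  # feature of a second augmented view of sample 0
print(bank.bank[0])                   # mean of the two features
```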

I am looking forward to your answers. Thank you very much!

Question about the number of groups

Hi, thanks for the interesting work! I have a question about the number of groups. When the number of groups is equal to the mini-batch size, can CLD still work? As I understand it, the cluster centroids would then be the samples themselves, which would turn the cross-level discrimination loss into an instance-level discrimination loss.
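
The degenerate case described in this question can be checked numerically. The sketch below runs a plain Lloyd's k-means with as many clusters as samples; the function name kmeans and the choice to initialise centroids from the first samples are assumptions for illustration and may differ from the clustering code in the repository.

```python
import numpy as np


def kmeans(features, num_clusters, num_iters=5):
    """Plain Lloyd's k-means; centroids start as the first `num_clusters` samples."""
    centroids = features[:num_clusters].copy()
    for _ in range(num_iters):
        # Squared Euclidean distance from every sample to every centroid.
        dists = ((features[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
        labels = dists.argmin(axis=1)
        for c in range(num_clusters):
            members = features[labels == c]
            if len(members) > 0:  # leave empty clusters where they are
                centroids[c] = members.mean(axis=0)
    return centroids, labels


rng = np.random.default_rng(0)
batch = rng.normal(size=(8, 4))                    # 8 samples, 4-dimensional features
centroids, labels = kmeans(batch, num_clusters=8)  # as many groups as samples

print(np.allclose(centroids, batch))  # each centroid coincides with one sample
print(labels.tolist())                # each sample is alone in its cluster
```

Under these assumptions every centroid is a sample, so comparing a feature with centroids is the same as comparing it with the other instances in the batch.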

CIFAR Long-Tail

Hi,

Thanks for your great work! I really enjoyed reading your paper. Just one quick question: the paper reports a set of experiments on CIFAR10-LT and CIFAR100-LT. What is the imbalance ratio of these sets? I couldn't find any specs on this in the paper.
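
For context, the imbalance ratio of a long-tailed CIFAR split usually means the size of the largest class divided by the size of the smallest one, with per-class counts decaying exponentially in between. The sketch below shows that common construction. The function name long_tail_counts is hypothetical, and the ratio of 100 is a placeholder: the ratio used in the paper is not stated in this thread.

```python
def long_tail_counts(num_classes, max_per_class, imbalance_ratio):
    """Per-class sample counts that decay exponentially from max to max / ratio."""
    counts = []
    for i in range(num_classes):
        frac = i / (num_classes - 1)  # 0 for the head class, 1 for the tail class
        counts.append(int(max_per_class * (1.0 / imbalance_ratio) ** frac))
    return counts


# Placeholder ratio of 100; CIFAR-10 has 5000 training images per class.
counts = long_tail_counts(num_classes=10, max_per_class=5000, imbalance_ratio=100)
print(counts)
print(counts[0] / counts[-1])  # recovered imbalance ratio
```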

Thanks again!

Why did you test with a different model after training?

In your linear evaluation scripts/imagenet/test_imagenet_moco_cld.sh, you used a different model from scripts/imagenet/train_imagenet_moco_cld.sh. Why is that?

# main_imagenet_moco_cld model -> training model
model = builder.MoCo(ResNet.__dict__[args_train['arch']],
     args_test['moco_dim'], args_test['moco_k'], args_test['moco_m'],
     args_test['moco_t'], args_test['mlp'], two_branch=True,
     normlinear=args_test['normlinear'])

# main_lincls model -> test model
model = models.__dict__['resnet50']()
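
A likely reason, following the convention of MoCo-style linear evaluation, is that only the query-encoder backbone is reused: its weights are copied into a plain ResNet and a fresh linear classifier is trained on top. This is an assumption about intent, not an answer from the author. The sketch below shows the key filtering such a script typically performs; the function name, the "module.encoder_q." prefix, and the "fc" head name are assumptions, and a two-branch model may carry additional heads with other names.

```python
def extract_query_encoder(state_dict, prefix="module.encoder_q.", head="fc"):
    """Keep only query-encoder backbone weights and strip the wrapper prefix."""
    backbone = {}
    for key, value in state_dict.items():
        if not key.startswith(prefix):
            continue  # skip the key encoder, the queue, and anything else
        name = key[len(prefix):]
        if name.startswith(head + "."):
            continue  # drop the projection head; linear evaluation trains a new classifier
        backbone[name] = value
    return backbone


# Hypothetical checkpoint contents; the values stand in for tensors.
pretrained = {
    "module.encoder_q.conv1.weight": "w0",
    "module.encoder_q.fc.0.weight": "w1",
    "module.encoder_k.conv1.weight": "w2",
    "module.queue": "q",
}
print(extract_query_encoder(pretrained))
```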

Cluster size for moco v2 + CLD seems to be too big?

Hi,

Thanks for this interesting work. When I ran the training script for moco v2 + CLD, I observed that the effective per-GPU batch size (64=512/8) is actually smaller than the number of clusters you assigned (120). This makes the clustering trivial, i.e., every example is its own cluster. Can you confirm that this is not a bug? Or did I run it in the wrong way?
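
The arithmetic behind this observation can be written out as a small check. The helper names below are hypothetical. The sketch assumes that clustering runs on each GPU's local batch; if features were gathered across GPUs before clustering, the clustering step would see all 512 samples, and which of the two the script does is not established in this thread.

```python
def per_gpu_batch_size(total_batch_size, num_gpus):
    """Split a global batch evenly across GPUs; raise if it does not divide."""
    if total_batch_size % num_gpus != 0:
        raise ValueError("total batch size must be divisible by the number of GPUs")
    return total_batch_size // num_gpus


def max_nonempty_clusters(num_clusters, num_samples):
    """k-means cannot populate more clusters than there are samples."""
    return min(num_clusters, num_samples)


local_batch = per_gpu_batch_size(512, 8)
print(local_batch)                              # 64 samples per GPU
print(max_nonempty_clusters(120, local_batch))  # at most 64 of the 120 clusters can be non-empty
```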

What is msg model?

I can see that you assign msg but never use it. What is it used for?
msg = model.load_state_dict(state_dict, strict=False)
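
For reference, in PyTorch the value returned by load_state_dict is a named tuple with the fields missing_keys and unexpected_keys, which is useful with strict=False because mismatches are then reported instead of raised. The sketch below demonstrates this on a toy nn.Linear model; the model and the deliberately misnamed key "extra.bias" are made up for illustration and are unrelated to the repository.

```python
import torch.nn as nn

model = nn.Linear(4, 2)

# Build a state dict with the right "weight" but a misnamed bias entry.
state_dict = {
    "weight": model.weight.detach().clone(),
    "extra.bias": model.bias.detach().clone(),
}

msg = model.load_state_dict(state_dict, strict=False)
print(msg.missing_keys)     # keys the model expects but the dict lacks
print(msg.unexpected_keys)  # keys in the dict that the model does not have
```

A common use of msg is an assertion right after loading, for example checking that the only missing keys are those of a freshly initialised classifier.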

Details of KNN classifier for ImageNet-LT

Thanks for your interesting work!

I would like to know some details of how you run the kNN classifier on the ImageNet-LT dataset, including the value of K and the temperature used in exp(cos_sim / t). Also, which split of the dataset (train, validation, or both) is used as the support set when searching for nearest neighbours? Thanks!
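
To make the quantities in this question concrete, here is a minimal sketch of a weighted kNN classifier in which each of the K nearest support samples votes for its class with weight exp(cos_sim / t). The function name knn_predict is hypothetical, and k=3 and temperature=0.1 are placeholders; the values and the support split used in the paper are not given in this thread.

```python
import numpy as np


def knn_predict(query, support, support_labels, num_classes, k=5, temperature=0.1):
    """Weighted kNN: each neighbour votes for its class with weight exp(cos_sim / t)."""
    q = query / np.linalg.norm(query, axis=1, keepdims=True)
    s = support / np.linalg.norm(support, axis=1, keepdims=True)
    sim = q @ s.T                                   # cosine similarity, shape (Q, S)
    topk = np.argsort(-sim, axis=1)[:, :k]          # indices of the k most similar support samples
    weights = np.exp(np.take_along_axis(sim, topk, axis=1) / temperature)
    votes = np.zeros((query.shape[0], num_classes))
    for c in range(num_classes):
        votes[:, c] = (weights * (support_labels[topk] == c)).sum(axis=1)
    return votes.argmax(axis=1)


# Toy data: two classes roughly aligned with the two coordinate axes.
support = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
labels = np.array([0, 0, 1, 1])
queries = np.array([[1.0, 0.2], [0.2, 1.0]])
print(knn_predict(queries, support, labels, num_classes=2, k=3))
```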

Accuracy error

Hi Wang!

Thanks for your code again!

When I tried to train with CLD+MoCo, I got the following error:

File "./main_imagenet_moco_cld.py", line 493, in accuracy

    correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
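
The error message itself names the fix: replace .view(-1) with .reshape(-1), which copies the data when the tensor is not contiguous. The sketch below is the widely used top-k accuracy helper with that one change applied; it is written from the common pattern, not copied from main_imagenet_moco_cld.py, so the surrounding lines may differ from the repository's function.

```python
import torch


def accuracy(output, target, topk=(1,)):
    """Top-k accuracy in percent for each k in `topk`."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)  # (batch, maxk) predicted class indices
        pred = pred.t()                             # (maxk, batch); transposing makes it non-contiguous
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape works on non-contiguous tensors, unlike view
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


logits = torch.tensor([[0.9, 0.1, 0.0],
                       [0.2, 0.7, 0.1],
                       [0.3, 0.6, 0.1],
                       [0.1, 0.2, 0.7]])
target = torch.tensor([0, 1, 0, 2])
top1, top2 = accuracy(logits, target, topk=(1, 2))
print(top1.item(), top2.item())
```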

Using CLD in barlow twins method

Hi!

I am currently learning SSL, and I found your code very helpful! I wonder whether your CLD method could work with Barlow Twins (https://github.com/facebookresearch/barlowtwins). In theory yes, but Barlow Twins takes advantage of a high-dimensional embedding, while your method has to cluster all features in a low-dimensional space. I understand that these are in different branches and may not interfere with each other, but I am not sure about that. Also, Barlow Twins scales its loss to a large value, which could cause compatibility problems.

Thx!

CIFAR Download Error

I am trying to run bash scripts/cifar/train_cifar10_moco_cld.sh but encounter a CIFAR download issue.

(simclr-1) liangyu@sphadmin-G560-V5:/space/liangyu/workspace/jhu/code/CLD-UnsupervisedLearning$ bash run.sh
[05/14 09:09:56 moco+cld]: Full config saved to checkpoint/cifar10/MoCo+CLD/resnet18/lr0.03-bs256-cldT0.2-nceT0.07-clusters200-lambda0.8-cosine-weightDecay8e-4-fp16-add_erasing-AugPlus-kMeans-ncek12288-bslr0.03-normlinear/config.json
==> Preparing data..
Traceback (most recent call last):
  File "train_cifar_moco_cld.py", line 366, in <module>
    main(opt)
  File "train_cifar_moco_cld.py", line 165, in main
    train_loader, test_loader, ndata = get_dataloader(args, add_erasing=args.erasing, aug_plus=args.aug_plus)
  File "/space/liangyu/workspace/jhu/code/CLD-UnsupervisedLearning/datasets/dataloader.py", line 81, in get_dataloader
    trainset = datasets.CIFAR10Instance(root='./data/CIFAR-10', train=True, download=True, transform=transform_train, two_imgs=args.two_imgs, three_imgs=args.three_imgs)
  File "/space/liangyu/workspace/jhu/code/CLD-UnsupervisedLearning/datasets/cifar.py", line 11, in __init__
    super(CIFAR10Instance, self).__init__(root=root, train=train, download=download, transform=transform$
  File "/home/liangyu/anaconda3/envs/simclr-1/lib/python3.7/site-packages/torchvision/datasets/cifar.py", line 65, in __init__
    self.download()
  File "/home/liangyu/anaconda3/envs/simclr-1/lib/python3.7/site-packages/torchvision/datasets/cifar.py", line 143, in download
    download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
  File "/home/liangyu/anaconda3/envs/simclr-1/lib/python3.7/site-packages/torchvision/datasets/utils.py", line 316, in download_and_extract_archive
    download_url(url, download_root, filename, md5)
  File "/home/liangyu/anaconda3/envs/simclr-1/lib/python3.7/site-packages/torchvision/datasets/utils.py", line 124, in download_url
    url = _get_redirect_url(url, max_hops=max_redirect_hops)
  File "/home/liangyu/anaconda3/envs/simclr-1/lib/python3.7/site-packages/torchvision/datasets/utils.py", line 75, in _get_redirect_url
    with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response:
  File "/home/liangyu/anaconda3/envs/simclr-1/lib/python3.7/urllib/request.py", line 222, in urlopen
    return opener.open(url, data, timeout)
  File "/home/liangyu/anaconda3/envs/simclr-1/lib/python3.7/urllib/request.py", line 531, in open
    response = meth(req, response)
  File "/home/liangyu/anaconda3/envs/simclr-1/lib/python3.7/urllib/request.py", line 641, in http_response
    'http', request, response, code, msg, hdrs)
  File "/home/liangyu/anaconda3/envs/simclr-1/lib/python3.7/urllib/request.py", line 569, in error
    return self._call_chain(*args)
  File "/home/liangyu/anaconda3/envs/simclr-1/lib/python3.7/urllib/request.py", line 503, in _call_chain
    result = func(*args)
  File "/home/liangyu/anaconda3/envs/simclr-1/lib/python3.7/urllib/request.py", line 649, in http_error_default
    raise HTTPError(req.full_url, code, msg, hdrs, fp)
urllib.error.HTTPError: HTTP Error 500: Internal Server Error
Killing subprocess 50818
Traceback (most recent call last):
  File "/home/liangyu/anaconda3/envs/simclr-1/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/liangyu/anaconda3/envs/simclr-1/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/liangyu/anaconda3/envs/simclr-1/lib/python3.7/site-packages/torch/distributed/launch.py", line 340, in <module>
    main()
  File "/home/liangyu/anaconda3/envs/simclr-1/lib/python3.7/site-packages/torch/distributed/launch.py", line 326, in main
    sigkill_handler(signal.SIGTERM, None)  # not coming back
  File "/home/liangyu/anaconda3/envs/simclr-1/lib/python3.7/site-packages/torch/distributed/launch.py", line 301, in sigkill_handler
    raise subprocess.CalledProcessError(returncode=last_return_code, cmd=cmd)
subprocess.CalledProcessError: Command '['/home/liangyu/anaconda3/envs/simclr-1/bin/python', '-u', 'train_cifar_moco_cld.py', '--local_rank=0', '--dataset', 'cifar10', '--num-workers', '4', '--batch-size', '256', '--nce-t', '0.07', '--nce-k', '12288', '--base-learning-rate', '0.03', '--lr-scheduler', 'cosine', '--warmup-epoch', '5', '--weight-decay', '8e-4', '--cld_t', '0.2', '--save-freq', '100', '--three-imgs', '--use-kmeans', '--num-iters', '5', '--Lambda', '0.8', '--normlinear', '--aug-plus', '--erasing', '--clusters', '200', '--save-dir', 'checkpoint/cifar10/MoCo+CLD/resnet18/lr0.03-bs256-cldT0.2-nceT0.07-clusters200-lambda0.8-cosine-weightDecay8e-4-fp16-add_erasing-AugPlus-kMeans-ncek12288-bslr0.03-normlinear']' returned non-zero exit status 1.

BYOL+CLD

Hi!

I am currently working with your proposed CLD loss in BYOL. I found your paper and your code really helpful! I wonder if I should use the detached output from the target encoder when calculating CLD loss in BYOL to stop gradients? Thank you!

How to perform image retrieval like Figure A.11?

Thank you for your previous answers. Now that I have trained models and saved checkpoints, I want to do image retrieval like yours (Figure A.11: Comparisons of top retrieves by NPID). How can I do this, and could you please share the code? @frank-xwang

What is kNN linear classifier?

I read your paper and was wondering what a kNN linear classifier is. As far as I know, kNN is a non-linear classifier, right? How is yours different?

Unexpected key(s) in state_dict: "module.conv1.weight"

After finishing the downstream task, a one-layer linear classifier (main_lincls.py), I saved the model using main_lincls.py. When I tried to reuse the saved model, I got this error:

Unexpected key(s) in state_dict: "module.conv1.weight" ........
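
Keys that start with "module." are what PyTorch produces when a model is saved while wrapped in nn.DataParallel or DistributedDataParallel, so loading them into an unwrapped model reports them as unexpected. A minimal sketch of the usual workaround, renaming the keys before loading, follows; the helper name strip_prefix and the example entries are hypothetical.

```python
def strip_prefix(state_dict, prefix="module."):
    """Return a copy of `state_dict` with `prefix` removed from keys that start with it."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Hypothetical saved entries; the values stand in for tensors.
saved = {"module.conv1.weight": 0, "module.fc.bias": 1}
cleaned = strip_prefix(saved)
print(sorted(cleaned))
```

The cleaned dictionary can then be passed to model.load_state_dict. The alternative is to wrap the model the same way it was wrapped at save time before loading.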
