
salesforce / pcl


PyTorch code for "Prototypical Contrastive Learning of Unsupervised Representations"

License: MIT License

Python 100.00%
representation-learning self-supervised-learning unsupervsied-learning contrastive-learning pre-trained-model

pcl's Introduction

Prototypical Contrastive Learning of Unsupervised Representations (Salesforce Research)

This is a PyTorch implementation of the PCL paper:

@inproceedings{PCL,
	title={Prototypical Contrastive Learning of Unsupervised Representations},
	author={Junnan Li and Pan Zhou and Caiming Xiong and Steven C.H. Hoi},
	booktitle={ICLR},
	year={2021}
}

Requirements:

  • ImageNet dataset
  • Python ≥ 3.6
  • PyTorch ≥ 1.4
  • faiss-gpu: pip install faiss-gpu
  • tqdm: pip install tqdm

Unsupervised Training:

This implementation only supports multi-GPU DistributedDataParallel training, which is faster and simpler; single-GPU or DataParallel training is not supported.

To perform unsupervised training of a ResNet-50 model on ImageNet using a 4-GPU or 8-GPU machine, run:

python main_pcl.py \
  -a resnet50 \
  --lr 0.03 \
  --batch-size 256 \
  --temperature 0.2 \
  --mlp --aug-plus --cos \
  --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 \
  --exp-dir experiment_pcl \
  [Imagenet dataset folder]

The --mlp --aug-plus --cos flags are only used for PCL v2; omit them to train PCL v1.

Download Pre-trained Models

PCL v1 PCL v2

Linear SVM Evaluation on VOC

To train a linear SVM classifier on the VOC dataset using frozen representations from a pre-trained model, run:

python eval_svm_voc.py --pretrained [your pretrained model] \
  -a resnet50 \
  --low-shot \
  [VOC2007 dataset folder]

Pass --low-shot only for low-shot evaluation; without it, the entire dataset is used.

Linear SVM classification result on VOC, using ResNet-50 pretrained with PCL for 200 epochs:

Model    k=1   k=2   k=4   k=8   k=16  Full
PCL v1   46.9  56.4  62.8  70.2  74.3  82.3
PCL v2   47.9  59.6  66.2  74.5  78.3  85.4

k is the number of training samples per class.

Linear Classifier Evaluation on ImageNet

Requirement: pip install tensorboard_logger
To train a logistic regression classifier on ImageNet, using frozen representations from a pre-trained model, run:

python eval_cls_imagenet.py --pretrained [your pretrained model] \
  -a resnet50 \
  --lr 5 \
  --batch-size 256 \
  --id ImageNet_linear \
  --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 \
  [Imagenet dataset folder]

Linear classification result on ImageNet, using ResNet-50 pretrained with PCL for 200 epochs:

Model    Accuracy (%)
PCL v1   61.5
PCL v2   67.6

pcl's People

Contributors

lijunnan1992, svc-scm


pcl's Issues

The training time of PCL?

Hi, I want to know the training time of PCL on each dataset (ImageNet and VOC), and the type of graphics card used. Thanks.

A question about preprocessing the ImageNet dataset

Hi,

I am very interested in your ProtoNCE paper, and I tried to run the unsupervised training example in your README. However, I got stuck loading the ImageNet dataset during training.

The problem is that I cannot find any code to generate the train folder:

import os

import torchvision.transforms as transforms

import pcl.loader

# Data loading code (`args` holds the parsed command-line arguments)
traindir = os.path.join(args.data, 'train')
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

if args.aug_plus:
    # MoCo v2's aug: similar to SimCLR https://arxiv.org/abs/2002.05709
    augmentation = [
        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([pcl.loader.GaussianBlur([.1, 2.])], p=0.5),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ]
else:
    # MoCo v1's aug: same as InstDisc https://arxiv.org/abs/1805.01978
    augmentation = [
        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
        transforms.RandomGrayscale(p=0.2),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ]

# center-crop augmentation
eval_augmentation = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize
])

# both datasets read the same training images; only the transforms differ
train_dataset = pcl.loader.ImageFolderInstance(
    traindir,
    pcl.loader.TwoCropsTransform(transforms.Compose(augmentation)))
eval_dataset = pcl.loader.ImageFolderInstance(
    traindir,
    eval_augmentation)

I used your code to download the whole VOC2007 dataset and extracted it to the folder VOCdevkit. However, the variable traindir implies that there must be a folder named train in the VOC2007 dataset folder. Where does this folder come from?

Best,
Wubin

About CIFAR-10

Thank you for providing this code.
My question is: what is the optimal number of clusters for training on CIFAR-10?

Last prototype never sampled as negative


In line 178 of the PCL builder file,
all_proto_id = [i for i in range(im2cluster.max())]
does not include the id of the last prototype, so the last prototype is never sampled as a negative in this implementation.
It should be:
all_proto_id = [i for i in range(im2cluster.max()+1)]
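A minimal pure-Python sketch of the off-by-one, assuming cluster ids run from 0 to max(im2cluster) with no gaps. The helper name `sample_negative_protos` and its seed argument are hypothetical, not the repository's code; it only mirrors the idea of sampling negatives from all prototype ids minus the positive ones.

```python
import random

def sample_negative_protos(im2cluster, pos_ids, r, seed=0):
    """Sample r negative prototype ids that are not in pos_ids (hypothetical helper)."""
    # Cluster ids are 0..max inclusive, so the range needs max + 1 entries.
    all_proto_id = range(max(im2cluster) + 1)
    # Sort the candidates so random.sample gets a sequence (sets are rejected in Python 3.11+).
    candidates = sorted(set(all_proto_id) - set(pos_ids))
    rng = random.Random(seed)
    return rng.sample(candidates, r)

im2cluster = [0, 1, 2, 3, 3, 1]
print(list(range(max(im2cluster))))      # [0, 1, 2]: prototype 3 is missing
print(list(range(max(im2cluster) + 1)))  # [0, 1, 2, 3]
print(sample_negative_protos(im2cluster, pos_ids=[1], r=3))
```

If the configured number of clusters (for example, the number of centroids) is at hand, using it instead of max + 1 also covers the case where the highest-numbered cluster ends up empty.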

Problem adapting PCL to graph data

Hi Junnan,

Thanks for your great work! I'd like to run PCL experiments on a graph classification task (i.e., TUDataset), which includes about 1-5k graph instances with two-class labels. I modified the code based on GraphCL. However, Acc@Proto is 0.0 and the eval results stay approximately unchanged. I printed this line; accp = 0, and proto_out & proto_target are as follows:

[Screenshot]

My hyperparameters are as follows:

    parser.add_argument('--num-cluster', default=[200], type=int,
                        help='number of clusters')
    parser.add_argument('--bs', dest='bs', type=int, default=128,
                        help='batch_size')
    # queue size; number of negative keys (default: 128)

Could you help me locate the reason for this problem? Thanks in advance!

Inconsistency between InfoNCE's negative samples and their corresponding clusters

Hi,

I very much appreciate this work, and thank you for providing the implementation code. I noticed that your implementation of the InfoNCE loss comes from the original MoCo repository, where negative samples are taken from the key queue. According to the paper's description, "instance discrimination task can be explained as a special case of prototypical contrastive learning," so two samples that belong to the same cluster should not be considered a negative pair. But after tracing your code, I think queue samples whose cluster is the same as the query's would be treated as negative pairs in InfoNCE. Have you tried treating same-cluster queue samples as positives?

best,
Chi-Chang Lee.
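To gauge how often this happens, one could track the cluster id of every key pushed into the queue. The sketch below is hypothetical: the MoCo-style queue stores only key embeddings, so `queue_cluster_ids` and both helper names are assumptions. It lists, per query, the queued keys that share the query's cluster and would currently be treated as negatives.

```python
def same_cluster_queue_positions(query_cluster_ids, queue_cluster_ids):
    """For each query, return the queue positions whose key shares the query's cluster.

    queue_cluster_ids is hypothetical: the cluster id of each queued key would have
    to be enqueued alongside the key embedding itself.
    """
    return [
        [j for j, c in enumerate(queue_cluster_ids) if c == q]
        for q in query_cluster_ids
    ]

def false_negative_rate(query_cluster_ids, queue_cluster_ids):
    """Fraction of (query, queued key) pairs that share a cluster."""
    positions = same_cluster_queue_positions(query_cluster_ids, queue_cluster_ids)
    hits = sum(len(p) for p in positions)
    return hits / (len(query_cluster_ids) * len(queue_cluster_ids))

print(same_cluster_queue_positions([1, 2], [1, 3, 2, 1]))  # [[0, 3], [2]]
print(false_negative_rate([1, 2], [1, 3, 2, 1]))           # 0.375
```

Those positions could then be masked out of the InfoNCE logits or counted as extra positives, which is the variant the question asks about; whether either helps would have to be tested empirically.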

Question about concentration around a prototype

In the paper, you mention "With the proposed φ, the similarity in a loose cluster (larger φ) are down-scaled, pulling embeddings closer to the prototype", but I am wondering why the down-scaled similarity forces them to get closer.
Could you please explain this in more detail? Thanks!
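One way to see the effect is a toy calculation, assuming made-up similarity values and a single φ shared by one positive and three negative prototypes (illustrative assumptions only). Dividing the same similarities by a larger φ flattens the softmax, so the positive term keeps a larger loss and, in this example, a larger gradient with respect to the positive similarity.

```python
import math

def proto_loss_and_grad(pos_sim, neg_sims, phi):
    """Cross-entropy of the positive prototype with logits sim / phi.

    Returns (loss, d loss / d pos_sim). Toy setting: one positive prototype and a
    few negatives, all divided by the same phi.
    """
    logits = [pos_sim / phi] + [s / phi for s in neg_sims]
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    loss = log_z - logits[0]
    p_pos = math.exp(logits[0] - log_z)
    # d loss / d logit_0 = p_pos - 1, and d logit_0 / d pos_sim = 1 / phi
    grad = -(1.0 - p_pos) / phi
    return loss, grad

for phi in (0.1, 0.5):  # tight cluster vs loose cluster
    loss, grad = proto_loss_and_grad(0.8, [0.2, 0.1, 0.0], phi)
    print(f"phi={phi}: loss={loss:.4f}, dL/dsim={grad:.4f}")
```

With φ = 0.1 the loss is already close to zero, so the embedding barely moves; with φ = 0.5 the same similarities leave a sizeable loss and gradient, so members of the loose cluster keep being pulled toward their prototype.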

About the cluster size analysis and the uniform distribution assumption for each cluster

Hi

I very much appreciate this work, and thank you for providing the implementation code. In your derivation of the proto loss term, the cluster size P(c; theta) is assumed to be 1/k, but in your analysis of cluster balance you mention that clusters might be imbalanced. I am curious why you made this assumption instead of counting the number of samples in each cluster. Have you tried such a setting in your experiments?

BTW, could you tell me how performance is affected when the clusters become imbalanced?

best,
Chi-Chang Lee.
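For comparison, the per-cluster prior the question suggests could be estimated by counting assignments instead of assuming P(c; θ) = 1/k. A minimal sketch with hypothetical helper names, where `im2cluster` is a plain list of cluster ids; the KL divergence from uniform is just one possible way to quantify imbalance.

```python
import math
from collections import Counter

def cluster_priors(im2cluster, k):
    """Empirical P(c) from cluster assignments, next to the uniform 1/k prior."""
    counts = Counter(im2cluster)
    n = len(im2cluster)
    empirical = [counts.get(c, 0) / n for c in range(k)]
    uniform = [1.0 / k] * k
    return empirical, uniform

def kl_from_uniform(empirical):
    """KL(empirical || uniform): 0 for perfectly balanced clusters, larger when imbalanced."""
    k = len(empirical)
    return sum(p * math.log(p * k) for p in empirical if p > 0)

empirical, uniform = cluster_priors([0, 0, 0, 1], k=2)
print(empirical, uniform)                    # [0.75, 0.25] [0.5, 0.5]
print(round(kl_from_uniform(empirical), 4))  # 0.1308
```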

Questions about concentration estimation

Hi,

Your work is great! I've gone through your paper and your code, and I think the improvements you made on MoCo are very clever. I'm just a little confused about Section 3.3 of your paper, where you present the concentration estimation. What do you mean by similar concentration? Could you provide two visualized examples with similar concentration but obviously different sizes? I know this shouldn't happen in your method, since you've intentionally maintained balanced clusters, but I want to understand this concept more intuitively. It would be even better if you could also explain how you arrived at that formula for phi.

Sorry if the questions sound dumb, but I'm new to this area and would really appreciate any help.
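To make the size dependence concrete, here is a minimal sketch of the concentration estimate as we read the paper's formula, φ = Σ_z ||v'_z − c||_2 / (Z log(Z + α)). The value α = 10 and the distances are assumptions, and any extra post-processing the released code may apply to φ is omitted.

```python
import math

def concentration(distances, alpha=10.0):
    """phi for one cluster: sum of distances to the prototype / (Z * log(Z + alpha)).

    distances: L2 distances from each of the Z momentum features to the prototype c.
    alpha: smoothing term; 10 is an assumed value in this sketch.
    """
    z = len(distances)
    return sum(distances) / (z * math.log(z + alpha))

# Two clusters with the same average distance (0.5) but different sizes.
small = concentration([0.5] * 5)
large = concentration([0.5] * 50)
print(f"Z=5: phi={small:.4f}, Z=50: phi={large:.4f}")
```

With identical average distances, the cluster holding more momentum features gets the smaller φ, so a cluster's size, not only its spread, enters the concentration estimate.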

Single GPU

I can only use a single GPU. Is it possible to modify your code to run on a single GPU? If so, could you tell me where I should modify it? Thank you so much!

Question about the clustering method

I want to know whether you have considered other clustering methods, because k-means is unsuitable for some high-dimensional and discrete data. I believe this framework has the potential to change the trend in the field of contrastive learning, because it allows us to capture more subtle information. The key point is to find a suitable clustering method.

About the InfoNCE loss and cluster_result

Hi,

  1. I notice that the labels created for the InfoNCE loss are always a zero vector:

    labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()

    I think this is wrong, since the loss would then always be zero. Did I misunderstand the code?

  2. When creating the cluster_result dictionary, I found that only the eval dataset is taken into consideration (PCL/main_pcl.py, line 299 in 964da1f):

    cluster_result['im2cluster'].append(torch.zeros(len(eval_dataset),dtype=torch.long).cuda())

    What is the motivation behind this operation? I think it should be run on the training set.

About NCE

Hi,
your work is really excellent!
However, I have a question from going through the code: why are InfoNCE and ProtoNCE implemented with cross-entropy instead of the way shown in the paper?
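The two forms compute the same quantity. With the positive logit placed in column 0, followed by the negatives, softmax cross-entropy with target 0 expands to exactly the InfoNCE expression -log(exp(q·k+/τ) / Σ_j exp(q·k_j/τ)). This is also why an all-zero label tensor does not make the loss zero: each label is the index of the positive column, not a loss value. A stdlib check with made-up similarities (τ = 0.2 is an arbitrary choice here); the same argument carries over to ProtoNCE with prototype logits.

```python
import math

def info_nce_paper_form(q_dot_pos, q_dot_negs, tau):
    """-log( exp(q.k+/tau) / sum_j exp(q.k_j/tau) ), written out term by term."""
    num = math.exp(q_dot_pos / tau)
    den = num + sum(math.exp(s / tau) for s in q_dot_negs)
    return -math.log(num / den)

def cross_entropy(logits, target):
    """Softmax cross-entropy for one sample (what nn.CrossEntropyLoss computes per row)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]

tau = 0.2
pos, negs = 0.7, [0.1, -0.2, 0.3]
# Positive first, then negatives, so the "class" to predict is always index 0.
logits = [pos / tau] + [s / tau for s in negs]
print(info_nce_paper_form(pos, negs, tau), cross_entropy(logits, target=0))
```

Using cross-entropy in the implementation is mainly a numerical and practical choice: the library loss applies the log-sum-exp stabilization shown above and works on a whole batch of logit rows at once.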

Apply to CIFAR-10

Hello, I'm trying to apply this PCL technique to CIFAR-10, since I don't have access to many GPUs. After modifying the code to run on the CPU and on MPS in PyTorch, I ran into a few problems. The most fundamental one is in the training:
1. The pseudo-code showing the EM updates doesn't match the code. In the code I see:

for epoch in start_epoch to epochs:
    if epoch >= warmup_epoch:
        features = compute_features(data_loader, model)
        normalize features with L2-norm > 1.5
        cluster_result = run_kmeans(features)

    for batch in train_loader:
        load data
        compute model outputs and targets for both instance and prototype learning
        calculate InfoNCE loss and, if applicable, ProtoNCE loss
        perform backpropagation

I don't see how this code performs an E-step AND THEN an M-step. Why is there a mismatch here, and which one is correct?

2. Assuming the code is correct, why do we compute cluster_result on eval_dataset (10,000 samples for CIFAR-10)? The culprits appear to be "eval_loader" and "len(eval_dataset)":
if epoch >= warmup_epoch:
    # compute momentum features for center-cropped images
    features = compute_features(eval_loader, model, low_dim, device)

    # placeholder for clustering result
    cluster_result = {'im2cluster':[],'centroids':[],'density':[]}
    for num_cluster in num_clusters:  # Assuming num_clusters is an iterable of desired cluster counts
        cluster_result['im2cluster'].append(torch.zeros(len(eval_dataset), dtype=torch.long))

However, when we call train a few lines later, we pass in train_loader (which has 50,000 samples, a different number from eval_loader's 10,000) and cluster_result (which holds 10,000 entries).

train(train_loader, model, criterion, optimizer, epoch, device, cluster_result)

Therefore, there's a shape mismatch in the forward function when we do

    111 for n, (im2cluster, prototypes, density) in enumerate(zip(cluster_result['im2cluster'], cluster_result['centroids'], cluster_result['density'])):
    112     # get positive prototypes
--> 113     pos_proto_id = im2cluster[index]

Thus leading to this error:

IndexError Traceback (most recent call last)
Cell In[77], line 28
24 adjust_learning_rate(optimizer, epoch, lr)
26 # train for one epoch
27 # print("device", device)
---> 28 losses = train(train_loader, model, criterion, optimizer, epoch, device, cluster_result)
30 print('Epoch: [{0}]\t'.format(epoch, loss=losses))
32 if (epoch+1)%5==0:

Cell In[76], line 44, in train(train_loader, model, criterion, optimizer, epoch, device, cluster_result)
34 # visualize the images
35 # plt.figure(figsize=(6, 3)) # Adjust the size as necessary
36 # plt.subplot(1, 2, 1)
(...)
41 # compute output
42 ### HHEEEREEE
43 print("min",min(index), "max", max(index))
---> 44 output, target, output_proto, target_proto = model(im_q=images[0], im_k=images[1], cluster_result=cluster_result, index=index)
45 target=target.to(device)
46 # InfoNCE loss

File ~/anaconda3/envs/x/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/x/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

Cell In[62], line 113, in MoCo.forward(self, im_q, im_k, is_eval, cluster_result, index)
110 proto_logits = []
111 for n, (im2cluster, prototypes, density) in enumerate(zip(cluster_result['im2cluster'], cluster_result['centroids'], cluster_result['density'])):
112 # get positive prototypes
--> 113 pos_proto_id = im2cluster[index]
114 pos_prototypes = prototypes[pos_proto_id]
116 # sample negative prototypes

IndexError: index 18843 is out of bounds for dimension 0 with size 10000

Any ideas how index works, or what cluster_result should hold?
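In the main_pcl.py excerpt quoted earlier on this page, eval_dataset and train_dataset are both built from traindir, so `index` (the per-sample dataset position that ImageFolderInstance yields) addresses the same images in both, and `im2cluster` has one entry per training image. The IndexError suggests that in the CIFAR-10 port eval_dataset comes from the 10,000-image test split instead. A minimal sketch of keeping the two aligned; `IndexedDataset` and `check_alignment` are hypothetical helpers, and a list of integers stands in for a real CIFAR-10 training set.

```python
class IndexedDataset:
    """Wrap a map-style dataset so each item is (transformed sample, dataset index).

    Hypothetical helper playing the role ImageFolderInstance plays in main_pcl.py.
    It only needs __len__ and __getitem__ on the base dataset (e.g. a torchvision
    dataset), so it can be passed to torch.utils.data.DataLoader.
    """

    def __init__(self, base, transform=None):
        self.base = base
        self.transform = transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        sample = self.base[index]
        if isinstance(sample, tuple):  # torchvision datasets return (image, label); drop the label
            sample = sample[0]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, index


def check_alignment(im2cluster, train_dataset):
    """Every training index must be a valid position in im2cluster."""
    if len(im2cluster) != len(train_dataset):
        raise ValueError(
            f"im2cluster has {len(im2cluster)} entries but the training set "
            f"has {len(train_dataset)} samples"
        )


# Build both datasets from the SAME split; only the transforms differ.
base = list(range(50_000))  # stand-in for the 50,000-image CIFAR-10 training split
train_dataset = IndexedDataset(base, transform=lambda x: (x, x))  # stand-in for TwoCropsTransform
eval_dataset = IndexedDataset(base)  # stand-in for the center-crop transform
im2cluster = [0] * len(eval_dataset)  # clustering runs over eval_dataset
check_alignment(im2cluster, train_dataset)
print(train_dataset[18843][1], len(im2cluster))
```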

Question about Eq. (9) in your paper

Hi, thanks for your paper and code. I have a question about Eq. (9) in your paper: it seems that this equation gives p(ci|xi), not p(xi|ci). I think p(xi|ci) should consist of only a single Gaussian distribution, so that its integral over xi equals 1. Can you explain this for me?

Settings for the kNN classifier?

Thanks for your work. I am confused about your settings for the kNN classifier.
I used the kNN predict function from MoCo, which matches each test sample against the training set using its K nearest neighbors and a temperature T.
I am curious what K and T you used, and whether you used the training set or the test set as the probe set.
Looking forward to your reply!
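For context on what K and T control, here is a stdlib sketch of a temperature-weighted kNN vote: cosine similarity on L2-normalized features, with each of the K nearest neighbours voting for its label with weight exp(similarity / T). That this matches the exact weighting in the MoCo-based function mentioned above is an assumption, the defaults k=200 and t=0.1 are placeholders rather than PCL's settings, and the toy bank stands in for training-set features.

```python
import math
from collections import defaultdict

def l2_normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def knn_predict(query, bank_feats, bank_labels, k=200, t=0.1):
    """Temperature-weighted kNN vote over a labelled feature bank.

    k=200 and t=0.1 are placeholder defaults, not PCL's reported settings.
    """
    q = l2_normalize(query)
    sims = [sum(a * b for a, b in zip(q, l2_normalize(f))) for f in bank_feats]
    top = sorted(range(len(sims)), key=lambda i: sims[i], reverse=True)[:k]
    scores = defaultdict(float)
    for i in top:
        scores[bank_labels[i]] += math.exp(sims[i] / t)  # closer neighbours weigh more
    return max(scores, key=scores.get)

bank = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
labels = [0, 0, 1, 1]
print(knn_predict([1.0, 0.05], bank, labels, k=3))  # 0
print(knn_predict([0.05, 1.0], bank, labels, k=3))  # 1
```

A smaller T makes the vote rely almost entirely on the closest neighbours, while a larger K mainly matters when T is large enough for distant neighbours to contribute.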

Imbalanced datasets

Hi,

Thank you for this implementation. It is my understanding that some contrastive frameworks are built on entropy maximization, which makes them inapplicable to imbalanced datasets. I don't see a direct connection between the ProtoNCE loss and entropy, so I was wondering: does this method support imbalanced datasets?

Thanks

Is it possible to resume training with one of the pretrained models?

Hi,

I was trying to resume the training using the "--resume" argument and one of the PCL pre-trained models provided on the homepage. But I am getting the following error.

Can anyone please help me with this? I am trying to use the pretrained ImageNet models on a different dataset, training for some extra epochs.

RuntimeError: Error(s) in loading state_dict for DistributedDataParallel: Missing key(s) in state_dict: "module.queue", "module.queue_ptr", "module.encoder_k.conv1.weight", "module.encoder_k.bn1.weight", ... [the list continues through every weight and BatchNorm buffer of module.encoder_k, layer1 through layer4, including the downsample blocks] ..., "module.encoder_k.fc.0.weight", "module.encoder_k.fc.0.bias", "module.encoder_k.fc.2.weight", "module.encoder_k.fc.2.bias".

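The missing keys are the momentum (key) encoder and the queue, which suggests the released checkpoints keep only the query encoder's weights. One workaround, assuming the checkpoint's `state_dict` uses `module.encoder_q.*` names, is to copy each query-encoder entry to the matching key-encoder name (MoCo-style builders initialize encoder_k as a copy of encoder_q anyway) and load with `strict=False`, so the queue keeps its fresh random initialization. `fill_key_encoder` is a hypothetical helper, not part of the repository.

```python
import copy

def fill_key_encoder(state_dict):
    """Return a copy of state_dict in which every encoder_q entry also exists under encoder_k.

    Existing encoder_k entries are left untouched; values are deep-copied so the two
    encoders do not share storage.
    """
    filled = dict(state_dict)
    for name, value in state_dict.items():
        parts = name.split(".")
        if "encoder_q" not in parts:
            continue
        i = parts.index("encoder_q")
        k_name = ".".join(parts[:i] + ["encoder_k"] + parts[i + 1:])
        if k_name not in filled:
            filled[k_name] = copy.deepcopy(value)
    return filled

# Usage with PyTorch (not run here; the path and key layout are assumptions):
# ckpt = torch.load("path/to/checkpoint.pth.tar", map_location="cpu")
# missing, unexpected = model.load_state_dict(fill_key_encoder(ckpt["state_dict"]), strict=False)
# print(missing)  # if the assumption holds, only the queue buffers should remain
```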
Hello, can I ask you about the accuracy on CIFAR-10 and ImageNet100?

I ran experiments on CIFAR-10 and ImageNet100 with ResNet-18 using this code, but the linear classifier accuracy is:
cifar10 : 82.77 (pcl_r : 512, num_cluster : [1000,2000,4000] / pcl_r 1024, num_cluster : [1500,3000,4500])
ImageNet100 : 64.40 (pcl_r : 1024, num_cluster : [2500,5000,10000] / pcl_r : 4096, num_cluster : [6000,8000,12000])

Even after adjusting the hyper-parameters, I cannot reach the performance of other methods (BYOL, etc.). If you know the reason, could you tell me? Thank you

About run_kmeans

Looking at run_kmeans, it uses L2 distance for clustering. Is L2 distance better than cosine distance?

index = faiss.GpuIndexFlatL2(res, d, cfg)
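Assuming the features passed to run_kmeans are L2-normalized, squared L2 distance and cosine similarity are tied by ||a − b||² = 2 − 2 cos(a, b), so they rank neighbours identically. A quick stdlib check with random 128-dimensional vectors (the dimension is arbitrary here):

```python
import math
import random

def normalize(v):
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def sq_l2(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))

def cosine(a, b):
    return sum(x * y for x, y in zip(a, b))  # a and b are assumed to be unit-norm

rng = random.Random(0)
a = normalize([rng.gauss(0, 1) for _ in range(128)])
b = normalize([rng.gauss(0, 1) for _ in range(128)])
# For unit vectors: ||a - b||^2 = 2 - 2 * cos(a, b)
print(sq_l2(a, b), 2 - 2 * cosine(a, b))
```

One caveat: the identity holds for pairs of unit vectors. k-means centroids are averages of unit vectors and are generally shorter than 1, so assigning points to centroids by L2 distance is close to, but not exactly the same as, assigning by cosine similarity.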

Question about Eq. (12) in the paper

Hi! Could I ask how you designed the self-adaptive temperature parameter, especially the denominator Z log(Z + α)? I don't quite understand why the "log(Z + α)" term is there.
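Writing the estimate out (as we read Eq. 12, with Z momentum features v'_z assigned to prototype c, and log taken as the natural log) shows what the log term does: for a fixed average distance d̄, φ shrinks slowly as the cluster grows, and α keeps the denominator away from zero for very small clusters.

```latex
\phi = \frac{\sum_{z=1}^{Z} \lVert v'_z - c \rVert_2}{Z \log(Z + \alpha)}
     = \frac{\bar{d}}{\log(Z + \alpha)},
\qquad
\bar{d} = \frac{1}{Z} \sum_{z=1}^{Z} \lVert v'_z - c \rVert_2

% Single-member cluster (Z = 1):
%   alpha = 0:  \log(1) = 0, so \phi is undefined (division by zero)
%   alpha = 10: \phi = \bar{d} / \log(11) \approx \bar{d} / 2.398
```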
