
sit's Introduction

SiT: Self-supervised vIsion Transformer

This repository contains the official PyTorch code for self-supervised pretraining, finetuning, and evaluation of SiT (Self-supervised vIsion Transformer).

The finetuning strategy is adapted from DeiT.

Usage

  • Create an environment

conda create -n SiT python=3.8

  • Activate the environment and install the necessary packages

conda activate SiT

conda install pytorch torchvision torchaudio cudatoolkit=11.0 -c pytorch

pip install -r requirements.txt

Self-supervised pre-training

python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --batch_size 64 --epochs 801 --data-set 'ImageNet' --output_dir 'checkpoints/SSL/ImageNet'

Self-supervised pre-trained models using SiT can be downloaded from here

Notes:

  1. Set the --dataset_location parameter to the location of the downloaded dataset.
  2. Set lmbda to a high value when pretraining on small datasets, e.g. lmbda=5.

If you use this code for a paper, please cite:

@article{atito2021sit,

  title={SiT: Self-supervised vIsion Transformer},

  author={Atito, Sara and Awais, Muhammad and Kittler, Josef},

  journal={arXiv preprint arXiv:2104.03602},

  year={2021}

}

License

This repository is released under the GNU General Public License.

sit's People

Contributors

sara-ahmed

sit's Issues

Please provide details when training on cifar10

Good job!!
Could you provide some hyperparameters for SSL training on CIFAR-10 (e.g., LR, optimizer, weight decay)? I'm curious how to train a ViT on CIFAR-10 instead of fine-tuning one.

SSL training on image folder

Hi,
I would like to train SiT on a set of images stored in a folder. Do you have training scripts that load the data from a folder?

Questions about updated code

Hello @Sara-Ahmed,

Thank you for posting the PyTorch implementation of SiT! It seems that the current commit cannot run, and I see in #28 you mention you will update the GitHub repository soon. I have tried running the code from commit 1aacd6a and have nearly gotten the pre-training working. However, I can't find a version of torchvision that works with both my GPU (GeForce RTX 3060, sm_86) and the SiT code from this commit. I have to update timm to version 0.4.12 to avoid this error, which seems to break functionality with your code.

Do you think this will be fixed in the next release of the SiT code on GitHub? Also, could you please include all package versions used in requirements.txt (or include a copy of pip freeze)?

Thank you for all your help!
Roshan
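For reference, a small stdlib-only helper like the one below can list the installed versions of the packages mentioned in this issue (torch, torchvision, timm), as a lighter alternative to a full pip freeze. It is a sketch; the function name is hypothetical and not part of the SiT repository.

```python
from importlib import metadata


def package_versions(names=("torch", "torchvision", "timm")):
    """Return {distribution name: installed version, or None if missing}."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


if __name__ == "__main__":
    # Print in a requirements.txt-like format.
    for pkg, ver in package_versions().items():
        print(f"{pkg}=={ver}" if ver else f"# {pkg} is not installed")
```

importlib.metadata is available from Python 3.8, which matches the environment created in the Usage section.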

An Error has occurred in self-supervised pre-training

@Sara-Ahmed
Thank you for sharing your wonderful achievements!

When I ran self-supervised pre-training as described, the following subprocess.CalledProcessError was raised. Could you please help me solve this problem?

Typed command
python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --batch-size 72 --epochs 501 --min-lr 5e-6 --lr 1e-3 --training-mode 'SSL' --data-set 'STL10' --output 'checkpoints/SSL/STL10' --validate-every 10

Errors encountered
subprocess.CalledProcessError: Command '['/usr/bin/python', '-u', 'main.py', '--batch-size', '72', '--epochs', '501', '--min-lr', '5e-6', '--lr', '1e-3', '--training-mode', 'SSL', '--data-set', 'STL10', '--output', 'checkpoints/SSL/STL10', '--validate-every', '10']' returned non-zero exit status 2.

why 3 tasks/objectives?

Hi, I have been reading your paper, but I don't really understand why you set up three tasks/objectives (image reconstruction, contrastive prediction, and rotation prediction). What are their purposes? Thanks

patch size on CIFAR100

Hi, @Sara-Ahmed

Thanks for your great work.

How do you deal with CIFAR100 (32x32 images)?

Do you adjust the patch size, or just resize the input images to 224x224 first?
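The two options trade sequence length against patch granularity. Assuming non-overlapping square patches, as in a standard ViT patch embedding, the number of patch tokens is (image_size / patch_size)^2. The short sketch below computes it for the options in the question; the helper name is illustrative only and says nothing about which option SiT uses.

```python
def num_patch_tokens(image_size: int, patch_size: int) -> int:
    """Number of non-overlapping patch tokens for a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size
    return per_side * per_side


# Option 1: keep the 32x32 input and shrink the patch size.
print(num_patch_tokens(32, 16))   # 4 tokens, very coarse
print(num_patch_tokens(32, 4))    # 64 tokens
# Option 2: resize to 224x224 and keep 16x16 patches.
print(num_patch_tokens(224, 16))  # 196 tokens
```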

utils.py is outdated

I am having trouble running the code, mostly because of error messages saying that utils.py is missing several functions.

AttributeError: module 'utils' has no attribute 'get_params_groups'

Thanks!
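A function named get_params_groups also exists in DINO's utils.py, where it splits trainable parameters so that biases and 1-D parameters (such as LayerNorm weights) get no weight decay. Assuming SiT's missing helper follows the same convention (not confirmed by this repository), a duck-typed sketch could look like this:

```python
def get_params_groups(model):
    """Split trainable parameters into weight-decayed and non-decayed groups.

    Duck-typed: `model` only needs a `named_parameters()` method yielding
    (name, param) pairs whose params expose `requires_grad` and `ndim`,
    as torch.nn.Module and torch.Tensor do.
    """
    regularized, not_regularized = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen parameters are not handed to the optimizer
        # Biases and 1-D parameters (e.g. norm weights) are not decayed.
        if name.endswith(".bias") or param.ndim == 1:
            not_regularized.append(param)
        else:
            regularized.append(param)
    return [
        {"params": regularized},
        {"params": not_regularized, "weight_decay": 0.0},
    ]
```

With PyTorch, the returned list can be passed directly to an optimizer, e.g. torch.optim.AdamW(get_params_groups(model), lr=...); the per-group weight_decay overrides the optimizer default for that group.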

--resume issue...!

When I resume training with "--resume", NaN values appear in the model in both SSL and fine-tuning. I checked the input images, and there was no problem with them.

Training then stops with the following warning: "Loss is nan, stopping training"
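When training stops with "Loss is nan", it can help to find out which of the three SSL terms became non-finite first. A minimal stdlib sketch, assuming the per-task losses are available as plain floats (for example via .item() on each PyTorch loss tensor); the helper name and the example values are made up for illustration:

```python
import math


def nonfinite_terms(losses):
    """Return the names of loss terms whose value is NaN or +/-inf."""
    return [name for name, value in losses.items() if not math.isfinite(value)]


# Hypothetical per-task values logged just before they are combined.
step_losses = {"rotation": 1.38, "contrastive": float("nan"), "reconstruction": 0.42}
bad = nonfinite_terms(step_losses)
if bad:
    print(f"non-finite loss terms: {bad}")  # prints: non-finite loss terms: ['contrastive']
```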

Cannot run distributed training

I only have one GPU, so I can't run distributed training. What should I do?

Get latent feature space

Hi, I would like to get the latent feature space. As I understand it, I can take the output of x = self.forward_features(x) in forward. Am I correct?


Thank you

How to see the image reconstruction task results

The usage example shows how to finetune the classifier head of the model in the command line, but I'm not sure how to get the reconstructed image from this output.
Can you please provide a code sample for image reconstruction?
Which part of the model output can be used to visually represent inference results like in the diagram?
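One way to view a reconstruction is to undo the input normalization and convert it to an 8-bit image. This rests on two assumptions: that the reconstruction head outputs a C x H x W image in the same normalized space as the network input, and that the inputs were normalized with the usual ImageNet mean and standard deviation (check the dataset transforms). A minimal NumPy sketch with a hypothetical helper name:

```python
import numpy as np

# Assumption: inputs were normalized with the standard ImageNet statistics.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def to_displayable(chw, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Undo per-channel normalization; convert a CHW float array to HWC uint8."""
    chw = np.asarray(chw, dtype=np.float64)
    restored = chw * std[:, None, None] + mean[:, None, None]  # x * std + mean
    hwc = np.clip(restored, 0.0, 1.0).transpose(1, 2, 0)       # channels last
    return (hwc * 255.0).round().astype(np.uint8)
```

For a PyTorch batch, something like plt.imshow(to_displayable(recon[0].detach().cpu().numpy())) with matplotlib would show the first image, where recon stands for whatever tensor holds the reconstructions.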

checkpoints

Thank you for your excellent work! I am interested in fine-tuning SiT for other tasks. Will you consider providing the pre-trained models?

ContrastiveLoss problem

I think that in the code

nominator = torch.exp(positives / self.temperature)  # 2*bs
denominator = self.negatives_mask * torch.exp(similarity_matrix / self.temperature)

the denominator contains the nominator, whereas the formula in the paper (k != i) is somewhat different. Thank you.
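For comparison, in the SimCLR formulation of the NT-Xent loss the indicator 1[k != i] removes only the self-similarity term, so the positive pair does appear in the denominator. The NumPy sketch below implements that formulation; it is not the repository's ContrastiveLoss, and whether the SiT paper intends exactly this formula is for the authors to confirm.

```python
import numpy as np


def nt_xent(z1, z2, temperature=0.5):
    """SimCLR-style NT-Xent loss for two (N x D) batches of paired embeddings.

    For anchor i the denominator sums exp(sim(i, k) / T) over all k != i,
    so it includes the positive pair and excludes only self-similarity.
    """
    z = np.concatenate([z1, z2], axis=0)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # unit vectors -> cosine similarity
    sim = z @ z.T / temperature                        # (2N, 2N)
    n = z1.shape[0]
    idx = np.arange(2 * n)
    positives = sim[idx, (idx + n) % (2 * n)]          # each row's partner view
    mask = ~np.eye(2 * n, dtype=bool)                  # k != i
    masked = np.where(mask, sim, -np.inf)
    # Stable log-sum-exp over k != i.
    row_max = masked.max(axis=1, keepdims=True)
    log_denom = row_max[:, 0] + np.log(np.exp(masked - row_max).sum(axis=1))
    return float(np.mean(log_denom - positives))
```

Because the positive term is part of the denominator, each per-anchor loss is strictly positive.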

ZeroDivisionError: float division by zero

Hi, thanks for the nice work. I tried to train the network on my own dataset. However, I got the error below.

Traceback (most recent call last):
  File "/home/Project/SiT/main.py", line 397, in <module>
    main(args)
  File "/home/Project/SiT/main.py", line 343, in main
    args.clip_grad, model_ema, mixup_fn)
  File "/home/Project/SiT/engine.py", line 125, in train_SSL
    for imgs1, rots1, imgs2, rots2 in metric_logger.log_every(data_loader, print_freq, header):
  File "/home/Project/SiT/utils.py", line 164, in log_every
    header, total_time_str, total_time / len(iterable)))
ZeroDivisionError: float division by zero

It seems that the iterable always has length 0 somehow, even though I checked my data loader and there are thousands of samples available.
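Two things that could make len(iterable) zero (neither is confirmed for this report): a dataset path pointing to an empty folder, or a training DataLoader with drop_last=True (common in DeiT-derived code; an assumption for SiT) where each process's share of the data is smaller than the batch size, so every batch is dropped. A sketch of the arithmetic for the second case, assuming a DistributedSampler-style even split; the helper name is hypothetical:

```python
import math


def batches_per_replica(num_samples, batch_size, world_size=1, drop_last=True):
    """Batches one process gets when samples are split evenly across processes.

    Mirrors DistributedSampler's default padding (each process gets
    ceil(num_samples / world_size) indices) followed by a DataLoader
    that optionally drops the final incomplete batch.
    """
    per_replica = math.ceil(num_samples / world_size)
    if drop_last:
        return per_replica // batch_size
    return math.ceil(per_replica / batch_size)


# 500 images on 4 GPUs with batch size 256: each process holds 125 samples,
# so with drop_last=True it gets no batch at all and len(data_loader) == 0.
print(batches_per_replica(500, 256, world_size=4))
```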

L1 Regression Task

Hi Sara,

shouldn't you use

torch.exp(-.5*_s) * loss + 0.5 * _s

in line 51 for the L1 loss? You want the factor 1/alpha_1, not 1/(2*alpha_1**2) as it is right now.

SiT/losses.py

Lines 32 to 51 in ffc7317

class LearnedLoss():
    def __init__(self, losstype, batch_size=None):
        if losstype == 'CrossEntropy':
            self.lossF = torch.nn.CrossEntropyLoss()
            self.adj = 1
        elif losstype == 'L1':
            self.lossF = torch.nn.L1Loss()
            self.adj = 0.5
        elif losstype == 'Contrastive':
            self.lossF = ContrastiveLoss(batch_size)
            self.adj = 1
    def calculate_loss(self, output, label):
        return self.lossF(output, label)
    def calculate_weighted_loss(self, loss, _s):
        return self.adj * torch.exp(-_s) * loss + 0.5 * _s

Thanks in advance for your answer!

EDIT 23.07:
I edited my question correcting an error in my proposal.
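Under one common reading (an assumption here), _s acts as a log-variance. With s = log(sigma^2), a Gaussian likelihood gives 0.5 * exp(-s) * loss + 0.5 * s, which is what the current code computes for L1 with adj = 0.5. With a Laplace likelihood of scale b and s = log(b^2), the negative log-likelihood loss / b + log(b) becomes exp(-0.5 * s) * loss + 0.5 * s (constants dropped), which matches the proposal above. A plain-Python sketch of both forms, with hypothetical function names:

```python
import math


def gaussian_weighted(loss, s):
    """Current L1 form (adj = 0.5): 0.5 * exp(-s) * loss + 0.5 * s, s = log(sigma^2)."""
    return 0.5 * math.exp(-s) * loss + 0.5 * s


def laplace_weighted(loss, s):
    """Proposed form: exp(-0.5 * s) * loss + 0.5 * s, i.e. loss / b + log(b), s = log(b^2)."""
    return math.exp(-0.5 * s) * loss + 0.5 * s


L = 0.8  # an example L1 value
for s in (-1.0, 0.0, 1.0):
    print(f"s={s:+.1f}  gaussian={gaussian_weighted(L, s):.4f}  laplace={laplace_weighted(L, s):.4f}")
```

Setting the derivative with respect to s to zero, the Laplace form is minimized at exp(s / 2) = loss, so the effective task weight tends to 1 / loss; the Gaussian form is minimized at exp(s) = loss.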

Random Erase is erroneously not being used during SSL training

Why is Random Erase not being used during SSL training? It seems that it is erroneously being turned off in main.py. See where args.reprob and args.recount are set to 0.

SiT/main.py

Lines 182 to 186 in 1aacd6a

# disable any harsh augmentation in case of Self-supervise training
if args.training_mode == 'SSL':
    print("NOTE: Smoothing, Mixup, CutMix, and AutoAugment will be disabled in case of Self-supervise training")
    args.smoothing = args.reprob = args.reprob = args.recount = args.mixup = args.cutmix = 0.0
    args.aa = ''

Calling Rotation after the transform leads to an error

I think the recent commit to "datasets_utils.py" introduces an error as the input to the "RandomRotation" call would now be a tensor (because of the "ToTensor" transform call before) instead of a PIL Image.

The error message is listed below:

TypeError: transpose() received an invalid combination of arguments - got (int), but expected one of:
 * (name dim0, name dim1)
 * (int dim0, int dim1)
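If the rotation has to run after ToTensor, it needs to work on C x H x W data rather than a PIL Image. A minimal NumPy sketch, assuming the rotation label is k for a counter-clockwise rotation by k * 90 degrees; the function name is hypothetical and this is not the repository's RandomRotation. For a PyTorch tensor, torch.rot90(x, k, dims=(1, 2)) performs the same rotation over the H and W axes. The other obvious fix is to call the rotation before ToTensor.

```python
import random

import numpy as np


def random_rotation_chw(x, rng=random):
    """Rotate a CHW array by a random multiple of 90 degrees; return (rotated, label).

    Label k means a counter-clockwise rotation by k * 90 degrees over (H, W).
    """
    k = rng.randrange(4)
    return np.rot90(x, k=k, axes=(1, 2)).copy(), k
```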

lack finetune and linprobe code

Hi,

Thanks for sharing this wonderful project.
When I run the following command to finetune the model, I find that some of the code for finetuning and linear probing is missing, so the command cannot run at all.

python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --batch-size 120 --epochs 501 --min-lr 5e-6 --training-mode 'finetune' --data-set 'STL10' --finetune 'checkpoints/SSL/STL10/checkpoint.pth' --output 'checkpoints/finetune/STL10' --validate-every 10

Also, many functions are missing from utils.py, e.g., utils.restart_from_checkpoint, utils.fix_random_seeds(args.seed), and utils.get_sha().

Would you please share it?

Thank you.

Best,
Vera
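The missing names match helpers in DINO's utils.py. Assuming SiT expects similar behavior (not confirmed by this repository), a sketch of fix_random_seeds that seeds Python's and NumPy's generators, and PyTorch's when it is installed, could look like this:

```python
import random

import numpy as np


def fix_random_seeds(seed):
    """Seed Python, NumPy and (if installed) PyTorch RNGs for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
    except ImportError:
        return  # PyTorch not available; nothing more to seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable
```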

Is the accuracy rate Top1 or Top5?

Thank you for your work and sharing.
I would like to know whether the accuracies reported in the paper are top-1 or top-5, especially the 93.02% obtained by SiT on STL-10 in Table 3.
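This does not answer which metric the paper reports, but it shows how the two differ: top-k accuracy counts a sample as correct when its target is among the k highest-scoring classes, so top-5 is always at least top-1. A minimal NumPy sketch with a hypothetical helper name:

```python
import numpy as np


def topk_accuracy(logits, targets, ks=(1, 5)):
    """Percentage of samples whose target is among the k highest-scoring classes."""
    logits = np.asarray(logits)
    targets = np.asarray(targets)
    ranked = np.argsort(-logits, axis=1)  # class indices, best first
    result = {}
    for k in ks:
        hits = (ranked[:, :k] == targets[:, None]).any(axis=1)
        result[k] = 100.0 * hits.mean()
    return result
```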

RASampler

I don't understand the RASampler. Could you share the related paper or blog post about it? Thank you!
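RASampler comes from DeiT, which uses it for repeated augmentation and credits the idea to Hoffer et al. ("Augment Your Batch") and Berman et al. (MultiGrain). Each image index is repeated a few times and the copies are spread across GPUs, so the same image is seen with different random augmentations within one step. The pure-Python sketch below is a simplified version of that index logic, not DeiT's exact code (DeiT's truncation rule for the epoch length differs); the function name is hypothetical.

```python
import math
import random


def repeated_aug_indices(dataset_len, num_replicas, rank, num_repeats=3, seed=0, epoch=0):
    """Indices one replica would visit under a simplified repeated-augmentation scheme."""
    rng = random.Random(seed + epoch)  # identical shuffle on every replica
    order = list(range(dataset_len))
    rng.shuffle(order)
    # Each index appears num_repeats times in a row: [a, a, a, b, b, b, ...]
    repeated = [i for i in order for _ in range(num_repeats)]
    # Pad so the list splits evenly across replicas.
    total = math.ceil(len(repeated) / num_replicas) * num_replicas
    repeated += repeated[: total - len(repeated)]
    # Strided split: consecutive copies of one image land on different replicas,
    # and each copy later receives its own random augmentation.
    shard = repeated[rank:total:num_replicas]
    # Keep roughly the length of a normal, non-repeated epoch.
    return shard[: math.ceil(dataset_len / num_replicas)]
```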

Data augmentation step before applying rotation

Hi Sara,

in your paper, you write: "We found that the network struggles to distinguish between the rotated image and the rotation of the flipped image as two different classes. Instead, we included the horizontal flipping to the data augmentation step before applying rotation, and hence, the network is trained to classify the image and the flipped image to the same class."

In:

def getItem(X, target = None, transform=None, training_mode = 'SSL'):
    # in case of finetuning, returning the image and the target
    if training_mode == 'finetune':
        if transform is not None:
            X = transform(X)
        return X, target
    X1, rot1 = RandomRotation(X)
    X2, rot2 = RandomRotation(X)
    if transform is not None:
        X1 = transform(X1)
        X2 = transform(X2)
    return X1, rot1, X2, rot2

You do it the other way around, or am I mistaken?
You first take the same batch, apply random rotation independently, and then apply the standard augmentation routine. So, in that case, the original image and the flipped one may have two different rotational classes, or not?

Thanks in advance for clarification,
Best,
Julian
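One way to see what the ordering changes: a horizontal flip applied after a rotation by k * 90 degrees gives exactly the same pixels as flipping first and then rotating by (-k mod 4) * 90 degrees. So with rotate-then-augment, a flipped sample keeps the label k of the rotation applied to the original image, which is label (-k mod 4) from the flipped image's point of view (labels 1 and 3 swap; 0 and 2 are unaffected). The NumPy sketch below checks this identity; hflip and rotate are illustrative helpers, not functions from the repository.

```python
import numpy as np


def hflip(img):
    """Mirror an H x W (x C) array left to right."""
    return img[:, ::-1]


def rotate(img, k):
    """Rotate an H x W (x C) array counter-clockwise by k * 90 degrees."""
    return np.rot90(img, k=k, axes=(0, 1))


img = np.arange(12).reshape(3, 4)  # small asymmetric test image

for k in range(4):
    rotate_then_flip = hflip(rotate(img, k))         # order used in getItem
    flip_then_rotate = rotate(hflip(img), (-k) % 4)  # same pixels, seen from the flipped image
    assert np.array_equal(rotate_then_flip, flip_then_rotate)
    print(f"label {k}: equals the flipped image rotated with label {(-k) % 4}")
```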

single node multi-GPU hangs

Hi,
I am running SSL training on a single node with two GPUs. It runs only when --nproc_per_node=1. When I set nproc_per_node=2 it gets stuck after init for the second GPU.

init_distributed_mode ....
| distributed init (rank 0): env://
| distributed init (rank 1): env://

setting dist_url to env://127.0.0.1 didn't fix it. I also tried --world_size=2.

Fine-tuning settings

First of all, great job! I just finished reading the paper, but could not find details regarding the fine-tuning settings on small-scale datasets. In particular, optimizer, number of epochs, and data augmentations when fine-tuning the SiT models pre-trained either on the same datasets or on ImageNet. If you could post those details it would make it easier to reproduce results. Thanks!

OSError: ... symbol free_gemm_select version libcublasLt.so.11 not defined

I followed the demo instructions for training an SSL model on STL-10 verbatim. However, I get this error immediately after starting the training process.

python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --batch-size 72 --epochs 501 --min-lr 5e-6 --lr 1e-3 --training-mode 'SSL' --data-set 'STL10' --output 'checkpoints/SSL/STL10' --validate-every 10

results in this error:

OSError: /home/mroos/miniconda3/envs/SiT/lib/python3.8/site-packages/torch/lib/../../../../libcublas.so.11: symbol free_gemm_select version libcublasLt.so.11 not defined in file libcublasLt.so.11 with link time reference

Passing token logits to the loss

First of all, I want to thank you for making the code available.
It is well written and easily understandable.

I have a question about the rotation token and the tokens used for reconstructing the original image. I'm not an expert in PyTorch, as I have always used TensorFlow, so please forgive me if I ask stupid things.

As far as I can see, you're defining all the heads for the SSL loss here:

# Classifier head(s)
if training_mode == 'SSL':
    self.rot_head = nn.Linear(self.num_features, 4)
    self.contrastive_head = nn.Linear(self.num_features, 512)
    self.convTrans = nn.ConvTranspose2d(embed_dim, in_chans, kernel_size=(patch_size, patch_size),
                                        stride=(patch_size, patch_size))
    # create learnable parameters for the MTL task
    self.rot_w = nn.Parameter(torch.tensor([1.0]))
    self.contrastive_w = nn.Parameter(torch.tensor([1.0]))
    self.recons_w = nn.Parameter(torch.tensor([1.0]))
else:
    self.rot_head = nn.Linear(self.num_features, num_classes)
    self.contrastive_head = nn.Linear(self.num_features, num_classes)

And as far as I understand it, there is no final activation on these heads; they return logits, am I correct?

In the train_SSL routine, you then pass these logits to your criterion routine:

SiT/engine.py

Lines 146 to 148 in 27cc31d

loss, (loss1, loss2, loss3) = criterion(rot_p, rots,
                                        contrastive1_p, contrastive2_p,
                                        imgs_recon, imgs, r_w, cn_w, rec_w)

Which, if the train_mode is SSL, is the MTL_loss routine. There the logits are passed directly into their respective loss functions:

SiT/losses.py

Lines 32 to 51 in 27cc31d

class LearnedLoss():
    def __init__(self, losstype, batch_size=None):
        if losstype == 'CrossEntropy':
            self.lossF = torch.nn.CrossEntropyLoss()
            self.adj = 1
        elif losstype == 'L1':
            self.lossF = torch.nn.L1Loss()
            self.adj = 0.5
        elif losstype == 'Contrastive':
            self.lossF = ContrastiveLoss(batch_size)
            self.adj = 1
    def calculate_loss(self, output, label):
        return self.lossF(output, label)
    def calculate_weighted_loss(self, loss, _s):
        return self.adj * torch.exp(-_s) * loss + 0.5 * _s

Am I missing something? Especially in the reconstruction case, I don't think this can work, because the L1 loss is computed between normalized original images and unnormalized reconstruction logits. The contrastive loss should be fine, since you normalize the logits inside the loss function. However, the CE loss for the rotation token is also calculated without a prior activation, and I wonder why.

Could you point me to the error in my thinking?
Thanks
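On the rotation head specifically: torch.nn.CrossEntropyLoss is documented to take unnormalized logits and apply log-softmax internally (it combines LogSoftmax and NLLLoss), so passing raw head outputs is its intended usage, and adding a softmax first would normalize twice. The NumPy sketch below mirrors that computation to make the point concrete; it says nothing about the reconstruction branch, and the helper name is hypothetical.

```python
import numpy as np


def cross_entropy_from_logits(logits, targets):
    """Mean cross-entropy computed from raw logits via a stable log-softmax."""
    logits = np.asarray(logits, dtype=np.float64)
    shifted = logits - logits.max(axis=1, keepdims=True)  # avoid overflow in exp
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(targets)), targets].mean())
```

Adding a constant to every logit of a sample leaves the loss unchanged, which is why the raw head output does not need to be normalized beforehand.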

visualize results

Hi, I fine-tuned the model on my custom dataset for object detection and now I want to visualize the images and detected bounding boxes. Any idea how to do that?
