
mboudiaf / tim


(NeurIPS 2020) Transductive Information Maximization for Few-Shot Learning https://arxiv.org/abs/2008.11297

License: MIT License

Shell 28.91% Python 71.09%
few-shot-classifcation few-shot-learning few-shot neurips-2020 mutual-information transductive-learning optimization-methods


TIM: Transductive Information Maximization

Introduction

This repo contains the code for our NeurIPS 2020 paper "Transductive Information Maximization (TIM) for Few-Shot Learning", available at https://arxiv.org/abs/2008.11297. Our method maximizes the mutual information between the query features and their label predictions for a few-shot task, subject to supervision constraints from the support set. The results reported in the paper can be reproduced with this repo. The code was developed under Python 3.8.3 and PyTorch 1.4.0. It is parallelized over tasks, which makes running the 10,000 tasks very efficient.
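
To make the objective concrete, here is a minimal NumPy sketch of a TIM-style loss. Like the repo, it is evaluated for a batch of tasks at once. The mutual information is estimated as H(Y_Q) - H(Y_Q|X_Q): the marginal entropy of the query predictions minus their conditional entropy. It is combined with a cross-entropy term on the labelled support samples. The function names and the weights w_ce, w_marg and w_cond are illustrative assumptions, not the repo's API or its default hyperparameters; the repo itself is written in PyTorch.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    # Numerically stable softmax over the last (class) axis.
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def tim_style_loss(support_logits, support_labels, query_logits,
                   w_ce=0.1, w_marg=1.0, w_cond=0.1, eps=1e-12):
    """Return one loss value per task (lower is better).

    support_logits: [n_tasks, n_support, K]
    support_labels: [n_tasks, n_support] integer class indices
    query_logits:   [n_tasks, n_query, K]
    The weights are illustrative assumptions, not the repo's defaults.
    """
    # Supervision term: cross-entropy on the labelled support samples.
    p_s = softmax(support_logits)
    p_true = np.take_along_axis(p_s, support_labels[..., None], axis=-1)[..., 0]
    ce = -np.log(p_true + eps).mean(axis=1)

    p_q = softmax(query_logits)
    # Conditional entropy H(Y_Q|X_Q): mean entropy of each query prediction.
    cond_ent = -(p_q * np.log(p_q + eps)).sum(axis=-1).mean(axis=1)
    # Marginal entropy H(Y_Q): entropy of the average query prediction.
    p_marg = p_q.mean(axis=1)
    marg_ent = -(p_marg * np.log(p_marg + eps)).sum(axis=-1)

    # Maximizing the weighted mutual information = subtracting it from the loss.
    return w_ce * ce - (w_marg * marg_ent - w_cond * cond_ent)
```

Confident query predictions that are spread evenly across classes lower this loss. Collapsing every query onto a single class is penalized, because the marginal-entropy term then drops to zero.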

1. Getting started

The data and pretrained models are available on iCloud Drive. Use the cat command to reassemble the original files from their shards, then extract them.
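
The cat step simply concatenates the shards byte by byte, in the order the shell glob lists them. For systems without cat, a Python equivalent might look like the sketch below. It assumes the shard names sort into the intended order when compared as strings, which is the order `cat tiered_imagenet_*` reads them in. The helper name is hypothetical. The name of the reassembled archive is not given here, so the output name in the example is a placeholder.

```python
import glob
import shutil
from pathlib import Path


def reassemble_shards(pattern: str, output: str, chunk_size: int = 1 << 20) -> int:
    """Concatenate every file matching `pattern`, in sorted order, into `output`.

    Mirrors `cat <pattern> > <output>`. Returns the number of shards joined.
    """
    out_path = Path(output).resolve()
    # Skip the output file in case it matches the pattern from a previous run.
    shards = [s for s in sorted(glob.glob(pattern)) if Path(s).resolve() != out_path]
    if not shards:
        raise FileNotFoundError(f"no shards match {pattern!r}")
    with open(out_path, "wb") as out:
        for shard in shards:
            with open(shard, "rb") as src:
                # Stream in chunks so multi-GB shards never sit fully in memory.
                shutil.copyfileobj(src, out, chunk_size)
    return len(shards)


# Example (output name is a hypothetical placeholder):
# reassemble_shards("tiered_imagenet_*", "tiered_imagenet_archive")
```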

Checkpoints: The checkpoints/ directory should be placed in the root directory and structured as follows:

├── mini
│   └── softmax
│       ├── densenet121
│       │   ├── best
│       │   ├── checkpoint.pth.tar
│       │   └── model_best.pth.tar
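
Before evaluating, you can check that the checkpoint files are where the scripts expect them. The sketch below assumes the same checkpoints/<dataset>/softmax/<arch>/ pattern holds for the other dataset and architecture pairs (e.g. mini/softmax/resnet18). It only checks the two .pth.tar files shown in the tree. The helper names are hypothetical and not part of the repo.

```python
from pathlib import Path

# Files listed under checkpoints/mini/softmax/densenet121/ in the tree above.
EXPECTED_CKPT_FILES = ("checkpoint.pth.tar", "model_best.pth.tar")


def checkpoint_dir(root: str, dataset: str, arch: str, loss: str = "softmax") -> Path:
    # Assumed pattern: <root>/checkpoints/<dataset>/<loss>/<arch>/
    return Path(root) / "checkpoints" / dataset / loss / arch


def missing_checkpoint_files(root: str, dataset: str, arch: str) -> list:
    """Return the expected checkpoint files that are absent for (dataset, arch)."""
    ckpt = checkpoint_dir(root, dataset, arch)
    return [name for name in EXPECTED_CKPT_FILES if not (ckpt / name).is_file()]
```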

Data: Because of its size, tiered_imagenet has been split into 24 shards of 1 GB each. Use cat tiered_imagenet_* to reassemble the original file. Extract everything to data/ in the root directory. The data folder should be structured as follows:

├── cub
│   ├── attributes.txt
│   └── CUB_200_2011
│       ├── attributes
│       ├── bounding_boxes.txt
│       ├── classes.txt
│       ├── image_class_labels.txt
│       ├── images
│       ├── images.txt
│       ├── parts
│       ├── README
│       └── train_test_split.txt
├── mini_imagenet
└── tiered_imagenet
    ├── class_names.txt
    ├── data
    ├── synsets.txt
    ├── test_images_png.pkl
    ├── test_labels.pkl
    ├── train_images_png.pkl
    ├── train_labels.pkl
    ├── val_images_png.pkl
    └── val_labels.pkl
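
A quick way to catch an incomplete extraction is to compare data/ against the layout above. The sketch below checks a subset of the listed entries. It expects mini_imagenet only to exist as a directory, since its contents are not listed. The helper name and the choice of entries are assumptions, not something the repo provides.

```python
from pathlib import Path

# Entries taken from the tree above (a subset); mini_imagenet contents are unlisted.
EXPECTED = {
    "cub": ["attributes.txt", "CUB_200_2011/images", "CUB_200_2011/images.txt",
            "CUB_200_2011/train_test_split.txt"],
    "mini_imagenet": [],
    "tiered_imagenet": ["class_names.txt", "synsets.txt"] + [
        f"{split}_{kind}.pkl"
        for split in ("train", "val", "test")
        for kind in ("images_png", "labels")
    ],
}


def missing_data_entries(data_root: str = "data") -> list:
    """List dataset folders, or files inside them, that are missing under data_root."""
    root = Path(data_root)
    missing = []
    for dataset, entries in EXPECTED.items():
        if not (root / dataset).is_dir():
            missing.append(dataset)
            continue
        missing += [f"{dataset}/{e}" for e in entries if not (root / dataset / e).exists()]
    return missing
```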

All other required libraries are easy to find online, except visdom_logger, which you can install with:

pip install git+https://github.com/luizgh/visdom_logger

2. Train models (optional)

Instead of using the pre-trained models, you may want to train the models from scratch. Before anything, don't forget to activate the downloaded environment:

source env/bin/activate

To visualize the results, start your local visdom server:

python -m visdom.server -port 8097

and open it in your browser at http://localhost:8097/. Then, for instance, to train a ResNet-18 on mini-ImageNet, go to the root of the repository and execute:

bash scripts/train/resnet18.sh
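
Because training logs to the visdom server, it can help to confirm beforehand that the server started in the previous step is listening on port 8097. This is a minimal standard-library sketch with a hypothetical function name. It only confirms that a TCP connection succeeds, not that the listener is actually visdom.

```python
import socket


def port_is_open(host: str = "localhost", port: int = 8097, timeout: float = 1.0) -> bool:
    """Return True if something accepts TCP connections on host:port."""
    try:
        # create_connection raises OSError (e.g. ConnectionRefusedError) on failure.
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False
```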

Important: whenever you train a new model yourself and want to test it, add the option eval.fresh_start=True to your test command. Otherwise, the code may use cached information (used to speed up experiments) from previously used models that is no longer valid.

3. Reproducing the main results

Before anything, don't forget to activate the downloaded environment:

source env/bin/activate

3.1 Benchmarks (Table 1. in paper)

Accuracy (%), 1-shot / 5-shot
Method    Arch        mini-ImageNet   tiered-ImageNet
TIM-ADM   ResNet-18   73.6 / 85.0     80.0 / 88.5
TIM-GD    ResNet-18   73.9 / 85.0     79.9 / 88.5
TIM-ADM   WRN28-10    77.5 / 87.2     82.0 / 89.7
TIM-GD    WRN28-10    77.8 / 87.4     82.1 / 89.8

To reproduce the results from Table 1 in the paper, use the bash files in scripts/evaluate/. For instance, to reproduce the methods on mini-ImageNet, go to the root of the repository and execute:

bash scripts/evaluate/<tim_adm or tim_gd>/mini.sh

This will reproduce the results for the three network architectures in the paper (ResNet-18, WideResNet28-10 and DenseNet-121). Upon completion, exhaustive logs can be found in the logs/ folder.

3.2 Domain shift (Table 2. in paper)

5-shot accuracy (%)
Method    Arch        CUB -> CUB   mini-ImageNet -> CUB
TIM-ADM   ResNet-18   90.7         70.3
TIM-GD    ResNet-18   90.8         71.0

To reproduce the methods on CUB -> CUB, go to the root of the repository and execute:

bash scripts/evaluate/<tim_adm or tim_gd>/cub.sh

To reproduce the methods on mini -> CUB, go to the root of the repository and execute:

bash scripts/evaluate/<tim_adm or tim_gd>/mini2cub.sh

3.3 Tasks with more ways (Table 3. in paper)

To reproduce the methods with more ways (10 and 20 ways) on mini-ImageNet, go to the root of the repository and execute:

bash scripts/evaluate/<tim_adm or tim_gd>/mini_10_20_ways.sh

Accuracy (%), 1-shot / 5-shot
Method    Arch        10 ways       20 ways
TIM-ADM   ResNet-18   56.0 / 72.9   39.5 / 58.8
TIM-GD    ResNet-18   56.1 / 72.8   39.3 / 59.5
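
Every evaluation script above follows the pattern scripts/evaluate/<tim_adm or tim_gd>/<benchmark>.sh, so the commands for both methods can be generated and run in one go. This is a hypothetical convenience wrapper, not part of the repo. It only knows the benchmark names given in this README (mini, cub, mini2cub, mini_10_20_ways). The tiered-ImageNet script is left out because its name is not stated here. Run it from the repository root with the environment activated.

```python
import subprocess

METHODS = ("tim_adm", "tim_gd")
# Only the script names that appear in this README; extend as needed.
BENCHMARKS = ("mini", "cub", "mini2cub", "mini_10_20_ways")


def evaluation_commands(benchmark: str, methods=METHODS) -> list:
    """Build one `bash scripts/evaluate/<method>/<benchmark>.sh` command per method."""
    if benchmark not in BENCHMARKS:
        raise ValueError(f"unknown benchmark {benchmark!r}, expected one of {BENCHMARKS}")
    return [["bash", f"scripts/evaluate/{method}/{benchmark}.sh"] for method in methods]


def run_evaluations(benchmark: str, dry_run: bool = True) -> None:
    """Print (and, unless dry_run, execute) the commands from the repository root."""
    for cmd in evaluation_commands(benchmark):
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops at the first script that exits with an error.
            subprocess.run(cmd, check=True)
```

For example, run_evaluations("mini2cub") only prints the two commands. Pass dry_run=False to actually launch them.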

3.4 Ablation study (Table 4. in paper)

To reproduce the four loss configurations on mini-ImageNet, tiered-ImageNet and CUB, go to the root of the repository and execute:

bash scripts/ablation/<tim_adm or tim_gd or tim_gd_all>/weighting_effect.sh

for TIM-ADM, TIM-GD {W} and TIM-GD {phi, W}, respectively (W denotes the classifier weights and phi the feature extractor).
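
Table 4 ablates the terms of the loss. Assume the four configurations correspond to switching the marginal entropy H(Y_Q) and the conditional entropy H(Y_Q|X_Q) on or off on top of the support cross-entropy. The combinations can then be enumerated as in the sketch below. The unit weights and the helper name are assumptions. The weights actually used by the ablation scripts are not stated here.

```python
from itertools import product


def ablation_losses(ce: float, marg_ent: float, cond_ent: float) -> dict:
    """Combine precomputed loss terms for each on/off configuration.

    ce, marg_ent, cond_ent: support cross-entropy, H(Y_Q) and H(Y_Q|X_Q)
    for one task. Unit weights are an assumption.
    """
    results = {}
    for use_marg, use_cond in product((False, True), repeat=2):
        name = "CE"
        name += " - H(Y_Q)" if use_marg else ""
        name += " + H(Y_Q|X_Q)" if use_cond else ""
        # Booleans act as 0/1 switches on each entropy term.
        results[name] = ce - use_marg * marg_ent + use_cond * cond_ent
    return results
```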

Contact

For further questions or details, reach out to Malik Boudiaf.

Acknowledgements

We would like to thank the authors of SimpleShot (https://github.com/mileyan/simple_shot) and LaplacianShot (https://github.com/imtiazziko/LaplacianShot) for giving access to their pre-trained models and code, which inspired this repo.


tim's Issues

Question

ERROR - FSL training - Failed after 0:04:19!
Traceback (most recent calls WITHOUT Sacred internals):
File "/home/data/qinyanfei/code/TIM-master/src/main.py", line 118, in main
if (epoch) % trainer.meta_val_interval == 0:
AttributeError: 'Trainer' object has no attribute 'meta_val_interval'

About training parameters for WRN

Thank you for releasing code for such a creative method! However I've faced some problems when reproducing the results.

From what I know, TIM is currently one of the state-of-the-art few-shot methods using WRN as a backbone. However, the pretraining strategy shown in the code seems problematic: I could only train a model that reaches 34.99% accuracy on a 16-way 1-shot task on miniImageNet, compared to ~44% with the strategy used in SIB or EPNet. It has been pointed out that the quality of the pretrained backbone significantly influences the performance of the method, and since TIM outperforms SIB and EPNet by a large margin with no extra fine-tuning stage, this is puzzling.

This is the training strategy for (miniImageNet, WRN) I found in your code:
initial LR=0.1, optimizer: SGD w/ nesterov momentum=0.9, weight decay=1e-4
N_epoch=90
LR schedule: multistep, LR*=0.1@epoch 45&67
Label smoothing 0.1
Data augment: Color jitter
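
The multistep schedule listed above (initial LR 0.1, multiplied by 0.1 at epochs 45 and 67, over 90 epochs) can be written out as below. This is a sketch assuming the scheduler is stepped once per epoch. Under that assumption it follows the same rule as torch.optim.lr_scheduler.MultiStepLR with milestones=[45, 67] and gamma=0.1. The function name is hypothetical.

```python
def lr_at_epoch(epoch: int, base_lr: float = 0.1,
                milestones=(45, 67), gamma: float = 0.1) -> float:
    """Learning rate in effect during `epoch` (0-indexed) under a multistep schedule."""
    # Count how many milestones have been reached; each one multiplies LR by gamma.
    n_decays = sum(epoch >= m for m in milestones)
    return base_lr * gamma ** n_decays
```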

Did I miss something? If not, could you tell me why you didn't adopt a better pretrained backbone to further improve the results? Thanks a lot.

Issues downloading tiered-imagenet

I have been having issues downloading tieredImagenet from the link provided: it says "forbidden network". Could you kindly change the settings so that I can download it? I am trying to use it for my research.

unknown to the library visdom_logger

Hello, I have a trivial question about the library visdom_logger.
I have never seen this library, and I could not find it on PyPI.
Could you please tell what it is and where I can download it?
Thank you!

Train a model to reproduce domain-shift results

Hi, to evaluate the domain shift setting (Table 2 in the paper), an appropriate pre-trained model must be provided (for example "checkpoints/mini2cub/softmax/resnet18"). But if we want to train this model from scratch, how can we train it for the cross-domain setting?
I did not find the script.

Hyperparameters for TIM

Hey Malik,

I am trying to use your model in my work and was wondering which hyperparameters you used in your experiments. In the paper you say that you use 1000 iterations and, for the Adam optimizer, the settings suggested in the paper, so I assumed the learning rate is PyTorch's default of 1e-3. However, in your tim.py code, config() uses lr = 1e-4. Is that correct?

ModuleNotFoundError: No module named 'visdom_logger'

Traceback (most recent call last):
  File "/home/dell/anaconda3/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/dell/anaconda3/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/data/qinyanfei/code/TIM-master/src/main.py", line 8, in <module>
    from visdom_logger import VisdomLogger
ModuleNotFoundError: No module named 'visdom_logger'

Is the module missing?

About Adapting This Method to Other Datasets

Hi,
In #7 I mentioned that I planned to adapt this method to an aluminum dataset.
I've already done that, with rather good results (5-way 1-shot accuracy of 0.5171, 5-way 5-shot accuracy of 0.6927).
One question: in the backbone training stage, the highest accuracy achieved was about 0.64 (below 0.6927). Is this abnormal?
Another interesting thing: when I use the backbone trained on the aluminum dataset to evaluate on the NEU-CLS dataset, it achieves an incredible accuracy of [0.7010, 0.8141]!!!
Best

How can I adapt this method to another dataset?

I want to adapt this method to the Ali Aluminum Dataset.
I've made my own split file and written my own al_tianchi.sh:

python3 -m src.main \
		with dataset.path="/path/to/al_tianchi" \
		visdom_port=8097 \
		dataset.split_dir="split/al_tianchi" \
		ckpt_path="checkpoints/al_tianchi/softmax/resnet18" \
		dataset.batch_size=128 \
		dataset.jitter=True \
		model.arch='resnet18' \
		model.num_classes=10 \
		optim.scheduler="multi_step" \
		epochs=90 \
		trainer.label_smoothing=0.1

and trained my own resnet18 model (as I understand it, the backbone).
What should I do next to test TIM on this dataset?
Sorry for my bad English and for bothering you.
Best Wishes for You :)

tuning parameters

Merry Christmas! I would like to ask you a question about tuning parameters. I ran the training following your steps, for example resnet18.sh. The only change I made was setting num_workers = 0 in src/datasets/ingredient.py, because otherwise it raised an error. After training, I only reach 0.3776 (1-shot) and 0.5026 (5-shot) accuracy on the mini dataset. Are there any other parameter-tuning tips?

AttributeError: 'Trainer' object has no attribute 'meta_val_interval'

When I run the training code with bash scripts/train/resnet18.sh, I get the following error:

ERROR - FSL training - Failed after 0:04:19!
Traceback (most recent calls WITHOUT Sacred internals):
File "/home/data/qinyanfei/code/TIM-master/src/main.py", line 118, in main
if (epoch) % trainer.meta_val_interval == 0:
AttributeError: 'Trainer' object has no attribute 'meta_val_interval'

can you answer me?

URL

Hi,
I tried to run "Download_data.py" and "Download_models.py" to get the resources, but they always show:
" requests.exceptions.ConnectionError: HTTPSConnectionPool(host='docs.google.com', port=443): Max retries exceeded with url: /uc?export=download&id=15MFsig6pjXO7vZdo-1znJoXtHv4NY-AF (Caused by NewConnectionError('<requests.packages.urllib3.connection.VerifiedHTTPSConnection object at 0x0000014B42FF0208>: Failed to establish a new connection: ". However, I can open "https://docs.google.com". Is this URL correct?

The google drive download link of Dataset is unavailable

@mboudiaf I cannot download the dataset files using the Python file 'download_data.py' in the directory './scripts/downloads'. Maybe the dataset's Google Drive download link is no longer available. Could you provide a new Google Drive download link? Thank you very much.
