hytseng0509 / crossdomainfewshot

Cross-Domain Few-Shot Classification via Learned Feature-Wise Transformation (ICLR 2020 spotlight)

few-shot-learning meta-learning domain-generalization iclr2020

crossdomainfewshot's Introduction

Cross-Domain Few-Shot Classification via Learned Feature-Wise Transformation

[Project Page][Paper]

PyTorch implementation for our cross-domain few-shot classification method. With the proposed learned feature-wise transformation layers, we are able to:

  1. improve the performance of existing few-shot classification methods under the cross-domain setting
  2. achieve state-of-the-art performance under the single-domain setting.

Contact: Hung-Yu Tseng ([email protected])

Paper

Please cite our paper if you find the code or dataset useful for your research.

Cross-Domain Few-Shot Classification via Learned Feature-Wise Transformation
Hung-Yu Tseng, Hsin-Ying Lee, Jia-Bin Huang, Ming-Hsuan Yang
International Conference on Learning Representations (ICLR), 2020 (spotlight)

@inproceedings{crossdomainfewshot,
  author = {Tseng, Hung-Yu and Lee, Hsin-Ying and Huang, Jia-Bin and Yang, Ming-Hsuan},
  booktitle = {International Conference on Learning Representations},
  title = {Cross-Domain Few-Shot Classification via Learned Feature-Wise Transformation},
  year = {2020}
}

Usage

Prerequisites

  • Python >= 3.5
  • PyTorch >= 1.3 and torchvision (https://pytorch.org/)
  • You can use the requirements.txt file we provide to set up the environment via Anaconda.
conda create --name py36 python=3.6
conda install pytorch torchvision -c pytorch
pip3 install -r requirements.txt
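
To verify the environment before moving on, here is a quick sanity check using only standard PyTorch/torchvision calls:

import torch, torchvision
print(torch.__version__, torchvision.__version__)  # expect torch >= 1.3
print(torch.cuda.is_available())                    # training assumes a CUDA-capable GPU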

Install

Clone this repository:

git clone https://github.com/hytseng0509/CrossDomainFewShot.git
cd CrossDomainFewShot

Datasets

Download the 5 datasets separately with the following commands.

  • Set DATASET_NAME to: cars, cub, miniImagenet, places, or plantae.
cd filelists
python3 process.py DATASET_NAME
cd ..
  • Refer to the instruction here for constructing your own dataset; a minimal sketch of the expected filelist format follows.
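
A minimal sketch of the filelist format used by CloserLookFewShot-style loaders, which this code follows; the class names and paths below are illustrative only:

import json

# Each split (e.g., base/val/novel) is described by one JSON file that lists
# class names, image paths, and integer labels (illustrative values).
meta = {
    'label_names': ['class_A', 'class_B'],
    'image_names': ['/path/to/class_A/img1.jpg', '/path/to/class_B/img2.jpg'],
    'image_labels': [0, 1],
}
with open('base.json', 'w') as f:
    json.dump(meta, f)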

Feature encoder pre-training

We adopt baseline++ for MatchingNet, and baseline from CloserLookFewShot for the other metric-based frameworks.

  • Download the pre-trained feature encoders.
cd output/checkpoints
python3 download_encoder.py
cd ../..
  • Or train your own pre-trained feature encoder (set PRETRAIN to baseline++ or baseline).
python3 train_baseline.py --method PRETRAIN --dataset miniImagenet --name PRETRAIN --train_aug

Training with multiple seen domains

Baseline training w/o feature-wise transformations.

  • METHOD: the metric-based framework, i.e., matchingnet, relationnet_softmax, or gnnnet.
  • TESTSET: the unseen domain, i.e., cars, cub, places, or plantae.
python3 train_baseline.py --method METHOD --dataset multi --testset TESTSET --name multi_TESTSET_ori_METHOD --warmup PRETRAIN --train_aug

Training w/ learning-to-learned feature-wise transformations.

python3 train.py --method METHOD --dataset multi --testset TESTSET --name multi_TESTSET_lft_METHOD --warmup PRETRAIN --train_aug

Evaluation

Test the metric-based framework METHOD on the unseen domain TESTSET.

  • Specify the saved model you want to evaluate with --name (e.g., --name multi_TESTSET_lft_METHOD from the above example).
python3 test.py --method METHOD --name NAME --dataset TESTSET

Note

  • This code is built upon the implementation from CloserLookFewShot.
  • The dataset, model, and code are for non-commercial research purposes only.
  • You can change the number of shots (i.e., 1 or 5 shots) using the argument --n_shot.
  • You need a GPU with 16GB of memory to train the gnnnet approach w/ learning-to-learned feature-wise transformations.
  • 04/2020: We've corrected the code for training with multiple domains. Please find the link here for the model trained with the current implementation on PyTorch 1.4.


crossdomainfewshot's Issues

train_baseline result and provided pre-trained model on baseline++ with ResNet10

Hi, thanks for sharing your code. When I use the pre-trained baseline++ with ResNet10 (downloaded according to the README), I always get a very good result combined with gnnnet (around 63% for 1-shot on miniImagenet). But when I train ResNet10 myself using train_baseline.py, the result combined with gnnnet drops by 4-5% (I tried 3 times with different baseline++ runs and only got 58%-59%). Do you have any idea about the possible reason?

baseline

The baseline and baseline++ archives cannot be unzipped now, and I cannot figure out why. I would appreciate any help.

Auxiliary Loss

Hi,

I saw you updated your code and added an auxiliary classifier. Why does it stabilize the training? Where did you get this idea?

result for protonet

Hi,

In https://openreview.net/forum?id=SJl5Np4tPr, I noticed that you've also done some experiments with the Prototypical Network. However, I tried to run this experiment (training only on miniImagenet with the pre-defined parameters) using your training code by passing the argument "--method protonet", and the results are quite low (1-shot 50%, 5-shot 63% on miniImagenet). Could you please tell me how you conducted the experiments with protonet?

Thanks.

GG! best accuracy 0.000000 in Feature encoder pre-training stage

command:

CUDA_VISIBLE_DEVICES=0 python3 train_baseline.py --method baseline++ --dataset cars --name baseline++ --train_aug

output:

Epoch 399 | Batch 75/257 | Loss 0.226424
Epoch 399 | Batch 100/257 | Loss 0.233635
Epoch 399 | Batch 125/257 | Loss 0.232594
Epoch 399 | Batch 150/257 | Loss 0.234371
Epoch 399 | Batch 175/257 | Loss 0.231336
Epoch 399 | Batch 200/257 | Loss 0.230566
Epoch 399 | Batch 225/257 | Loss 0.225214
Epoch 399 | Batch 250/257 | Loss 0.222642
GG! best accuracy 0.000000

Whether I use PyTorch 1.4 or the latest version (PyTorch 1.7), I get the same problem.

train.py problem??

Sorry, when I run train.py, there is a bug:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [8, 1]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
Could you please tell me why that happened?
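
For readers hitting the same error, a minimal sketch of the hint from the message itself (torch.autograd.set_detect_anomaly is a standard PyTorch API; the training step is illustrative):

import torch

# Enable anomaly detection before the failing step; backward() will then
# report the forward-pass operation that produced the failing gradient.
torch.autograd.set_detect_anomaly(True)
# ... run the train.py step that raises the RuntimeError and inspect the
# extended traceback for the in-place modification (e.g., +=, .data writes).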

Cannot get CUB_200_2011 dataset from "python3 process.py cub"

As mentioned in the title

Since Caltech moved the dataset to Google Drive, we can't use the command wget http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz to get the dataset.

So I used the following method to download the dataset, manually unzip the file, and then run write_cub_filelist.py:

$ wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1hbzc_P1FuxMkcabkgn9ZKinBwW683j45' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1hbzc_P1FuxMkcabkgn9ZKinBwW683j45" -O CUB_200_2011.tgz && rm -rf /tmp/cookies.txt
$ tar -zxf CUB_200_2011.tgz
$ python3 write_cub_filelist.py

Do you have any idea how to handle this problem?
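
A simpler alternative to the cookie-based wget trick above, assuming the third-party gdown package is available (the Google Drive file ID is the one from the command above):

pip3 install gdown
gdown --id 1hbzc_P1FuxMkcabkgn9ZKinBwW683j45 -O CUB_200_2011.tgz
tar -zxf CUB_200_2011.tgz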

pseudo seen and pseudo unseen domains

Hi,
I found in your code that each epoch you split the ps and pu domains with random.sample(base_set, k=2), so you only use one domain as the ps domain and another single domain as the pu domain during your experiments?
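
For context, a minimal sketch of the split described above (the domain names are illustrative):

import random

base_set = ['cars', 'cub', 'places', 'plantae']   # example seen domains
# Each split draws two distinct domains: one pseudo-seen, one pseudo-unseen.
ps_domain, pu_domain = random.sample(base_set, k=2)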

Plantae Dataset NOT Available?

Hello, I tried to get the plantae dataset via the link you provided in process.py, but the website it points to responds with "404 not found". Then I tried to get this dataset from Amazon, i.e., the iNaturalist 2018 competition, but it's too large. More surprisingly, the dataset I got from the provided link was broken and could not be unzipped; it seems the tar format is invalid or the data is simply corrupted. Do you have any other way to get this dataset, or can I use the inaturalist-12K dataset (https://www.kaggle.com/datasets/aryanpandey1109/inaturalist12k) instead to reproduce your experiment? Thanks!

About ft_optim grad from ft_loss

Hi, I reproduced your code and found that ft_loss does not produce a gradient in the FiLM layer, so how does your learning-to-learn scheme update via ft_optim?

which is the inner loop?

Hi,

I read your paper several times, and I think optimizing the model parameters is the inner loop and optimizing the FT layers is the outer loop. But I am not 100% sure. What do you think?

About the Plantae dataset

Hi,
I downloaded the Plantae dataset from the link in your code and found that it contains more than 2,900 classes. However, the referenced paper says Plantae has only 2,100 classes. Could you please tell me how you got the Plantae dataset?
Thanks

Fail to download/process database

  1. I fail to download the cub dataset with your code because the official download website returns a 404 file-not-found error.
  2. I fail to run train_baseline.py with the terminal command:
    python3 train_baseline.py --method gnnnet --dataset multi --testset cub --name multi_TESTSET_ori_METHOD --warmup model722
    The program raises an exception saying that it can't find some images in the folder cars_train. I found the cause: process.py runs wget to download the cars dataset, but the archive is saved as cars_train.tgz.1, which causes the problem.

I report these issues to help and to make you aware of them. Best wishes.

code error in dataset.py

Hello, when I run train_baseline.py with multiple domains during training, the code reports an error.

File "/home/wangchengcheng/WCC/CrossDomainFewShot/train_baseline.py", line 96, in
base_loader = base_datamgr.get_data_loader( base_file , aug=params.train_aug )
File "/home/wangchengcheng/WCC/CrossDomainFewShot/data/datamgr.py", line 57, in get_data_loader
dataset = SimpleDataset(data_file, transform)
File "/home/wangchengcheng/WCC/CrossDomainFewShot/data/dataset.py", line 15, in init
with open(data_file, 'r') as f:
TypeError: expected str, bytes or os.PathLike object, not list

The failing code is in dataset.py:

class SimpleDataset:
    def __init__(self, data_file, transform, target_transform=identity):
        with open(data_file, 'r') as f:
            self.meta = json.load(f)
        self.transform = transform
        self.target_transform = target_transform

I saw that data_file is a list, hence the error. Maybe you have already corrected this problem and there is a version mismatch in the code. Thank you for your reply.
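
For illustration only (this is not the repository's actual fix), a hedged sketch of a SimpleDataset variant that tolerates either a single filelist path or a list of paths, which is what the multi-domain setting appears to pass:

import json

class SimpleDataset:
    def __init__(self, data_file, transform, target_transform=lambda x: x):
        # Accept one filelist path or a list of them and merge the metadata.
        files = data_file if isinstance(data_file, list) else [data_file]
        self.meta = {'image_names': [], 'image_labels': []}
        for path in files:
            with open(path, 'r') as f:
                meta = json.load(f)
            self.meta['image_names'] += meta['image_names']
            self.meta['image_labels'] += meta['image_labels']
        self.transform = transform
        self.target_transform = target_transform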

about training problem.

Thanks for your sharing.
I have a question: should I train the pre-trained encoder from scratch, or can I transfer the weights of ResNet10/18/34 to train the model? Will that have any impact, such as decreased accuracy or overfitting?
Also, does ResNet34 perform worse than ResNet10 or ResNet18?
I only recently started working on cross-domain few-shot learning, so I hope you can reply; it will help me a lot. Thank you so much!

About the training time

Hi, @hytseng0509. Thanks for your insightful work! I am now running the learning-to-learn LFT on the multi-dataset setting following your code train.py. Could you tell me the average training time for this process? Much appreciated!

Question about model training.

Hello Dr. Tseng, thanks for releasing your code for LFT. I am using your repo to reimplement the paper. I wonder whether you use any tricks, for example learning rate decay, to train the model. And which model do you use for testing: best_model.tar, 400.tar, or something else? Thanks for your help.

when Training with multiple seen domains

When I run python3 train_baseline.py --method relationnet_softmax --dataset multi --testset cars --name multi_cars_ori_relationnet_softmax --warmup baseline --train_aug, I get:

[screenshot of training output]

But when I run python3 train_baseline.py --method gnnnet --dataset multi --testset cars --name multi_cars_ori_gnnnet --warmup PRETRAIN --train_aug, I get:

[screenshot of training output]

Question about the use of some code

Hello, thank you for your answer last time. My coding ability is relatively weak and some parts are not very clear to me. Could you please explain them?

[screenshot 1]

This is a two-dimensional convolution method. I think you choose between nn.Conv2d and nn.functional.conv2d according to the two attributes self.weight.fast and self.bias.fast, but I don't understand what these two attributes mean. According to the function description, they seem to be the filter tensor and the bias tensor, but I don't know why it is done this way; the code I usually see just uses nn.Conv2d directly.

[screenshot 2]

I roughly understand this method: based on whether the name of a network parameter contains gamma or beta, the parameters are divided into model and FT groups. What I don't understand is why this is done; you saved these groups, and it seems you used them in the later training.
@hytseng0509
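
For readers with the same question, a hedged sketch of the "fast weight" convolution pattern common in MAML-style code such as CloserLookFewShot, which this repo builds on; the code below is illustrative, not the repository's exact implementation:

import torch.nn as nn
import torch.nn.functional as F

class Conv2d_fw(nn.Conv2d):
    # A conv layer whose weights can be temporarily overridden by "fast"
    # weights from an inner-loop update, without touching the original
    # parameters owned by the outer optimizer.
    def forward(self, x):
        if getattr(self.weight, 'fast', None) is not None:
            bias = self.bias
            if bias is not None and getattr(bias, 'fast', None) is not None:
                bias = bias.fast
            # The functional API lets the output depend on the fast weights,
            # keeping the autograd graph connected through the inner update.
            return F.conv2d(x, self.weight.fast, bias,
                            stride=self.stride, padding=self.padding)
        return super().forward(x)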

gnn baseline does not seem to work

The loss seems to converge to a bad state even after the first few epochs, and the validation result is equivalent to random guessing. I waited about 40 epochs, and it seems to me that it won't work even if I train a few hundred more epochs. I trained the model from scratch instead of from pre-trained parameters.

problem about download_models.py

Why is it if len(sys.argv) != 6: rather than if len(sys.argv) != 4:?
I use python3 download_models.py cars 1 matchingnet but it doesn't work.
Could you please tell me why it is 6 and give a correct example?

How to get Table 1

Hi,
It seems there is no code corresponding to Table 1. How do I get the results in Table 1? And how do I get the results marked '-', 'FT', and 'LFT' in that table?

Thanks.

Please verify the code for training on multiple domains

Hi,
I found a few bugs in your code. For example:

  1. m is not used:
    for m in model.modules():
  2. There is no feat_aug in the input args:
    model = protonet.ProtoNet( model_dict[params.model], feat_aug=params.feat_aug, tf_path=params.tf_dir, **train_few_shot_params)

Could you please spend some effort to improve the code and make sure the training on multiple domains works? Thanks.

About the pre-trained backbone

Hello,
Thanks for your work! I have one question about the pre-trained ResNet-10 backbone: did you use different hyperparameters to train this backbone, or can we get the same result just by running train_baseline.py, where I see you use the default Adam optimizer and train for 400 epochs?

more detailed explanation for "create_graph=True" and "weight.fast"

LFTNet.py, "update model parameters according to model_loss"
meta_grad = torch.autograd.grad(model_loss, self.split_model_parameters()[0], create_graph=True) for k, weight in enumerate(self.split_model_parameters()[0]): weight.fast = weight - self.model_optim.param_groups[0]['lr']*meta_grad[k] meta_grad = [g.detach() for g in meta_grad]
What is the purpose of passing "create_graph=True"? Why is weight.fast updated rather than weight?
Does this have anything to do with "ft_loss.backward()"?

  Could you please give me a more detailed explanation? Thanks!
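
For context, a hedged sketch of the pattern the snippet above follows (a second-order, MAML-style inner update; the variable names are illustrative):

import torch

def inner_update(model_loss, model_params, lr):
    # create_graph=True retains the graph of this gradient computation, so
    # the fast weights below remain differentiable functions of the feature-
    # wise transformation parameters; ft_loss.backward() can then propagate
    # back through this update step.
    grads = torch.autograd.grad(model_loss, model_params, create_graph=True)
    # Writing to weight.fast instead of overwriting weight leaves the real
    # parameters intact for the model optimizer's own step.
    for weight, grad in zip(model_params, grads):
        weight.fast = weight - lr * grad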

Question about Training w/ learning-to-learned feature-wise transformations.

Hello Dr. Tseng, Thanks for releasing your code for LFT.
Namespace(data_dir='./filelists', dataset='multi', method='gnnnet', model='ResNet12', n_shot=5, name='tmp', num_classes=200, resume='resnet12/', resume_epoch=-1, save_dir='./output', save_freq=1, start_epoch=0, stop_epoch=400, test_n_way=5, testset='plantae', train_aug=True, train_n_way=5, warmup='gg3b0')
Traceback (most recent call last):
  File "train.py", line 103, in
    start_epoch = model.resume(resume_file)
  File "/home/huangjianping/CrossDomainFewShot-master/methods/LFTNet.py", line 225, in resume
    self.model.load_state_dict(state['model_state'])
KeyError: 'model_state'
help me!

Some questions about the code.

As far as I can see, lines 280 and 283 of the file methods/backbone.py mean that the Feature-Wise Transformation module is used in MAML, not in the metric-based models. But this contradicts the paper, doesn't it?
