lins-lab / ttab
[ICML23] On Pitfalls of Test-Time Adaptation
Home Page: https://arxiv.org/abs/2306.03536
License: Apache License 2.0
Thanks again for your marvelous work. After running run_exps.py, there are lots of JSON files in the ./logs folder. Are there existing tools to analyze the results, e.g., to get the mean accuracy over all corruptions for CIFAR-10-C?
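No analysis tool is described in this thread. As a hedged sketch, assuming each JSON file under ./logs stores the final accuracy under a top-level key (the key `"accuracy"` and the path filter `"cifar10_c"` below are assumptions; open one of your log files to find the real field name and naming scheme), the mean over the CIFAR-10-C runs could be computed like this:

```python
import json
from pathlib import Path
from statistics import mean


def mean_accuracy(log_dir, key="accuracy", name_filter="cifar10_c"):
    """Average one scalar metric over all JSON logs under ``log_dir``.

    Assumptions (hypothetical, check your own logs): each JSON file stores
    the final metric under the top-level ``key``, and files belonging to
    the runs of interest contain ``name_filter`` somewhere in their path.
    """
    scores = {}
    for path in sorted(Path(log_dir).rglob("*.json")):
        if name_filter and name_filter not in str(path):
            continue  # skip logs from unrelated runs
        with path.open() as f:
            record = json.load(f)
        if key in record:
            scores[str(path)] = float(record[key])
    if not scores:
        raise ValueError(f"no '{key}' values found under {log_dir}")
    return mean(scores.values()), scores
```

Call it as `avg, per_run = mean_accuracy("./logs", key="...")` with the key your logs actually use; `per_run` keeps the per-file values so a missing or duplicated corruption is easy to spot.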
Also, after running run_exps.py, there are lots of tmux sessions, each with only one window. It is a hassle to kill all the related sessions afterwards when there are other tmux sessions belonging to other experiments. Am I using the code incorrectly? I'm looking forward to more detailed documentation for all the useful tools in the repository.
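If the launcher gives its sessions a common name prefix (the prefix `ttab_` used below is hypothetical; run `tmux ls` to see the real names), the leftover sessions can be killed selectively without touching unrelated ones. A minimal sketch using the standard `tmux list-sessions -F` and `tmux kill-session -t` commands:

```python
import subprocess
from typing import Iterable, List


def matching_sessions(names: Iterable[str], prefix: str) -> List[str]:
    """Return only the session names that start with ``prefix``."""
    return [name for name in names if name.startswith(prefix)]


def kill_sessions(prefix: str, dry_run: bool = True) -> List[str]:
    """Kill (or, with dry_run=True, only list) tmux sessions named prefix*."""
    try:
        listing = subprocess.run(
            ["tmux", "list-sessions", "-F", "#{session_name}"],
            capture_output=True, text=True, check=False,
        )
    except FileNotFoundError:  # tmux is not installed
        return []
    if listing.returncode != 0:  # e.g. no tmux server is running
        return []
    targets = matching_sessions(listing.stdout.splitlines(), prefix)
    for name in targets:
        if dry_run:
            print(f"would kill: {name}")
        else:
            # A leading "=" asks tmux for an exact session-name match.
            subprocess.run(["tmux", "kill-session", "-t", f"={name}"], check=True)
    return targets
```

Run it once as `kill_sessions("ttab_")` to review the list, then with `dry_run=False`. Matching on a prefix, rather than on "has one window", keeps sessions from other experiments untouched.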
Thanks very much.
The error occurs when adapting models from the source domain realworld to the target domain art. The art domain has fewer samples.
In your source code (ttab.loads.datasets.dataset.py, line 332), a target domain is loaded by first loading the source domain and then replacing its data and targets (ttab.loads.datasets.dataset_shift.py). The data_size and indices attributes are not updated, so len(dataset.data) != len(dataset). The iterator still uses the larger indices of realworld to fetch items, which causes the error.
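The fix the report points to can be illustrated with a toy dataset that mirrors only the attributes it names (`data`, `targets`, `data_size`, `indices`). The `ToyDataset` class and `replace_domain` helper below are hypothetical and do not reproduce TTAB's code in dataset.py or dataset_shift.py; the point is only that all four attributes must be updated together when the target-domain samples are swapped in:

```python
from dataclasses import dataclass, field
from typing import List, Sequence


@dataclass
class ToyDataset:
    """Hypothetical stand-in mirroring the attributes named in the report."""
    data: List
    targets: List
    data_size: int = field(init=False)
    indices: List[int] = field(init=False)

    def __post_init__(self):
        self.data_size = len(self.data)
        self.indices = list(range(self.data_size))

    def __len__(self):
        return self.data_size

    def __getitem__(self, i):
        idx = self.indices[i]
        return self.data[idx], self.targets[idx]


def replace_domain(dataset, new_data: Sequence, new_targets: Sequence):
    """Swap in target-domain samples and keep the bookkeeping consistent."""
    if len(new_data) != len(new_targets):
        raise ValueError("data and targets must have the same length")
    dataset.data = list(new_data)
    dataset.targets = list(new_targets)
    # The step reported as missing: resize the index bookkeeping, otherwise
    # iteration keeps using the (larger) source-domain indices.
    dataset.data_size = len(dataset.data)
    dataset.indices = list(range(dataset.data_size))
    return dataset
```

Without the last two assignments, `len(ds)` would still report the source size and `ds[i]` would index past the end of the smaller target data, which matches the symptom described above.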
Hi, thanks for the fantastic work! :D Is there a leaderboard for each setting that shows the performance of each method along with the corresponding hyperparameters? Or is there a way to submit tuned results?
Thanks a lot.
Hello, I noticed a new test-time adaptation method called Robust Test-Time Adaptation (RoTTA) proposed in this paper
https://arxiv.org/abs/2303.13899.
According to the paper, RoTTA achieves state-of-the-art performance and outperforms some existing methods like TENT, CoTTA, NOTE, etc.
I tried integrating RoTTA into the codebase myself. However, when I tested it, the accuracy I got was much lower than that reported in the original RoTTA paper. I tried to debug it but couldn't bring the accuracy up. Have you considered integrating RoTTA? As the original authors of the codebase, you may have better insight into how to integrate it properly.
Please let me know if you would be interested in integrating RoTTA. I'm happy to provide more details on the issues I encountered. I think RoTTA could further improve the codebase's adaptation capabilities.
Thank you for your attention! Please let me know if you need any other information from me.
Dear authors, I'm using your pretraining script ssl_pretrain.py and would like to apply it to the ImageNet dataset. However, the code currently does not seem to support ImageNet. Would it be possible to add support for it? Please let me know if you have any plans for this. Thanks!
Hi,
Based on my understanding of these two settings, HomogeneousNoMixture passes the test samples of one domain and then the other without mixing them, whereas CrossMixture first merges the two datasets and then passes the mixed test samples to the model. The batches sent to the model should therefore be very different. However, when I print the logits in the one_adapt_step function (ttab/ttab/model_adaptation/tent.py, lines 72 and 83 at commit 815fc94), the output is the same in both settings.
Could you help me find the reason for the identical output?
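To make the expected difference concrete, the toy sketch below builds both orderings as the question describes them. The function names are hypothetical and this is not TTAB's implementation; samples are tagged with their domain so the batches can be inspected directly:

```python
import random
from typing import List, Tuple

Sample = Tuple[str, int]  # (domain name, sample index)


def homogeneous_no_mixture(domains: List[List[Sample]]) -> List[Sample]:
    # Domains are streamed one after another; no batch crosses a boundary
    # except possibly the one straddling the end of a domain.
    return [s for domain in domains for s in domain]


def cross_mixture(domains: List[List[Sample]], seed: int = 0) -> List[Sample]:
    # All domains are merged first, then shuffled jointly.
    merged = [s for domain in domains for s in domain]
    random.Random(seed).shuffle(merged)
    return merged


def batches(stream: List[Sample], batch_size: int) -> List[List[Sample]]:
    return [stream[i:i + batch_size] for i in range(0, len(stream), batch_size)]
```

Both streams contain exactly the same samples; only the batch composition differs. If the printed logits are identical across the two settings, a natural next check is whether the first batches reaching `one_adapt_step` differ at all, e.g. by comparing their labels or a checksum of the input tensors.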
I want to train resnet18 on the ImageNet dataset.
I changed the class CIFARDataset into ImageNetDataset in the file pretrain_self_supervised_cifar10.py.
Here is my config:
config = {
# general
"seed": 2022,
"ckpt_path": "/data/TTA-exp/ttab/pretrain/ckpt",
"device": "cuda:0",
"log_dir": "/data/TTA-exp/ttab/data/runs",
# data
"data_path": "/data/TTA-exp/ttab/datasets",
"data_name": "imagenet",
"num_classes": 15,
# model
"model_name": "resnet18",
"task_name": "classification",
"resume": False,
# hyperparams
"entry_of_shared_layers": "layer3",
"dim_out": 4,
"use_iabn": False,
"iabn_k": 4,
"use_ls": True, # label smoothing: https://arxiv.org/abs/1906.02629
"threshold_note": 1,
"rotation_type": "expand",
"lr": 0.1,
"momentum": 0.9,
"weight_decay": 5e-4,
"maxEpochs": 150,
"milestones": [75, 125],
"batch_size": 128,
"save_epoch": 10, # save weights file per SAVE_EPOCH epoch.
}
When we train resnet18 on the ImageNet dataset, we get this error.
We then changed line 310 in ttab/loads/models/resnet.py from `self.avgpool = nn.AvgPool2d(kernel_size=7, stride=1)` to `self.avgpool = nn.AvgPool2d(kernel_size=1, stride=1)`, and then we get this error.
Moreover, we used the model wideresnet28_10 and still get this error.
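The error messages themselves are not included here, so the following is only a hypothesis. `nn.AvgPool2d(kernel_size=7)` assumes a 7×7 final feature map, which a standard ImageNet-style ResNet produces only for 224×224 inputs. The sketch below computes that final size, assuming the torchvision-style layout (7×7 stride-2 stem, stride-2 max-pool, then three stride-2 stages); check it against the ImageNet branch of ttab/loads/models/resnet.py before relying on it:

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a conv/pool layer along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def imagenet_resnet_final_size(size: int) -> int:
    """Spatial size before the final pooling, assuming a standard layout."""
    size = conv_out(size, kernel=7, stride=2, padding=3)  # conv1
    size = conv_out(size, kernel=3, stride=2, padding=1)  # max-pool
    for _ in range(3):                                    # layer2..layer4
        size = conv_out(size, kernel=3, stride=2, padding=1)
    return size


for side in (224, 64, 32):
    print(f"{side} -> {imagenet_resnet_final_size(side)}")
# prints 224 -> 7, 64 -> 2, 32 -> 1 (one per line)
```

Under these assumptions, if the data loader yields images smaller than 224×224, kernel_size=7 exceeds the feature map, and kernel_size=1 can leave a map larger than 1×1 whose flattened size no longer matches the final linear layer. PyTorch's `nn.AdaptiveAvgPool2d(1)` pools to 1×1 for any input size and avoids hard-coding the kernel.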
In TTAB, we trained WideResNet-28-10 on CIFAR-10, which is also the model used in the original CoTTA paper. Nevertheless, CoTTA reaches 51% accuracy on Snow-5 in TTAB, versus 85% in the original paper. Why is there such a large gap between the two results?
Dear authors,
I have a few requests.
We run the code with the following settings:
--model_adaptation_method: "note"
--model_selection_method: "last_iterate"
--model_selection_method: "cifar10"
--model_name: "resnet26"
--episodic: "false"
--data_names: ("cifar10_c_deterministic-snow-5;"
    "cifar10_c_deterministic-brightness-5;"
    "cifar10_c_deterministic-fog-5;"
    "cifar10_c_deterministic-frost-5;"
    "cifar10_c_deterministic-contrast-5;"
    "cifar10_c_deterministic-motion_blur-5;"
    "cifar10_c_deterministic-glass_blur-5;"
    "cifar10_c_deterministic-zoom_blur-5;"
    "cifar10_c_deterministic-gaussian_noise-5;"
    "cifar10_c_deterministic-shot_noise-5;"
    "cifar10_c_deterministic-jpeg_compression-5;"
    "cifar10_c_deterministic-impulse_noise-5;"
    "cifar10_c_deterministic-pixelate-5;"
    "cifar10_c_deterministic-elastic_transform-5;"
    "cifar10_c_deterministic-defocus_blur-5",)
--batch_size: 100
--lr: 1e-4
--n_train_steps: 1
--inter_domain: "HomogeneousNoMixture"
The error rate of NOTE is 41%, which is quite different from the result reported for "NOTE-online" in Table 2 of the paper (24.0 ± 0.1). Is this discrepancy caused by differences between our experimental environment and the setting of the original paper, or are there other reasons?
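The `--data_names` value above is the 15 CIFAR-10-C corruptions at severity 5, joined with `;`. A small helper like the one below (hypothetical, not part of TTAB) rebuilds that string from a list, which makes it easy to confirm that no corruption is missing or duplicated before comparing against the paper's numbers:

```python
CORRUPTIONS = [
    "snow", "brightness", "fog", "frost", "contrast",
    "motion_blur", "glass_blur", "zoom_blur", "gaussian_noise",
    "shot_noise", "jpeg_compression", "impulse_noise", "pixelate",
    "elastic_transform", "defocus_blur",
]


def build_data_names(corruptions, severity=5, prefix="cifar10_c_deterministic"):
    """Join per-corruption dataset names with ';' in the format used above."""
    if len(set(corruptions)) != len(corruptions):
        raise ValueError("duplicate corruption names")
    return ";".join(f"{prefix}-{name}-{severity}" for name in corruptions)


data_names = build_data_names(CORRUPTIONS)
```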
Dear authors,
I have a few requests.
There are two "--inter_domain" options in the code: HeterogeneousNoMixture and CrossMixture. CrossMixture shuffles data across domains; HeterogeneousNoMixture constructs label shift (non-i.i.d.) over domains.
I want to mix the multiple domains first and then construct label shift.
So I do the following two steps:
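For reference, one hedged way to realize that combination is sketched below. It follows the common Dirichlet-based recipe for label shift and is not necessarily how TTAB implements HeterogeneousNoMixture; the function name and parameters are hypothetical. Samples are (input, label) pairs, and a smaller `alpha` concentrates each class in fewer time slots, i.e. stronger label shift:

```python
import random
from collections import defaultdict
from typing import Hashable, List, Sequence, Tuple

Sample = Tuple[object, Hashable]  # (input, label)


def mix_then_label_shift(
    domains: Sequence[Sequence[Sample]],
    num_slots: int,
    alpha: float,
    seed: int = 0,
) -> List[Sample]:
    """Mix all domains, then order the stream so label proportions drift."""
    rng = random.Random(seed)

    # Step 1: cross-domain mixing (merge, then shuffle jointly).
    merged = [sample for domain in domains for sample in domain]
    rng.shuffle(merged)

    # Group the mixed samples by label.
    by_label = defaultdict(list)
    for sample in merged:
        by_label[sample[1]].append(sample)

    # Step 2: split each class across time slots with Dirichlet(alpha)
    # proportions, drawn as normalized Gamma(alpha, 1) variates.
    slots: List[List[Sample]] = [[] for _ in range(num_slots)]
    for samples in by_label.values():
        weights = [rng.gammavariate(alpha, 1.0) for _ in range(num_slots)]
        total = sum(weights) or 1.0
        bounds, acc = [0], 0.0
        for w in weights[:-1]:
            acc += w / total
            bounds.append(min(len(samples), round(acc * len(samples))))
        bounds.append(len(samples))
        for k in range(num_slots):
            slots[k].extend(samples[bounds[k]:bounds[k + 1]])

    # Concatenate slots in time order; shuffle only within each slot.
    stream: List[Sample] = []
    for slot in slots:
        rng.shuffle(slot)
        stream.extend(slot)
    return stream
```

Because mixing happens before the split, each time slot draws from all domains while its label distribution is skewed by the Dirichlet proportions.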
Thank you for creating a great benchmark.
I have one question about the implementation.
When I run CoTTA with `python3 run_exp.py --model_adaptation_method cotta`, I get the following error:
TypeError: __init__() got an unexpected keyword argument 'resample'
It is raised by line 125 in cotta.py:
transforms.RandomAffine(
degrees=[-8, 8] if soft else [-15, 15],
translate=(1 / 16, 1 / 16),
scale=(0.95, 1.05) if soft else (0.9, 1.1),
shear=None,
resample=PIL.Image.BILINEAR,
fillcolor=None,
),
Looking around, I found that the official PyTorch documentation gives the signature
class torchvision.transforms.RandomAffine(degrees, translate=None, scale=None, shear=None, interpolation=InterpolationMode.NEAREST, fill=0, center=None)
and there is no resample argument.
Thanks in advance.
Hi, thanks for your excellent work!
To make sure my replication is accurate, I would like to request access to the pretrained checkpoints you used in your CIFAR-100 experiments. I tried to pretrain a CIFAR-100 model with "python ssl_pretrain.py --data-name cifar100 --model-name resnet26", but I cannot reproduce the results in the paper.
Dear authors,
Thank you for your great work!
I'm trying to reproduce some results of your paper with the TTAB benchmark.
However, the results I get are not consistent with those in your paper.
In particular, the figure below is our reproduction of Figure 3(a) with lr=0.01 (the green curve in the paper).
As you can see, the result is quite different from Figure 3(a).
In the experiment, I used the following configuration, based on your paper:
ckpt_path: "./pretrain/ckpt/resnet26_bn_ssh_cifar10.pth" (the checkpoint you provided on Google Drive)
seed: 2022
model_adaptation_method: "shot"
model_selection_method: "last_iterate"
data_names: "cifar10_c_deterministic-gaussian_noise-5"
lr: 0.01
n_train_steps: 1
(All other settings are the same as the base setting in the TTAB implementation.)
Could you provide more details about the configurations for Figures 3(a) and 3(b)?
Thank you!
What is the group_norm_num_group value for resnet26_gn? Could you add the pretraining command line for the GN model to the README? Thanks for the marvelous work!
Dear authors,
I have several requests, which I have already sent to the authors by email.
My requests are only about reproducing the results in Table 2 of your paper.
I hope these requests do not disturb you too much.
Thank you.
Best,