lins-lab / ttab

[ICML23] On Pitfalls of Test-Time Adaptation

Home Page: https://arxiv.org/abs/2306.03536

License: Apache License 2.0

Languages: Python 99.30%, Jupyter Notebook 0.70%
Topics: source-free-domain-adaptation, test-time-adaptation

ttab's People

Contributors

marcelluszhao, tlin-taolin, ymaster7, yuejiangliu


ttab's Issues

Are there guides to analyze the results from run_exps.py?

Thanks again for your marvelous work. After running run_exps.py, there are lots of JSON files in the ./logs folder. Are there any tools for analyzing the results, e.g. to get the mean accuracy over all the corruptions for cifar10_c?

Also, after running run_exps.py, there are lots of tmux sessions with only one window each. It is tedious to kill all the related sessions afterwards when other tmux sessions belong to other experiments. Is there anything wrong with my usage of the code? I am looking forward to a more detailed doc covering all the useful tools in the repository.

Thanks very much.
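For reference, here is the kind of aggregation script I have in mind. This is only a sketch on my side: the JSON field names "data_name" and "accuracy" are guesses, not the repo's actual log schema.

import json
import pathlib
from collections import defaultdict

# Collect per-corruption accuracies from the ./logs folder written by
# run_exps.py, then average over all cifar10_c corruptions.
records = defaultdict(list)
for path in pathlib.Path("./logs").rglob("*.json"):
    with open(path) as f:
        result = json.load(f)
    # hypothetical field names; the real schema may differ
    records[result["data_name"]].append(result["accuracy"])

per_corruption = {
    name: sum(accs) / len(accs)
    for name, accs in records.items()
    if name.startswith("cifar10_c")
}
print("mean acc over corruptions:",
      sum(per_corruption.values()) / len(per_corruption))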

Different dataset sizes cause a list-index-out-of-range error in the OfficeHome experiment

The error shows up when adapting models from the source domain realworld to the target domain art. Domain art has a smaller dataset.

In your source code (ttab/loads/datasets/dataset.py, line 332), a target domain is loaded by first loading the source domain and then replacing the data and targets (ttab/loads/datasets/dataset_shift.py). The data_size and indices are not updated, so len(dataset.data) != len(dataset). The iterator still uses the larger indices of realworld to fetch items, which causes the error.
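A minimal sketch of the kind of fix I would expect, assuming the dataset keeps the data, targets, data_size, and indices attributes described above (attribute names from my reading of the code, not verified against the repo):

def replace_data(dataset, new_data, new_targets):
    # Swap in the target-domain data and keep the bookkeeping consistent.
    dataset.data = new_data
    dataset.targets = new_targets
    # Without these two updates the iterator keeps the source-domain
    # length and indices, so it indexes past the smaller target data.
    dataset.data_size = len(new_data)
    dataset.indices = list(range(len(new_data)))
    return dataset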

Consider integrating RoTTA

Hello, I noticed a new test-time adaptation method called Robust Test-Time Adaptation (RoTTA), proposed in this paper: https://arxiv.org/abs/2303.13899. According to the paper, RoTTA achieves state-of-the-art performance and outperforms existing methods such as TENT, CoTTA, and NOTE.

I tried integrating RoTTA into the codebase myself. However, when I tested it, the accuracy I got was much lower than reported in the original RoTTA paper. I tried to debug it but couldn't get the accuracy up. Have you considered integrating RoTTA? As the original authors of the codebase, you may have better insights on how to integrate it properly.

Please let me know if you would be interested in integrating RoTTA. I'm happy to provide more details on the issues I encountered. I think RoTTA could further improve the codebase's adaptation capabilities.

Thank you for your attention! Please let me know if you need any other information from me.

I have one question about CrossMixture and HomogeneousNoMixture settings

Hi,
Based on my understanding of these two settings, HomogeneousNoMixture passes the test samples of one domain after the other without mixing them, while CrossMixture first merges the datasets and then passes the mixed test samples to the model, so the data batches sent to the model should be largely different. But when I print the logits from the one_adapt_step function,

def one_adapt_step(...):
    ...
    y_hat = model(batch._x)

I found that no matter whether I use the CrossMixture or the HomogeneousNoMixture setting, the output is exactly the same, so the input must be the same as well. This is really weird to me, since I expected the inputs to be quite different.
However, based on the final results, the accuracy after adaptation does differ between the two settings, so the settings do take effect. I pasted the two printouts below so that you can understand my question.

Could you help me figure out why the outputs are the same?

[screenshots of the two printed logit outputs]
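One way to check whether the inputs are truly identical, independent of the model, is to fingerprint each batch. This is only a debugging sketch and assumes batch._x is a torch tensor, as in the snippet above:

import hashlib

def batch_fingerprint(batch):
    # Hash the raw tensor bytes: identical batches give identical digests,
    # so the two inter_domain settings should produce different digests
    # if the data streams really differ.
    return hashlib.md5(batch._x.detach().cpu().numpy().tobytes()).hexdigest()

Logging this digest inside one_adapt_step under both settings would show whether the data streams actually diverge or only the final bookkeeping does.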

Unexpected errors when training resnet and wideresnet on the ImageNet dataset

I want to train resnet18 on the ImageNet dataset, so I changed the class CIFARDataset into ImageNetDataset in the pretrain_self_supervised_cifar10.py file.
Here is my config:

config = {
    # general
    "seed": 2022,
    "ckpt_path": "/data/TTA-exp/ttab/pretrain/ckpt",
    "device": "cuda:0",
    "log_dir": "/data/TTA-exp/ttab/data/runs",
    # data
    "data_path": "/data/TTA-exp/ttab/datasets",
    "data_name": "imagenet",
    "num_classes": 15,
    # model
    "model_name": "resnet18",
    "task_name": "classification",
    "resume": False,
    # hyperparams
    "entry_of_shared_layers": "layer3",
    "dim_out": 4,
    "use_iabn": False,
    "iabn_k": 4,
    "use_ls": True,  # label smoothing: https://arxiv.org/abs/1906.02629
    "threshold_note": 1,
    "rotation_type": "expand",
    "lr": 0.1,
    "momentum": 0.9,
    "weight_decay": 5e-4,
    "maxEpochs": 150,
    "milestones": [75, 125],
    "batch_size": 128,
    "save_epoch": 10,  # save weights file per SAVE_EPOCH epoch.
}

When training resnet18 on the ImageNet dataset, we get this error:
[screenshot of the first error message]

We then changed line 310 in ttab/loads/models/resnet.py from self.avgpool = nn.AvgPool2d(kernel_size=7, stride=1) to self.avgpool = nn.AvgPool2d(kernel_size=1, stride=1), and we get this error instead:
[screenshot of the second error message]

Moreover, with the model wideresnet28_10 we still get this error:
[screenshot of the third error message]
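For context, this looks like the classic CIFAR-vs-ImageNet shape mismatch: the CIFAR-style ResNet expects small inputs, so on larger ImageNet images its final feature map is no longer the size the fixed pooling kernel assumes. One common workaround, my suggestion rather than anything from this repo, is to make the pooling input-size-agnostic:

import torch.nn as nn

# In ttab/loads/models/resnet.py, instead of the fixed-size
#   self.avgpool = nn.AvgPool2d(kernel_size=7, stride=1)
# an adaptive pooling always reduces to a 1x1 feature map,
# so the flattened dimension fed to the classifier stays constant:
avgpool = nn.AdaptiveAvgPool2d(output_size=1)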

The experimental results of NOTE are inconsistent with the results in the paper

Dear authors,
I have a question about the NOTE results. We ran the code with the following settings:
"--model_adaptation_method"——— "note"
"--model_selection_method"——— "last_iterate"
"--model_selection_method"——— "cifar10"
"--model_name"——— "resnet26"
"--episodic"——— "false"
"--data_names"——— ("cifar10_c_deterministic-snow-5;"
"cifar10_c_deterministic-brightness-5;"
"cifar10_c_deterministic-fog-5;"
"cifar10_c_deterministic-frost-5;"
"cifar10_c_deterministic-contrast-5;"
"cifar10_c_deterministic-motion_blur-5;"
"cifar10_c_deterministic-glass_blur-5;"
"cifar10_c_deterministic-zoom_blur-5;"
"cifar10_c_deterministic-gaussian_noise-5;"
"cifar10_c_deterministic-shot_noise-5;"
"cifar10_c_deterministic-jpeg_compression-5;"
"cifar10_c_deterministic-impulse_noise-5;"
"cifar10_c_deterministic-pixelate-5;"
"cifar10_c_deterministic-elastic_transform-5;"
"cifar10_c_deterministic-defocus_blur-5",)
"--batch_size"—— 100
"--lr"—— 1e-4
"--n_train_steps"——— 1
"--inter_domain"———“HomogeneousNoMixture”
The error rate of NOTE is 41%, which is quite different from the "NOTE-online" result in Table 2 of the paper (24.0 ± 0.1). Is this discrepancy caused by a difference between our experimental environment and the setting of the original paper, or are there other reasons?

Question about mixing multiple domains before constructing label shift (I already sent an email to the authors)

Dear authors,
I have requests on several things; I already sent an email to the authors.
There are two "--inter_domain" options in the code: HeterogeneousNoMixture and CrossMixture. CrossMixture shuffles data across domains; HeterogeneousNoMixture constructs label shift (non-iid) over domains.
I want to mix the multiple domains first and then construct label shift, so I do the following two steps:

# step 1: merge the datasets across domains and shuffle them together
data = self._intra_shuffle_dataset(
    self._merge_datasets(test_datasets), random_seed=random_seed
)
# step 2: construct label shift on the merged dataset
data2 = self._intra_non_iid_shift(
    dataset=data,
    non_iid_pattern=test_case.inter_domain.non_iid_pattern,
    non_iid_ness=test_case.inter_domain.non_iid_ness,
    random_seed=random_seed,
)
return data2

However, after step 1, step 2 does not succeed: the label shift fails to be constructed.
Could you please help me solve this issue, or add some documentation to _intra_non_iid_shift?
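To clarify what I mean by "label shift on the merged stream", here is my own illustration of the behavior I am after; this is not the repo's _intra_non_iid_shift:

import random
from collections import defaultdict

def label_shift_stream(samples, random_seed=2022):
    # Reorder (x, y) pairs so labels arrive in long homogeneous runs,
    # i.e. a simple non-iid (label-shift) test stream over merged domains.
    rng = random.Random(random_seed)
    buckets = defaultdict(list)
    for x, y in samples:
        buckets[y].append((x, y))
    labels = list(buckets)
    rng.shuffle(labels)  # randomize which class comes first
    return [pair for label in labels for pair in buckets[label]]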

The CoTTA code is not working properly.

Thank you for creating a great benchmark.

I have one question on the implementation.

python3 run_exp.py --model_adaptation_method cotta

When I try to use CoTTA with this parameter, I get the following error.

TypeError: __init__() got an unexpected keyword argument 'resample'

Line 125 in cotta.py:

transforms.RandomAffine(
    degrees=[-8, 8] if soft else [-15, 15],
    translate=(1 / 16, 1 / 16),
    scale=(0.95, 1.05) if soft else (0.9, 1.1),
    shear=None,
    resample=PIL.Image.BILINEAR,
    fillcolor=None,
),

Looking at the official PyTorch documentation, the signature of RandomAffine is now

torchvision.transforms.RandomAffine(degrees, translate=None, scale=None, shear=None, interpolation=InterpolationMode.NEAREST, fill=0, center=None)

and there is no resample argument.
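If it helps, here is the same transform ported to the newer torchvision API; this is a sketch on my side rather than a tested patch. In torchvision 0.9+, resample became interpolation and fillcolor became fill, and soft comes from the surrounding cotta.py code as in the snippet above:

import torchvision.transforms as transforms

transforms.RandomAffine(
    degrees=[-8, 8] if soft else [-15, 15],
    translate=(1 / 16, 1 / 16),
    scale=(0.95, 1.05) if soft else (0.9, 1.1),
    shear=None,
    # resample= -> interpolation=, fillcolor= -> fill=
    interpolation=transforms.InterpolationMode.BILINEAR,
    fill=0,
),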

Thanks in advance.

Request for checkpoints on the CIFAR100 dataset

Hi, thanks for your excellent work!
To ensure the accuracy of my replication efforts, I am kindly requesting access to the pretrained checkpoints you used in your experiments on the CIFAR100 dataset. I have tried "python ssl_pretrain.py --data-name cifar100 --model-name resnet26" to pre-train a CIFAR100 model, but it cannot reproduce the results in the paper.

Reproducing the Figure 3 results from the paper

Dear authors,
thank you for your great work!
I am trying to reproduce some results from your paper under the TTAB benchmark, but the results I get are not consistent with those reported.

Specifically, the figure below is our reproduced version of Figure 3(a) with lr=0.01 (the green curve in the paper).
[our reproduced version of Figure 3(a)]

As you can see, the result is quite different from Figure 3(a). In the experiment, I used the following configuration based on your paper:

ckpt_path: "./pretrain/ckpt/resnet26_bn_ssh_cifar10.pth" (the ckpt you provided on Google Drive)
seed:2022
model_adaptation_method:"shot"
model_selection_method:"last_iterate"
data_names:"cifar10_c_deterministic-gaussian_noise-5"
lr:0.01
n_train_steps:1
(the other settings are the same as the base setting in the TTAB implementation)

Could you provide more details about the configurations for Figures 3(a) and 3(b)?

Thank you!

Requests on pretrain code and experimental settings for other datasets

Dear authors,

I have requests on several things; I already sent an email to the authors.

  1. Could you release the pretraining code for datasets other than CIFAR10?
    • It would be a pleasure if you provided the pretraining code and its corresponding experimental settings, such as epochs, architecture, and learning rate, for Office-Home or PACS.
  2. Could you release 'parameter.py' for each model/dataset under the standard settings?
    • By standard settings I mean the settings used in Table 2, including the pretraining code (and settings) for each dataset.

My requests are just for reproducing the results in Table 2 of your paper.

I hope these requests do not disturb you much.

Thank you.
Best,
