Regarding transforms about trojanzoo (closed)

ain-soph commented on May 24, 2024
Regarding transforms

Comments (5)

ain-soph commented on May 24, 2024

Close this issue if you have no further concerns.

ain-soph commented on May 24, 2024

Actually, I have never tested MNIST, so let me know if any issue exists. The transform for MNIST is a simple ToTensor.
Note that the normalization parameters are embedded into the model as a layer rather than applied as a dataset transform, so all image tensors produced by the dataset lie in [0, 1].
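
A minimal stdlib-only sketch of that design, assuming per-channel statistics: the dataset side only rescales 8-bit pixels into [0, 1] (the value scaling ToTensor performs), and a hypothetical `NormalizedModel` wrapper applies (x - mean) / std before calling the wrapped network. `to_unit_range` and `NormalizedModel` are illustrative names, not trojanzoo's API; the CIFAR-10 statistics are the ones agSidharth quotes in this thread.

```python
from typing import Callable, Sequence

Pixel = Sequence[float]  # one value per channel, e.g. (r, g, b)


def to_unit_range(pixel_uint8: Sequence[int]) -> list[float]:
    """Mimic ToTensor's value scaling: 8-bit values 0..255 -> floats in [0, 1]."""
    return [v / 255.0 for v in pixel_uint8]


class NormalizedModel:
    """Hypothetical wrapper: normalization lives in the model, not in the dataset transform."""

    def __init__(self, net: Callable[[list[float]], float],
                 mean: Sequence[float], std: Sequence[float]) -> None:
        self.net = net
        self.mean = list(mean)
        self.std = list(std)

    def __call__(self, pixel: Pixel) -> float:
        # Per-channel (x - mean) / std, applied as the model's first "layer".
        normalized = [(x - m) / s for x, m, s in zip(pixel, self.mean, self.std)]
        return self.net(normalized)


CIFAR10_MEAN = [0.49139968, 0.48215827, 0.44653124]
CIFAR10_STD = [0.24703233, 0.24348505, 0.26158768]

# Stand-in "network" that simply returns the normalized red channel.
model = NormalizedModel(lambda p: p[0], CIFAR10_MEAN, CIFAR10_STD)
pixel = to_unit_range([255, 128, 0])
print(pixel)         # every value lies in [0, 1]
print(model(pixel))  # equals (1.0 - 0.49139968) / 0.24703233
```

Keeping the statistics inside the model means every tensor the dataset hands out (including adversarial or poisoned inputs) stays in the natural [0, 1] image range.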

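```python
# Excerpt from the dataset class: the last branches of the transform selection
# (the preceding `if` branches are not shown in this comment).
elif self.data_shape in ([3, 16, 16], [3, 32, 32]):
    transform = get_transform_cifar(mode, auto_augment=self.auto_augment,
                                    cutout=self.cutout, cutout_length=self.cutout_length,
                                    data_shape=self.data_shape)
else:
    transform = transforms.ToTensor()
```

```python
import torchvision.transforms as transforms
# `Cutout` is trojanzoo's own transform; its import is not shown in this comment.


def get_transform_cifar(mode: str, auto_augment: bool = False,
                        cutout: bool = False, cutout_length: int = None,
                        data_shape: list[int] = [3, 32, 32]) -> transforms.Compose:
    # Non-training data is only converted to a [0, 1] float tensor (no augmentation).
    if mode != 'train':
        return transforms.ToTensor()
    # Default cutout square: half the image width.
    cutout_length = data_shape[-1] // 2 if cutout_length is None else cutout_length
    transform_list = [
        # Zero-pad by width // 8 on every side, then randomly crop back to the original size.
        transforms.RandomCrop(data_shape[-2:], padding=data_shape[-1] // 8),
        transforms.RandomHorizontalFlip(),
    ]
    if auto_augment:
        # AutoAugment is applied to the image before it is converted to a float tensor.
        transform_list.append(transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10))
    transform_list.append(transforms.ToTensor())
    if cutout:
        # Cutout is applied to the tensor, after ToTensor.
        transform_list.append(Cutout(cutout_length))
    return transforms.Compose(transform_list)
```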
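
All augmentation sizes in `get_transform_cifar` are derived from `data_shape`. The hypothetical helper below (stdlib only, not part of trojanzoo) mirrors those expressions to make the arithmetic explicit for the two shapes handled by the `elif` branch.

```python
from typing import Optional


def cifar_aug_params(data_shape: list[int],
                     cutout_length: Optional[int] = None) -> dict[str, object]:
    """Reproduce the size arithmetic used by get_transform_cifar (illustrative helper)."""
    width = data_shape[-1]
    return {
        "crop_size": list(data_shape[-2:]),  # RandomCrop output size (H, W)
        "padding": width // 8,               # zero-padding added on every side before cropping
        "cutout_length": width // 2 if cutout_length is None else cutout_length,
    }


for shape in ([3, 32, 32], [3, 16, 16]):
    print(shape, cifar_aug_params(shape))
# [3, 32, 32] {'crop_size': [32, 32], 'padding': 4, 'cutout_length': 16}
# [3, 16, 16] {'crop_size': [16, 16], 'padding': 2, 'cutout_length': 8}
```

So a 32x32 CIFAR image is padded to 40x40 and randomly cropped back to 32x32, and the default cutout square covers 16x16 pixels.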

ain-soph commented on May 24, 2024
```bash
CUDA_VISIBLE_DEVICES=0 python examples/train.py --verbose 1 --color --epoch 600 --batch_size 96 --cutout --grad_clip 5.0 --lr 0.025 --lr_scheduler --save --dataset cifar10 --model resnet18_comp
```

This gives you a ResNet18 (the version with a compressed first convolutional layer) with 96.5% accuracy.
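
Two of these flags are easier to reason about with numbers. The sketch below assumes that `--lr_scheduler` means cosine annealing from `--lr 0.025` over the 600 epochs and that `--grad_clip 5.0` clips the global L2 norm of the gradients; both are common implementations (compare PyTorch's `CosineAnnealingLR` and `clip_grad_norm_`), but they are assumptions about trojanzoo's training code, not confirmed by this thread. `cosine_lr` and `clip_by_global_norm` are illustrative stdlib-only helpers.

```python
import math


def cosine_lr(epoch: int, base_lr: float = 0.025, total_epochs: int = 600,
              min_lr: float = 0.0) -> float:
    """Cosine-annealed learning rate at a given epoch (assumed schedule)."""
    return min_lr + (base_lr - min_lr) * (1 + math.cos(math.pi * epoch / total_epochs)) / 2


def clip_by_global_norm(grads: list[float], max_norm: float = 5.0) -> list[float]:
    """Scale all gradients down together if their combined L2 norm exceeds max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads)
    scale = max_norm / total_norm
    return [g * scale for g in grads]


print(cosine_lr(0), cosine_lr(300), cosine_lr(600))  # start, halfway, end of training
print(clip_by_global_norm([3.0, 4.0]))  # norm is exactly 5.0, so unchanged
print(clip_by_global_norm([6.0, 8.0]))  # norm is 10.0, so everything is scaled by 0.5
```

Clipping by the global norm keeps the gradient's direction and only shrinks its length, which is why it is a popular safeguard for long training runs like this one.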

agSidharth commented on May 24, 2024

Yeah, I was testing these models manually, and the transform that gave me the expected accuracy is:

```python
import torchvision.transforms as transforms
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize([0.49139968, 0.48215827, 0.44653124],
                                                     [0.24703233, 0.24348505, 0.26158768])])
```
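
The values passed to `Normalize` here are the commonly quoted per-channel mean and standard deviation of the CIFAR-10 training images, measured after scaling pixels to [0, 1]. As an illustrative stdlib-only sketch (the helper name `channel_mean_std` is hypothetical), such statistics come from pooling every pixel of every image per channel; the toy input is made up and is not CIFAR-10 data.

```python
import math


def channel_mean_std(images: list[list[list[float]]]) -> tuple[list[float], list[float]]:
    """Per-channel mean and population std over all pixels of all images.

    Each image is a list of channels, and each channel is a flat list of
    pixel values already scaled to [0, 1].
    """
    num_channels = len(images[0])
    means, stds = [], []
    for c in range(num_channels):
        values = [v for img in images for v in img[c]]  # pool channel c across the dataset
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)
        means.append(mean)
        stds.append(math.sqrt(var))
    return means, stds


# Two made-up 2-pixel RGB "images" (channels first, values in [0, 1]).
toy = [
    [[0.0, 1.0], [0.5, 0.5], [0.2, 0.4]],
    [[1.0, 0.0], [0.5, 0.5], [0.6, 0.8]],
]
mean, std = channel_mean_std(toy)
print(mean, std)
```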

ain-soph commented on May 24, 2024

As I remember, without data augmentation such as random crop and cutout, the model accuracy won't exceed 92%, but I could be wrong.
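
Cutout zeroes out a square patch at a random position of each training image. The stdlib-only sketch below follows the commonly used formulation, where the patch center can fall anywhere in the image and the square is clipped at the borders; it assumes, without checking, that trojanzoo's `Cutout` behaves the same way. The function name `cutout` and the nested-list, channels-first image layout are illustrative.

```python
import random


def cutout(image: list[list[list[float]]], length: int,
           rng: random.Random) -> list[list[list[float]]]:
    """Zero a length x length square (clipped at the borders) in every channel.

    `image` is channels-first: image[c][y][x]. A new image is returned.
    """
    height, width = len(image[0]), len(image[0][0])
    cy, cx = rng.randrange(height), rng.randrange(width)  # patch center, anywhere in the image
    y1, y2 = max(cy - length // 2, 0), min(cy + length // 2, height)
    x1, x2 = max(cx - length // 2, 0), min(cx + length // 2, width)
    return [
        [[0.0 if y1 <= y < y2 and x1 <= x < x2 else v for x, v in enumerate(row)]
         for y, row in enumerate(channel)]
        for channel in image
    ]


rng = random.Random(0)
img = [[[1.0] * 8 for _ in range(8)] for _ in range(3)]  # 3x8x8 image of ones
out = cutout(img, length=4, rng=rng)
zeros = sum(1 for row in out[0] for v in row if v == 0.0)
print(zeros)  # at most 4 * 4 = 16, fewer when the patch is clipped at a border
```

Because the same mask is applied to every channel, the network sees a fully occluded region, which forces it to use context from the rest of the image.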

Also, if you use my model class and put the normalization into the dataset transform, you should set the mean/std of the model's normalization layer to 0/1; otherwise the inputs get normalized twice.
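
A quick stdlib-only check of why the layer should be reset to 0/1 in that setup (the `normalize` helper is illustrative): normalizing an already-normalized value shifts it a second time, while a layer with mean 0 and std 1 is the identity and passes the transform's output through unchanged.

```python
def normalize(x: float, mean: float, std: float) -> float:
    """(x - mean) / std, the per-channel operation used by Normalize."""
    return (x - mean) / std


MEAN_R, STD_R = 0.49139968, 0.24703233  # CIFAR-10 red-channel statistics from this thread

x = 0.8                                  # a red-channel value in [0, 1]
once = normalize(x, MEAN_R, STD_R)       # normalization done in the dataset transform
twice = normalize(once, MEAN_R, STD_R)   # model layer still using the CIFAR-10 stats
identity = normalize(once, 0.0, 1.0)     # model layer reset to mean 0, std 1

print(once, twice, identity)
assert identity == once  # mean 0 / std 1 leaves the input unchanged
assert twice != once     # keeping the original stats normalizes a second time
```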
