oatml / non-parametric-transformers
Code for "Self-Attention Between Datapoints: Going Beyond Individual Input-Output Pairs in Deep Learning"
License: Apache License 2.0
Hi y’all, awesome paper! I was digging through the code but can’t find the details of the CIFAR10 experiments that use the ResNet18 encoder. Are there any details on which ResNet code was used for that? In particular: were the images upscaled from 32x32? Is the architecture the same as in the original ResNet paper for CIFAR10 (not ImageNet)? Is the training recipe the same?
This is easily fixable, but it's a bit frustrating that someone didn't just run pip freeze.
Traceback (most recent call last):
File "run.py", line 12, in <module>
from npt.column_encoding_dataset import ColumnEncodingDataset
File "/home/cottrell/dev/non-parametric-transformers/npt/column_encoding_dataset.py", line 10, in <module>
from npt.datasets.boston_housing import BostonHousingDataset
File "/home/cottrell/dev/non-parametric-transformers/npt/datasets/boston_housing.py", line 2, in <module>
from sklearn.datasets import load_boston
File "/home/cottrell/anaconda3/envs/npt/lib/python3.8/site-packages/sklearn/datasets/__init__.py", line 157, in __getattr__
raise ImportError(msg)
ImportError:
`load_boston` has been removed from scikit-learn since version 1.2.
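For anyone hitting this: the scikit-learn error message itself suggests two workarounds. Either pin scikit-learn below 1.2 in the environment, or load the original Boston data directly. The snippet below is adapted from scikit-learn's deprecation notice (not from this repository); note the dataset was removed from scikit-learn because of ethical concerns about one of its features.

import numpy as np
import pandas as pd

# Adapted from scikit-learn's load_boston deprecation notice.
data_url = "http://lib.stat.cmu.edu/datasets/boston"
raw_df = pd.read_csv(data_url, sep=r"\s+", skiprows=22, header=None)
data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])  # 13 features per sample
target = raw_df.values[1::2, 2]  # median home value (MEDV)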
The documentation refers to what appear to be non-existent entities, like:
$ git grep data_loaders
npt/model/npt.py: expected to provide datasets as given in `npt.data_loaders`.)
The description in the paper does not appear to be entirely consistent with the code, at least not obviously so.
In non-parametric-transformers/npt/model/npt.py lines 145 and 149:
self.c.embedding_layer_norm should be self.c.model_embedding_layer_norm to match the config file.
Hi @jlko
great results and repository! Using the provided environment file, I get the following error when running the examples:
.../non-parametric-transformers/npt/loss.py", line 5, in
from pytorch_lightning.metrics.functional import auroc as lightning_auroc
ModuleNotFoundError: No module named 'pytorch_lightning.metrics'
My environment has pytorch-lightning 1.6.0, and it seems the package's module structure has changed. Could you provide the correct version of the package, or perhaps update the environment.yml for compatibility?
Thanks!
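For what it's worth, in recent pytorch-lightning releases the metrics API lives in the separate torchmetrics package, so one possible fix (a sketch, not tested against the rest of the pinned environment) is to change the import in npt/loss.py:

from torchmetrics.functional import auroc as lightning_auroc

Note that the call signature may also have changed (newer torchmetrics versions require a task argument for auroc), so sticking to the pytorch-lightning version pinned in environment.yml may still be the safer route.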
When using the conda env I get a similar import error. Fixed locally by replacing, in non-parametric-transformers/npt/utils/batch_utils.py, line 7:
from torch._six import container_abcs
with
import collections.abc as container_abcs
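A version-agnostic variant of the same fix (just a sketch, not what the repository ships) wraps the import in a try/except so it also keeps working on older torch releases that still provide torch._six:

try:
    from torch._six import container_abcs  # older torch releases
except ImportError:
    import collections.abc as container_abcs  # newer torch no longer exposes container_abcs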
After setting up the environment, I could successfully train the NPT model on a tabular dataset with the following command.
python run.py --data_set boston-housing --exp_eval_every_n 1 --exp_eval_every_epoch_or_steps epochs --exp_optimizer_warmup_proportion 0.5 --exp_optimizer lookahead_lamb --metrics_auroc False --model_num_heads 8 --model_stacking_depth 8 --exp_lr 1e-3 --exp_scheduler flat_and_anneal --exp_num_total_steps 2000 --exp_gradient_clipping 1 --exp_batch_size -1 --model_dim_hidden 128 --model_augmentation_bert_mask_prob "dict(train=0.15, val=0, test=0)" --exp_tradeoff 1 --exp_tradeoff_annealing cosine --exp_cache_cadence 1 --model_checkpoint_key boston__bs_-1__feature_mask --exp_n_runs 10 --exp_test_perc 0.1 --exp_val_perc 0.2
When I run the cifar10 experiment with the following command, I encounter an OOM error after iterating over some batches.
python run.py --data_set cifar10 --exp_eval_every_n 1 --exp_eval_every_epoch_or_steps epochs --exp_optimizer_warmup_proportion 0.7 --exp_optimizer lookahead_lamb --metrics_auroc False --exp_print_every_nth_forward 100 --model_num_heads 8 --model_stacking_depth 8 --exp_lr 1e-3 --exp_scheduler flat_and_anneal --exp_num_total_steps 1000000 --exp_gradient_clipping 1 --exp_batch_size 512 --model_dim_hidden 64 --model_augmentation_bert_mask_prob "dict(train=0.15, val=0, test=0)" --exp_tradeoff 1 --exp_tradeoff_annealing cosine --model_checkpoint_key cifar10__bs_512__feature_mask__linear__flip_crop --data_force_reload True --model_image_n_patches 64 --model_image_patch_type linear --model_image_n_channels 3 --model_image_n_classes 10 --model_image_random_crop_and_flip True
I can see GPU memory usage increasing as the iterations progress, so it looks like a memory leak.
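One way to check whether allocations really grow per batch (rather than the first pass simply not fitting) is to log the CUDA allocator statistics each iteration, e.g. with a small helper like this (names are illustrative, not from the repository):

import torch

def log_cuda_memory(step):
    # A steady climb in allocated memory across steps at a fixed batch size
    # suggests tensors (e.g. graph references or cached activations) are being kept alive.
    allocated = torch.cuda.memory_allocated() / 2**20
    reserved = torch.cuda.memory_reserved() / 2**20
    print(f"step {step}: allocated {allocated:.0f} MiB | reserved {reserved:.0f} MiB")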
Below is the output log:
Configuring arguments...
Reading dict for model_augmentation_bert_mask_prob.
Running model with CUDA
wandb: (1) Create a W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results
wandb: Enter your choice: 3
wandb: You chose 'Don't visualize my results'
wandb: W&B syncing is set to `offline` in this directory.
wandb: Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
Disabling AUROC metric.
Using random crops and flips in data augmentation.
Files already downloaded and verified
Files already downloaded and verified
Detected cifar10 dataset. Setting num_classes = 10.
Constructed train, val, test binary matrices with n_targets:
tensor(45000)
tensor(5000)
tensor(10000)
CV Index: 0
Train-test Split 1/5
c.exp_n_runs = 1. Stopping at 1 splits.
Building NPT.
Using feature type embedding (unique embedding for categorical and numerical features).
Using feature index embedding (unique embedding for each column).
Clipping gradients to value 1.0.
Model has 900355706 parameters, batch size 512.
Initialized "lookahead_lamb" optimizer.
Warming up for 700000.0/1000000.0 steps.
/usr/local/lib/python3.6/dist-packages/torch/optim/lr_scheduler.py:136: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
"https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
Initialized "flat_and_anneal" learning rate scheduler.
Initialized "cosine" augmentation/label tradeoff annealer. Annealing to minimum value in 1000000 steps.
Disabled AUROC in loss module.
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/container.py:435: UserWarning: Setting attributes on ParameterList is not supported.
warnings.warn("Setting attributes on ParameterList is not supported.")
Loading image dataset at epoch 1.
Loaded image dataset with 45000 train rows | 5000 val rows | 10000 test rows |
Finished loading.
Traceback (most recent call last):
File "run.py", line 253, in <module>
main(args)
File "run.py", line 23, in main
run_cv(args=args, wandb_args=wandb_args)
File "run.py", line 117, in run_cv
run_cv_splits(wandb_args, args, c, wandb_run)
File "run.py", line 195, in run_cv_splits
trainer.train_and_eval()
File "/root/zengdajun11new/data/zjh/non-parametric-transformers/npt/train.py", line 218, in train_and_eval
if self.per_epoch_train_eval(epoch=epoch):
File "/root/zengdajun11new/data/zjh/non-parametric-transformers/npt/train.py", line 174, in per_epoch_train_eval
eval_model=False)
File "/root/zengdajun11new/data/zjh/non-parametric-transformers/npt/train.py", line 441, in run_epoch
epoch, print_n, batch_index)
File "/root/zengdajun11new/data/zjh/non-parametric-transformers/npt/train.py", line 551, in run_batch
self.forward_and_loss(**forward_kwargs)
File "/root/zengdajun11new/data/zjh/non-parametric-transformers/npt/train.py", line 606, in forward_and_loss
output = self.model(masked_tensors, **extra_args)  # input interface to the model's forward pass
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/root/zengdajun11new/data/zjh/non-parametric-transformers/npt/model/npt.py", line 431, in forward
X = self.enc(X)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/container.py", line 117, in forward
input = module(input)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/root/zengdajun11new/data/zjh/non-parametric-transformers/npt/model/npt_modules.py", line 237, in forward
return self.mab(X, X)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/root/zengdajun11new/data/zjh/non-parametric-transformers/npt/model/npt_modules.py", line 168, in forward
multihead = torch.cat(multihead.split(Q.size(0), 0), 2)
RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 31.75 GiB total capacity; 28.66 GiB already allocated; 15.75 MiB free; 30.54 GiB reserved in total by PyTorch)
wandb: Waiting for W&B process to finish, PID 4826... (failed 1).
wandb: You can sync this run to the cloud by running:
wandb: wandb sync ./wandb/offline-run-20230104_133104-1f12jy2j
wandb: Find logs at: ./wandb/offline-run-20230104_133104-1f12jy2j/logs/debug.log
wandb: