
labram's Introduction

LaBraM

This is the official implementation of our ICLR 2024 paper "Large Brain Model for Learning Generic Representations with Tremendous EEG Data in BCI".


Abstract

The current electroencephalogram (EEG) based deep learning models are typically designed for specific datasets and applications in brain-computer interaction (BCI), limiting the scale of the models and thus diminishing their perceptual capabilities and generalizability. Recently, Large Language Models (LLMs) have achieved unprecedented success in text processing, prompting us to explore the capabilities of Large EEG Models (LEMs). We hope that LEMs can break through the limitations of different task types of EEG datasets, and obtain universal perceptual capabilities of EEG signals through unsupervised pre-training. Then the models can be fine-tuned for different downstream tasks. However, compared to text data, the volume of EEG datasets is generally small and the format varies widely. For example, there can be mismatched numbers of electrodes, unequal length data samples, varied task designs, and low signal-to-noise ratio. To overcome these challenges, we propose a unified foundation model for EEG called Large Brain Model (LaBraM). LaBraM enables cross-dataset learning by segmenting the EEG signals into EEG channel patches. Vector-quantized neural spectrum prediction is used to train a semantically rich neural tokenizer that encodes continuous raw EEG channel patches into compact neural codes. We then pre-train neural Transformers by predicting the original neural codes for the masked EEG channel patches. The LaBraMs were pre-trained on about 2,500 hours of various types of EEG signals from around 20 datasets and validated on multiple different types of downstream tasks. Experiments on abnormal detection, event type classification, emotion recognition, and gait prediction show that our LaBraM outperforms all compared SOTA methods in their respective fields.

Environment Setup

Install required packages:

conda create -n labram python=3.11
conda activate labram
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia
conda install tensorboardX
pip install -r requirements.txt

Run Experiments

Prepare pre-training data

Convert raw EEG files (such as .cnt, .edf, or .bdf) into HDF5 files using the example code in dataset_maker/make_h5dataset_for_pretrain.py. You can also write your own preprocessing code, as long as it is consistent with the preprocessing in our paper: removing useless channels, band-pass filtering between 0.1 Hz and 75 Hz, notch filtering at 50 Hz, resampling to 200 Hz, and setting the unit to $\mu V$.
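The following is a minimal SciPy sketch of the filtering, resampling, and unit steps listed above. It takes an array of shape (channels, samples) in volts. It is not the code in dataset_maker/make_h5dataset_for_pretrain.py. The filter designs (4th-order zero-phase Butterworth band-pass, IIR notch with Q = 30) and the helper name `preprocess` are assumptions. Dropping useless channels is left to the caller, because it depends on each dataset's montage.

```python
from math import gcd

import numpy as np
from scipy.signal import butter, filtfilt, iirnotch, resample_poly, sosfiltfilt


def preprocess(raw_volts, fs, l_freq=0.1, h_freq=75.0, notch=50.0, target_fs=200):
    """Hypothetical helper: band-pass, notch, resample, and convert volts to microvolts.

    raw_volts: (n_channels, n_samples), useless channels already removed.
    fs: original sampling rate in Hz (assumed integer, e.g. 250, 500, 1000).
    """
    if h_freq >= fs / 2:
        raise ValueError("high cut-off must be below the Nyquist frequency")
    # Zero-phase Butterworth band-pass, in second-order sections for numerical stability.
    sos = butter(4, [l_freq, h_freq], btype="bandpass", fs=fs, output="sos")
    x = sosfiltfilt(sos, raw_volts, axis=-1)
    # Zero-phase IIR notch at the power-line frequency.
    b, a = iirnotch(notch, Q=30.0, fs=fs)
    x = filtfilt(b, a, x, axis=-1)
    # Rational resampling to the target rate (resample_poly applies an anti-aliasing FIR).
    g = gcd(int(fs), int(target_fs))
    x = resample_poly(x, int(target_fs) // g, int(fs) // g, axis=-1)
    # Volts -> microvolts.
    return x * 1e6


fs = 1000
t = np.arange(10 * fs) / fs
# 10 uV at 10 Hz plus 10 uV of 50 Hz line noise, three identical channels.
one = 1e-5 * (np.sin(2 * np.pi * 10 * t) + np.sin(2 * np.pi * 50 * t))
out = preprocess(np.stack([one, one, one]), fs)
print(out.shape)  # (3, 2000)
```

Zero-phase filtering (filtfilt/sosfiltfilt) avoids shifting EEG features in time. Whether the original pipeline is zero-phase is not stated in the README.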

Train the neural tokenizer

The neural tokenizer is trained by vector-quantized neural spectrum prediction. We recommend training it on a machine with 8 NVIDIA GeForce RTX 3090 GPUs or better.

OMP_NUM_THREADS=1 torchrun --nnodes=1 --nproc_per_node=8 run_vqnsp_training.py \
    --output_dir ./checkpoints/vqnsp/ \
    --log_dir ./log/vqnsp/ \
    --model vqnsp_encoder_base_decoder_3x200x12 \
    --codebook_n_emd 8192 \
    --codebook_emd_dim 64 \
    --quantize_kmeans_init \
    --batch_size 128 \
    --opt adamw \
    --opt_betas 0.9 0.99 \
    --weight_decay 1e-4  \
    --warmup_epochs 10 \
    --epochs 100 \
    --save_ckpt_freq 20 
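The two ingredients of vector-quantized neural spectrum prediction can be illustrated in NumPy. This is a sketch under assumptions, not the repository's VQNSP class, and the function names are hypothetical. It assumes that:

* the spectrum targets are the Fourier amplitude and phase of each 200-sample patch;
* codes are selected by cosine similarity between l2-normalized vectors.

The codebook size (8192) and code dimension (64) come from --codebook_n_emd and --codebook_emd_dim above.

```python
import numpy as np


def spectrum_targets(patches):
    """Amplitude and phase of the real FFT of each patch.

    patches: (..., patch_size), e.g. (n_patches, 200).
    Returns two arrays of shape (..., patch_size // 2 + 1).
    """
    spec = np.fft.rfft(patches, axis=-1)
    return np.abs(spec), np.angle(spec)


def quantize(z, codebook):
    """Nearest-code lookup by cosine similarity.

    z: encoder outputs (n, d); codebook: (K, d).
    Returns code indices (n,) and the selected l2-normalized codes (n, d).
    """
    z_n = z / np.linalg.norm(z, axis=-1, keepdims=True)
    c_n = codebook / np.linalg.norm(codebook, axis=-1, keepdims=True)
    idx = np.argmax(z_n @ c_n.T, axis=-1)
    return idx, c_n[idx]


rng = np.random.default_rng(0)
codebook = rng.standard_normal((8192, 64))  # --codebook_n_emd 8192, --codebook_emd_dim 64
n = np.arange(200)
amp, phase = spectrum_targets(np.cos(2 * np.pi * 5 * n / 200))
print(amp.shape, round(float(amp[5]), 6))  # (101,) 100.0
idx, _ = quantize(2.0 * codebook[[3, 7]], codebook)
print(idx)  # [3 7]
```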

LaBraM pre-train

We pre-train LaBraM by predicting the original neural codes for the masked EEG channel patches.

OMP_NUM_THREADS=1 torchrun --nnodes=1 --nproc_per_node=8 run_labram_pretraining.py \
        --output_dir ./checkpoints/labram_base \
        --log_dir ./log/labram_base \
        --model labram_base_patch200_1600_8k_vocab \
        --tokenizer_model vqnsp_encoder_base_decoder_3x200x12 \
        --tokenizer_weight ./checkpoints/vqnsp.pth \
        --batch_size 64 \
        --lr 5e-4 \
        --warmup_epochs 5 \
        --clip_grad 3.0 \
        --drop_path 0. \
        --layer_scale_init_value 0.1 \
        --opt_betas 0.9 0.98 \
        --opt_eps 1e-8  \
        --epochs 50 \
        --save_ckpt_freq 5 \
        --codebook_dim 64 \
        --gradient_accumulation_steps 1
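The pre-training objective described above can be sketched as a cross-entropy over the tokenizer's 8192-entry vocabulary, computed only at masked patch positions. This NumPy sketch is illustrative only. The 0.5 mask ratio, the uniform random masking, and the function names are assumptions, not values read from run_labram_pretraining.py.

```python
import numpy as np


def random_patch_mask(n_patches, mask_ratio, rng):
    """Boolean mask with round(n_patches * mask_ratio) randomly chosen True entries."""
    n_masked = int(round(n_patches * mask_ratio))
    mask = np.zeros(n_patches, dtype=bool)
    mask[rng.choice(n_patches, size=n_masked, replace=False)] = True
    return mask


def masked_code_loss(logits, target_codes, mask):
    """Mean cross-entropy between predicted logits and tokenizer codes at masked positions.

    logits: (n_patches, vocab_size); target_codes: (n_patches,) integer codes.
    """
    lg = logits[mask]
    lg = lg - lg.max(axis=-1, keepdims=True)  # stabilize the log-softmax
    log_probs = lg - np.log(np.exp(lg).sum(axis=-1, keepdims=True))
    return float(-log_probs[np.arange(len(lg)), target_codes[mask]].mean())


rng = np.random.default_rng(0)
vocab_size, n_patches = 8192, 256
mask = random_patch_mask(n_patches, 0.5, rng)  # hypothetical mask ratio
codes = rng.integers(0, vocab_size, size=n_patches)
# Uniform logits give a loss of log(vocab_size), i.e. chance level.
print(mask.sum(), masked_code_loss(np.zeros((n_patches, vocab_size)), codes, mask))
```

Unmasked positions contribute nothing to the loss, so the model is only rewarded for inferring codes it cannot see.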

Fine-tune on downstream tasks

Before fine-tuning, use dataset_maker/make_TUAB.py and dataset_maker/make_TUEV.py to preprocess the downstream datasets and split them into training, validation, and test sets. We encourage you to try different hyperparameters, such as the learning rate and warmup_epochs, which can strongly influence the final performance. Here are the hyperparameters we used in the paper:

OMP_NUM_THREADS=1 torchrun --nnodes=1 --nproc_per_node=8 run_class_finetuning.py \
        --output_dir ./checkpoints/finetune_tuab_base/ \
        --log_dir ./log/finetune_tuab_base \
        --model labram_base_patch200_200 \
        --finetune ./checkpoints/labram-base.pth \
        --weight_decay 0.05 \
        --batch_size 64 \
        --lr 5e-4 \
        --update_freq 1 \
        --warmup_epochs 3 \
        --epochs 30 \
        --layer_decay 0.65 \
        --drop_path 0.1 \
        --dist_eval \
        --save_ckpt_freq 5 \
        --disable_rel_pos_bias \
        --abs_pos_emb \
        --dataset TUAB \
        --disable_qkv_bias \
        --seed 0
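`--layer_decay 0.65` applies layer-wise learning-rate decay: layers closer to the input get smaller learning rates than the classification head. Below is a sketch of the per-layer multiplier. It assumes the BEiT-style convention, in which:

* the patch embedding is layer 0;
* the 12 Transformer blocks are layers 1 to 12;
* the head is layer 13.

The function name is hypothetical.

```python
def layer_lr_scales(num_blocks=12, layer_decay=0.65):
    """Learning-rate multiplier per layer id, from the embedding (0) to the head (num_blocks + 1)."""
    num_layers = num_blocks + 2
    return [layer_decay ** (num_layers - 1 - layer_id) for layer_id in range(num_layers)]


scales = layer_lr_scales()
base_lr = 5e-4  # --lr
for layer_id in (0, 1, 12, 13):
    print(layer_id, f"{base_lr * scales[layer_id]:.2e}")
```

With these settings the embedding layer trains at about 0.65^13 ≈ 0.0037 times the head's learning rate. This keeps the pre-trained low-level features close to their original values.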

Citation

If you find our paper/code useful, please consider citing our work:

@inproceedings{
jiang2024large,
title={Large Brain Model for Learning Generic Representations with Tremendous {EEG} Data in {BCI}},
author={Wei-Bang Jiang and Li-Ming Zhao and Bao-Liang Lu},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
url={https://openreview.net/forum?id=QzTpTRVtrP}
}

labram's People

Contributors

935963004, itsaphel

labram's Issues

Preprocessing data

Hi!

Thanks for the exciting work. I have a question regarding the preprocessing of the pre-training datasets.

In the readme, you say, "Notably, you can also write your own codes for preprocessing EEG data. Make sure that the preprocessing is consistent with that of our paper, that is, removing useless channels, filtering between 0.1 Hz and 75 Hz, notch filtering of 50 Hz, resampling to 200 Hz, and setting the unit to uV".

I'm wondering whether you have example code for exactly this preprocessing. If not, what do you mean by "removing useless channels", i.e., what is a useless channel?

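Dataset loading when training the model

Hello, following the code you provided, I created several dataset files for training VQ-NSP, each saved as a separate .hdf5 file, and loaded them with the ShockDataset() class provided in the repository. However, I found that a sampler created with shuffle=True slows down data loading dramatically, so the GPU sits idle most of the time. Only with shuffle=False does training return to normal speed. I think this is related to how the DataLoader samples: random sampling requires frequent reads from many different locations on disk, whereas sequential sampling reads neighboring samples that are stored close together, so sequential loading is faster. But sequential sampling does not fit the logic of mini-batch stochastic gradient descent. Did you run into this problem during training, and how did you solve it?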
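One common compromise between full shuffling and sequential reads is block shuffling. The repository does not implement it. The idea is to shuffle the order of contiguous blocks of indices, and optionally the samples inside each block, so that each read still touches neighboring records in the HDF5 file. A minimal pure-Python sketch of the index order is shown below; the function name and block size are hypothetical. The resulting list could be wrapped in a torch.utils.data.Sampler.

```python
import random


def block_shuffled_indices(n_samples, block_size, seed=0, shuffle_within=True):
    """Return a permutation of range(n_samples) that keeps contiguous blocks together."""
    rng = random.Random(seed)
    blocks = [list(range(start, min(start + block_size, n_samples)))
              for start in range(0, n_samples, block_size)]
    rng.shuffle(blocks)  # random block order across the file
    if shuffle_within:
        for block in blocks:
            rng.shuffle(block)  # random order inside a block, still nearby on disk
    return [i for block in blocks for i in block]


order = block_shuffled_indices(10, block_size=4)
print(order)
```

Changing the seed every epoch (for example seed + epoch) gives a different order each pass. Larger blocks read faster but make batches less random.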

AttributeError: 'VQNSP' object has no attribute 'module'. Did you mean: 'modules'?

Hello! I get the following error while training the vqnsp.

raise AttributeError("'{}' object has no attribute '{}'".format( AttributeError: 'VQNSP' object has no attribute 'module'. Did you mean: 'modules'?

If I use "modules" instead of "module", then the code works. Is it required to update your code or am I missing something?
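A general note, not a diagnosis of the specific line that fails. When a model is wrapped in torch.nn.parallel.DistributedDataParallel, the original model is reachable as `.module`. A model that was never wrapped (for example, when the script runs without distributed initialization) has no such attribute, which is what this error message reports. `modules` is something different: nn.Module.modules() is a method that iterates over all submodules. A defensive pattern is to unwrap only when the attribute exists. The sketch below uses hypothetical stand-in classes so it runs without PyTorch.

```python
class TinyModel:
    """Hypothetical stand-in for an nn.Module such as VQNSP."""


class FakeDDP:
    """Hypothetical stand-in for DistributedDataParallel, which stores the wrapped model as .module."""

    def __init__(self, module):
        self.module = module


def unwrap(model):
    # Return the underlying model whether or not it was wrapped for distributed training.
    return getattr(model, "module", model)


plain, wrapped = TinyModel(), FakeDDP(TinyModel())
print(type(unwrap(plain)).__name__, type(unwrap(wrapped)).__name__)  # TinyModel TinyModel
```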

How to deal with data sets with different number of channels?

Thank you very much for the code. You mentioned in the code that different time windows are needed so that the sequence length of each dataset equals 256. For datasets with fewer than 64 or 32 channels, do we need to pad them to 64 or 32 channels?
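Here is the arithmetic behind the question, under the assumption that each channel is cut into 1-second patches (200 samples at 200 Hz) and the total number of patches per sample is capped at 256. The time window is then the largest whole number of seconds for which channels × seconds stays within the cap. The helper name is hypothetical. Whether the repository pads channels instead is exactly what is being asked.

```python
def max_window_seconds(n_channels, max_patches=256, patch_seconds=1):
    """Longest window (in seconds) such that n_channels * n_patches <= max_patches."""
    if n_channels > max_patches:
        raise ValueError("too many channels for the patch budget")
    return (max_patches // n_channels) * patch_seconds


for c in (64, 32, 19):
    w = max_window_seconds(c)
    print(c, w, c * w)  # 64 4 256, then 32 8 256, then 19 13 247
```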

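Why does the TUAB dataset use 2000 samples per example?

Hello, and thank you very much for your work. I have one question. Your time window is 4 or 8 seconds, but the make_TUAB.py script takes one example every 2000 samples. Since 1 s is 200 samples, that corresponds to a 10-second time window. Is there a particular reason for this setting? Thank you.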

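How many pre-training epochs are optimal?

Hello,

First of all, thank you very much for sharing such excellent work. I have a small question and would be grateful for an answer.
During pre-training, did you hold out a validation set to evaluate the pre-trained model? If there is no validation set and all the data are used for pre-training, how do you decide after how many epochs to stop in order to obtain the best pre-trained model?

Best regards,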

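Could you provide an example of how to use the model?

I would like to analyze LaBraM's encoding vectors for EEG data. The model weights are loaded from labram-base.pth and then frozen; no training or fine-tuning is needed.

Assume the signal has already been band-pass filtered between 0.1 Hz and 75 Hz, notch-filtered at 50 Hz, and resampled to 200 Hz.

Input:
sig.shape: (64, 200*60*60)  # 64 channels, 200 Hz sampling rate, one hour of EEG

Output:
LaBraM's representation vectors for this EEG: V

V.shape: (64, n, V_dim)
64: number of channels
n: number of segments; the EEG signal is split by the window width (1 second?) into n = 3600 segments
V_dim: dimension of the model's encoding vectors (it seems to be 32)

V = LaBraM(sig)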
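The input side of this request can be sketched in NumPy. One hour at 200 Hz is 200 × 3600 = 720000 samples per channel, i.e. 3600 one-second patches of 200 samples. Assume, following the 256-patch limit mentioned in an earlier issue, that 64 channels allow 4 patches per channel per forward pass. The recording is then fed as 900 windows. The generator below is a hypothetical helper, not part of the repository. It yields windows lazily, so the whole hour is never copied at once.

```python
import numpy as np


def iter_windows(sig, patch_size=200, patches_per_window=4):
    """Yield (n_channels, patches_per_window, patch_size) windows from a (n_channels, n_samples) array.

    Trailing samples that do not fill a whole window are dropped.
    """
    n_channels, n_samples = sig.shape
    window = patch_size * patches_per_window
    for start in range(0, n_samples - window + 1, window):
        chunk = sig[:, start:start + window]
        yield chunk.reshape(n_channels, patches_per_window, patch_size)


fs, seconds = 200, 3600
print(fs * seconds, fs * seconds // 200, fs * seconds // 800)  # 720000 3600 900
# Small demo with one minute of 64-channel data instead of a full hour.
demo = np.zeros((64, fs * 60), dtype=np.float32)
windows = list(iter_windows(demo))
print(len(windows), windows[0].shape)  # 15 (64, 4, 200)
```

Concatenating per-window outputs along the patch axis would give one vector per channel per second, matching the (64, 3600, V_dim) layout requested above. This assumes the model returns per-patch features rather than a pooled vector.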

Creating embeddings

Hello, thanks so much for making your code available!

I'd like to embed EEG data using the model and wanted to check whether the following makes sense to you. Given a pretrained model, e.g. your provided base model, I'd avoid adding a classification head to NeuralTransformer and catch the output of self.forward_features(...). Given input signals of [batch_size, channels, samples] and default repo parameters, the model would then provide embeddings [batch_size, embed_dim].

Does that seem OK or would you recommend another approach?

Thanks for your help!

Classification accuracy calculation error in fine-tuning in the multi-class setting

Hi @935963004, thanks for open-sourcing this excellent work!

As the title says, at line 117 of the fine-tuning code the training accuracy is calculated as class_acc = (output.max(-1)[-1] == targets).float().mean(). The output has shape (64, 3) and the target has shape (64, 1). After max(-1)[-1], the prediction has shape (64,), so calling == directly triggers PyTorch broadcasting:

  • pred is broadcast from [64] to [1, 64] and then to [64, 64]
  • target is broadcast from [64, 1] to [64, 64]

Both tensors are expanded to [64, 64] and compared element-wise. Entry [i, j] of the result is pred[j] == target[i, 0], so every prediction is compared with every target, not just its own. The "accuracy" is therefore the mean over all 64 × 64 combinations.

The correct approach is either to replace the line with
    class_acc = (output.max(-1)[-1] == targets.squeeze()).float().mean()
or, as in the validation and test code, to use
    class_acc = utils.get_metrics(output.detach().cpu().numpy(), targets.detach().cpu().numpy(), ['accuracy'], is_binary)['accuracy']

Here is the minimal reproducible code:

import torch

pred = torch.tensor([0, 0, 2, 1, 0, 2, 2, 2, 1, 1, 0, 2, 0, 2, 0, 0, 1, 1, 2, 0, 2, 0, 0, 2,
                     0, 0, 0, 0, 0, 2, 2, 0, 1, 2, 0, 2, 0, 2, 1, 2, 0, 0, 1, 0, 2, 1, 2, 0,
                     2, 1, 1, 2, 1, 1, 2, 2, 0, 1, 0, 0, 1, 1, 0, 2])  # shape (64,)
target = torch.tensor([[0], [0], [2], [1], [0], [2], [2], [2], [1], [1], [0], [2], [0], [2], [0], [1],
                       [1], [1], [2], [0], [2], [1], [0], [2], [0], [0], [0], [0], [0], [2], [2], [0],
                       [1], [2], [0], [2], [0], [2], [1], [2], [0], [0], [1], [0], [2], [1], [2], [0],
                       [2], [1], [1], [2], [1], [1], [2], [2], [0], [1], [0], [0], [1], [1], [0], [2]])  # shape (64, 1)

comparison = (pred == target).float()                    # broadcasts to (64, 64)
comparison_squeeze = (pred == target.squeeze()).float()  # element-wise, shape (64,)

accuracy = comparison.mean()
accuracy_squeeze = comparison_squeeze.mean()

print((accuracy.item(), accuracy_squeeze.item()))

(0.3408203125, 0.96875)

'RelativePositionBias' is missing

Traceback (most recent call last):
File "/root/autodl-tmp/LaBraM/run_labram_pretraining.py", line 28, in <module>
import modeling_pretrain
File "/root/autodl-tmp/LaBraM/modeling_pretrain.py", line 16, in <module>
from modeling_finetune import Block, _cfg, PatchEmbed, RelativePositionBias
ImportError: cannot import name 'RelativePositionBias' from 'modeling_finetune' (/root/autodl-tmp/LaBraM/modeling_finetune.py)

There is no 'RelativePositionBias' in modeling_finetune.py. How can I solve this?

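On the ablation of LaBraM with and without the codebook

Congratulations! I was glad to see your team publish this insightful paper, and I would like to ask about one of the experimental details. The paper says that whether or not the codebook is used does not improve the results much. I am curious about this experiment: how were the experiments with and without the codebook carried out?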

UserWarning: y_pred contains classes not in y_true warnings.warn("y_pred contains classes not in y_true")

Test: [7/8] eta: 0:00:06 loss: 1.3861 (1.3866) accuracy: 0.2396 (0.2500) balanced_accuracy: 0.2500 (0.2471) cohen_kappa: 0.0000 (0.0000) f1_weighted: 0.0926 (0.1013) time: 6.2649 data: 4.9884 max mem: 6735
Test: Total time: 0:00:51 (6.4069 s / it)

* loss 1.387
Accuracy of the network on the 680 test EEG: 0.25%
Max accuracy val: 0.25%, max accuracy test: 0.25%

I can't seem to see any effect of the training.

What are A and N in B N A T?

import torch.nn as nn
from einops import rearrange


class TemporalConv(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, in_chans=1, out_chans=8):
        super().__init__()
        self.conv1 = nn.Conv2d(in_chans, out_chans, kernel_size=(1, 15), stride=(1, 8), padding=(0, 7))
        self.gelu1 = nn.GELU()
        self.norm1 = nn.GroupNorm(4, out_chans)
        self.conv2 = nn.Conv2d(out_chans, out_chans, kernel_size=(1, 3), padding=(0, 1))
        self.gelu2 = nn.GELU()
        self.norm2 = nn.GroupNorm(4, out_chans)
        self.conv3 = nn.Conv2d(out_chans, out_chans, kernel_size=(1, 3), padding=(0, 1))
        self.norm3 = nn.GroupNorm(4, out_chans)
        self.gelu3 = nn.GELU()

    def forward(self, x, **kwargs):
        x = rearrange(x, 'B N A T -> B (N A) T')   # merge the N and A axes into one token axis
        B, NA, T = x.shape
        x = x.unsqueeze(1)                          # (B, 1, NA, T)
        x = self.gelu1(self.norm1(self.conv1(x)))   # stride 8 along the time axis
        x = self.gelu2(self.norm2(self.conv2(x)))
        x = self.gelu3(self.norm3(self.conv3(x)))
        x = rearrange(x, 'B C NA T -> B NA (T C)')  # one flattened feature vector per token
        return x

Hello. I am studying your code in order to use it. In the forward method of the TemporalConv class in modeling_pretrain.py (above), the einops rearrange treats the input as four-dimensional. I assume B is the batch size, N the number of electrodes, and T the sample length, but I couldn't work out what A means. Because of A, I'm also unsure whether N really is the number of electrodes.
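The shapes can be traced without running the model. A later issue's toy example notes that the expected input is (batch size, channels, time // patch_size, patch_size). That suggests N is the number of electrodes, A the number of patches per electrode, and T the patch length; this reading is an inference, not a statement from the authors. Using the standard convolution output-length formula, floor((L + 2·padding − kernel) / stride) + 1, a hypothetical helper reproduces the shape change for the kernels, strides, and paddings in the code above:

```python
def conv_out_len(length, kernel, stride=1, padding=0):
    """Output length along one axis of a convolution."""
    return (length + 2 * padding - kernel) // stride + 1


def temporal_conv_shape(B, N, A, T, out_chans=8):
    """Output shape of TemporalConv.forward for an input of shape (B, N, A, T)."""
    NA = N * A                                           # 'B N A T -> B (N A) T'
    t = conv_out_len(T, kernel=15, stride=8, padding=7)  # conv1
    t = conv_out_len(t, kernel=3, padding=1)             # conv2 keeps the length
    t = conv_out_len(t, kernel=3, padding=1)             # conv3 keeps the length
    return (B, NA, t * out_chans)                        # 'B C NA T -> B NA (T C)'


print(temporal_conv_shape(2, 64, 4, 200))  # (2, 256, 200)
```

Each 200-sample patch becomes a 200-dimensional token (25 time steps × 8 channels). This matches embed_dim=200 in the toy example of a later issue.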

CUDA Error

I'm trying to fine-tune the model, and it results in the following error:

Start training for 50 epochs
/opt/conda/conda-bld/pytorch_1682343995622/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [52,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
/opt/conda/conda-bld/pytorch_1682343995622/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [13,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
/opt/conda/conda-bld/pytorch_1682343995622/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [25,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
/opt/conda/conda-bld/pytorch_1682343995622/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [18,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
/opt/conda/conda-bld/pytorch_1682343995622/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [60,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.

WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 3392161 closing signal SIGTERM
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -6) local_rank: 1 (pid: 3392162) of binary: /home/anaconda3/envs/labram/bin/python
Traceback (most recent call last):
File "/home/anaconda3/envs/labram/bin/torchrun", line 33, in <module>
sys.exit(load_entry_point('torch==2.0.1', 'console_scripts', 'torchrun')())

NVCC

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Wed_Nov_22_10:17:15_PST_2023
Cuda compilation tools, release 12.3, V12.3.107
Build cuda_12.3.r12.3/compiler.33567101_0

I'm trying to run it on a node with 4 RTX A5000 GPUs.
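A frequent cause of this device-side assertion ("index out of bounds" in ScatterGatherKernel) is an out-of-range index passed to a gather or scatter operation. In a classification script, the usual suspect is a label outside [0, nb_classes), for example when labels are one-hot encoded with scatter or when the configured number of classes is smaller than in the data. Whether that is the cause here is an assumption. Running with CUDA_LAUNCH_BLOCKING=1 makes the failing call appear in the Python traceback. A hypothetical sanity check to run on the labels before training:

```python
import numpy as np


def check_labels(labels, num_classes):
    """Raise if any label falls outside [0, num_classes); return the sorted unique labels."""
    labels = np.asarray(labels)
    if not np.issubdtype(labels.dtype, np.integer):
        raise TypeError(f"labels must be integers, got {labels.dtype}")
    bad = (labels < 0) | (labels >= num_classes)
    if bad.any():
        raise ValueError(f"{int(bad.sum())} label(s) outside [0, {num_classes}): {np.unique(labels[bad])}")
    return np.unique(labels)


print(check_labels([0, 1, 1, 0], num_classes=2))  # [0 1]
try:
    check_labels([0, 1, 2, 6], num_classes=6)
except ValueError as err:
    print(err)
```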

Some help with reproduction

Hello @935963004,

I would like to start by thanking you for your work. I think it is fundamental and much-needed work in EEG decoding.

I am trying to understand and run your code, but some things are not working, and I would like to ask for your help, starting from the beginning with a toy example.

import torch
from torch import nn

from modeling_finetune import NeuralTransformer

# As mentioned in the meeting, the expected input is:
# batch size, channels, time // patch_size, patch_size

in_chans = 1  # Not working if in_chans is different from 1 (issue with temporal_embedding).
batch_size = 1
patch_size = 200
n_time_points_patched = 16  # Maximum number of patches; the value is hard-coded in the model
EEG_size = 1600

# Generating an empty vector just to get the output.
X = torch.zeros(batch_size, in_chans, n_time_points_patched, patch_size)
# Everything is default
model = NeuralTransformer(
    EEG_size=EEG_size,
    patch_size=patch_size,
    in_chans=in_chans,
    out_chans=8,
    num_classes=1000,
    embed_dim=200,
    depth=12,
    num_heads=10,
    mlp_ratio=4.,
    qkv_bias=False,
    qk_norm=None,
    qk_scale=None,
    drop_rate=0.,
    attn_drop_rate=0.,
    drop_path_rate=0.,
    norm_layer=nn.LayerNorm,
    init_values=0, # default value is not working, changed from None to zero.
    use_abs_pos_emb=False,  # Not working
    use_rel_pos_bias=False, 
    use_shared_rel_pos_bias=False,
    use_mean_pooling=True,
    init_scale=0.001,
)

with torch.no_grad():
    y_pred = model(X)

My questions are:

  • How to make it work with any number of channels?
  • How do we solve the issue with positional embedding? And what about temporal embedding?
  • How can the model be adapted to take input of shape "(batch, channel, time_steps)"?

My naive intuition was that changing in_chans should just work because of the TemporalConv module, but it doesn't.

FYI @LemonFace0309, @jonxuxu and @shahbuland, @RashikShahjahan

All the best!
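Regarding the third question above: the comment in the toy example says the model expects (batch, channels, time // patch_size, patch_size). A (batch, channel, time_steps) array can be brought into that layout by cutting each channel into consecutive patch_size-sample patches. Below is a NumPy sketch under that assumption. The helper name and the 256-patch check (channels × patches, as discussed in an earlier issue) are assumptions, and the sketch does not address the channel or positional embeddings.

```python
import numpy as np


def to_patches(x, patch_size=200, max_patches=256):
    """(batch, channels, time_steps) -> (batch, channels, time_steps // patch_size, patch_size).

    Trailing samples that do not fill a whole patch are dropped.
    """
    b, c, t = x.shape
    n_patches = t // patch_size
    if c * n_patches > max_patches:
        raise ValueError(f"{c} channels x {n_patches} patches exceeds {max_patches}")
    return x[:, :, : n_patches * patch_size].reshape(b, c, n_patches, patch_size)


x = np.random.default_rng(0).standard_normal((8, 19, 2050))  # 19 channels, 10.25 s at 200 Hz
print(to_patches(x).shape)  # (8, 19, 10, 200)
```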

TUAB fine-tuning replication: unexpected accuracies and loss values

Hi @935963004,

Thanks so much for your work with this project - really exciting stuff. I was trying out fine-tuning on the TUAB dataset and came across some unexpected numbers. It could be something on my end and I'll provide update(s) as I dig into it further but wanted to get a thread started. Thanks for any help!

I created the TUAB dataset with make_TUAB.py and then ran fine-tuning with the recommended settings from the README, except that I used 1 GPU, increased the batch size and the number of workers, and removed the --dist_eval flag:

!OMP_NUM_THREADS=1 torchrun --nnodes=1 --nproc_per_node=1 run_class_finetuning.py \
        --output_dir ./checkpoints/finetune_tuab_base/ \
        --log_dir ./log/finetune_tuab_base \
        --model labram_base_patch200_200 \
        --finetune ./checkpoints/labram-base.pth \
        --weight_decay 0.05 \
        --batch_size 512 \
        --lr 5e-4 \
        --update_freq 1 \
        --warmup_epochs 5 \
        --epochs 50 \
        --layer_decay 0.65 \
        --drop_path 0.1 \
        --save_ckpt_freq 5 \
        --disable_rel_pos_bias \
        --abs_pos_emb \
        --dataset TUAB \
        --disable_qkv_bias \
        --seed 0 \
        --num_workers 12
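The `lr` column in the log below can be checked against a per-step learning-rate schedule. This pure-Python sketch assumes a BEiT-style schedule: a linear warmup from 0 to --lr over warmup_epochs × iterations per epoch, followed by cosine decay to a minimum learning rate. Here the warmup is 5 × 121 = 605 steps, matching `Set warmup steps = 605`. The final value of 1e-6 is a hypothetical placeholder, and the function name is not from the repository. With --lr 5e-4, iteration 10 of epoch 0 and iteration 110 of epoch 2 come out at 0.000008 and 0.000291, the values logged at those iterations.

```python
import math


def warmup_cosine_schedule(base_lr, final_lr, epochs, steps_per_epoch, warmup_epochs):
    """Per-step learning rates: linear warmup from 0 to base_lr, then cosine decay to final_lr."""
    warmup_steps = warmup_epochs * steps_per_epoch
    total_steps = epochs * steps_per_epoch
    schedule = []
    for step in range(total_steps):
        if step < warmup_steps:
            # Evenly spaced from 0 to base_lr inclusive, like np.linspace(0, base_lr, warmup_steps).
            schedule.append(base_lr * step / max(warmup_steps - 1, 1))
        else:
            progress = (step - warmup_steps) / (total_steps - warmup_steps)
            schedule.append(final_lr + 0.5 * (base_lr - final_lr) * (1 + math.cos(math.pi * progress)))
    return schedule


lrs = warmup_cosine_schedule(5e-4, 1e-6, epochs=50, steps_per_epoch=121, warmup_epochs=5)
for epoch, it in ((0, 10), (2, 110)):
    print(epoch, it, f"{lrs[epoch * 121 + it]:.6f}")  # 0 10 0.000008, then 2 110 0.000291
```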

Below is some sample output I'm seeing:

...
Use step level LR scheduler!
Set warmup steps = 605
Set warmup steps = 0
Max WD = 0.0500000, Min WD = 0.0500000
criterion = BCEWithLogitsLoss()
Auto resume checkpoint: 
Start training for 50 epochs
[W reducer.cpp:1346] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration,  which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
Warning: NaN or Inf found in input tensor.
Epoch: [0]  [  0/121]  eta: 3 days, 16:01:13  lr: 0.000000  min_lr: 0.000000  loss: 0.6931 (0.6931)  class_acc: 1.0000 (1.0000)  loss_scale: 32768.0000 (32768.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: inf (inf)  time: 2618.7900  data: 2615.6816  max mem: 34391
Warning: NaN or Inf found in input tensor.
Warning: NaN or Inf found in input tensor.
Epoch: [0]  [ 10/121]  eta: 7:21:42  lr: 0.000008  min_lr: 0.000000  loss: 0.6929 (0.6927)  class_acc: 1.0000 (0.9998)  loss_scale: 8192.0000 (11170.9091)  weight_decay: 0.0500 (0.0500)  grad_norm: 6.2091 (inf)  time: 238.7637  data: 237.7895  max mem: 34433
Epoch: [0]  [ 20/121]  eta: 3:31:06  lr: 0.000017  min_lr: 0.000000  loss: 0.6915 (0.6911)  class_acc: 1.0000 (0.9999)  loss_scale: 8192.0000 (9752.3810)  weight_decay: 0.0500 (0.0500)  grad_norm: 6.1851 (inf)  time: 0.7382  data: 0.0030  max mem: 34433
Epoch: [0]  [ 30/121]  eta: 2:09:22  lr: 0.000025  min_lr: 0.000000  loss: 0.6854 (0.6872)  class_acc: 1.0000 (0.9999)  loss_scale: 8192.0000 (9249.0323)  weight_decay: 0.0500 (0.0500)  grad_norm: 6.0339 (inf)  time: 0.9033  data: 0.1924  max mem: 34433
Epoch: [0]  [ 40/121]  eta: 1:27:24  lr: 0.000033  min_lr: 0.000000  loss: 0.6681 (0.6793)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8991.2195)  weight_decay: 0.0500 (0.0500)  grad_norm: 5.7410 (inf)  time: 1.0505  data: 0.3342  max mem: 34433
Epoch: [0]  [ 50/121]  eta: 1:01:48  lr: 0.000041  min_lr: 0.000000  loss: 0.6368 (0.6673)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8834.5098)  weight_decay: 0.0500 (0.0500)  grad_norm: 5.5797 (inf)  time: 0.9733  data: 0.2568  max mem: 34433
Epoch: [0]  [ 60/121]  eta: 0:44:31  lr: 0.000050  min_lr: 0.000000  loss: 0.5923 (0.6513)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8729.1803)  weight_decay: 0.0500 (0.0500)  grad_norm: 5.6199 (inf)  time: 0.8588  data: 0.1478  max mem: 34433
Epoch: [0]  [ 70/121]  eta: 0:32:04  lr: 0.000058  min_lr: 0.000000  loss: 0.5391 (0.6321)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8653.5211)  weight_decay: 0.0500 (0.0500)  grad_norm: 5.5495 (inf)  time: 0.7452  data: 0.0358  max mem: 34433
Epoch: [0]  [ 80/121]  eta: 0:22:40  lr: 0.000066  min_lr: 0.000000  loss: 0.4821 (0.6105)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8596.5432)  weight_decay: 0.0500 (0.0500)  grad_norm: 5.2271 (inf)  time: 0.7910  data: 0.0819  max mem: 34433
Epoch: [0]  [ 90/121]  eta: 0:15:19  lr: 0.000075  min_lr: 0.000000  loss: 0.4252 (0.5875)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8552.0879)  weight_decay: 0.0500 (0.0500)  grad_norm: 4.8080 (inf)  time: 1.0040  data: 0.2929  max mem: 34433
Epoch: [0]  [100/121]  eta: 0:09:22  lr: 0.000083  min_lr: 0.000000  loss: 0.3711 (0.5638)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8516.4356)  weight_decay: 0.0500 (0.0500)  grad_norm: 4.3451 (inf)  time: 0.9471  data: 0.2359  max mem: 34433
Epoch: [0]  [110/121]  eta: 0:04:28  lr: 0.000091  min_lr: 0.000000  loss: 0.3215 (0.5400)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8487.2072)  weight_decay: 0.0500 (0.0500)  grad_norm: 3.8768 (inf)  time: 0.7331  data: 0.0248  max mem: 34433
Epoch: [0]  [120/121]  eta: 0:00:22  lr: 0.000099  min_lr: 0.000000  loss: 0.2762 (0.5166)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8462.8099)  weight_decay: 0.0500 (0.0500)  grad_norm: 3.4341 (inf)  time: 0.7080  data: 0.0001  max mem: 34433
Epoch: [0] Total time: 0:45:20 (22.4855 s / it)
Averaged stats: lr: 0.000099  min_lr: 0.000000  loss: 0.2762 (0.5166)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8462.8099)  weight_decay: 0.0500 (0.0500)  grad_norm: 3.4341 (inf)
Val:  [ 0/20]  eta: 2:50:41  loss: 0.2346 (0.2346)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 512.0632  data: 511.3996  max mem: 34433
Val:  [10/20]  eta: 0:09:48  loss: 0.2346 (0.2346)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 58.8147  data: 58.4929  max mem: 34433
Val:  [19/20]  eta: 0:01:01  loss: 0.2346 (0.2346)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 61.3067  data: 61.0130  max mem: 34433
Val: Total time: 0:20:26 (61.3127 s / it)
* loss 0.235
Accuracy of the network on the 14718 val EEG: 0.00%
Test:  [ 0/11]  eta: 1:34:08  loss: 0.2346 (0.2346)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 513.4699  data: 513.1569  max mem: 34433
Test:  [10/11]  eta: 0:00:48  loss: 0.2346 (0.2346)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 48.5245  data: 48.2462  max mem: 34433
Test: Total time: 0:08:53 (48.5333 s / it)
* loss 0.235
Accuracy of the network on the 14718 test EEG: 0.00%
Max accuracy val: 0.00%, max accuracy test: 0.00%
Epoch: [1]  [  0/121]  eta: 0:25:29  lr: 0.000100  min_lr: 0.000000  loss: 0.2362 (0.2362)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 3.0189 (3.0189)  time: 12.6372  data: 11.8822  max mem: 34435
Epoch: [1]  [ 10/121]  eta: 0:03:19  lr: 0.000108  min_lr: 0.000000  loss: 0.2185 (0.2189)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 2.8253 (2.8273)  time: 1.7946  data: 1.0805  max mem: 34435
Epoch: [1]  [ 20/121]  eta: 0:02:23  lr: 0.000117  min_lr: 0.000000  loss: 0.1991 (0.2033)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 2.6067 (2.6502)  time: 0.8571  data: 0.1461  max mem: 34435
Epoch: [1]  [ 30/121]  eta: 0:01:58  lr: 0.000125  min_lr: 0.000000  loss: 0.1711 (0.1894)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 2.2782 (2.4876)  time: 1.0270  data: 0.3148  max mem: 34435
Epoch: [1]  [ 40/121]  eta: 0:01:40  lr: 0.000133  min_lr: 0.000000  loss: 0.1473 (0.1768)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 1.9917 (2.3386)  time: 1.0426  data: 0.3300  max mem: 34435
Epoch: [1]  [ 50/121]  eta: 0:01:25  lr: 0.000142  min_lr: 0.000001  loss: 0.1273 (0.1656)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 1.7436 (2.2026)  time: 1.0637  data: 0.3517  max mem: 34435
Epoch: [1]  [ 60/121]  eta: 0:01:12  lr: 0.000150  min_lr: 0.000001  loss: 0.1103 (0.1555)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 1.5298 (2.0785)  time: 1.1007  data: 0.3882  max mem: 34435
Epoch: [1]  [ 70/121]  eta: 0:00:57  lr: 0.000158  min_lr: 0.000001  loss: 0.0962 (0.1463)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 1.3468 (1.9653)  time: 0.9096  data: 0.1978  max mem: 34435
Epoch: [1]  [ 80/121]  eta: 0:00:46  lr: 0.000166  min_lr: 0.000001  loss: 0.0842 (0.1381)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 1.1898 (1.8619)  time: 0.9218  data: 0.2102  max mem: 34435
Epoch: [1]  [ 90/121]  eta: 0:00:34  lr: 0.000175  min_lr: 0.000001  loss: 0.0741 (0.1306)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 1.0556 (1.7673)  time: 1.1428  data: 0.4293  max mem: 34435
Epoch: [1]  [100/121]  eta: 0:00:23  lr: 0.000183  min_lr: 0.000001  loss: 0.0654 (0.1238)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.9381 (1.6807)  time: 1.1746  data: 0.4622  max mem: 34435
Epoch: [1]  [110/121]  eta: 0:00:12  lr: 0.000191  min_lr: 0.000001  loss: 0.0580 (0.1176)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.8384 (1.6012)  time: 1.2109  data: 0.4988  max mem: 34435
Epoch: [1]  [120/121]  eta: 0:00:01  lr: 0.000200  min_lr: 0.000001  loss: 0.0518 (0.1120)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.7521 (1.5281)  time: 0.9662  data: 0.2559  max mem: 34435
Epoch: [1] Total time: 0:02:13 (1.1073 s / it)
Averaged stats: lr: 0.000200  min_lr: 0.000001  loss: 0.0518 (0.1120)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.7521 (1.5281)
Val:  [ 0/20]  eta: 0:05:09  loss: 0.0461 (0.0461)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 15.4873  data: 15.0481  max mem: 34435
Val:  [10/20]  eta: 0:00:16  loss: 0.0461 (0.0461)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 1.6699  data: 1.3682  max mem: 34435
Val:  [19/20]  eta: 0:00:01  loss: 0.0461 (0.0461)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 1.3343  data: 1.0552  max mem: 34435
Val: Total time: 0:00:26 (1.3377 s / it)
* loss 0.046
Accuracy of the network on the 14718 val EEG: 0.00%
Test:  [ 0/11]  eta: 0:02:29  loss: 0.0461 (0.0461)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 13.6329  data: 13.3226  max mem: 34435
Test:  [10/11]  eta: 0:00:01  loss: 0.0461 (0.0461)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 1.4797  data: 1.2113  max mem: 34435
Test: Total time: 0:00:16 (1.4852 s / it)
* loss 0.046
Accuracy of the network on the 14718 test EEG: 0.00%
Max accuracy val: 0.00%, max accuracy test: 0.00%
Epoch: [2]  [  0/121]  eta: 0:21:41  lr: 0.000200  min_lr: 0.000001  loss: 0.0463 (0.0463)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.6762 (0.6762)  time: 10.7563  data: 10.0010  max mem: 34435
Epoch: [2]  [ 10/121]  eta: 0:03:00  lr: 0.000209  min_lr: 0.000001  loss: 0.0437 (0.0438)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.6413 (0.6420)  time: 1.6236  data: 0.9095  max mem: 34435
Epoch: [2]  [ 20/121]  eta: 0:02:06  lr: 0.000217  min_lr: 0.000001  loss: 0.0410 (0.0416)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.6035 (0.6112)  time: 0.7737  data: 0.0637  max mem: 34435
Epoch: [2]  [ 30/121]  eta: 0:01:43  lr: 0.000225  min_lr: 0.000001  loss: 0.0370 (0.0396)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.5469 (0.5830)  time: 0.8707  data: 0.1589  max mem: 34435
Epoch: [2]  [ 40/121]  eta: 0:01:27  lr: 0.000233  min_lr: 0.000001  loss: 0.0335 (0.0378)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.4963 (0.5570)  time: 0.9061  data: 0.1927  max mem: 34435
Epoch: [2]  [ 50/121]  eta: 0:01:14  lr: 0.000242  min_lr: 0.000001  loss: 0.0304 (0.0361)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.4528 (0.5330)  time: 0.9294  data: 0.2166  max mem: 34435
Epoch: [2]  [ 60/121]  eta: 0:01:02  lr: 0.000250  min_lr: 0.000001  loss: 0.0277 (0.0345)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.4138 (0.5108)  time: 0.9026  data: 0.1909  max mem: 34435
Epoch: [2]  [ 70/121]  eta: 0:00:49  lr: 0.000258  min_lr: 0.000001  loss: 0.0252 (0.0331)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.3788 (0.4902)  time: 0.7824  data: 0.0718  max mem: 34435
Epoch: [2]  [ 80/121]  eta: 0:00:39  lr: 0.000267  min_lr: 0.000001  loss: 0.0231 (0.0317)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.3477 (0.4711)  time: 0.7819  data: 0.0711  max mem: 34435
Epoch: [2]  [ 90/121]  eta: 0:00:29  lr: 0.000275  min_lr: 0.000001  loss: 0.0212 (0.0305)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.3201 (0.4533)  time: 0.8553  data: 0.1451  max mem: 34435
Epoch: [2]  [100/121]  eta: 0:00:19  lr: 0.000283  min_lr: 0.000001  loss: 0.0195 (0.0293)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.2956 (0.4367)  time: 0.8650  data: 0.1550  max mem: 34435
Epoch: [2]  [110/121]  eta: 0:00:10  lr: 0.000291  min_lr: 0.000001  loss: 0.0180 (0.0283)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.2735 (0.4211)  time: 0.8804  data: 0.1692  max mem: 34435
Epoch: [2]  [120/121]  eta: 0:00:00  lr: 0.000300  min_lr: 0.000001  loss: 0.0166 (0.0272)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.2533 (0.4066)  time: 0.7990  data: 0.0883  max mem: 34435
Epoch: [2] Total time: 0:01:51 (0.9207 s / it)
Averaged stats: lr: 0.000300  min_lr: 0.000001  loss: 0.0166 (0.0272)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.2533 (0.4066)
Val:  [ 0/20]  eta: 0:04:56  loss: 0.0154 (0.0154)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 14.8285  data: 14.5373  max mem: 34435
Val:  [10/20]  eta: 0:00:15  loss: 0.0154 (0.0154)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 1.5975  data: 1.3217  max mem: 34435
Val:  [19/20]  eta: 0:00:01  loss: 0.0154 (0.0154)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 1.2781  data: 1.0134  max mem: 34435
Val: Total time: 0:00:25 (1.2811 s / it)
* loss 0.015
Accuracy of the network on the 14718 val EEG: 0.00%
Test:  [ 0/11]  eta: 0:02:27  loss: 0.0154 (0.0154)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 13.4538  data: 13.1486  max mem: 34435
Test:  [10/11]  eta: 0:00:01  loss: 0.0154 (0.0154)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 1.4634  data: 1.1955  max mem: 34435
Test: Total time: 0:00:16 (1.4693 s / it)
* loss 0.015
Accuracy of the network on the 14718 test EEG: 0.00%
Max accuracy val: 0.00%, max accuracy test: 0.00%
Epoch: [3]  [  0/121]  eta: 0:22:54  lr: 0.000300  min_lr: 0.000001  loss: 0.0154 (0.0154)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.2353 (0.2353)  time: 11.3583  data: 10.6085  max mem: 34435
Epoch: [3]  [ 10/121]  eta: 0:03:06  lr: 0.000309  min_lr: 0.000001  loss: 0.0148 (0.0148)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.2266 (0.2267)  time: 1.6779  data: 0.9647  max mem: 34435
Epoch: [3]  [ 20/121]  eta: 0:02:11  lr: 0.000317  min_lr: 0.000001  loss: 0.0142 (0.0143)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.2171 (0.2189)  time: 0.7983  data: 0.0875  max mem: 34435
Epoch: [3]  [ 30/121]  eta: 0:01:45  lr: 0.000325  min_lr: 0.000001  loss: 0.0132 (0.0138)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.2025 (0.2116)  time: 0.8703  data: 0.1595  max mem: 34435
Epoch: [3]  [ 40/121]  eta: 0:01:27  lr: 0.000334  min_lr: 0.000001  loss: 0.0122 (0.0133)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.1886 (0.2047)  time: 0.8464  data: 0.1368  max mem: 34435
Epoch: [3]  [ 50/121]  eta: 0:01:13  lr: 0.000342  min_lr: 0.000001  loss: 0.0114 (0.0129)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.1765 (0.1982)  time: 0.8564  data: 0.1469  max mem: 34435
Epoch: [3]  [ 60/121]  eta: 0:01:01  lr: 0.000350  min_lr: 0.000001  loss: 0.0107 (0.0125)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.1654 (0.1920)  time: 0.8801  data: 0.1703  max mem: 34435
Epoch: [3]  [ 70/121]  eta: 0:00:49  lr: 0.000358  min_lr: 0.000001  loss: 0.0100 (0.0121)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.1552 (0.1862)  time: 0.7979  data: 0.0883  max mem: 34435
Epoch: [3]  [ 80/121]  eta: 0:00:39  lr: 0.000367  min_lr: 0.000001  loss: 0.0094 (0.0117)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.1456 (0.1807)  time: 0.7810  data: 0.0717  max mem: 34435
Epoch: [3]  [ 90/121]  eta: 0:00:29  lr: 0.000375  min_lr: 0.000001  loss: 0.0088 (0.0114)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.1370 (0.1755)  time: 0.8510  data: 0.1414  max mem: 34435
Epoch: [3]  [100/121]  eta: 0:00:19  lr: 0.000383  min_lr: 0.000001  loss: 0.0083 (0.0110)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.1290 (0.1706)  time: 0.8608  data: 0.1515  max mem: 34435
Epoch: [3]  [110/121]  eta: 0:00:10  lr: 0.000392  min_lr: 0.000001  loss: 0.0078 (0.0107)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.1217 (0.1659)  time: 0.8725  data: 0.1636  max mem: 34435
Epoch: [3]  [120/121]  eta: 0:00:00  lr: 0.000400  min_lr: 0.000001  loss: 0.0073 (0.0104)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.1148 (0.1614)  time: 0.7904  data: 0.0820  max mem: 34435
Epoch: [3] Total time: 0:01:50 (0.9140 s / it)
Averaged stats: lr: 0.000400  min_lr: 0.000001  loss: 0.0073 (0.0104)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.1148 (0.1614)
Val:  [ 0/20]  eta: 0:04:54  loss: 0.0069 (0.0069)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 14.7361  data: 14.4306  max mem: 34435
Val:  [10/20]  eta: 0:00:15  loss: 0.0069 (0.0069)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 1.5994  data: 1.3224  max mem: 34435
Val:  [19/20]  eta: 0:00:01  loss: 0.0069 (0.0069)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 1.2757  data: 1.0101  max mem: 34435
Val: Total time: 0:00:25 (1.2790 s / it)
* loss 0.007
Accuracy of the network on the 14718 val EEG: 0.00%
Test:  [ 0/11]  eta: 0:02:28  loss: 0.0069 (0.0069)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 13.4998  data: 13.2017  max mem: 34435
Test:  [10/11]  eta: 0:00:01  loss: 0.0069 (0.0069)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 1.4676  data: 1.2003  max mem: 34435
Test: Total time: 0:00:16 (1.4737 s / it)
* loss 0.007
Accuracy of the network on the 14718 test EEG: 0.00%
Max accuracy val: 0.00%, max accuracy test: 0.00%
Epoch: [4]  [  0/121]  eta: 0:22:45  lr: 0.000401  min_lr: 0.000001  loss: 0.0069 (0.0069)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.1085 (0.1085)  time: 11.2845  data: 10.5294  max mem: 34435
Epoch: [4]  [ 10/121]  eta: 0:03:05  lr: 0.000409  min_lr: 0.000002  loss: 0.0067 (0.0067)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.1053 (0.1055)  time: 1.6719  data: 0.9574  max mem: 34435
Epoch: [4]  [ 20/121]  eta: 0:02:10  lr: 0.000417  min_lr: 0.000002  loss: 0.0065 (0.0065)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.1020 (0.1027)  time: 0.7890  data: 0.0790  max mem: 34435
Epoch: [4]  [ 30/121]  eta: 0:01:45  lr: 0.000425  min_lr: 0.000002  loss: 0.0061 (0.0063)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.0966 (0.1000)  time: 0.8712  data: 0.1616  max mem: 34435
Epoch: [4]  [ 40/121]  eta: 0:01:27  lr: 0.000434  min_lr: 0.000002  loss: 0.0058 (0.0062)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.0916 (0.0974)  time: 0.8668  data: 0.1574  max mem: 34435
Epoch: [4]  [ 50/121]  eta: 0:01:14  lr: 0.000442  min_lr: 0.000002  loss: 0.0055 (0.0060)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.0869 (0.0950)  time: 0.8685  data: 0.1592  max mem: 34435
Epoch: [4]  [ 60/121]  eta: 0:01:01  lr: 0.000450  min_lr: 0.000002  loss: 0.0052 (0.0059)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.0826 (0.0927)  time: 0.8772  data: 0.1675  max mem: 34435
Epoch: [4]  [ 70/121]  eta: 0:00:49  lr: 0.000459  min_lr: 0.000002  loss: 0.0049 (0.0057)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.0785 (0.0904)  time: 0.7934  data: 0.0831  max mem: 34435
Epoch: [4]  [ 80/121]  eta: 0:00:39  lr: 0.000467  min_lr: 0.000002  loss: 0.0047 (0.0056)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.0747 (0.0883)  time: 0.7916  data: 0.0816  max mem: 34435
Epoch: [4]  [ 90/121]  eta: 0:00:29  lr: 0.000475  min_lr: 0.000002  loss: 0.0045 (0.0054)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.0712 (0.0863)  time: 0.8624  data: 0.1533  max mem: 34435
Epoch: [4]  [100/121]  eta: 0:00:19  lr: 0.000483  min_lr: 0.000002  loss: 0.0042 (0.0053)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.0678 (0.0843)  time: 0.8471  data: 0.1383  max mem: 34435
Epoch: [4]  [110/121]  eta: 0:00:10  lr: 0.000492  min_lr: 0.000002  loss: 0.0040 (0.0052)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.0647 (0.0824)  time: 0.8458  data: 0.1372  max mem: 34435
Epoch: [4]  [120/121]  eta: 0:00:00  lr: 0.000500  min_lr: 0.000002  loss: 0.0039 (0.0051)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.0618 (0.0806)  time: 0.7797  data: 0.0708  max mem: 34435
Epoch: [4] Total time: 0:01:50 (0.9124 s / it)
Averaged stats: lr: 0.000500  min_lr: 0.000002  loss: 0.0039 (0.0051)  class_acc: 1.0000 (1.0000)  loss_scale: 8192.0000 (8192.0000)  weight_decay: 0.0500 (0.0500)  grad_norm: 0.0618 (0.0806)
Val:  [ 0/20]  eta: 0:04:55  loss: 0.0037 (0.0037)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 14.7555  data: 14.4493  max mem: 34435
Val:  [10/20]  eta: 0:00:15  loss: 0.0037 (0.0037)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 1.5907  data: 1.3137  max mem: 34435
Val:  [19/20]  eta: 0:00:01  loss: 0.0037 (0.0037)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 1.2756  data: 1.0098  max mem: 34435
Val: Total time: 0:00:25 (1.2788 s / it)
* loss 0.004
Accuracy of the network on the 14718 val EEG: 0.00%
Test:  [ 0/11]  eta: 0:02:27  loss: 0.0037 (0.0037)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 13.4220  data: 13.1036  max mem: 34435
Test:  [10/11]  eta: 0:00:01  loss: 0.0037 (0.0037)  accuracy: 0.0000 (0.0000)  balanced_accuracy: 0.0000 (0.0000)  pr_auc: 0.0000 (0.0000)  roc_auc: 0.0000 (0.0000)  time: 1.4606  data: 1.1914  max mem: 34435
Test: Total time: 0:00:16 (1.4663 s / it)
* loss 0.004
Accuracy of the network on the 14718 test EEG: 0.00%
Max accuracy val: 0.00%, max accuracy test: 0.00%
...

einops.EinopsError: Error while processing rearrange-reduction pattern "B N (A T) -> B N A T". Input tensor shape: torch.Size([23, 2000]). Additional info: {'T': 200}.

einops.EinopsError: Error while processing rearrange-reduction pattern "B N (A T) -> B N A T".
Input tensor shape: torch.Size([23, 2000]). Additional info: {'T': 200}.
Wrong shape: expected 3 dims. Received 2-dim tensor.
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 790759) of binary: /d/miniconda3/envs/lb_labram/bin/python

I used the TUAB dataset and the labram-base.pth model,
but the input and output shapes do not match.
How can I solve this error? Thank you so much!
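The shapes in the message point to the problem. The tensor that reaches `rearrange` has shape (23, 2000), which is 23 channels by 2000 samples. The pattern `B N (A T) -> B N A T` needs a leading batch axis `B` as well.

Below is a minimal sketch that assumes a single sample is being passed to the model without batching. It adds the batch axis and then splits the time axis into 200-sample patches. NumPy's C-order `reshape` splits the last axis the same way the einops group `(A T)` does. `add_batch_dim` and `split_into_patches` are hypothetical helper names.

```python
import numpy as np

PATCH_SIZE = 200  # T in the pattern, as in "Additional info: {'T': 200}"


def add_batch_dim(x: np.ndarray) -> np.ndarray:
    """Turn a single (channels, samples) array into a batch of one."""
    if x.ndim == 2:
        x = x[np.newaxis, ...]  # (N, A*T) -> (1, N, A*T)
    if x.ndim != 3:
        raise ValueError(f"expected 2 or 3 dims, got shape {x.shape}")
    return x


def split_into_patches(x: np.ndarray, patch_size: int = PATCH_SIZE) -> np.ndarray:
    """Same result as rearrange(x, 'B N (A T) -> B N A T', T=patch_size)."""
    b, n, length = x.shape
    if length % patch_size != 0:
        raise ValueError(f"length {length} is not a multiple of {patch_size}")
    # C-order reshape: element a*T + t of the last axis lands at [a, t]
    return x.reshape(b, n, length // patch_size, patch_size)


sample = np.zeros((23, 2000), dtype=np.float32)  # shape from the error message
patches = split_into_patches(add_batch_dim(sample))
print(patches.shape)  # (1, 23, 10, 200)
```

In PyTorch the equivalent fix is `x.unsqueeze(0)`. Batches drawn from a `DataLoader` already carry this axis.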

Classification tasks on .cnt files

If my dataset consists of .cnt files, what code do I need to use or write to run the model on it? I can see a couple of preprocessing Python scripts: make_h5dataset_for_pretrain.py and dataset.py. Do I just change the folder paths and run them directly, or do I need to write my own code? If I need to write my own code, what should I pay attention to? Many people seem to have questions about running the model on their own datasets; could you provide some help? Thank you.
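Here is one possible starting point. It is not the repository's own pipeline.

MNE-Python can read Neuroscan .cnt files with `mne.io.read_raw_cnt`. Calling `raw.resample(200)` and then `raw.get_data()` gives a (channels, samples) array at 200 Hz. This sketch assumes LaBraM's 200 Hz rate and 200-sample patches, as the model name `labram_base_patch200_200` suggests.

The code covers only the step after loading: cutting that array into fixed-length windows aligned to the patch size. `segment_recording` and its 4-second default window are hypothetical choices. Labels, channel selection and scaling still have to match what your dataset class returns.

```python
import numpy as np

SFREQ = 200       # assumed target sampling rate (Hz)
PATCH_SIZE = 200  # assumed samples per patch (1 s at 200 Hz)


def segment_recording(data: np.ndarray, window_sec: int = 4,
                      sfreq: int = SFREQ, patch_size: int = PATCH_SIZE) -> np.ndarray:
    """Cut a (channels, samples) array into non-overlapping windows.

    Returns shape (n_windows, channels, window_sec * sfreq); trailing
    samples that do not fill a whole window are dropped.
    """
    if data.ndim != 2:
        raise ValueError(f"expected (channels, samples), got {data.shape}")
    win_len = window_sec * sfreq
    if win_len % patch_size != 0:
        raise ValueError("window length must be a multiple of the patch size")
    n_channels, n_samples = data.shape
    n_windows = n_samples // win_len
    trimmed = data[:, : n_windows * win_len]
    # (C, W*L) -> (C, W, L) -> (W, C, L)
    return trimmed.reshape(n_channels, n_windows, win_len).transpose(1, 0, 2)


# Stand-in for raw.get_data() from an MNE Raw object resampled to 200 Hz.
recording = np.random.randn(32, 200 * 10 + 57)
windows = segment_recording(recording)
print(windows.shape)  # (2, 32, 800)
```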

Issue with CUDA

Hello, and thank you for the awesome work!
I got the following error, which I think is related to distributed training, while trying to run your example on Colab Pro+:

Traceback (most recent call last):
File "/content/LaBraM/run_class_finetuning.py", line 565, in <module>
main(opts, ds_init)
File "/content/LaBraM/run_class_finetuning.py", line 232, in main
utils.init_distributed_mode(args)
File "/content/LaBraM/utils.py", line 428, in init_distributed_mode
torch.cuda.set_device(args.gpu)
File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 399, in set_device
torch._C._cuda_setDevice(device)
RuntimeError: CUDA error: invalid device ordinal
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

| distributed init (rank 0): env://, gpu 0

W0630 22:56:38.144000 139452743365248 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 20757 closing signal SIGTERM
W0630 22:56:38.144000 139452743365248 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 20758 closing signal SIGTERM
W0630 22:56:38.145000 139452743365248 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 20759 closing signal SIGTERM
W0630 22:56:38.145000 139452743365248 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 20761 closing signal SIGTERM
W0630 22:56:38.145000 139452743365248 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 20763 closing signal SIGTERM
W0630 22:56:38.145000 139452743365248 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 20764 closing signal SIGTERM
E0630 22:56:38.262000 139452743365248 torch/distributed/elastic/multiprocessing/api.py:826] failed (exitcode: 1) local_rank: 3 (pid: 20760) of binary: /usr/bin/python3
Traceback (most recent call last):
File "/usr/local/bin/torchrun", line 8, in <module>
sys.exit(main())
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
return f(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 879, in main
run(args)
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 870, in run
elastic_launch(
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 132, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 263, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:

run_class_finetuning.py FAILED

Failures:
[1]:
time : 2024-06-30_22:56:38
host : 7283cc30eeb0
rank : 5 (local_rank: 5)
exitcode : 1 (pid: 20762)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html

Root Cause (first observed failure):
[0]:
time : 2024-06-30_22:56:38
host : 7283cc30eeb0
rank : 3 (local_rank: 3)
exitcode : 1 (pid: 20760)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
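The PIDs in the log (20757 to 20764) show that `torchrun` started eight worker processes. A Colab runtime usually exposes one GPU. As a result, `torch.cuda.set_device(args.gpu)` raises `invalid device ordinal` in every process whose local rank has no matching device.

Below is a minimal sketch, assuming the run was launched with a fixed `--nproc_per_node` copied from a multi-GPU example. It caps the process count at the number of visible GPUs. `build_torchrun_cmd` is a hypothetical helper. `--nproc_per_node` is a real `torchrun` option, and in practice `available_gpus` would come from `torch.cuda.device_count()`.

```python
import shlex
from typing import List


def build_torchrun_cmd(script: str, requested_procs: int, available_gpus: int,
                       extra_args: List[str]) -> List[str]:
    """Build a torchrun command whose process count never exceeds the GPU count."""
    if available_gpus < 1:
        raise RuntimeError("no CUDA device visible; check the runtime type")
    nproc = min(requested_procs, available_gpus)
    return ["torchrun", f"--nproc_per_node={nproc}", script, *extra_args]


cmd = build_torchrun_cmd("run_class_finetuning.py", requested_procs=8,
                         available_gpus=1, extra_args=["--batch_size", "64"])
print(shlex.join(cmd))
# torchrun --nproc_per_node=1 run_class_finetuning.py --batch_size 64
```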

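Running the fine-tuning code on the public CHB-MIT dataset?

Hello, when I ran your fine-tuning code on the CHB-MIT dataset (binary classification), the training accuracy got very high and the loss kept decreasing, but the evaluation metrics on the validation and test sets stayed at 0 the whole time. Have you run into this problem, and if so, how did you solve it?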
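This is a diagnostic sketch, not a confirmed cause. If a validation or test split contains only one class, ROC AUC is undefined and the reported metrics can sit at 0 even while training looks fine.

The training log earlier on this page shows a related symptom: `class_acc: 1.0000` from the very first step. That can happen when the training labels are (nearly) all one class. The snippet below counts labels per split so this can be ruled out. `check_splits` and the example label lists are hypothetical.

```python
from collections import Counter
from typing import Dict, Iterable


def check_splits(splits: Dict[str, Iterable[int]]) -> Dict[str, Counter]:
    """Count labels per split and warn when a split has fewer than two classes."""
    counts = {}
    for name, labels in splits.items():
        c = Counter(int(y) for y in labels)
        counts[name] = c
        if len(c) < 2:
            print(f"WARNING: split '{name}' has a single class {dict(c)}; "
                  "ROC AUC is undefined for it")
    return counts


# Hypothetical label lists; in practice, collect them from each Dataset.
result = check_splits({
    "train": [0, 0, 1, 0],
    "val": [0, 0, 0],
    "test": [0, 1],
})
print({name: dict(c) for name, c in result.items()})
```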

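Why not use the model weights trained in the first step during pre-training?

Hello, thank you for sharing such an impressive project. I read through the LaBraM code carefully and noticed that the model in the lower half of the overview figure is the same as the neural tokenizer model, except for the head. In other words, the position embeddings and temporal embeddings are already incorporated when the tokenizer is trained in step 1, and the modeling turns each patch into a feature and then models all patches together. What I don't understand is why, in step 2, the LaBraM model in the lower half is trained completely from scratch as a student model. Couldn't you simply take the neural tokenizer trained in step 1, load its weights, and swap only the head? That is how you handle fine-tuning in step 3: load the existing model weights and retrain only the head. This is a sincere question, and I hope to get a reply.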

Error: Unexpected key(s) in state_dict: "logit_scale".

Hi! Thank you for your great work. I am trying to load the pre-trained base model to extract feature embeddings (no fine-tuning). When I load the model, I get the following error:

raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for NeuralTransformerForMEM: Unexpected key(s) in state_dict: "logit_scale".

Any solution? Also, could you point me to the Python file I should run to extract feature embeddings from the pre-trained model?
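One possible workaround is to drop the keys the target model does not define before loading. This minimal sketch assumes the checkpoint is a dict whose weights may be nested under a key such as `model` or `module`. The fine-tuning `Namespace` printed further down this page uses `model_key='model|module'`.

PyTorch's `model.load_state_dict(state_dict, strict=False)` has a similar effect and returns the missing and unexpected keys. Filtering explicitly keeps the list of discarded keys visible. `extract_state_dict` and `drop_unexpected_keys` are hypothetical helpers.

```python
from typing import Any, Dict, Iterable, List, Tuple


def extract_state_dict(checkpoint: Dict[str, Any],
                       candidate_keys: Iterable[str] = ("model", "module")) -> Dict[str, Any]:
    """Return the nested weights dict if one of candidate_keys holds it."""
    for key in candidate_keys:
        if isinstance(checkpoint.get(key), dict):
            return dict(checkpoint[key])
    return dict(checkpoint)


def drop_unexpected_keys(state_dict: Dict[str, Any],
                         model_keys: Iterable[str]) -> Tuple[Dict[str, Any], List[str]]:
    """Keep only the entries the target model defines and list the rest."""
    wanted = set(model_keys)
    kept = {k: v for k, v in state_dict.items() if k in wanted}
    dropped = sorted(k for k in state_dict if k not in wanted)
    return kept, dropped


# With PyTorch this would be used roughly as:
#   ckpt = torch.load("labram-base.pth", map_location="cpu")
#   weights, dropped = drop_unexpected_keys(extract_state_dict(ckpt),
#                                           model.state_dict().keys())
#   model.load_state_dict(weights, strict=False)  # keys absent from weights keep their init
ckpt = {"model": {"blocks.0.attn.qkv.weight": 1, "logit_scale": 2}}
weights, dropped = drop_unexpected_keys(extract_state_dict(ckpt),
                                        ["blocks.0.attn.qkv.weight"])
print(dropped)  # ['logit_scale']
```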

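Preprocessing of the TUAB and TUEV datasets

Hello, and thank you for your excellent work! I am reproducing your experiments and ran into some problems converting the raw TUAB and TUEV data into h5 datasets. The code you provide reads .cnt files, whereas make_TUAB.py and make_TUEV.py save their processed output as .pkl files. How should these be turned into an h5 dataset? My understanding is that the pre-training model takes h5 datasets as input. Thanks in advance for your patient reply!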
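Below is a rough sketch of the conversion the question asks about. It rests on two assumptions that should be checked against the actual scripts:

- Each `.pkl` file written by `make_TUAB.py`/`make_TUEV.py` holds a dict with the signal under `"X"` (channels × samples) and the label under `"y"`.
- A single HDF5 file with one dataset per array is acceptable.

The pre-training loader may expect a different HDF5 layout, such as the one `make_h5dataset_for_pretrain.py` produces, so the dataset names `signals` and `labels` are hypothetical. `h5py` is imported lazily so the loading step runs without it.

```python
import pickle
from pathlib import Path
from typing import Iterable, Tuple

import numpy as np


def load_pkl_samples(paths: Iterable[Path]) -> Tuple[np.ndarray, np.ndarray]:
    """Stack the (assumed) {"X": ..., "y": ...} dicts from several .pkl files."""
    signals, labels = [], []
    for path in sorted(paths):
        with open(path, "rb") as f:
            sample = pickle.load(f)
        signals.append(np.asarray(sample["X"], dtype=np.float32))
        labels.append(sample["y"])
    # np.stack requires every sample to share one (channels, samples) shape.
    return np.stack(signals), np.asarray(labels)


def write_h5(out_path: Path, signals: np.ndarray, labels: np.ndarray) -> None:
    """Write both arrays to one HDF5 file (hypothetical dataset names)."""
    import h5py  # lazy import: only needed for this step

    with h5py.File(str(out_path), "w") as f:
        f.create_dataset("signals", data=signals, compression="gzip")
        f.create_dataset("labels", data=labels)


# Usage (paths are placeholders):
#   X, y = load_pkl_samples(Path("processed/train").glob("*.pkl"))
#   write_h5(Path("tuab_train.h5"), X, y)
```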

RuntimeError: The size of tensor a (341) must match the size of tensor b (286) at non-singleton dimension 1

I ran the model on my modified .cnt dataset and got the following errors.
D:\ProgramData\anaconda3\envs\labram\python.exe E:\lab\DL\LaBraM-main\run_class_finetuning.py
Not using distributed mode
Namespace(batch_size=64, epochs=30, update_freq=1, save_ckpt_freq=5, robust_test=None, model='labram_base_patch200_200', qkv_bias=True, rel_pos_bias=True, abs_pos_emb=True, layer_scale_init_value=0.1, input_size=200, drop=0.0, attn_drop_rate=0.0, drop_path=0.1, disable_eval_during_finetuning=False, model_ema=False, model_ema_decay=0.9999, model_ema_force_cpu=False, opt='adamw', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.05, weight_decay_end=None, lr=0.0005, layer_decay=0.9, warmup_lr=1e-06, min_lr=1e-06, warmup_epochs=5, warmup_steps=-1, smoothing=0.1, reprob=0.25, remode='pixel', recount=1, resplit=False, finetune='', model_key='model|module', model_prefix='', model_filter_name='gzp', init_scale=0.001, use_mean_pooling=True, disable_weight_decay_on_rel_pos_bias=False, nb_classes=4, output_dir='E:/lab/DL/LaBraM-main/checkpoints/finetune_MI_base', log_dir='E:/lab/DL/LaBraM-main/log/finetune_MI_base', device='cuda', seed=0, resume='', auto_resume=True, save_ckpt=True, start_epoch=0, eval=False, dist_eval=False, num_workers=10, pin_mem=True, world_size=1, local_rank=-1, dist_on_itp=False, dist_url='env://', enable_deepspeed=False, dataset='MI', distributed=False)
2199 399 680
Sampler_train = <torch.utils.data.distributed.DistributedSampler object at 0x000001F93E230D50>
Patch size = 200
Model = NeuralTransformer(
(patch_embed): TemporalConv(
(conv1): Conv2d(1, 8, kernel_size=(1, 15), stride=(1, 8), padding=(0, 7))
(gelu1): GELU(approximate='none')
(norm1): GroupNorm(4, 8, eps=1e-05, affine=True)
(conv2): Conv2d(8, 8, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
(gelu2): GELU(approximate='none')
(norm2): GroupNorm(4, 8, eps=1e-05, affine=True)
(conv3): Conv2d(8, 8, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
(norm3): GroupNorm(4, 8, eps=1e-05, affine=True)
(gelu3): GELU(approximate='none')
)
(pos_drop): Dropout(p=0.0, inplace=False)
(blocks): ModuleList(
(0): Block(
(norm1): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=200, out_features=600, bias=False)
(q_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(k_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=200, out_features=200, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): Identity()
(norm2): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=200, out_features=800, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=800, out_features=200, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(1): Block(
(norm1): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=200, out_features=600, bias=False)
(q_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(k_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=200, out_features=200, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): DropPath(p=0.00909090880304575)
(norm2): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=200, out_features=800, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=800, out_features=200, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(2): Block(
(norm1): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=200, out_features=600, bias=False)
(q_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(k_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=200, out_features=200, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): DropPath(p=0.0181818176060915)
(norm2): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=200, out_features=800, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=800, out_features=200, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(3): Block(
(norm1): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=200, out_features=600, bias=False)
(q_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(k_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=200, out_features=200, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): DropPath(p=0.027272727340459824)
(norm2): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=200, out_features=800, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=800, out_features=200, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(4): Block(
(norm1): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=200, out_features=600, bias=False)
(q_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(k_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=200, out_features=200, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): DropPath(p=0.036363635212183)
(norm2): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=200, out_features=800, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=800, out_features=200, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(5): Block(
(norm1): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=200, out_features=600, bias=False)
(q_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(k_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=200, out_features=200, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): DropPath(p=0.045454543083906174)
(norm2): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=200, out_features=800, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=800, out_features=200, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(6): Block(
(norm1): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=200, out_features=600, bias=False)
(q_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(k_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=200, out_features=200, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): DropPath(p=0.054545458406209946)
(norm2): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=200, out_features=800, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=800, out_features=200, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(7): Block(
(norm1): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=200, out_features=600, bias=False)
(q_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(k_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=200, out_features=200, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): DropPath(p=0.06363636255264282)
(norm2): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=200, out_features=800, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=800, out_features=200, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(8): Block(
(norm1): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=200, out_features=600, bias=False)
(q_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(k_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=200, out_features=200, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): DropPath(p=0.0727272778749466)
(norm2): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=200, out_features=800, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=800, out_features=200, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(9): Block(
(norm1): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=200, out_features=600, bias=False)
(q_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(k_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=200, out_features=200, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): DropPath(p=0.08181818574666977)
(norm2): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=200, out_features=800, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=800, out_features=200, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(10): Block(
(norm1): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=200, out_features=600, bias=False)
(q_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(k_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=200, out_features=200, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): DropPath(p=0.09090909361839294)
(norm2): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=200, out_features=800, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=800, out_features=200, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(11): Block(
(norm1): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=200, out_features=600, bias=False)
(q_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(k_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=200, out_features=200, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): DropPath(p=0.10000000149011612)
(norm2): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=200, out_features=800, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=800, out_features=200, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
)
(norm): Identity()
(fc_norm): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
(head): Linear(in_features=200, out_features=4, bias=True)
)
number of params: 5825540
LR = 0.00050000
Batch size = 64
Update frequent = 1
Number of training examples = 2199
Number of training training per epoch = 34
Assigned values = [0.2541865828329001, 0.2824295364810001, 0.31381059609000006, 0.3486784401000001, 0.3874204890000001, 0.4304672100000001, 0.4782969000000001, 0.531441, 0.5904900000000001, 0.6561, 0.7290000000000001, 0.81, 0.9, 1.0]
Skip weight decay name marked in model: {'time_embed', 'cls_token', 'pos_embed'}
Param groups = {
"layer_0_no_decay": {
"weight_decay": 0.0,
"params": [
"cls_token",
"pos_embed",
"patch_embed.conv1.bias",
"patch_embed.norm1.weight",
"patch_embed.norm1.bias",
"patch_embed.conv2.bias",
"patch_embed.norm2.weight",
"patch_embed.norm2.bias",
"patch_embed.conv3.bias",
"patch_embed.norm3.weight",
"patch_embed.norm3.bias"
],
"lr_scale": 0.2541865828329001
},
"layer_13_no_decay": {
"weight_decay": 0.0,
"params": [
"time_embed",
"fc_norm.weight",
"fc_norm.bias",
"head.bias"
],
"lr_scale": 1.0
},
"layer_0_decay": {
"weight_decay": 0.05,
"params": [
"patch_embed.conv1.weight",
"patch_embed.conv2.weight",
"patch_embed.conv3.weight"
],
"lr_scale": 0.2541865828329001
},
"layer_1_no_decay": {
"weight_decay": 0.0,
"params": [
"blocks.0.gamma_1",
"blocks.0.gamma_2",
"blocks.0.norm1.weight",
"blocks.0.norm1.bias",
"blocks.0.attn.q_bias",
"blocks.0.attn.v_bias",
"blocks.0.attn.q_norm.weight",
"blocks.0.attn.q_norm.bias",
"blocks.0.attn.k_norm.weight",
"blocks.0.attn.k_norm.bias",
"blocks.0.attn.proj.bias",
"blocks.0.norm2.weight",
"blocks.0.norm2.bias",
"blocks.0.mlp.fc1.bias",
"blocks.0.mlp.fc2.bias"
],
"lr_scale": 0.2824295364810001
},
"layer_1_decay": {
"weight_decay": 0.05,
"params": [
"blocks.0.attn.qkv.weight",
"blocks.0.attn.proj.weight",
"blocks.0.mlp.fc1.weight",
"blocks.0.mlp.fc2.weight"
],
"lr_scale": 0.2824295364810001
},
"layer_2_no_decay": {
"weight_decay": 0.0,
"params": [
"blocks.1.gamma_1",
"blocks.1.gamma_2",
"blocks.1.norm1.weight",
"blocks.1.norm1.bias",
"blocks.1.attn.q_bias",
"blocks.1.attn.v_bias",
"blocks.1.attn.q_norm.weight",
"blocks.1.attn.q_norm.bias",
"blocks.1.attn.k_norm.weight",
"blocks.1.attn.k_norm.bias",
"blocks.1.attn.proj.bias",
"blocks.1.norm2.weight",
"blocks.1.norm2.bias",
"blocks.1.mlp.fc1.bias",
"blocks.1.mlp.fc2.bias"
],
"lr_scale": 0.31381059609000006
},
"layer_2_decay": {
"weight_decay": 0.05,
"params": [
"blocks.1.attn.qkv.weight",
"blocks.1.attn.proj.weight",
"blocks.1.mlp.fc1.weight",
"blocks.1.mlp.fc2.weight"
],
"lr_scale": 0.31381059609000006
},
"layer_3_no_decay": {
"weight_decay": 0.0,
"params": [
"blocks.2.gamma_1",
"blocks.2.gamma_2",
"blocks.2.norm1.weight",
"blocks.2.norm1.bias",
"blocks.2.attn.q_bias",
"blocks.2.attn.v_bias",
"blocks.2.attn.q_norm.weight",
"blocks.2.attn.q_norm.bias",
"blocks.2.attn.k_norm.weight",
"blocks.2.attn.k_norm.bias",
"blocks.2.attn.proj.bias",
"blocks.2.norm2.weight",
"blocks.2.norm2.bias",
"blocks.2.mlp.fc1.bias",
"blocks.2.mlp.fc2.bias"
],
"lr_scale": 0.3486784401000001
},
"layer_3_decay": {
"weight_decay": 0.05,
"params": [
"blocks.2.attn.qkv.weight",
"blocks.2.attn.proj.weight",
"blocks.2.mlp.fc1.weight",
"blocks.2.mlp.fc2.weight"
],
"lr_scale": 0.3486784401000001
},
"layer_4_no_decay": {
"weight_decay": 0.0,
"params": [
"blocks.3.gamma_1",
"blocks.3.gamma_2",
"blocks.3.norm1.weight",
"blocks.3.norm1.bias",
"blocks.3.attn.q_bias",
"blocks.3.attn.v_bias",
"blocks.3.attn.q_norm.weight",
"blocks.3.attn.q_norm.bias",
"blocks.3.attn.k_norm.weight",
"blocks.3.attn.k_norm.bias",
"blocks.3.attn.proj.bias",
"blocks.3.norm2.weight",
"blocks.3.norm2.bias",
"blocks.3.mlp.fc1.bias",
"blocks.3.mlp.fc2.bias"
],
"lr_scale": 0.3874204890000001
},
"layer_4_decay": {
"weight_decay": 0.05,
"params": [
"blocks.3.attn.qkv.weight",
"blocks.3.attn.proj.weight",
"blocks.3.mlp.fc1.weight",
"blocks.3.mlp.fc2.weight"
],
"lr_scale": 0.3874204890000001
},
"layer_5_no_decay": {
"weight_decay": 0.0,
"params": [
"blocks.4.gamma_1",
"blocks.4.gamma_2",
"blocks.4.norm1.weight",
"blocks.4.norm1.bias",
"blocks.4.attn.q_bias",
"blocks.4.attn.v_bias",
"blocks.4.attn.q_norm.weight",
"blocks.4.attn.q_norm.bias",
"blocks.4.attn.k_norm.weight",
"blocks.4.attn.k_norm.bias",
"blocks.4.attn.proj.bias",
"blocks.4.norm2.weight",
"blocks.4.norm2.bias",
"blocks.4.mlp.fc1.bias",
"blocks.4.mlp.fc2.bias"
],
"lr_scale": 0.4304672100000001
},
"layer_5_decay": {
"weight_decay": 0.05,
"params": [
"blocks.4.attn.qkv.weight",
"blocks.4.attn.proj.weight",
"blocks.4.mlp.fc1.weight",
"blocks.4.mlp.fc2.weight"
],
"lr_scale": 0.4304672100000001
},
"layer_6_no_decay": {
"weight_decay": 0.0,
"params": [
"blocks.5.gamma_1",
"blocks.5.gamma_2",
"blocks.5.norm1.weight",
"blocks.5.norm1.bias",
"blocks.5.attn.q_bias",
"blocks.5.attn.v_bias",
"blocks.5.attn.q_norm.weight",
"blocks.5.attn.q_norm.bias",
"blocks.5.attn.k_norm.weight",
"blocks.5.attn.k_norm.bias",
"blocks.5.attn.proj.bias",
"blocks.5.norm2.weight",
"blocks.5.norm2.bias",
"blocks.5.mlp.fc1.bias",
"blocks.5.mlp.fc2.bias"
],
"lr_scale": 0.4782969000000001
},
"layer_6_decay": {
"weight_decay": 0.05,
"params": [
"blocks.5.attn.qkv.weight",
"blocks.5.attn.proj.weight",
"blocks.5.mlp.fc1.weight",
"blocks.5.mlp.fc2.weight"
],
"lr_scale": 0.4782969000000001
},
"layer_7_no_decay": {
"weight_decay": 0.0,
"params": [
"blocks.6.gamma_1",
"blocks.6.gamma_2",
"blocks.6.norm1.weight",
"blocks.6.norm1.bias",
"blocks.6.attn.q_bias",
"blocks.6.attn.v_bias",
"blocks.6.attn.q_norm.weight",
"blocks.6.attn.q_norm.bias",
"blocks.6.attn.k_norm.weight",
"blocks.6.attn.k_norm.bias",
"blocks.6.attn.proj.bias",
"blocks.6.norm2.weight",
"blocks.6.norm2.bias",
"blocks.6.mlp.fc1.bias",
"blocks.6.mlp.fc2.bias"
],
"lr_scale": 0.531441
},
"layer_7_decay": {
"weight_decay": 0.05,
"params": [
"blocks.6.attn.qkv.weight",
"blocks.6.attn.proj.weight",
"blocks.6.mlp.fc1.weight",
"blocks.6.mlp.fc2.weight"
],
"lr_scale": 0.531441
},
"layer_8_no_decay": {
"weight_decay": 0.0,
"params": [
"blocks.7.gamma_1",
"blocks.7.gamma_2",
"blocks.7.norm1.weight",
"blocks.7.norm1.bias",
"blocks.7.attn.q_bias",
"blocks.7.attn.v_bias",
"blocks.7.attn.q_norm.weight",
"blocks.7.attn.q_norm.bias",
"blocks.7.attn.k_norm.weight",
"blocks.7.attn.k_norm.bias",
"blocks.7.attn.proj.bias",
"blocks.7.norm2.weight",
"blocks.7.norm2.bias",
"blocks.7.mlp.fc1.bias",
"blocks.7.mlp.fc2.bias"
],
"lr_scale": 0.5904900000000001
},
"layer_8_decay": {
"weight_decay": 0.05,
"params": [
"blocks.7.attn.qkv.weight",
"blocks.7.attn.proj.weight",
"blocks.7.mlp.fc1.weight",
"blocks.7.mlp.fc2.weight"
],
"lr_scale": 0.5904900000000001
},
"layer_9_no_decay": {
"weight_decay": 0.0,
"params": [
"blocks.8.gamma_1",
"blocks.8.gamma_2",
"blocks.8.norm1.weight",
"blocks.8.norm1.bias",
"blocks.8.attn.q_bias",
"blocks.8.attn.v_bias",
"blocks.8.attn.q_norm.weight",
"blocks.8.attn.q_norm.bias",
"blocks.8.attn.k_norm.weight",
"blocks.8.attn.k_norm.bias",
"blocks.8.attn.proj.bias",
"blocks.8.norm2.weight",
"blocks.8.norm2.bias",
"blocks.8.mlp.fc1.bias",
"blocks.8.mlp.fc2.bias"
],
"lr_scale": 0.6561
},
"layer_9_decay": {
"weight_decay": 0.05,
"params": [
"blocks.8.attn.qkv.weight",
"blocks.8.attn.proj.weight",
"blocks.8.mlp.fc1.weight",
"blocks.8.mlp.fc2.weight"
],
"lr_scale": 0.6561
},
"layer_10_no_decay": {
"weight_decay": 0.0,
"params": [
"blocks.9.gamma_1",
"blocks.9.gamma_2",
"blocks.9.norm1.weight",
"blocks.9.norm1.bias",
"blocks.9.attn.q_bias",
"blocks.9.attn.v_bias",
"blocks.9.attn.q_norm.weight",
"blocks.9.attn.q_norm.bias",
"blocks.9.attn.k_norm.weight",
"blocks.9.attn.k_norm.bias",
"blocks.9.attn.proj.bias",
"blocks.9.norm2.weight",
"blocks.9.norm2.bias",
"blocks.9.mlp.fc1.bias",
"blocks.9.mlp.fc2.bias"
],
"lr_scale": 0.7290000000000001
},
"layer_10_decay": {
"weight_decay": 0.05,
"params": [
"blocks.9.attn.qkv.weight",
"blocks.9.attn.proj.weight",
"blocks.9.mlp.fc1.weight",
"blocks.9.mlp.fc2.weight"
],
"lr_scale": 0.7290000000000001
},
"layer_11_no_decay": {
"weight_decay": 0.0,
"params": [
"blocks.10.gamma_1",
"blocks.10.gamma_2",
"blocks.10.norm1.weight",
"blocks.10.norm1.bias",
"blocks.10.attn.q_bias",
"blocks.10.attn.v_bias",
"blocks.10.attn.q_norm.weight",
"blocks.10.attn.q_norm.bias",
"blocks.10.attn.k_norm.weight",
"blocks.10.attn.k_norm.bias",
"blocks.10.attn.proj.bias",
"blocks.10.norm2.weight",
"blocks.10.norm2.bias",
"blocks.10.mlp.fc1.bias",
"blocks.10.mlp.fc2.bias"
],
"lr_scale": 0.81
},
"layer_11_decay": {
"weight_decay": 0.05,
"params": [
"blocks.10.attn.qkv.weight",
"blocks.10.attn.proj.weight",
"blocks.10.mlp.fc1.weight",
"blocks.10.mlp.fc2.weight"
],
"lr_scale": 0.81
},
"layer_12_no_decay": {
"weight_decay": 0.0,
"params": [
"blocks.11.gamma_1",
"blocks.11.gamma_2",
"blocks.11.norm1.weight",
"blocks.11.norm1.bias",
"blocks.11.attn.q_bias",
"blocks.11.attn.v_bias",
"blocks.11.attn.q_norm.weight",
"blocks.11.attn.q_norm.bias",
"blocks.11.attn.k_norm.weight",
"blocks.11.attn.k_norm.bias",
"blocks.11.attn.proj.bias",
"blocks.11.norm2.weight",
"blocks.11.norm2.bias",
"blocks.11.mlp.fc1.bias",
"blocks.11.mlp.fc2.bias"
],
"lr_scale": 0.9
},
"layer_12_decay": {
"weight_decay": 0.05,
"params": [
"blocks.11.attn.qkv.weight",
"blocks.11.attn.proj.weight",
"blocks.11.mlp.fc1.weight",
"blocks.11.mlp.fc2.weight"
],
"lr_scale": 0.9
},
"layer_13_decay": {
"weight_decay": 0.05,
"params": [
"head.weight"
],
"lr_scale": 1.0
}
}
Optimizer config: {'lr': 0.0005, 'weight_decay': 0.0, 'eps': 1e-08}
Use step level LR scheduler!
Set warmup steps = 170
Set warmup steps = 0
Max WD = 0.0500000, Min WD = 0.0500000
criterion = LabelSmoothingCrossEntropy()
Auto resume checkpoint:
Start training for 30 epochs
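The numbers in the log above can be reproduced with a few lines of arithmetic. This sketch reads several settings off the printed values; none of them are confirmed from the LaBraM source:

- a layer-wise learning-rate decay of 0.9, which is the ratio between consecutive `lr_scale` values;
- 12 Transformer blocks, plus one group for the patch embedding and one for the head;
- a stochastic-depth rate that rises linearly to 0.1 (the printed `DropPath` values for blocks 6 to 11 fit that ramp).

`layer_scales` and `drop_path_rates` are illustrative helpers.

```python
from typing import List


def layer_scales(num_blocks: int, decay: float) -> List[float]:
    """Layer-wise LR decay (assumed grouping): group 0 is the patch embedding,
    groups 1..num_blocks are the Transformer blocks, the last group is the head.
    Each group's LR is base_lr * decay ** (its distance from the head)."""
    num_groups = num_blocks + 2
    return [decay ** (num_groups - 1 - i) for i in range(num_groups)]


def drop_path_rates(num_blocks: int, max_rate: float) -> List[float]:
    """Stochastic-depth rate rising linearly from 0 to max_rate over the blocks."""
    return [max_rate * i / (num_blocks - 1) for i in range(num_blocks)]


base_lr, batch_size, num_train = 5e-4, 64, 2199  # values printed in the log
steps_per_epoch = num_train // batch_size         # 34, matching the log (2199 / 64 is about 34.4)
scales = layer_scales(num_blocks=12, decay=0.9)   # 14 groups: 0.9**13 ... 0.9**0
group_lrs = [base_lr * s for s in scales]         # effective LR per parameter group
dpr = drop_path_rates(num_blocks=12, max_rate=0.1)

print(steps_per_epoch, len(scales), scales[-1])  # 34 14 1.0
print(round(dpr[6], 6), round(dpr[11], 6))       # 0.054545 0.1
```

The printed `DropPath` values differ from these in the last digits only because they were stored as 32-bit floats. The reported 170 warmup steps equal 5 × 34, which is consistent with five warmup epochs at this steps-per-epoch value.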
Traceback (most recent call last):
File "E:\lab\DL\LaBraM-main\run_class_finetuning.py", line 582, in <module>
main(opts, ds_init)
File "E:\lab\DL\LaBraM-main\run_class_finetuning.py", line 496, in main
train_stats = train_one_epoch(
^^^^^^^^^^^^^^^^
File "E:\lab\DL\LaBraM-main\engine_for_finetuning.py", line 77, in train_one_epoch
loss, output = train_class_batch(
^^^^^^^^^^^^^^^^^^
File "E:\lab\DL\LaBraM-main\engine_for_finetuning.py", line 19, in train_class_batch
outputs = model(samples, ch_names)
^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\ProgramData\anaconda3\envs\labram\Lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\ProgramData\anaconda3\envs\labram\Lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "E:\lab\DL\LaBraM-main\modeling_finetune.py", line 395, in forward
x = self.forward_features(x, input_chans=input_chans, return_patch_tokens=return_patch_tokens, return_all_tokens=return_all_tokens, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "E:\lab\DL\LaBraM-main\modeling_finetune.py", line 362, in forward_features
x = x + pos_embed
~~^~~~~~~~~~~
RuntimeError: The size of tensor a (341) must match the size of tensor b (286) at non-singleton dimension 1

Process finished with exit code 1
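The `RuntimeError` is raised at `x = x + pos_embed`. Along dimension 1, the input carries 341 tokens and the positional embedding carries 286.

This reading rests on two assumptions:

- the token layout is one CLS token plus one token per (channel, time patch) pair;
- the positional embedding is built from the channel list passed in `model(samples, ch_names)`.

If both hold, a mismatch like this suggests that `samples` and `ch_names` disagree on the number of channels. Below is a hypothetical pre-flight check in plain Python that reports such a disagreement before the forward pass. `expected_tokens` and `check_channels` are illustrative names, not functions from the LaBraM code.

```python
from typing import Sequence, Tuple


def expected_tokens(n_channels: int, n_patches: int) -> int:
    """Assumed layout: one CLS token plus one token per (channel, time patch)."""
    return 1 + n_channels * n_patches


def check_channels(sample_shape: Tuple[int, int, int, int],
                   ch_names: Sequence[str]) -> None:
    """Hypothetical check to run before model(samples, ch_names).

    Assumes samples are shaped (batch, channels, patches, patch_size).
    """
    _, n_channels, n_patches, _ = sample_shape
    x_tokens = expected_tokens(n_channels, n_patches)
    pos_tokens = expected_tokens(len(ch_names), n_patches)
    if x_tokens != pos_tokens:
        raise ValueError(
            f"samples have {n_channels} channels ({x_tokens} tokens) but "
            f"{len(ch_names)} channel names were given ({pos_tokens} tokens)"
        )


names = [f"CH{i}" for i in range(19)]    # hypothetical channel names
check_channels((64, 19, 5, 200), names)  # consistent, so no error

try:
    check_channels((64, 23, 5, 200), names)
except ValueError as err:
    print(err)
# samples have 23 channels (116 tokens) but 19 channel names were given (96 tokens)
```

If the check fails, the channel list passed with each batch needs to name exactly the channels present in `samples`, in the same order.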

How to deal with different numbers of channels within one dataset?

Hi Weibang,

Thanks for sharing your excellent work! There's one thing that I'm not super clear about. When preprocessing a dataset like TUEP, in which the number of channels ranges from 19 to 23, how did you deal with this variation so that different samples could be batched and dumped into HDF5 files for later use?

Thanks in advance.
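One common way to give every sample the same channel dimension is to keep only the channels shared by all recordings. Each recording is then reordered to that fixed list before being written out. This is a sketch only and not necessarily what the LaBraM preprocessing does. The helpers `common_channels` and `select_channels` and the example recordings are hypothetical. Arrays are plain nested lists so the example runs without NumPy.

```python
from typing import List, Sequence


def common_channels(channel_lists: Sequence[Sequence[str]]) -> List[str]:
    """Channel names present in every recording, ordered as in the first one."""
    shared = set(channel_lists[0])
    for names in channel_lists[1:]:
        shared &= set(names)
    return [ch for ch in channel_lists[0] if ch in shared]


def select_channels(data: Sequence[Sequence[float]], names: Sequence[str],
                    target: Sequence[str]) -> List[List[float]]:
    """Subset and reorder the rows of a (channels x samples) array to `target`."""
    row = {name: i for i, name in enumerate(names)}
    missing = [ch for ch in target if ch not in row]
    if missing:
        raise ValueError(f"recording lacks channels: {missing}")
    return [list(data[row[ch]]) for ch in target]


# Hypothetical recordings: a 5-channel one (with an extra EKG lead) and a
# 4-channel one stored in a different channel order.
recordings = {
    "rec_a": (["FP1", "FP2", "C3", "C4", "EKG"],
              [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 9.0]]),
    "rec_b": (["C4", "C3", "FP2", "FP1"],
              [[0.4, 0.4], [0.3, 0.3], [0.2, 0.2], [0.1, 0.1]]),
}

target = common_channels([names for names, _ in recordings.values()])
aligned = {key: select_channels(data, names, target)
           for key, (names, data) in recordings.items()}

print(target)            # ['FP1', 'FP2', 'C3', 'C4']
print(aligned["rec_b"])  # [[0.1, 0.1], [0.2, 0.2], [0.3, 0.3], [0.4, 0.4]]
```

Intersecting channel sets discards any channel that only some recordings have. An alternative keeps those channels:

- store each channel configuration separately (for example, one HDF5 group per configuration);
- build batches only within a configuration;
- pass the matching channel list with each batch.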
