
emadeldeen24 / adast


[IEEE TETCI] "ADAST: Attentive Cross-domain EEG-based Sleep Staging Framework with Iterative Self-Training"

License: Apache License 2.0

Python 100.00%
attention-mechanism deep-learning domain-adaptation eeg pseudo-label self-attention self-training sleep-stage-classification time-series transfer-learning

adast's People

Contributors

emadeldeen24

Forkers

mohamedr002

adast's Issues

Info

Great paper! May I know where it was published? I found it on arXiv.

to_data_frame() got an unexpected keyword argument 'scaling_time'

Hello, I get the following error on a Windows 11 PC, even though I have already installed the latest versions of pandas, NumPy, and MNE:

D:/Desktop/Pattern Project/Project_Pattern/Dataset/PSG/SHHS1\shhs1-200097.edf
Extracting EDF parameters from D:\Desktop\Pattern Project\Project_Pattern\Dataset\PSG\SHHS1\shhs1-200097.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 3959999 = 0.000 ... 31679.992 secs...
Traceback (most recent call last):
File "d:\Desktop\Pattern Project\Project_Pattern\Code\prepare_shhs.py", line 137, in <module>
main()
File "d:\Desktop\Pattern Project\Project_Pattern\Code\prepare_shhs.py", line 65, in main
raw_ch_df = raw.to_data_frame(scaling_time=sampling_rate)[select_ch]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: to_data_frame() got an unexpected keyword argument 'scaling_time'
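Recent MNE releases no longer accept the `scaling_time` keyword in `Raw.to_data_frame()`, which is why this older-style preparation code fails. Because the script only uses the values of the selected channel, dropping the argument (`raw.to_data_frame()[select_ch]`) should work. Another option is to read the channel directly with `raw.get_data()`. The sketch below assumes `raw` is an MNE `Raw` object and `select_ch` is the channel name used in `prepare_shhs.py`. The helper name `read_channel_uv` is hypothetical. `get_data()` returns volts, whereas `to_data_frame()` reports EEG-typed channels in microvolts by default, so the helper rescales to keep the units the rest of the script expects.

```python
import numpy as np


def read_channel_uv(raw, select_ch):
    """Return one channel of an MNE Raw object as a 1-D array in microvolts.

    Hypothetical replacement for raw.to_data_frame(scaling_time=...)[select_ch]
    on MNE versions that no longer accept `scaling_time`.
    """
    # get_data() returns an array of shape (n_picked_channels, n_times) in volts
    data = raw.get_data(picks=[select_ch])
    # to_data_frame() scales EEG channels to microvolts by default, so match that
    return np.asarray(data[0]) * 1e6


# Usage inside prepare_shhs.py (variable names taken from the traceback):
#   raw_ch = read_channel_uv(raw, select_ch)
# or, keeping the DataFrame-based code path unchanged:
#   raw_ch_df = raw.to_data_frame()[select_ch]
```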

dataloader

Hello,
Could you please explain the dataloader file? I couldn't load the downloaded dataset with it.
How should I split the data so that it works with the following code?

def data_generator(data_path, domain_id, configs):
    # loading path
    train_dataset = torch.load(os.path.join(data_path, "train_" + domain_id + ".pt"))
    valid_dataset = torch.load(os.path.join(data_path, "val_" + domain_id + ".pt"))
    test_dataset = torch.load(os.path.join(data_path, "test_" + domain_id + ".pt"))

    # Loading datasets
    train_dataset = Load_Dataset(train_dataset)
    valid_dataset = Load_Dataset(valid_dataset)
    test_dataset = Load_Dataset(test_dataset)
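`data_generator` calls `torch.load` on `train_<domain_id>.pt`, `val_<domain_id>.pt` and `test_<domain_id>.pt` inside `data_path`, so each split must already exist as its own file before training starts. The conversion script in the reply below stores each split as a dictionary with `samples` and `labels` keys. Assuming `Load_Dataset` reads those two keys (this is not confirmed in this thread), a quick check like the following can catch a malformed file before it reaches the loader. `check_split` is a hypothetical helper. The default of 5 classes matches the `out_features=5` classifier printed later in this thread.

```python
import numpy as np


def check_split(dataset, num_classes=5):
    """Sanity-check a {"samples": ..., "labels": ...} split before training.

    Works on NumPy arrays or torch tensors (both expose .shape, .min() and .max()).
    Returns (number of epochs, shape of a single epoch).
    """
    missing = {"samples", "labels"} - set(dataset)
    if missing:
        raise KeyError(f"split is missing keys: {sorted(missing)}")

    samples, labels = dataset["samples"], dataset["labels"]
    if samples.shape[0] != labels.shape[0]:
        raise ValueError(f"{samples.shape[0]} epochs but {labels.shape[0]} labels")
    if int(labels.min()) < 0 or int(labels.max()) >= num_classes:
        raise ValueError(f"labels must lie in [0, {num_classes - 1}]")
    return samples.shape[0], tuple(samples.shape[1:])


# Usage (assuming PyTorch and a file produced by the conversion script):
#   import torch
#   n_epochs, epoch_shape = check_split(torch.load("train_a.pt"))
```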

dataloader

Here is the code I used to convert the .npz files to .pt:

import os
import random

import numpy as np
import torch


def split_dataset(data_dir, domain_id, train_ratio, val_ratio, test_ratio, seed=0):
    if not os.path.exists(data_dir):
        print(f"Directory '{data_dir}' does not exist.")
        return
    if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
        raise ValueError("train_ratio + val_ratio + test_ratio must equal 1")

    # Collect all .npz files (one per recording); sort first so that the
    # seeded shuffle produces the same split on every run
    npz_files = sorted(f for f in os.listdir(data_dir) if f.endswith(".npz"))
    random.Random(seed).shuffle(npz_files)

    # Split by file, so no recording appears in more than one split;
    # the test split receives whatever remains after train and val
    total_files = len(npz_files)
    train_count = int(total_files * train_ratio)
    val_count = int(np.around(total_files * val_ratio))
    test_count = total_files - train_count - val_count
    print("train:", train_count, "val:", val_count, "test:", test_count)

    train_files = npz_files[:train_count]
    val_files = npz_files[train_count:train_count + val_count]
    test_files = npz_files[train_count + val_count:]
    print(train_files)

    # Convert each split and save it as a .pt file
    convert_and_save(train_files, data_dir, domain_id, "train")
    convert_and_save(val_files, data_dir, domain_id, "val")
    convert_and_save(test_files, data_dir, domain_id, "test")

    print("Conversion to .pt files completed.")


def convert_and_save(npz_files, data_dir, domain_id, split_type):
    # torch.cat fails on an empty list, so report empty splits explicitly
    if not npz_files:
        raise ValueError(f"The {split_type} split is empty; add more files or adjust the ratios.")

    samples_list = []
    labels_list = []

    for file in npz_files:
        npz_path = os.path.join(data_dir, file)
        with np.load(npz_path) as data:
            samples_list.append(torch.from_numpy(data['x']).float())  # EEG epochs
            labels_list.append(torch.from_numpy(data['y']).long())    # sleep-stage labels

    # Concatenate all recordings of this split along the epoch axis
    samples_tensor = torch.cat(samples_list, dim=0)
    labels_tensor = torch.cat(labels_list, dim=0)
    dataset = {"samples": samples_tensor, "labels": labels_tensor}
    pt_filename = f"{split_type}_{domain_id}.pt"
    pt_path = os.path.join(data_dir, pt_filename)

    print(pt_filename, pt_path)
    torch.save(dataset, pt_path)

    print(f"{split_type.capitalize()} converted and saved as .pt files.")


if __name__ == "__main__":
    data_dir = "./"     # directory containing the .npz files
    train_ratio = 0.7   # fraction of recordings for the training set
    val_ratio = 0.15    # fraction of recordings for the validation set
    test_ratio = 0.15   # fraction of recordings for the test set
    domain_id = 'a'
    split_dataset(data_dir, domain_id, train_ratio, val_ratio, test_ratio)
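`torch.cat` in `convert_and_save` only works if every recording has the same number of samples per epoch. Because source and target share one feature extractor and classifier, both domains also need that same epoch length. Before converting, one way to check this is to group the `.npz` files by per-epoch shape. The sketch assumes the layout used above (`x` holds epochs along the first axis). `epoch_shapes` is a hypothetical helper.

```python
import os

import numpy as np


def epoch_shapes(data_dir):
    """Map each per-epoch shape found in data_dir's .npz files to the files that have it."""
    shapes = {}
    for name in sorted(os.listdir(data_dir)):
        if not name.endswith(".npz"):
            continue
        with np.load(os.path.join(data_dir, name)) as data:
            # data["x"] has shape (n_epochs, ...); drop the epoch axis
            shape = tuple(data["x"].shape[1:])
        shapes.setdefault(shape, []).append(name)
    return shapes


# More than one key means the recordings disagree (for example, 30-s epochs at
# 100 Hz give 3000 samples, at 125 Hz 3750) and need resampling before conversion.
```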

After processing the data, running the code fails in val_self_training:

model_output_dim 29 configs.final_out_channels 128
features torch.Size([128, 8960])
model[1]: Self_Attn(
(query_conv): Conv1d(128, 64, kernel_size=(1,), stride=(1,))
(key_conv): Conv1d(128, 64, kernel_size=(1,), stride=(1,))
(value_conv): Conv1d(128, 128, kernel_size=(1,), stride=(1,))
(softmax): Softmax(dim=-1)
) model[2][0]: Classifier(
(logits): Linear(in_features=3712, out_features=5, bias=True)
) model[2][1]: Classifier(
(logits): Linear(in_features=3712, out_features=5, bias=True)
)
input torch.Size([128, 8960])
Traceback (most recent call last):
File "train_CD.py", line 112, in <module>
main_train_cd()
File "train_CD.py", line 99, in main_train_cd
target_model = cross_domain_train(src_train_dl, trg_train_dl, trg_valid_dl,
File "/mnt/public/home/leichen01/ADAST/ADAST-main/trainer/ADAST.py", line 51, in cross_domain_train
val_self_training((feature_extractor, trg_att, (classifier_1, classifier_2)), trg_train_dl, device, src_id,
File "/mnt/public/home/leichen01/ADAST/ADAST-main/trainer/training_evaluation.py", line 31, in val_self_training
predictions = model[2][0](features)
File "/ADAST/ADAST-main/models/models.py", line 81, in forward
logits = self.logits(input)
File "torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x8960 and 3712x5)
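This error means the classifier's `nn.Linear` was built for `features_len * final_out_channels = 29 * 128 = 3712` inputs. The flattened target features, however, have 8960 = 70 * 128 values. In other words, the feature extractor produced 70 time steps for the target data instead of the 29 the classifier was sized for. The small hypothetical helper below (standard library only) reproduces that arithmetic for any such error message.

```python
def diagnose_linear_mismatch(flat_features, final_out_channels, features_len):
    """Explain a 'mat1 and mat2 shapes cannot be multiplied' error in Classifier.

    Classifier builds nn.Linear(features_len * final_out_channels, num_classes),
    so a flattened input must have exactly that many columns.
    """
    expected = features_len * final_out_channels
    if flat_features == expected:
        return "shapes match"
    if flat_features % final_out_channels:
        return f"{flat_features} is not a multiple of {final_out_channels} channels"
    actual_len = flat_features // final_out_channels
    return (f"classifier expects {expected} = {features_len} x {final_out_channels}, "
            f"but input has {flat_features} = {actual_len} x {final_out_channels}: "
            f"the feature extractor produced {actual_len} time steps, not {features_len}")


print(diagnose_linear_mismatch(8960, 128, 29))
```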

Then I set model_output_dim to 70, but I ran into another error:

input torch.Size([128, 8960])
input torch.Size([128, 8960])
features torch.Size([37, 8960])
model[1]: Self_Attn(
(query_conv): Conv1d(128, 64, kernel_size=(1,), stride=(1,))
(key_conv): Conv1d(128, 64, kernel_size=(1,), stride=(1,))
(value_conv): Conv1d(128, 128, kernel_size=(1,), stride=(1,))
(softmax): Softmax(dim=-1)
) model[2][0]: Classifier(
(logits): Linear(in_features=8960, out_features=5, bias=True)
) model[2][1]: Classifier(
(logits): Linear(in_features=8960, out_features=5, bias=True)
)

input torch.Size([37, 8960])
input torch.Size([37, 8960])
feature_extractor src_feat torch.Size([128, 128, 29])
att src_feat torch.Size([128, 3712])
input torch.Size([128, 3712])
Traceback (most recent call last):
File "train_CD.py", line 112, in <module>
main_train_cd()
File "train_CD.py", line 99, in main_train_cd
target_model = cross_domain_train(src_train_dl, trg_train_dl, trg_valid_dl,
File "ADAST/ADAST-main/trainer/ADAST.py", line 89, in cross_domain_train
src_pred = classifier_1(src_feat)
File "ADAST-main/models/models.py", line 81, in forward
logits = self.logits(input)
File "/mnt/public/home/leichen01/miniconda3/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x3712 and 8960x5)
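This second traceback shows source features of 128 x 29 = 3712 and target features of 128 x 70 = 8960. No single `features_len` can fit both domains, so changing it only moves the error from the target branch to the source branch. Both datasets need the same number of samples per epoch before they reach the shared feature extractor. The log does not reveal why the target yields 70 steps; a different sampling rate, epoch duration, or array layout in the target `.pt` files are all possibilities. For the sampling-rate case, one option is to resample every epoch to a common length, as sketched below. The sketch uses plain linear interpolation (`np.interp`) to stay dependency-free. It applies no anti-aliasing filter, so for real recordings a polyphase resampler such as `scipy.signal.resample_poly`, or MNE's `Raw.resample` before epoching, is safer. `resample_epochs` is a hypothetical helper.

```python
import numpy as np


def resample_epochs(epochs, target_len):
    """Linearly interpolate each epoch (last axis = time) to target_len samples.

    Naive sketch: no anti-aliasing filter is applied when downsampling.
    """
    epochs = np.asarray(epochs, dtype=np.float32)
    src_len = epochs.shape[-1]
    # Sample positions of the old and new grids, both spanning [0, 1]
    old_t = np.linspace(0.0, 1.0, src_len)
    new_t = np.linspace(0.0, 1.0, target_len)
    flat = epochs.reshape(-1, src_len)
    out = np.stack([np.interp(new_t, old_t, row) for row in flat])
    return out.reshape(epochs.shape[:-1] + (target_len,)).astype(np.float32)


# e.g. 30-s epochs recorded at 125 Hz (3750 samples) -> 100 Hz (3000 samples)
target = resample_epochs(np.random.randn(4, 1, 3750), 3000)
print(target.shape)  # (4, 1, 3000)
```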

The code I used:

def cross_domain_train(src_train_dl, trg_train_dl, trg_valid_dl,
                       src_id, trg_id,
                       device, logger, configs, args, param_config):
........
# training..
        for epoch in range(1, configs.num_epoch + 1):
            joint_loaders = enumerate(zip(src_train_dl, pseudo_trg_train_dl))
            feature_extractor.train()
            classifier_1.train()
            classifier_2.train()

            src_att.train()
            trg_att.train()
            feature_discriminator.train()

            for step, ((src_data, src_labels), (trg_data, pseudo_trg_labels)) in joint_loaders:
                src_data, src_labels, trg_data, pseudo_trg_labels = src_data.float().to(device), src_labels.long().to(
                    device), trg_data.float().to(device), pseudo_trg_labels.long().to(device)

                for param in feature_discriminator.parameters():
                    param.requires_grad = True

                # pass data through the source model network.
                src_feat = feature_extractor(src_data)
                print("feature_extractor src_feat",src_feat.shape)
                src_feat = src_att(src_feat)
                print("att src_feat",src_feat.shape)
                src_pred = classifier_1(src_feat)
                src_pred_2 = classifier_2(src_feat)
                
class Classifier(nn.Module):
    def __init__(self, configs):
        super(Classifier, self).__init__()
        model_output_dim = configs.features_len
        print("model_output_dim",model_output_dim,"configs.final_out_channels",configs.final_out_channels)
        self.logits = nn.Linear(model_output_dim * configs.final_out_channels, configs.num_classes)

    def forward(self, input):
        print("input",input.shape)
        logits = self.logits(input)
        return logits                

Could you give me some guidance? My email: [email protected]
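`Classifier` takes its input width from `configs.features_len`. That value has to equal the number of time steps the feature extractor actually outputs for the epoch length in use; for the source data in the log above, the output is `torch.Size([128, 128, 29])`, so 29. Rather than guessing, one can derive it from PyTorch's documented output-length formula for `Conv1d` and `MaxPool1d` (with `ceil_mode=False`): L_out = floor((L_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1. The layer list below is purely hypothetical and is not ADAST's actual feature extractor. To use it, substitute the kernel sizes, strides and paddings from `models/models.py`.

```python
def conv1d_out_len(length, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of nn.Conv1d / nn.MaxPool1d (ceil_mode=False) along time."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def features_len(input_len, layers):
    """Apply conv1d_out_len through a stack of (kernel_size, stride, padding) layers."""
    for kernel_size, stride, padding in layers:
        input_len = conv1d_out_len(input_len, kernel_size, stride, padding)
    return input_len


# HYPOTHETICAL stack for illustration only, not ADAST's real architecture
layers = [(25, 6, 12), (8, 2, 4), (8, 1, 4), (2, 2, 1), (8, 1, 4), (2, 2, 1)]
for n in (3000, 3750):
    print(n, "samples ->", features_len(n, layers), "time steps")
```

With the real layer parameters filled in, the printed value for the source epoch length is what `configs.features_len` should be set to. If the two domains print different values, their epochs still differ in length.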
