emadeldeen24 / adast
[IEEE TETCI] "ADAST: Attentive Cross-domain EEG-based Sleep Staging Framework with Iterative Self-Training"
License: Apache License 2.0
Great paper! May I know where it was published? I came across it on arXiv.
Hello, I got the following error on a Windows 11 PC, even though I have already installed the latest versions of pandas, numpy and mne.
D:/Desktop/Pattern Project/Project_Pattern/Dataset/PSG/SHHS1\shhs1-200097.edf
Extracting EDF parameters from D:\Desktop\Pattern Project\Project_Pattern\Dataset\PSG\SHHS1\shhs1-200097.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 3959999 = 0.000 ... 31679.992 secs...
Traceback (most recent call last):
File "d:\Desktop\Pattern Project\Project_Pattern\Code\prepare_shhs.py", line 137, in
main()
File "d:\Desktop\Pattern Project\Project_Pattern\Code\prepare_shhs.py", line 65, in main
raw_ch_df = raw.to_data_frame(scaling_time=sampling_rate)[select_ch]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: to_data_frame() got an unexpected keyword argument 'scaling_time'
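The TypeError means the installed MNE version no longer accepts the scaling_time keyword of Raw.to_data_frame(); it was removed in newer MNE releases. In the older API that argument only rescaled the time axis of the returned DataFrame, not the channel values. A minimal sketch that avoids pinning a version, assuming the script only needs the channel column(s) and not the time index (check how your copy of prepare_shhs.py uses the index afterwards): inspect the method signature and pass scaling_time only when it exists. raw_to_dataframe is a hypothetical helper, not part of ADAST or MNE.

```python
import inspect


def raw_to_dataframe(raw, sampling_rate):
    """Call raw.to_data_frame() in a way that tolerates both old and new MNE APIs.

    Hypothetical helper. Older MNE releases accept scaling_time (it rescales the
    time axis only); newer releases removed it and raise TypeError if it is passed.
    """
    params = inspect.signature(raw.to_data_frame).parameters
    if "scaling_time" in params:
        # Old API: keep the original behaviour of the preprocessing script
        return raw.to_data_frame(scaling_time=sampling_rate)
    # New API: channel values are unaffected; only the time axis differs
    return raw.to_data_frame()
```

In prepare_shhs.py, line 65 would then read raw_ch_df = raw_to_dataframe(raw, sampling_rate)[select_ch].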
Hello,
Could you please explain the dataloader file? I couldn't load the downloaded dataset with it.
How should I split the data so that it fits the following function?
import os
import torch

def data_generator(data_path, domain_id, configs):
    # Load the three split files of this domain: train_<id>.pt, val_<id>.pt, test_<id>.pt
    train_dataset = torch.load(os.path.join(data_path, "train_" + domain_id + ".pt"))
    valid_dataset = torch.load(os.path.join(data_path, "val_" + domain_id + ".pt"))
    test_dataset = torch.load(os.path.join(data_path, "test_" + domain_id + ".pt"))

    # Wrap the loaded data in the repository's Load_Dataset class
    train_dataset = Load_Dataset(train_dataset)
    valid_dataset = Load_Dataset(valid_dataset)
    test_dataset = Load_Dataset(test_dataset)
You may need to install an older version of MNE: mne==0.20.7
import os
import random

import numpy as np
import torch


def split_dataset(data_dir, domain_id, train_ratio, val_ratio, test_ratio, seed=42):
    """Split the .npz recordings into train/val/test and save each split as a .pt file.

    The split is done per file (recording), so no recording appears in two splits.
    test_ratio is implied: the test split receives all files left after train and val.
    """
    if not os.path.exists(data_dir):
        print(f"Directory '{data_dir}' does not exist.")
        return

    # Collect all .npz files; sort first so the seeded shuffle is reproducible
    npz_files = sorted(f for f in os.listdir(data_dir) if f.endswith(".npz"))
    random.Random(seed).shuffle(npz_files)

    total_files = len(npz_files)
    train_count = int(total_files * train_ratio)
    val_count = int(np.around(total_files * val_ratio))
    test_count = total_files - train_count - val_count
    print("train:", train_count, "val:", val_count, "test:", test_count)

    train_files = npz_files[:train_count]
    val_files = npz_files[train_count:train_count + val_count]
    test_files = npz_files[train_count + val_count:]
    print(train_files)

    # Convert each split and save it as a .pt file
    convert_and_save(train_files, data_dir, domain_id, "train")
    convert_and_save(val_files, data_dir, domain_id, "val")
    convert_and_save(test_files, data_dir, domain_id, "test")
    print("Conversion to .pt files completed.")


def convert_and_save(npz_files, data_dir, domain_id, split_type):
    """Concatenate the epochs of several .npz files and save them as {split_type}_{domain_id}.pt."""
    if not npz_files:
        # torch.cat would fail on an empty list
        print(f"No files for the {split_type} split; skipping.")
        return

    samples_list = []
    labels_list = []
    for file in npz_files:
        npz_path = os.path.join(data_dir, file)
        with np.load(npz_path) as data:
            # 'x': signal epochs, 'y': per-epoch labels
            samples_list.append(torch.from_numpy(data["x"]).float())
            labels_list.append(torch.from_numpy(data["y"]).long())

    dataset = {
        "samples": torch.cat(samples_list, dim=0),
        "labels": torch.cat(labels_list, dim=0),
    }
    # File names must match what data_generator() loads: train_<id>.pt, val_<id>.pt, test_<id>.pt
    pt_filename = f"{split_type}_{domain_id}.pt"
    pt_path = os.path.join(data_dir, pt_filename)
    print(pt_filename, pt_path)
    torch.save(dataset, pt_path)
    print(f"{split_type.capitalize()} split converted and saved as a .pt file.")


if __name__ == "__main__":
    data_dir = "./"     # Directory containing the .npz files
    train_ratio = 0.7   # Fraction of files for the training set
    val_ratio = 0.15    # Fraction of files for the validation set
    test_ratio = 0.15   # Fraction for the test set (it receives the remaining files)
    domain_id = "a"     # Suffix of the output files, e.g. train_a.pt
    split_dataset(data_dir, domain_id, train_ratio, val_ratio, test_ratio)
model_output_dim 29 configs.final_out_channels 128
features torch.Size([128, 8960])
model[1]: Self_Attn(
(query_conv): Conv1d(128, 64, kernel_size=(1,), stride=(1,))
(key_conv): Conv1d(128, 64, kernel_size=(1,), stride=(1,))
(value_conv): Conv1d(128, 128, kernel_size=(1,), stride=(1,))
(softmax): Softmax(dim=-1)
) model[2][0]: Classifier(
(logits): Linear(in_features=3712, out_features=5, bias=True)
) model[2][1]: Classifier(
(logits): Linear(in_features=3712, out_features=5, bias=True)
)
input torch.Size([128, 8960])
Traceback (most recent call last):
File "train_CD.py", line 112, in
main_train_cd()
File "train_CD.py", line 99, in main_train_cd
target_model = cross_domain_train(src_train_dl, trg_train_dl, trg_valid_dl,
File "/mnt/public/home/leichen01/ADAST/ADAST-main/trainer/ADAST.py", line 51, in cross_domain_train
val_self_training((feature_extractor, trg_att, (classifier_1, classifier_2)), trg_train_dl, device, src_id,
File "/mnt/public/home/leichen01/ADAST/ADAST-main/trainer/training_evaluation.py", line 31, in val_self_training
predictions = model[2][0](features)
File "/ADAST/ADAST-main/models/models.py", line 81, in forward
logits = self.logits(input)
File "torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x8960 and 3712x5)
input torch.Size([128, 8960])
input torch.Size([128, 8960])
features torch.Size([37, 8960])
model[1]: Self_Attn(
(query_conv): Conv1d(128, 64, kernel_size=(1,), stride=(1,))
(key_conv): Conv1d(128, 64, kernel_size=(1,), stride=(1,))
(value_conv): Conv1d(128, 128, kernel_size=(1,), stride=(1,))
(softmax): Softmax(dim=-1)
) model[2][0]: Classifier(
(logits): Linear(in_features=8960, out_features=5, bias=True)
) model[2][1]: Classifier(
(logits): Linear(in_features=8960, out_features=5, bias=True)
)
input torch.Size([37, 8960])
input torch.Size([37, 8960])
feature_extractor src_feat torch.Size([128, 128, 29])
att src_feat torch.Size([128, 3712])
input torch.Size([128, 3712])
Traceback (most recent call last):
File "train_CD.py", line 112, in
main_train_cd()
File "train_CD.py", line 99, in main_train_cd
target_model = cross_domain_train(src_train_dl, trg_train_dl, trg_valid_dl,
File "ADAST/ADAST-main/trainer/ADAST.py", line 89, in cross_domain_train
src_pred = classifier_1(src_feat)
File "ADAST-main/models/models.py", line 81, in forward
logits = self.logits(input)
File "/mnt/public/home/leichen01/miniconda3/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x3712 and 8960x5)
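Both tracebacks come from the same arithmetic. Each Classifier is a single nn.Linear with in_features = features_len × final_out_channels. With features_len = 29 and 128 channels that is 3712, but in the first run the target batch reaches the classifier with 8960 = 70 × 128 columns. After the classifier is resized to 8960 inputs, the source batch (128 × 128 × 29 from the feature extractor, flattened to 3712) fails instead. The two domains therefore appear to produce feature sequences of different lengths (29 vs. 70), so no single features_len fits both. A hypothetical check (not part of ADAST) that could be called on the attention output before the classifier turns the opaque matmul error into a message naming the implied features_len:

```python
def implied_features_len(flat_dim, final_out_channels):
    """Recover the per-channel feature length from a flattened feature width."""
    if flat_dim % final_out_channels:
        raise ValueError(f"{flat_dim} is not a multiple of {final_out_channels}")
    return flat_dim // final_out_channels


def check_classifier_input(feature_shape, features_len, final_out_channels):
    """Raise a descriptive error if a (batch, flat_dim) tensor shape does not fit the classifier.

    Hypothetical helper; feature_shape is e.g. tuple(src_feat.shape).
    """
    expected = features_len * final_out_channels
    actual = feature_shape[-1]
    if actual != expected:
        found = implied_features_len(actual, final_out_channels)
        raise ValueError(
            f"classifier expects {expected} = {features_len} x {final_out_channels} "
            f"features, got {actual} (features_len would be {found}); "
            "check that all domains use the same number of samples per epoch"
        )
```

With the values printed above, check_classifier_input((128, 8960), 29, 128) reports that features_len would be 70.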
The code I used, from cross_domain_train in trainer/ADAST.py and the Classifier in models/models.py:
def cross_domain_train(src_train_dl, trg_train_dl, trg_valid_dl,
                       src_id, trg_id,
                       device, logger, configs, args, param_config):
    # ........

    # training..
    for epoch in range(1, configs.num_epoch + 1):
        joint_loaders = enumerate(zip(src_train_dl, pseudo_trg_train_dl))
        feature_extractor.train()
        classifier_1.train()
        classifier_2.train()
        src_att.train()
        trg_att.train()
        feature_discriminator.train()

        for step, ((src_data, src_labels), (trg_data, pseudo_trg_labels)) in joint_loaders:
            src_data = src_data.float().to(device)
            src_labels = src_labels.long().to(device)
            trg_data = trg_data.float().to(device)
            pseudo_trg_labels = pseudo_trg_labels.long().to(device)

            for param in feature_discriminator.parameters():
                param.requires_grad = True

            # pass data through the source model network.
            src_feat = feature_extractor(src_data)
            print("feature_extractor src_feat", src_feat.shape)
            src_feat = src_att(src_feat)
            print("att src_feat", src_feat.shape)
            src_pred = classifier_1(src_feat)
            src_pred_2 = classifier_2(src_feat)
from torch import nn

class Classifier(nn.Module):
    def __init__(self, configs):
        super(Classifier, self).__init__()
        model_output_dim = configs.features_len
        print("model_output_dim", model_output_dim, "configs.final_out_channels", configs.final_out_channels)
        # in_features = features_len * final_out_channels (29 * 128 = 3712 with the printed config)
        self.logits = nn.Linear(model_output_dim * configs.final_out_channels, configs.num_classes)

    def forward(self, input):
        print("input", input.shape)
        logits = self.logits(input)
        return logits
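One likely reason for feature lengths that differ between domains is that the source and target recordings do not have the same number of samples per epoch, for example because they were stored at different sampling rates. For each Conv1d or MaxPool1d layer (ceil_mode=False), the PyTorch documentation gives the output length as floor((L_in + 2·padding − dilation·(kernel_size − 1) − 1) / stride) + 1, so any difference in input length propagates to features_len. The sketch below chains that formula. The layer stack is hypothetical and is NOT ADAST's feature extractor; replace it with the kernel sizes, strides and paddings actually in use. Resampling all recordings to one rate before epoching (for example with MNE's raw.resample) keeps the input length, and therefore features_len, identical across domains.

```python
def samples_per_epoch(sfreq_hz, epoch_sec=30):
    """Samples in one scoring epoch, e.g. 100 Hz * 30 s = 3000."""
    return int(round(sfreq_hz * epoch_sec))


def conv1d_out_len(length, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of Conv1d / MaxPool1d (ceil_mode=False), per the PyTorch docs."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def predicted_features_len(input_len, layers):
    """Chain conv1d_out_len over (kernel_size, stride, padding) tuples."""
    for kernel_size, stride, padding in layers:
        input_len = conv1d_out_len(input_len, kernel_size, stride, padding)
    return input_len


# HYPOTHETICAL layer stack for illustration only; it is NOT ADAST's feature extractor.
EXAMPLE_LAYERS = [
    (25, 6, 12),  # Conv1d
    (2, 2, 1),    # MaxPool1d
    (8, 1, 4),    # Conv1d
    (2, 2, 1),    # MaxPool1d
]

for fs in (100, 125):
    n = samples_per_epoch(fs)
    print(f"{fs} Hz -> {n} samples -> features_len {predicted_features_len(n, EXAMPLE_LAYERS)}")
# 100 Hz -> 3000 samples -> features_len 127
# 125 Hz -> 3750 samples -> features_len 158
```

With this example stack, the same 30 s epoch yields a different features_len at each rate, which is exactly the situation in which one nn.Linear cannot accept both domains.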