
clnet's Introduction

Overview

This is the PyTorch implementation of the paper CLNet: Complex Input Lightweight Neural Network designed for Massive MIMO CSI Feedback. If you find this repo helpful, please cite our paper:

@article{ji2021clnet,
  title={CLNet: Complex Input Lightweight Neural Network designed for Massive MIMO CSI Feedback},
  author={Ji, Sijie and Li, Mo},
  journal={IEEE Wireless Communications Letters},
  year={2021},
  publisher={IEEE},
  doi={10.1109/LWC.2021.3100493}
}

Requirements

To use this project, you need to ensure the following requirements are installed.

Project Preparation

A. Data Preparation

The channel state information (CSI) matrix is generated from the COST2100 model. Chao-Kai Wen and Shi Jin's group provides a pre-processed version of the COST2100 dataset on Google Drive, which is easier to use for the CSI feedback task; you can also download it from Baidu Netdisk.

You can also generate your own dataset with the open-source COST2100 library. The details of data pre-processing can be found in our paper.

B. Project Tree Arrangement

We recommend arranging the project tree as follows.

home
├── CLNet  # The cloned CLNet repository
│   ├── dataset
│   ├── models
│   ├── utils
│   ├── main.py
├── COST2100  # The data folder
│   ├── DATA_Htestin.mat
│   ├── ...
├── Experiments
│   ├── checkpoints  # The checkpoints folder
│   │     ├── in_04.pth
│   │     ├── ...
│   ├── run.sh  # The bash script
...

Train CLNet from Scratch

An example of run.sh is listed below. Run it with sh run.sh to train CLNet from scratch. Change the scenario with --scenario and the compression ratio with --cr.

python /home/CLNet/main.py \
  --data-dir '/home/COST2100' \
  --scenario 'in' \
  --epochs 1000 \
  --batch-size 200 \
  --workers 8 \
  --cr 4 \
  --scheduler cosine \
  --gpu 0 \
  2>&1 | tee log.out
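The --scheduler cosine flag presumably selects a cosine learning-rate schedule. The sketch below is for illustration only and shows the standard cosine-annealing formula. The base learning rate, minimum learning rate and any warm-up phase used by main.py are not stated here, so cosine_lr, base_lr, eta_min and the step counts are all hypothetical.

```python
import math


def cosine_lr(step: int, total_steps: int, base_lr: float, eta_min: float = 0.0) -> float:
    """Standard cosine annealing from base_lr (step 0) down to eta_min (step total_steps).

    Hypothetical helper for illustration; the repo's scheduler may add warm-up
    or use different defaults.
    """
    if total_steps <= 0:
        raise ValueError("total_steps must be positive")
    step = min(max(step, 0), total_steps)  # clamp to the schedule range
    cosine = (1 + math.cos(math.pi * step / total_steps)) / 2  # goes from 1 to 0
    return eta_min + (base_lr - eta_min) * cosine


if __name__ == "__main__":
    # Hypothetical values: 1000 epochs x 500 iterations, base LR 2e-3, floor 5e-5.
    total = 1000 * 500
    for step in (0, total // 4, total // 2, total):
        print(step, f"{cosine_lr(step, total, base_lr=2e-3, eta_min=5e-5):.2e}")
```

Compared with a step schedule, cosine annealing avoids abrupt learning-rate drops and spends the final epochs at a small learning rate.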

Results and Reproduction

A. Model Complexity

The params and FLOPs are calculated directly by thop. If you use this repo's code directly, the model complexity is printed in the training log; a sample training log is available for your reference. The FLOPs reported in the paper are calculated by fvcore to align with other SOTA works. fvcore does not count the BN layers, so its FLOPs figures are lower than thop's.

Compression Ratio   #Params   FLOPs
1/4                 2102K     4.42M
1/8                 1053K     3.37M
1/16                528.7K    2.85M
1/32                266.5K    2.58M
1/64                135.4K    2.45M
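The parameter counts above appear to be dominated by the fully connected layers that map the CSI matrix to and from the codeword. The following is a back-of-the-envelope check. It assumes a 2x32x32 real-valued input (2048 values, as in the CsiNet/CRNet COST2100 setup) compressed to 2048/cr values, with one biased linear layer in the encoder and one in the decoder. This layer layout is an assumption and is not taken from the CLNet source, and fc_params is a hypothetical helper.

```python
def fc_params(cr: int, n: int = 2 * 32 * 32) -> int:
    """Parameters of an encoder FC (n -> n/cr) plus a decoder FC (n/cr -> n), both with bias.

    Assumes a 2x32x32 real-valued CSI input (n = 2048); hypothetical helper.
    """
    m = n // cr  # codeword length
    encoder = n * m + m  # weights + bias
    decoder = m * n + n  # weights + bias
    return encoder + decoder


if __name__ == "__main__":
    for cr in (4, 8, 16, 32, 64):
        print(f"1/{cr}: {fc_params(cr) / 1000:.1f}K")
```

This prints 2099.7K, 1050.9K, 526.5K, 264.3K and 133.2K. Each value is roughly 2K below the corresponding table entry, and that remainder would presumably come from the convolutional and BN layers. It also explains why #Params roughly halves each time the compression ratio doubles.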

B. Performance

The NMSE results reported in the paper are as follows:

Scenario   Compression Ratio   NMSE (dB)   Checkpoint
indoor     1/4                 -29.16      in4.pth
indoor     1/8                 -15.60      in8.pth
indoor     1/16                -11.15      in16.pth
indoor     1/32                -8.95       in32.pth
indoor     1/64                -6.34       in64.pth
outdoor    1/4                 -12.88      out4.pth
outdoor    1/8                 -8.29       out8.pth
outdoor    1/16                -5.56       out16.pth
outdoor    1/32                -3.49       out32.pth
outdoor    1/64                -2.19       out64.pth
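The NMSE values are in dB, so lower is better. Below is a minimal NumPy sketch of the usual definition, NMSE = E[ ||H - H_hat||^2 / ||H||^2 ], reported as 10*log10 of that mean. It assumes complex arrays with the batch on the first axis. It is not the repo's evaluator, and nmse_db is a hypothetical name.

```python
import numpy as np


def nmse_db(h_true: np.ndarray, h_pred: np.ndarray) -> float:
    """NMSE in dB: the per-sample ratio ||H - H_hat||^2 / ||H||^2, averaged over the batch.

    Both arrays are complex with shape (batch, ...); the first axis indexes samples.
    """
    axes = tuple(range(1, h_true.ndim))  # sum over everything except the batch axis
    error = np.sum(np.abs(h_true - h_pred) ** 2, axis=axes)
    power = np.sum(np.abs(h_true) ** 2, axis=axes)
    return float(10 * np.log10(np.mean(error / power)))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    h = rng.standard_normal((8, 32, 32)) + 1j * rng.standard_normal((8, 32, 32))
    print(nmse_db(h, 0.9 * h))  # a uniform 10% amplitude error gives about -20 dB
```

For example, -29.16 dB for indoor 1/4 corresponds to a residual error energy of roughly 0.12% of the channel energy.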

If you want to reproduce our results, you can download the corresponding checkpoints directly from Dropbox.

To reproduce all these results, simply add --evaluate to run.sh and pick the corresponding pre-trained model with --pretrained. An example is shown as follows.

python /home/CLNet/main.py \
  --data-dir '/home/COST2100' \
  --scenario 'in' \
  --pretrained './checkpoints/in4.pth' \
  --evaluate \
  --batch-size 200 \
  --workers 0 \
  --cr 4 \
  --cpu \
  2>&1 | tee test_log.out
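To evaluate every scenario and compression ratio in one go, a small Python driver can build the same command line for each checkpoint in the table. This is a hypothetical helper and not part of the repository. It uses only the flags shown above, and it assumes the checkpoints are stored as ./checkpoints/{scenario}{cr}.pth, following the names in the table.

```python
import subprocess
from typing import List

SCENARIOS = ("in", "out")
COMPRESSION_RATIOS = (4, 8, 16, 32, 64)


def eval_command(scenario: str, cr: int, data_dir: str = "/home/COST2100") -> List[str]:
    """Build the evaluation command for one checkpoint (hypothetical helper)."""
    return [
        "python", "/home/CLNet/main.py",
        "--data-dir", data_dir,
        "--scenario", scenario,
        "--pretrained", f"./checkpoints/{scenario}{cr}.pth",
        "--evaluate",
        "--batch-size", "200",
        "--workers", "0",
        "--cr", str(cr),
        "--cpu",
    ]


def run_all(dry_run: bool = True) -> None:
    """Print (and optionally run) the evaluation for all ten checkpoints."""
    for scenario in SCENARIOS:
        for cr in COMPRESSION_RATIOS:
            cmd = eval_command(scenario, cr)
            print(" ".join(cmd))
            if not dry_run:
                # Send each run's output to its own log file instead of piping through tee.
                with open(f"test_log_{scenario}{cr}.out", "w") as log:
                    subprocess.run(cmd, stdout=log, stderr=subprocess.STDOUT, check=True)


if __name__ == "__main__":
    run_all(dry_run=True)  # set dry_run=False to actually launch the evaluations
```

Passing the command as a list avoids shell quoting, so the path needs no quotes, unlike the '/home/COST2100' argument in the shell example.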

Acknowledgment

This repository is modified from the open-source CRNet code. Thanks to Zhilin for his amazing work. Thanks to Chao-Kai Wen and Shi Jin's group for providing the pre-processed COST2100 dataset; you can find their related work, CsiNet, at Github-Python_CsiNet.

clnet's People

Contributors

sijieji, waxpple

clnet's Issues

Is sth WRONG?

When I run the code with the required environment (python=3.7.11, pytorch=1.9.0), it runs successfully for the first 10 epochs, but at the validation step Python raises an error, printed as follows:
I 09.13/13:00 C:\Users\86158\Desktop\CLNet-main\utils\solver.py:87 ] Epoch: [10/1000][400/500] lr: 6.53e-04 | MSE loss: 4.142e-04 | time: 0.453
I 09.13/13:00 C:\Users\86158\Desktop\CLNet-main\utils\solver.py:87 ] Epoch: [10/1000][420/500] lr: 6.56e-04 | MSE loss: 4.134e-04 | time: 0.452
I 09.13/13:00 C:\Users\86158\Desktop\CLNet-main\utils\solver.py:87 ] Epoch: [10/1000][440/500] lr: 6.59e-04 | MSE loss: 4.121e-04 | time: 0.453
I 09.13/13:01 C:\Users\86158\Desktop\CLNet-main\utils\solver.py:87 ] Epoch: [10/1000][460/500] lr: 6.61e-04 | MSE loss: 4.114e-04 | time: 0.454
I 09.13/13:01 C:\Users\86158\Desktop\CLNet-main\utils\solver.py:87 ] Epoch: [10/1000][480/500] lr: 6.64e-04 | MSE loss: 4.107e-04 | time: 0.453
I 09.13/13:01 C:\Users\86158\Desktop\CLNet-main\utils\solver.py:87 ] Epoch: [10/1000][500/500] lr: 6.67e-04 | MSE loss: 4.096e-04 | time: 0.454
I 09.13/13:01 C:\Users\86158\Desktop\CLNet-main\utils\solver.py:87 ] => Train Loss: 4.096e-04

I 09.13/13:01 C:\Users\86158\Desktop\CLNet-main\utils\solver.py:98 ] Epoch: [10/1000][20/150] lr: 6.67e-04 | MSE loss: 4.028e-04 | time: 0.210
I 09.13/13:01 C:\Users\86158\Desktop\CLNet-main\utils\solver.py:98 ] Epoch: [10/1000][40/150] lr: 6.67e-04 | MSE loss: 4.009e-04 | time: 0.174
I 09.13/13:01 C:\Users\86158\Desktop\CLNet-main\utils\solver.py:98 ] Epoch: [10/1000][60/150] lr: 6.67e-04 | MSE loss: 3.993e-04 | time: 0.163
I 09.13/13:01 C:\Users\86158\Desktop\CLNet-main\utils\solver.py:98 ] Epoch: [10/1000][80/150] lr: 6.67e-04 | MSE loss: 4.010e-04 | time: 0.160
I 09.13/13:01 C:\Users\86158\Desktop\CLNet-main\utils\solver.py:98 ] Epoch: [10/1000][100/150] lr: 6.67e-04 | MSE loss: 4.007e-04 | time: 0.157
I 09.13/13:01 C:\Users\86158\Desktop\CLNet-main\utils\solver.py:98 ] Epoch: [10/1000][120/150] lr: 6.67e-04 | MSE loss: 4.008e-04 | time: 0.154
I 09.13/13:01 C:\Users\86158\Desktop\CLNet-main\utils\solver.py:98 ] Epoch: [10/1000][140/150] lr: 6.67e-04 | MSE loss: 4.006e-04 | time: 0.155
I 09.13/13:01 C:\Users\86158\Desktop\CLNet-main\utils\solver.py:98 ] => Val Loss: 4.004e-04

Traceback (most recent call last):
  File "main.py", line 87, in <module>
    main()
  File "main.py", line 56, in main
    trainer.loop(args.epochs, train_loader, val_loader, test_loader)
  File "C:\Users\86158\Desktop\CLNet-main\utils\solver.py", line 71, in loop
    self.test_loss, rho, nmse = self.test(test_loader)
  File "C:\Users\86158\Desktop\CLNet-main\utils\solver.py", line 109, in test
    return self.tester(test_loader, verbose=False)
  File "C:\Users\86158\Desktop\CLNet-main\utils\solver.py", line 230, in __call__
    loss, rho, nmse = self._iteration(test_data)
  File "C:\Users\86158\Desktop\CLNet-main\utils\solver.py", line 250, in _iteration
    rho, nmse = evaluator(sparse_pred, sparse_gt, raw_gt)
  File "C:\Users\86158\Desktop\CLNet-main\utils\statics.py", line 61, in evaluator
    raw_pred = torch.fft(sparse_pred, signal_ndim=1)[:, :, :125, :]
TypeError: 'module' object is not callable

I ran it on different computers and got the same error. I also can't understand what raw_pred means. Could you help solve this problem? Thanks.
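For context, in PyTorch 1.8 and later torch.fft is a module (torch.fft.fft, torch.fft.ifft, ...) rather than a function, which is consistent with the TypeError above. The old torch.fft(x, signal_ndim=1) took a real tensor whose last dimension held (real, imaginary) pairs and transformed the dimension just before it. A possible replacement is sketched below. The helper name legacy_fft_1d is hypothetical, and this is not an official fix from the repository. Judging from the surrounding call, raw_pred looks like the predicted channel mapped back to the frequency domain so that it can be compared with raw_gt when computing rho.

```python
import torch


def legacy_fft_1d(x: torch.Tensor) -> torch.Tensor:
    """Emulate the removed torch.fft(x, signal_ndim=1) with the torch.fft module.

    x: real tensor of shape (..., N, 2), last dimension = (real, imaginary).
    Returns a tensor of the same shape holding the unnormalized 1-D DFT over N.
    """
    x_complex = torch.view_as_complex(x.contiguous())  # (..., N) complex view
    return torch.view_as_real(torch.fft.fft(x_complex, dim=-1))  # back to (..., N, 2)


if __name__ == "__main__":
    # The failing line in utils/statics.py could then read:
    #   raw_pred = legacy_fft_1d(sparse_pred)[:, :, :125, :]
    impulse = torch.zeros(1, 4, 2)
    impulse[0, 0, 0] = 1.0  # unit impulse at index 0
    print(legacy_fft_1d(impulse))  # DFT of an impulse: real parts all 1, imaginary parts 0
```

Both the old function (normalized=False) and torch.fft.fft (norm="backward") leave the forward transform unscaled, so the numerical results should match.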

Model Flops

Hello, great work!
I tried to run your code, and the model FLOPs came out as 5.702M. However, your paper reports 4.05M. Why? Thank you very much!

what is Rho?

I ran this code in PyCharm on Win10 and ran into some bugs.
To fit my GPU's CUDA version, I installed PyTorch 1.8.0, in which some functions differ from your original version.
For example, the fft function:
[image]
I modified it as the picture shows,
but got the result below:
[image]
rho is nan, and I can't understand what rho means. Does it matter if I can't get it?
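In the CsiNet line of work, rho usually denotes the cosine similarity between the reconstructed and original channel vectors, averaged over subcarriers and samples. A value close to 1 means the reconstruction points in the same direction as the true channel. Below is a minimal NumPy sketch under that assumed definition, rho = E[ (1/Nc) * sum_n |h_hat_n^H h_n| / (||h_hat_n|| ||h_n||) ]. The (batch, antennas, subcarriers) layout and the name cosine_rho are assumptions. The sketch also shows one plausible source of NaN: if a channel vector or its prediction is all zeros, the ratio becomes 0/0.

```python
import numpy as np


def cosine_rho(h_true: np.ndarray, h_pred: np.ndarray) -> float:
    """Mean of |h_hat^H h| / (||h_hat|| ||h||) over subcarriers and samples.

    Arrays are complex with shape (batch, antennas, subcarriers); the inner
    product and norms are taken over the antenna axis.
    """
    inner = np.abs(np.sum(np.conj(h_pred) * h_true, axis=1))  # |h_hat^H h| per subcarrier
    norms = np.linalg.norm(h_pred, axis=1) * np.linalg.norm(h_true, axis=1)
    return float(np.mean(inner / norms))  # NaN if any norm is zero (0 / 0)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    h = rng.standard_normal((4, 32, 125)) + 1j * rng.standard_normal((4, 32, 125))
    print(cosine_rho(h, h))         # 1.0 up to rounding
    print(cosine_rho(h, 0.5j * h))  # still about 1: a common scale and phase do not matter
```

Unlike NMSE, rho ignores a common complex scaling of each vector. This is why the two metrics are usually reported together.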

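Question about the NMSE calculation

Hello, I'd like to ask about the NMSE and rho calculations in the code:
(1) The channels used in training should all have gone through a 2D FFT and then been normalized, right?
If so, recovering the original channel should require applying a 2D IFFT (IFFT2) to it, so why does your code perform a 1D FFT? I don't quite understand this part.
(2) If the FFT and the IFFT2 are equivalent, why is the frequency domain transformed to 257 points, with only 125 of them kept afterwards? This is also unclear to me.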
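About the 257 and 125 in question (2): the evaluator line shown in the traceback earlier in this page slices an FFT result with [:, :, :125, :]. The NumPy sketch below only illustrates the mechanics of that pad-FFT-slice pattern, assuming the reconstructed matrix keeps 32 delay taps per antenna. Zero-padding the delay axis to 257 points and taking a 1-D DFT along it evaluates the channel's frequency response on a 257-point grid, and the first 125 bins are kept. It does not explain why those particular sizes were chosen. The constants come from the question and the traceback, and delay_to_frequency is a hypothetical helper.

```python
import numpy as np


def delay_to_frequency(h_delay: np.ndarray, n_fft: int = 257, n_keep: int = 125) -> np.ndarray:
    """Zero-pad the delay axis (last axis) to n_fft points, DFT it, keep the first n_keep bins.

    h_delay: complex array (..., n_taps) of truncated delay-domain taps.
    Hypothetical helper mirroring the pad-FFT-slice pattern, not the repo's code.
    """
    n_taps = h_delay.shape[-1]
    if n_taps > n_fft:
        raise ValueError("n_fft must be at least the number of taps")
    pad = [(0, 0)] * (h_delay.ndim - 1) + [(0, n_fft - n_taps)]
    padded = np.pad(h_delay, pad)  # zeros appended after the last tap
    return np.fft.fft(padded, axis=-1)[..., :n_keep]


if __name__ == "__main__":
    h = np.ones((32, 32), dtype=complex)  # 32 antennas x 32 delay taps (assumed layout)
    print(delay_to_frequency(h).shape)  # (32, 125)
```

Each output bin k equals sum_t h[t] * exp(-2j*pi*k*t/257). The zero padding therefore adds no information and only samples the same frequency response on a finer grid.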
