
The code for the Kernel Attention Transformer (KAT)

License: MIT License

whole-slide-image histopathology histopathology-images histopathology-wsi wsi wsi-images wsi-classificaiton wsi-representation pathology-image wsi-encode


KAT: Kernel Attention Transformer for histopathology whole slide image classification

This is a PyTorch implementation of the MICCAI 2022 paper KAT: Kernel Attention Transformer for Histopathology Whole Slide Image Classification.

Data preparation

The whole slide image dataset should be organized in the following structure to run the code.

# Take a lung cancer dataset collected from TCGA as the example.
./data                                                              # The directory of the data.
├─ TCGA-55-8510-01Z-00-DX1.BB1EAC72-6215-400B-BCBF-E3D51A60182D     # The directory for a slide.
│  ├─ Large                                                         # The directory of image tiles in Level 0 (40X lens).
│  │  ├─ 0000_0000.jpg                                              # The image tile in Row 0 and Column 0.
│  │  ├─ 0000_0001.jpg                                              # The image tile in Row 0 and Column 1.
│  │  └─ ...
│  ├─ Medium                                                        # The directory of image tiles in Level 1 (20X lens).
│  │  ├─ 0000_0000.jpg
│  │  ├─ 0000_0001.jpg
│  │  └─ ...
│  ├─ Small                                                         # The directory of image tiles in Level 2 (10X lens).
│  │  ├─ 0000_0000.jpg
│  │  ├─ 0000_0001.jpg
│  │  └─ ...
│  ├─ Overview                                                      # The directory of image tiles in Level 3 (5X lens).
│  │  ├─ 0000_0000.jpg
│  │  ├─ 0000_0001.jpg
│  │  └─ ...
│  └─ Overview.jpg                                                  # The thumbnail of the WSI in Level 3.
├─ TCGA-44-3919-01A-01-BS1.9251d6ad-dab8-42fd-836d-1b18e5d2afed
└─ ...
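The tiles themselves are not shipped with the repository. If you start from SVS files, one way to produce this layout is to read each pyramid level (e.g. with OpenSlide) and cut it into a row/column grid of JPEG tiles. The sketch below covers only the grid-cutting step; the function name `save_level_tiles` and the tile size of 512 are illustrative assumptions, and reading the level image out of the SVS is omitted:

```python
import os
import numpy as np
from PIL import Image

def save_level_tiles(level_image, out_dir, tile_size=512):
    """Split one pyramid level (an H x W x 3 uint8 array) into a
    row/column grid of JPEG tiles named rrrr_cccc.jpg, matching the
    directory layout above (Row 0 / Column 0 is the top-left tile)."""
    os.makedirs(out_dir, exist_ok=True)
    h, w, _ = level_image.shape
    n_rows = (h + tile_size - 1) // tile_size  # ceil division
    n_cols = (w + tile_size - 1) // tile_size
    for row in range(n_rows):
        for col in range(n_cols):
            # Border tiles may be smaller than tile_size; that is fine,
            # the loader pads against a background canvas when reading.
            tile = level_image[row * tile_size:(row + 1) * tile_size,
                               col * tile_size:(col + 1) * tile_size]
            Image.fromarray(tile).save(
                os.path.join(out_dir, '{:04d}_{:04d}.jpg'.format(row, col)))
```

Calling this once per level (Large, Medium, Small, Overview) reproduces the directory tree above.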

Generate the configuration file for the dataset:

python dataset/configure_dataset.py

Train

Run the code on a single GPU:

CONFIG_FILE='configs/tcga_lung.yaml'
WORKERS=8
GPU=0

python cnn_sample.py --cfg $CONFIG_FILE --num-workers $WORKERS
for((FOLD=0;FOLD<5;FOLD++)); 
do
    python cnn_train_cl.py --cfg $CONFIG_FILE --fold $FOLD\
        --epochs 21 --batch-size 100 --workers $WORKERS\
        --fix-pred-lr --eval-freq 2 --gpu $GPU

    python cnn_wsi_encode.py --cfg $CONFIG_FILE --fold $FOLD\
        --batch-size 512 --num-workers $WORKERS --gpu $GPU

    python kat_train.py --cfg $CONFIG_FILE --fold $FOLD --node-aug\
        --num-epochs 200 --batch-size 32 --num-workers $WORKERS  --weighted-sample\
        --eval-freq 5 --gpu $GPU
done 

Run the code on multiple GPUs:

CONFIG_FILE='configs/tcga_lung.yaml'
WORKERS=8
WORLD_SIZE=1

python cnn_sample.py --cfg $CONFIG_FILE --num-workers $WORKERS

for((FOLD=0;FOLD<5;FOLD++)); 
do
    python cnn_train_cl.py --cfg $CONFIG_FILE --fold $FOLD\
        --epochs 21 --batch-size 400 --workers $WORKERS\
        --fix-pred-lr --eval-freq 2\
        --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size $WORLD_SIZE --rank 0

    python cnn_wsi_encode.py --cfg $CONFIG_FILE --fold $FOLD\
        --batch-size 512 --num-workers $WORKERS\
        --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size $WORLD_SIZE --rank 0

    python kat_train.py --cfg $CONFIG_FILE --fold $FOLD --node-aug\
        --num-epochs 200 --batch-size 128 --num-workers $WORKERS  --weighted-sample --eval-freq 5\
        --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size $WORLD_SIZE --rank 0
done

Train KAT with kernel contrastive learning (KCL)

In our extended work, we added a contrastive representation learning module to the kernels for better accuracy and generalization.
Run katcl_train.py instead of kat_train.py to use the contrastive learning module.

Run on a single GPU:

python katcl_train.py --cfg $CONFIG_FILE --fold $FOLD \
        --num-epochs 200 --batch-size 32 --num-workers $WORKERS  --weighted-sample\
        --eval-freq 5 --gpu $GPU

Run on multiple GPUs:

python katcl_train.py --cfg $CONFIG_FILE --fold $FOLD \
        --num-epochs 200 --batch-size 128 --num-workers $WORKERS  --weighted-sample --eval-freq 5\
        --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size $WORLD_SIZE --rank 0

If the code is helpful to your research, please cite:

@inproceedings{zheng2022kernel,
    author    = {Yushan Zheng and Jun Li and Jun Shi and Fengying Xie and Zhiguo Jiang},
    title     = {Kernel Attention Transformer (KAT) for Histopathology Whole Slide Image Classification},
    booktitle = {Medical Image Computing and Computer Assisted Intervention 
                -- MICCAI 2022},
    pages     = {283--292},
    year      = {2022}
}

@article{zheng2023kernel,
    author    = {Yushan Zheng and Jun Li and Jun Shi and Fengying Xie and Jianguo Huai and Ming Cao and Zhiguo Jiang},
    title     = {Kernel Attention Transformer for Histopathology Whole Slide Image Analysis and Assistant Cancer Diagnosis},
    journal   = {IEEE Transactions on Medical Imaging},
    year      = {2023}
}


kat's Issues

the structure of the whole slide image dataset

How can I get the structure of the whole slide image dataset as you describe? I have part of the TCGA dataset, and the slides are in SVS format. How can I generate these image tiles for each slide? Thanks.

can't understand the function 'extract_tile' in loader.py

def extract_tile(image_dir, tile_size, x, y, width, height):
    x_start_tile = x // tile_size
    y_start_tile = y // tile_size
    x_end_tile = (x + width) // tile_size
    y_end_tile = (y + height) // tile_size

    tmp_image = np.ones(
        ((y_end_tile - y_start_tile + 1) * tile_size, (x_end_tile - x_start_tile + 1) * tile_size, 3),
        np.uint8) * 240

    for y_id, col in enumerate(range(x_start_tile, x_end_tile + 1)):
        for x_id, row in enumerate(range(y_start_tile, y_end_tile + 1)):
            img_path = os.path.join(image_dir, '{:04d}_{:04d}.jpg'.format(row, col))
            if not os.path.exists(img_path):
                continue
            img = cv2.imread(img_path)
            h, w, _ = img.shape
            tmp_image[(x_id * tile_size):(x_id * tile_size + h), (y_id * tile_size):(y_id * tile_size + w), :] = img

    x_off = x % tile_size
    y_off = y % tile_size
    output = tmp_image[y_off:y_off + height, x_off:x_off + width]

    return output
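In short, the function works out which tiles overlap the requested region, pastes them onto a temporary canvas (background value 240, so missing tiles show as near-white), and crops the region out of that canvas. A small worked example of the index arithmetic, with an assumed tile size of 512 and an arbitrary example region:

```python
tile_size = 512
x, y, width, height = 700, 300, 600, 400  # requested region, in level-pixel coordinates

# Inclusive range of tile indices that overlap the region
x_start_tile, x_end_tile = x // tile_size, (x + width) // tile_size   # columns
y_start_tile, y_end_tile = y // tile_size, (y + height) // tile_size  # rows

# Offset of the region inside the stitched canvas
x_off, y_off = x % tile_size, y % tile_size

# Columns 1..2 and rows 0..1 are pasted; the crop then starts at (188, 300)
print(x_start_tile, x_end_tile, y_start_tile, y_end_tile, x_off, y_off)
```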

What is the meaning of 'pos[1] * step' in the function below? Why do you regard it as x and pass it to the function extract_tile?

def extract_and_save_tiles(image_dir, slide_save_dir, position_list, tile_size,
                           imsize, step, invert_rgb=False):
    for pos in position_list:
        img = extract_tile(image_dir, tile_size, pos[1] * step, pos[0] * step,
                           imsize, imsize)

        if len(img) > 0:
            if invert_rgb:
                img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            cv2.imwrite(
                os.path.join(slide_save_dir, '{:04d}_{:04d}.jpg'.format(pos[1], pos[0])), img)
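As the call itself shows, each entry of position_list is a (row, column) grid index: pos[0] indexes rows (vertical, y) and pos[1] indexes columns (horizontal, x), so multiplying by the sampling step converts the grid index to pixel coordinates. A tiny illustration (the step value 224 is an arbitrary example, not a value from the repository):

```python
step = 224       # assumed sampling step in pixels
pos = (3, 5)     # (row, column) grid index

# column index * step -> horizontal pixel coordinate x
# row index    * step -> vertical pixel coordinate y
x, y = pos[1] * step, pos[0] * step
print(x, y)
```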

Questions about the model FLOP, GPU memory cost, and speed

Hello,
very impressed with your work.

But I have a few small questions about Table 1 and Table 2, specifically about the model FLOPs, GPU memory cost, and speed.

As far as I know, the size of each slide is different, which makes the computation cost differ from image to image.

So my question is: what size of slide are you computing on, and how many patches does it contain?

Looking forward to your reply.

Yours.
pzSuen.
