UNet3plus

UNet3+ in PyTorch

Reference

[2021] PVT v2: Improved Baselines with Pyramid Vision Transformer

[2022] ConvNeXt: A ConvNet for the 2020s (CVPR 2022)

[2020] UNet 3+: A Full-Scale Connected UNet for Medical Image Segmentation (ICASSP 2020)

Dependencies

  • Python == 3.9
  • PyTorch >= 1.10.2
  • Torchvision >= 0.11.3
  • CUDA >= 11.3 (cu113 builds)
  • numpy >= 1.20.3
  • Pillow >= 8.3.1
  • tensorboard >= 2.9.1
  • tqdm == 4.62.3
  • scikit-learn == 1.2.2
  • timm == 0.6.13
  • opencv-python == 4.7.0.72
  • albumentations == 1.3.0
  • segmentation-models-pytorch == 0.3.2
  • pandas == 2.0.1

Augmentations

Augmentations are tuned for histopathological images:
  • pixel-level: Equalize, HueSaturationValue, ColorJitter, Blur, RandomBrightnessContrast, ChannelShuffle
  • spatial: Transpose, RandomRotate90, Flip

Training methodology

  • Mixed precision
  • Gradient accumulation
  • Label smoothing
  • Gradient clipping
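A minimal sketch of how these four techniques are commonly combined in one PyTorch epoch. This is not the repository's train.py: the function name is hypothetical, reading `clip` as a maximum gradient norm is an assumption, and `ignore_index=255` follows the mask convention described under Training. Defaults mirror config.py (`label_sm = 0.08`, `accum_iter = 1`).

```python
import torch
import torch.nn as nn


def train_one_epoch(model, loader, optimizer, device, accum_iter=1,
                    label_sm=0.08, clip=None, ignore_index=255, use_amp=None):
    """One epoch with mixed precision, gradient accumulation, label smoothing
    and optional gradient-norm clipping. Returns the mean per-batch loss."""
    if use_amp is None:
        use_amp = device.type == "cuda"  # in this sketch, AMP is only enabled on GPU
    criterion = nn.CrossEntropyLoss(ignore_index=ignore_index, label_smoothing=label_sm)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)  # becomes a no-op when disabled
    model.train()
    optimizer.zero_grad(set_to_none=True)
    total, num_batches = 0.0, len(loader)
    for step, (images, masks) in enumerate(loader):
        images, masks = images.to(device), masks.to(device).long()
        with torch.autocast(device_type=device.type, enabled=use_amp):
            logits = model(images)            # (N, C, H, W)
            loss = criterion(logits, masks)   # masks: (N, H, W) class indices
        # Divide so the accumulated gradient equals the mean over accum_iter batches.
        scaler.scale(loss / accum_iter).backward()
        if (step + 1) % accum_iter == 0 or step + 1 == num_batches:
            if clip is not None:
                scaler.unscale_(optimizer)    # clip the true (unscaled) gradients
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
        total += loss.item()
    return total / max(num_batches, 1)
```

With `train_batch = 8` and `accum_iter = 4`, each optimizer step would see gradients from 32 images while only 8 are held in memory at a time.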

Output

  • checkpoint.pth.tar
  • bestmodel.pth.tar, written when a new best model is found
  • commandlines.txt
  • log.txt
  • TensorBoard event file for tracking the training curves
  • train_dice_score.mat and val_dice_score.mat, which contain the path of each image and its mean Dice score
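The exact Dice definition behind these files is not given here. A plausible per-image mean Dice is sketched below, assuming pixels labelled 255 are excluded and classes absent from both prediction and ground truth are skipped rather than scored; the function name is hypothetical.

```python
import numpy as np


def image_mean_dice(pred: np.ndarray, target: np.ndarray, num_classes: int,
                    ignore_index: int = 255) -> float:
    """Mean Dice over classes for one image; pred and target are HxW label maps."""
    valid = target != ignore_index
    scores = []
    for c in range(num_classes):
        p = (pred == c) & valid
        t = (target == c) & valid
        denom = p.sum() + t.sum()
        if denom == 0:
            continue  # class absent in both maps: skip instead of scoring 0 or 1
        scores.append(2.0 * (p & t).sum() / denom)
    return float(np.mean(scores)) if scores else float("nan")
```

Skipping absent classes keeps an image that contains only two tissue types from being rewarded or penalised for the other classes.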

Usage

from utils.models import Unet3plus, Unet3plusGlcm, Unet3plus_deepsupervision

model = Unet3plus(num_classes=num_classes, encoder=config.encoder)

Run locally

Note: use Python 3 (the dependencies above pin Python 3.9).

Training

python train.py


Train the UNet on images and target masks. Before running, update the settings in config.py (defaults shown):
	"Segmentation Training"
	num_workers = 8
	epochs = 50
	train_batch = 8
	lr = 0.0005
	weight_decay = 0.0005
	checkpoint = 'exp1'
	resume = '' #path to checkpoint
	gpu = 0
	seed = 42
	clip = None  # gradient clipping disabled; set e.g. 0.99999 to enable
	size = 512
	ignore_label = 5  # keep equal to num_classes
	accum_iter = 1  # gradient accumulation is enabled when accum_iter > 1
	label_sm = 0.08
	freeze_backbone = False
	encoder = 'convnext_tiny'
	num_classes  = 5
	train_image_path = '../dataset/train/images/'
	train_mask_path = '../dataset/train/masks/'
	validation_image_path = '../dataset/val/images/'
	validation_mask_path = '../dataset/val/masks/'              
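How train.py applies `freeze_backbone` is not shown here. A generic sketch, assuming the encoder's parameters share a name prefix such as "encoder." (a hypothetical default; check `model.named_parameters()` for the real one):

```python
import torch.nn as nn


def set_backbone_trainable(model: nn.Module, trainable: bool, prefix: str = "encoder.") -> int:
    """Set requires_grad on every parameter whose name starts with `prefix`.

    Returns the number of parameter elements affected.
    """
    affected = 0
    for name, param in model.named_parameters():
        if name.startswith(prefix):
            param.requires_grad = trainable
            affected += param.numel()
    return affected
```

After freezing, pass only trainable parameters to the optimizer, e.g. `torch.optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=0.0005, weight_decay=0.0005)`; the choice of AdamW is illustrative.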

Input images and their target masks must have the same filename. Pixels to be ignored should have the value 255 in the mask.
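The same-name requirement can be checked before training starts. This sketch pairs files by stem (filename without extension), which is an assumption: it accepts, for example, a .jpg image with a .png mask. The helper name and the extension list are hypothetical.

```python
from pathlib import Path

IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".tif", ".tiff"}  # assumed set of extensions


def pair_images_and_masks(image_dir, mask_dir):
    """Return sorted (image_path, mask_path) pairs matched by filename stem."""
    image_dir, mask_dir = Path(image_dir), Path(mask_dir)
    masks = {p.stem: p for p in mask_dir.iterdir() if p.suffix.lower() in IMAGE_EXTS}
    pairs, missing = [], []
    for img in sorted(image_dir.iterdir()):
        if img.suffix.lower() not in IMAGE_EXTS:
            continue  # skip non-image files such as notes or hidden files
        mask = masks.get(img.stem)
        if mask is None:
            missing.append(img.name)
        else:
            pairs.append((img, mask))
    if missing:
        raise FileNotFoundError(f"No mask found for: {', '.join(missing)}")
    return pairs
```

Failing early on a missing mask is cheaper than discovering it mid-epoch inside a DataLoader worker.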

Contributors

owais-ansari
