UNet with ResNet Backbone

This repository contains an implementation of a UNet model with a ResNet backbone using PyTorch. The model leverages pretrained weights from ResNet to enhance feature extraction in the encoder part of the UNet architecture. Additionally, an ensemble model is provided to combine multiple UNet models for improved performance.

Introduction

UNet is a popular convolutional neural network architecture for biomedical image segmentation. It consists of a contracting path that captures context and a symmetric expanding path that enables precise localization. Using a ResNet backbone as the encoder strengthens feature extraction, which can improve segmentation performance.

Installation

To get started, clone the repository and install the required dependencies:

git clone https://github.com/gyb357/UNet-Segmentation.git
cd UNet-Segmentation
pip install -r requirements.txt

File Structure

The repository is structured as follows:

.
├── config/
│   └── config.yaml       # Configuration files
├── csv/
│   └── train_logg.csv    # Training log CSV file
├── dataset/
│   ├── image/            # Directory for train image dataset
│   ├── mask/             # Directory for train mask dataset
│   └── test/             # Directory for test image dataset
├── model/
│   ├── checkpoint/       # Directory for checkpoint models
│   └── pretrained/       # Directory for storing pretrained models
├── runs/                 # Directory for TensorBoard logs
│
├── dataset.py            # Script for dataset preparation
├── main.py               # Main script to run the training and evaluation
├── miou.py               # Script to calculate mean Intersection over Union (mIoU)
├── requirement.txt       # Required dependencies
├── resnet.py             # ResNet model definition
├── test.py               # Script for testing/evaluation
├── train.py              # Script for training the model
├── unet.py               # UNet model definition
└── utils.py              # Utility functions
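
As an illustration of the metric that miou.py is described as computing, here is a minimal sketch of mean Intersection over Union on flattened label masks. It uses plain Python only; the function name `mean_iou` and its signature are assumptions made for this example and are not taken from the repository, which may compute the metric differently (for example on PyTorch tensors).

```python
from typing import Sequence


def mean_iou(pred: Sequence[int], target: Sequence[int], num_classes: int) -> float:
    """Mean IoU over the classes that occur in pred or target.

    Hypothetical helper: pred and target are flattened lists of class labels.
    """
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same number of pixels")
    ious = []
    for cls in range(num_classes):
        # Pixels labelled cls in both masks.
        intersection = sum(1 for p, t in zip(pred, target) if p == cls and t == cls)
        # Pixels labelled cls in at least one mask.
        union = sum(1 for p, t in zip(pred, target) if p == cls or t == cls)
        if union > 0:  # skip classes absent from both masks
            ious.append(intersection / union)
    return sum(ious) / len(ious) if ious else 0.0


pred = [0, 0, 1, 1]
target = [0, 1, 1, 1]
print(mean_iou(pred, target, num_classes=2))
```

In this toy case class 0 has IoU 1/2 and class 1 has IoU 2/3, so the mean is 7/12.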

How to use it

  1. Check config.yaml and adjust the model parameters and file paths.
  2. Place the dataset and any pretrained weights in the specified directories.
  3. Run main.py to train and test the model.
  4. In a terminal, run "tensorboard --logdir=<your TensorBoard log dir>" and open "localhost:6006" in a web browser to view TensorBoard.

Pretrained Weights

Download the ImageNet-1K pretrained model and place it in model/pretrained/. Download links (source: https://pytorch.org/vision/stable/_modules/torchvision/models/resnet.html):

resnet18 v1: https://download.pytorch.org/models/resnet18-f37072fd.pth
resnet34 v1: https://download.pytorch.org/models/resnet34-b627a593.pth
resnet50 v2: https://download.pytorch.org/models/resnet50-11ad3fa6.pth
resnet101 v2: https://download.pytorch.org/models/resnet101-cd907fc2.pth
resnet152 v2: https://download.pytorch.org/models/resnet152-f82ba261.pth
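
The download step can also be scripted. The sketch below uses only the Python standard library and the URLs listed above. The helper names `weight_path` and `download_weight` are hypothetical, and keeping the original file name inside model/pretrained/ is an assumption; check config.yaml for the path the project actually expects.

```python
from pathlib import Path
from urllib.request import urlretrieve

# URLs copied from the list above.
WEIGHT_URLS = {
    "resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth",
    "resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth",
    "resnet50": "https://download.pytorch.org/models/resnet50-11ad3fa6.pth",
    "resnet101": "https://download.pytorch.org/models/resnet101-cd907fc2.pth",
    "resnet152": "https://download.pytorch.org/models/resnet152-f82ba261.pth",
}


def weight_path(name: str, root: str = "model/pretrained") -> Path:
    """Return the local path for a checkpoint without downloading it."""
    url = WEIGHT_URLS[name]
    # Assumption: the file keeps its original name under `root`.
    return Path(root) / url.rsplit("/", 1)[-1]


def download_weight(name: str, root: str = "model/pretrained") -> Path:
    """Download the checkpoint unless it is already present (hypothetical helper)."""
    dest = weight_path(name, root)
    dest.parent.mkdir(parents=True, exist_ok=True)
    if not dest.exists():
        urlretrieve(WEIGHT_URLS[name], str(dest))
    return dest
```

Calling `download_weight("resnet34")` from the repository root would then save the file as model/pretrained/resnet34-b627a593.pth.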

Support for pretrained weights of models above resnet50 is still under development.

The train_logg.csv file records how training progresses. You can also monitor training in TensorBoard.
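
To inspect the log outside TensorBoard, a small reader such as the following may help. It is a sketch built on the standard csv module. The column names `epoch` and `loss` in the sample are hypothetical, since the actual header of train_logg.csv is not documented here, and `read_log` is a name chosen for this example.

```python
import csv
import io
from typing import Dict, Iterable, List


def read_log(lines: Iterable[str]) -> Dict[str, List[str]]:
    """Read CSV lines into {column name: list of raw string values}."""
    columns: Dict[str, List[str]] = {}
    for row in csv.DictReader(lines):
        for key, value in row.items():
            columns.setdefault(key, []).append(value)
    return columns


# Hypothetical sample data; the real train_logg.csv columns may differ.
sample = io.StringIO("epoch,loss\n1,0.9\n2,0.7\n")
log = read_log(sample)
print(log["epoch"], log["loss"])
```

For the real file, open it with `open("csv/train_logg.csv", newline="")` and pass the file object to `read_log`. Values are returned as strings, so convert them with `float` before plotting.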
