Coder Social home page Coder Social logo

unet-pytorch's Introduction

Unet:U-Net: Convolutional Networks for Biomedical Image Segmentation目标检测模型在Pytorch当中的实现


目录

  1. 仓库更新 Top News
  2. 相关仓库 Related code
  3. 性能情况 Performance
  4. 所需环境 Environment
  5. 文件下载 Download
  6. 训练步骤 How2train
  7. 预测步骤 How2predict
  8. 评估步骤 miou
  9. 参考资料 Reference

Top News

2022-03:进行大幅度更新、支持step、cos学习率下降法、支持adam、sgd优化器选择、支持学习率根据batch_size自适应调整。
BiliBili视频中的原仓库地址为:https://github.com/bubbliiiing/unet-pytorch/tree/bilibili

2020-08:创建仓库、支持多backbone、支持数据miou评估、标注数据处理、大量注释等。

相关仓库

模型 路径
Unet https://github.com/bubbliiiing/unet-pytorch
PSPnet https://github.com/bubbliiiing/pspnet-pytorch
deeplabv3+ https://github.com/bubbliiiing/deeplabv3-plus-pytorch

性能情况

unet并不适合VOC此类数据集,其更适合特征少,需要浅层特征的医药数据集之类的。

训练数据集 权值文件名称 测试数据集 输入图片大小 mIOU
VOC12+SBD unet_vgg_voc.pth VOC-Val12 512x512 58.78
VOC12+SBD unet_resnet_voc.pth VOC-Val12 512x512 67.53

所需环境

torch==1.2.0
torchvision==0.4.0

文件下载

训练所需的权值可在百度网盘中下载。
链接: https://pan.baidu.com/s/1A22fC5cPRb74gqrpq7O9-A
提取码: 6n2c

VOC拓展数据集的百度网盘如下:
链接: https://pan.baidu.com/s/1vkk3lMheUm6IjTXznlg7Ng
提取码: 44mk

训练步骤

一、训练voc数据集

1、将我提供的voc数据集放入VOCdevkit中(无需运行voc_annotation.py)。
2、运行train.py进行训练,默认参数已经对应voc数据集所需要的参数了。

二、训练自己的数据集

1、本文使用VOC格式进行训练。
2、训练前将标签文件放在VOCdevkit文件夹下的VOC2007文件夹下的SegmentationClass中。
3、训练前将图片文件放在VOCdevkit文件夹下的VOC2007文件夹下的JPEGImages中。
4、在训练前利用voc_annotation.py文件生成对应的txt。
5、注意修改train.py的num_classes为分类个数+1。
6、运行train.py即可开始训练。

三、训练医药数据集

1、下载VGG的预训练权重到model_data下面。
2、按照默认参数运行train_medical.py即可开始训练。

预测步骤

一、使用预训练权重

a、VOC预训练权重
  1. 下载完库后解压,如果想要利用voc训练好的权重进行预测,在百度网盘或者release下载权值,放入model_data,运行即可预测。
img/street.jpg
  1. 在predict.py里面进行设置可以进行fps测试和video视频检测。
b、医药预训练权重
  1. 下载完库后解压,如果想要利用医药数据集训练好的权重进行预测,在百度网盘或者release下载权值,放入model_data,修改unet.py中的model_path和num_classes;
_defaults = {
    #-------------------------------------------------------------------#
    #   model_path指向logs文件夹下的权值文件
    #   训练好后logs文件夹下存在多个权值文件,选择验证集损失较低的即可。
    #   验证集损失较低不代表miou较高,仅代表该权值在验证集上泛化性能较好。
    #-------------------------------------------------------------------#
    "model_path"    : 'model_data/unet_vgg_medical.pth',
    #--------------------------------#
    #   所需要区分的类的个数+1
    #--------------------------------#
    "num_classes"   : 2,
    #--------------------------------#
    #   所使用的的主干网络:vgg、resnet50   
    #--------------------------------#
    "backbone"      : "vgg",
    #--------------------------------#
    #   输入图片的大小
    #--------------------------------#
    "input_shape"   : [512, 512],
    #--------------------------------#
    #   blend参数用于控制是否
    #   让识别结果和原图混合
    #--------------------------------#
    "blend"         : True,
    #--------------------------------#
    #   是否使用Cuda
    #   没有GPU可以设置成False
    #--------------------------------#
    "cuda"          : True,
}
  1. 运行即可预测。
img/cell.png

二、使用自己训练的权重

  1. 按照训练步骤训练。
  2. 在unet.py文件里面,在如下部分修改model_path、backbone和num_classes使其对应训练好的文件;model_path对应logs文件夹下面的权值文件
_defaults = {
    #-------------------------------------------------------------------#
    #   model_path指向logs文件夹下的权值文件
    #   训练好后logs文件夹下存在多个权值文件,选择验证集损失较低的即可。
    #   验证集损失较低不代表miou较高,仅代表该权值在验证集上泛化性能较好。
    #-------------------------------------------------------------------#
    "model_path"    : 'model_data/unet_vgg_voc.pth',
    #--------------------------------#
    #   所需要区分的类的个数+1
    #--------------------------------#
    "num_classes"   : 21,
    #--------------------------------#
    #   所使用的的主干网络:vgg、resnet50   
    #--------------------------------#
    "backbone"      : "vgg",
    #--------------------------------#
    #   输入图片的大小
    #--------------------------------#
    "input_shape"   : [512, 512],
    #--------------------------------#
    #   blend参数用于控制是否
    #   让识别结果和原图混合
    #--------------------------------#
    "blend"         : True,
    #--------------------------------#
    #   是否使用Cuda
    #   没有GPU可以设置成False
    #--------------------------------#
    "cuda"          : True,
}
  1. 运行predict.py,输入
img/street.jpg
  1. 在predict.py里面进行设置可以进行fps测试和video视频检测。

评估步骤

1、设置get_miou.py里面的num_classes为预测的类的数量加1。
2、设置get_miou.py里面的name_classes为需要去区分的类别。
3、运行get_miou.py即可获得miou大小。

Reference

https://github.com/ggyyzm/pytorch_segmentation
https://github.com/bonlime/keras-deeplab-v3-plus

unet-pytorch's People

Contributors

bubbliiiing avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar  avatar  avatar  avatar

unet-pytorch's Issues

训练一段时间后,CE loss变为NAN

您好,看了您的教程我试着自己搭建了一个U-Net模型,并采用Dice + CE loss作为损失函数,但在迭代几十个epoch后,我的CE loss返回了NAN值,反馈的结果是 ‘Function 'LogSoftmaxBackward' returned nan values in its 0th output.’ 同样的数据在您源码上运行没有出现这个问题,请问您是否知道些解决方法?

模型结构

可以直接将本项目中经过修改的unet结构换成原生的unet结构吗

预测结果问题

up你好,我现在用的是resnet50为主干网络,输入resize为512512,输出为256256.但是输出预测标签于真实标签少一块,请问是什么原因呢?我将输入resize为256*256,输出尺寸不变还是有这个问题?
预测:
J114`TKENH 10LU{8WFLL
原图:
38IRV$5Z}S7G9LTGVZ)D5N2

各种指标震荡

请问这种情况应该怎么办? loss也是类似的情况,一直在震荡,但总体趋势在下降。
epoch_miou

泡泡哥请问一下制作数据集的问题

可以看到视频用labelme标出来的是彩色图,猫狗分别是红色和绿色,但是训练的时候图又变成了二值图,中间是不是缺了什么流程!我用3060跑上面说的猫狗红色和绿色的数据集根本用不了,哭死

博主您好,我做二分类使用您的train_Medical训练,效果很差

我想问一下,你的数据medical数据预处理modify_png[png <= 127.5] = 1,用自己的数据集这里需要修改一下吗:
`pg, png = self.get_random_data(jpg, png, self.input_shape, random = self.train)

    jpg         = np.transpose(preprocess_input(np.array(jpg, np.float64)), [2,0,1])
    png         = np.array(png)
    #-------------------------------------------------------#
    #   这里的标签处理方式和普通voc的处理方式不同
    #   将小于127.5的像素点设置为目标像素点。
    #-------------------------------------------------------#
    modify_png  = np.zeros_like(png)
    modify_png[png <= 127.5] = 1
    seg_labels  = modify_png
    seg_labels  = np.eye(self.num_classes + 1)[seg_labels.reshape([-1])]
    seg_labels  = seg_labels.reshape((int(self.input_shape[0]), int(self.input_shape[1]), self.num_classes + 1))

`

loss很低但没有效果

421f9be8e6413ae380328bd26f7937b
博主您好,我问下用您的模型训练那个医学分割的图像,loss训练到很低,但是出来的图没有效果是为什么呢。

预测出来的图片是灰的

按照b站的教学来的,训练的自己的图片,为什么预测出来的照片,就像是加了一层灰色蒙版似的,没有识别出来?

Batch Normalization

博主,我想问一下我在网络中增加了batchnorm,它的均值和方差是按照一个batch来计算的。但是我看在计算miou时,图片是一张一张做预测的,这个会有影响吗?

DDP评估

up主你好!
感谢您分享的工作!
有一个问题想跟您讨论下,我看您的代码里面可以采用DDP模式进行训练,但是我看你在评估的时候没有用到dist.all_gather(tensor_list, tensor)等相关代码通过DDP的方式进行评估,想问下只用local_rank==0的最终权重来进行评估和通过DDP的方式进行评估,两者是一样的是吗?
谢谢

运行很慢1fps

我直接运行预训练权重,GTX960,只有1-2fps,这是unet本省就慢吗

Intel MKL ERROR

出现Intel MKL ERROR:Parameter 6 was incorrect on entry to DGELSD.这个错误,请问是什么情况啊,一直解决不了。是不是我的模型问题呢?

报错

博主您好!从b站来的,我在train.py 的from utils.dataloader import yolo_dataset_collate, YoloDataset这里就报错No module named 'utils.dataloader'在百度上也搜不到所以来麻烦你,不知道怎么解决,跪求解答

utils.py中的cvtColor函数

if len(np.shape(image)) == 3 and np.shape(image)[-2] == 3:
这个if语句的第二个条件是不是写错了?
我觉得应该是np.shape(image)[2],不是-2.
图像的shape应该是(h,w, c)
-2的话是指w(宽),2才是指c(通道)。
是不是?难道我理解错了?

过拟合

博主,我用你的代码和VOC数据集,主干网络为vgg,从0开始训练,为什么会过拟合啊,trainloss下降,但是valloss已经基本不变了,而miou也没有你训练出来的那么高,我的只有40左右。

请问一下UP,为什么我predict.py出来的结果都是黑色的?

我看了之前其他小伙伴的提问,我更改了lr epoch但是仍然不行。
我输入的图片和预测的图片都是512.512.3的,训练的标签是单通道的(8位图),这个问题困扰我一个星期了,我看在B站也有很多小伙伴有同样的遭遇.....

期待并感谢您的回答。

训练自己的数据集问题

博主你好,已经可以正常训练并进行预测,用的Vgg,在转成ONNX后,在cpu上推理预测,速度很慢,输入图像是30722048,shape是512512,推理时间在2.2S左右,请问这是代码问题还是模型问题呢?

fit_one_epoch

有个疑惑为什么fit_one_epoch传两个模型参数,model_train用作训练,model用来保存

预训练权重

打扰一下,请问有unet(res50)在imagenet上的预训练权重吗。“unet_resnet50_voc.pth”是直接用unet在voc上训练好的权重吗?还有“resnet50_19c8e573.pth”是什么权重呢?

关于训练自己的数据集的问题

我仔细看了作者的代码和VOC数据集的格式,发现VOC数据集训练集和验证集mask图像的像素值就是它的分类值,比如这个物体属于第五类那它的像素值就是(5,5,5),我用的数据集mask图像的像素值是它对应的颜色,直接跑程序的话,miou和mPA除了背景其他都是0,我把训练集和验证集mask图像的像素值改成了对应的分类值,图片就变成全黑的了,训练之后miou和mPA也是除了背景其他都是0,应该怎么解决呢?

关于训练数据集的问题

直接训练16bit的数据集会报错,请问在不改变数据集的情况下应该修改代码的那一部分呢?

dice loss 的一个小疑问

博主您好,非常感谢您的开源代码,我在数据读取的地方有一个小疑问,就是dice loss对应的标签seg_labels的问题。seg_labels对应的是one-shot形式,其中num_classes已经是加一的了,为什么该部分还要再加一呢?
如能回答,万分感谢呀,这地方没弄懂。

预测图像全黑问题

您好,我在用您的网络训练遥感图像数据集VOCdevkit时,进行预测得到的结果是全黑的(如下图)
2023-03-17 21-32-44屏幕截图
,我将待预测的图像改为位深为8的灰度图后依然为全黑,将训练数据改为位深为8的灰度图重新训练再预测依然为全黑请问是为什么呢?

我要分两类,数据标签已透过np.array检查过了「背景为0,目标为1」,但训练完后预测没有结果

博主您好,您的影片我都看过了,也照着您的方式进行操作,且我第一次训练后预测有成功(我是训练水槽水面变化),但当我重新再做一次一模一样的数据时,参数也一样,就一直没结果,即使重新安装环境及下载Github代码都不行。

后来我把之前训练的权重套入,发现应该是训练没训练好,请问博主知道是什么问题?
7 mp4_20220815172942_0825

裁剪预测图

你好,请问有没有crop检测结果的代码在里面或者大概怎么实现?
比如,我想用模型把舌头裁剪出:

询问一下预训练的问题

你好,打扰了。我是想问下主干模型是指的是在下采样过程中使用的vgg吗?如果我不改变上采样是不是就不用使用imagenet训练。然后注销掉model_path=‘’ 以及 if model_path !=‘’这段。然后使用自己的数据集去进行训练。 谢谢大佬!!!!!!。实际上大佬你的voc的权重文件是不是为二次预训练的数据。
不好意思,语言表达能力不行。俺不晓得这样说大佬明不明白。

网络结构

泡泡老师你好,您的这篇代码中的unet是原生的unet吗?我读了您的代码发现encoder是用了resnet50,比unet结构图要复杂一些。

from tqdm import tqdm 报错

import os
import time

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

如何修改输出热力图

你好,泡泡导师,请问下这个框架能不能输出网络预测的热力特征图呢?我想自己改进网络结构,然后通过热力图来可视化网络学到的东西,但是不知道怎么改,不知道导师有空能不能加上这个功能呢?感激不尽!!!

关于训练自己的数据集

请问一下大佬,如果训练自己的数据集,是不是dataloader返回的jpg和png就够用了,labels是针对voc数据集的白边的,再用diceloss的时候把参数里的label换成png?

运行中断

D:\xuexi\anaconda\envs\unet\python.exe D:/unet-pytorch-main/train.py
initialize network with normal type
Load weights model_data/unet_vgg_voc.pth.
Start Train
Epoch 1/50: 0%| | 0/27 [00:00<?, ?it/s<class 'dict'>]
Process finished with exit code -1073741819 (0xC0000005)
刚运行就中断了,请问怎么回事

Loss

不使用预训练模型,用ECSSD重新训练主干网络,数据集中有1000张图片,但每次epoch有450张,并且损失一直为0,期待您的回复

predict.py问题

predict.py文件和博主视频上的不一样啊?请博主回答一下,谢谢!

loss计算

导师,我的数据集缺陷数据集,就两个类别,前景(缺陷),背景,label的像素点是0和1,在数据增强的时候对label用0填充
image
数据增强以后,我的label的像素点还是0和1啊,那下面代码的意义何在?
image
如此,在计算损失的时候,我的num_class=2,ignore_index= 2就没有任何作用啊,计算的损失还是包含了我数据增强填充的0
image

很想请教这个损失问题,谢谢UP

为啥在dataloader第40行转换的array的shape和cv2不一样呢

我使用json_to_dataset.py转化mask后尝试使用代码查看shape
import cv2
import numpy as np
from PIL import Image

file = '/home/fut/Downloads/unet-pytorch-main/mydata/masks/ID_1110_json.png'
img = cv2.imread(file, cv2.IMREAD_UNCHANGED)
print(img.shape)

pil = Image.open(file)
img2 = np.array(pil)
print(img2.shape)
结果会是:
(800, 800, 3)
(800, 800)
为什么PIL读取后通道就没了,正是因为这个原因你的项目会很好跑起来。

医药数据集预测时的默认设置是不是要改成如下参数

_defaults = {
    "model_path"        : 'model_data/unet_medical.pth',
    "model_image_size"  : (256, 256, 3),   #cell.png
    "num_classes"       : 2,
    "cuda"              : True,
    #--------------------------------#
    #   blend参数用于控制是否
    #   让识别结果和原图混合
    #--------------------------------#
    "blend"             : True
}

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.