Coder Social home page Coder Social logo

zheng-yuwei / yolov2-tensorflow Goto Github PK

View Code? Open in Web Editor NEW
3.0 3.0 0.0 27.98 MB

tf-keras-implemented YOLOv2

Home Page: https://zheng-yuwei.github.io/2018/10/03/4_YOLOv2/

License: MIT License

Python 100.00%
yolov2 tensorflow tensorflow-keras python3 resnet-18 resnet-v2 mixnet resnext radam

yolov2-tensorflow's Introduction

YOLOv2-tensorflow

基于tf.keras,实现YOLOv2模型。

本项目相比其他YOLO v2项目的特色

与所有YOLO v2项目相比:

  1. 使用tf.data.Dataset读取数据,tf.keras构造模型,简单易懂,同时易于多GPU训练、模型转换等操作;
  2. 全中文详细代码注释,算法理解等说明;
  3. 自由开启/关闭train-from-scratch的预测框校正功能。

如何使用

取coco数据集中的20张图片做训练,测试效果如下,更多结果可查看dataset/test_result*

测试结果

快速上手

  1. 制作数据集label.txt,一行为image_path x0 y0 w0 h0 cls0 x1 y1 x1 h1 cls1 ..., 其中xywh为待检测目标的bounding box中心点坐标和宽高相对于原图的比例(归一化了),cls为类别;
  2. 实际用自己的数据训练时,可能需要执行以下utils/check_label_file.py,确保标签文件中的图片真实可用;
  3. 修改并运行utils/anchors/kmeans_anchors.py,聚类预定义anchors;
  4. run.py同目录下新建 logs文件夹,存放日志文件;训练完毕会出现models文件夹,存放模型;
  5. 查看configs.py并进行修改,此为参数配置文件;
  6. 执行python run.py,会根据配置文件configs.py进行训练/测试/模型转换等(需要注意我设置了随机种子)。

anchor聚类图

不同聚类中心下,待检测目标与归属anchor的IOU-样本比例的ROC曲线

以上的IOU-Ratio曲线需要从右往左看,表示随着与聚类中心IOU越小,类内label框的占比比例。

学习掌握

  1. 先看README.md;
  2. 再看1_learning_note下的note;
  3. multi_label下的trainer.py里的__init__函数,把整体模型串起来;
  4. run.py文件,结合着看configs.py

目录结构

  • A_learning_notes: README后,先查看本部分了解本项目大致结构;
  • backbone: 模型的骨干网络脚本,basic_backbone.py包含了基类BasicBackbone, 实现了5个类型的骨干网络:resnet-18, resnet-18-v2, mobilenet-v2, mixnet-18, resnext-18; 其中,前三个网络基本遵照原始网络结构,后两个是借鉴了对应网络的**,在resnet-18基础上改写;
  • dataset: 数据集构造脚本;
    • dataset_util.py: 使用tf.image API进行图像数据增强,然后用tf.data进行数据集构建;
    • file_util.py: 以txt标签文件的形式,构造tf.data数据集用于训练;
  • images: 项目图片;
  • logs: 存放训练过程中的日志文件和tensorboard文件(当前可能不存在);
  • models: 存放训练好的模型文件(当前可能不存在);
  • utils: 一些工具脚本;
    • anchors: 通过k-means聚类计算得到预定义anchors;
    • check_label_file.py: 在训练前检查训练集,确保标签文件中的图片真实可用;
    • logger.py:构造文件和控制台日志句柄;
    • logger_callback.py: 日志打印的keras回调函数;
    • radam.py: RAdam算法的tf.keras优化器实现;
  • yolov2: yolov2模型构建脚本;
    • train.py: 模型训练接口,集成模型构建/编译/训练/debug/预测、数据集构建等功能;
    • yolov2_decoder.py: 对YOLO v2模型的预测输出进行解码;
    • yolov2_trainer.py: 构造YOLO v2检测器模型;
    • yolov2_loss.py: YOLO v2的损失函数;
    • yolov3_post_process.py:YOLO v2后处理,预测和测试的时候用。
  • configs.py: 配置文件;
  • run.py: 启动脚本;

代码库特别说明

标签文件格式说明

标签文件格式内容为:

image_path x0 y0 w0 h0 cls0 ...

其中,image_path是图片相对路径,会拼接上configs.py中的FLAGS.train_set_dir(测试的话则是FLAGS.test_set_dir); x0 y0 w0 h0是归一化后的待检测物品中心点坐标、宽高,归一化也就是 实际尺寸/图片尺寸; cls0是图片类别,即使是单类别且不计算类别损失,该位也必须存在(可以任意值)。 前者是多类别的目标检测,后者主要是单类别的目标检测。 后续省略号表示多个待检测对象的标签x0 y0 w0 h0 cls0

算法说明

MixNet的理解

MixNet是Google在轻量级网络结构上探索的又一成果。

2019-05 Google将NAS用到了轻量级网络结构的搜索上,得到MnasNet,也就是MobileNet v3 (Searching for MobileNetV3)。 论文中的启示可能有:

  1. 沿用了MobileNet v2的基本结构块:x6通道数的1x1卷积,步长为2的3x3 depthwise卷积,/2通道数的1x1线性卷积;
  2. 基于squeeze and excitation结构的轻量级注意力模型;
  3. 使用 h-swish激活函数(hard swish):(x * ReLU6(x+3)) / 6;
  4. 手工微调网络开端和结尾这两个开销比较大的部分。

2019-07 针对kernel size的影响进行了系统性研究,观察 multiple kernel sizes 的融合可以带来精度和效率上的提升, 也就是 mixed depthwise convolution (MixConv, 论文 MixConv: Mixed Depthwise Convolutional Kernels)。 然后将其集成到AutoML的搜索空间中,最终得到MixNets轻量网络结构体系(论文结果优于MobileNet v3)。 启示可能包括(沿用MobileNet v2/v3的基础模块结构):

  1. 加上了分组卷积的思维:在前后两个1x1卷积中进行了分组操作;
  2. 将中间的3x3卷积模块替换为MixConv:有[16, 8, 4, 4]比例的4组[3, 5, 7, 9]卷积核尺寸;
  3. 若卷积核太大,计算量不太可承受,可考虑使用dilated convolution。

我在backbone中实现的MixNet并不是论文中的网络结构,而是使用了MixConv不同卷积核尺寸的**构造的网络。

TODO

  • RAdam;
  • 多尺度输入;
  • mixup;
  • focal loss;
  • GHM损失函数;
  • GIOU;
  • TIOU-Recall;
  • Guassian YOLO;
  • 模型测试,计算mAP;Cartucho/mAP

yolov2-tensorflow's People

Contributors

zheng-yuwei avatar

Stargazers

 avatar  avatar  avatar

Watchers

 avatar  avatar  avatar

yolov2-tensorflow's Issues

关于tf.data的padded_batch使用问题

label_set = label_set.padded_batch(batch_size, (tf.TensorShape([None])), padding_values=-1.)

hi,请教一个问题,在读取数据的时候,针对objs不定长问题用了padding,但是在处理数据的时候如何操作呢?逐个batch拆分开处理吗?您的代码中相关操作的位置在哪里呢?谢谢

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.