Coder Social home page Coder Social logo

lxztju / pytorch_classification Goto Github PK

View Code? Open in Web Editor NEW
1.3K 14.0 338.0 3.13 MB

利用pytorch实现图像分类的一个完整的代码,训练,预测,TTA,模型融合,模型部署,cnn提取特征,svm或者随机森林等进行分类,模型蒸馏,一个完整的代码

License: MIT License

Python 2.21% Jupyter Notebook 96.71% CMake 0.15% C++ 0.85% Shell 0.07%
pytorch image-classification deployment svm knn cnn label-smoothing densenet resnext resnet

pytorch_classification's Introduction

简介

基于torchision实现的pytorch图像分类功能。

近期更新

  • 2022.11.05更新

    • 新添加tensorrt c++的推理方案
  • 2022.10.29更新,进行代码重构,基本的功能基本一致。

    • 支持pytorch ddp的训练
    • 支持c++ libtorch的模型推理
    • 支持script脚本一键运行
    • 添加日志模块

习惯之前版本的请看v1版本的代码:V1版本

主要功能:

利用pytorch实现图像分类,基于torchision可以扩展使用densenet,resnext,mobilenet,efficientnet,swin transformer等图像分类网络

如果有用欢迎star

实现功能

  • 基础功能利用pytorch实现图像分类
  • 包含带有warmup的cosine学习率调整
  • warmup的step学习率优调整
  • 多模型融合预测,加权与投票融合
  • 利用flask + redis实现模型云端api部署(tag v1)
  • c++ libtorch的模型部署
  • 使用tta测试时增强进行预测(tag v1)
  • 添加label smooth的pytorch实现(标签平滑)(tag v1)
  • 添加使用cnn提取特征,并使用SVM,RF,MLP,KNN等分类器进行分类(tag v1)。
  • 可视化特征层

运行环境

  • python3.7
  • pytorch 1.8.1
  • torchvision 0.9.1
  • opencv(libtorch cpp推理使用, 版本3.4.6)(可选)
  • libtorch cpp推理使用(可选)

快速开始

数据集形式

数据集的组织形式,参考sample_files/imgs/listfile.txt

训练 测试

修改run.sh中的参数,直接运行run.sh即可运行

主要修改的参数:

OUTPUT_PATH 模型保存和log文件的路径

TRAIN_LIST 训练数据集的list文件
VAL_LIST  测试集合的list文件
model_name 默认是resnet50
lr 学习率
epochs 训练总的epoch
batch-size  batch的大小
j dataloader的num_workers的大小
num_classes 类别数

libtorch inference

代码存储在cpp_inference文件夹中。

  1. 利用cpp_inference/traced_model/trace_model.py将训练好的模型导出。

  2. 编译所需的opencv和libtorch代码到cpp_inference/third_party_library

  3. 编译

sh compile.sh
  1. 可执行文件测试
./bin/imgCls imgpath

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.