
How do I convert the model to ONNX? (about iim, open)

taohan10200 commented on July 28, 2024
How do I convert the model to ONNX?

from iim.

Comments (3)

taohan10200 commented on July 28, 2024

We provide models saved with PyTorch. You can convert them to ONNX yourself using the model parameters we have open-sourced.
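Converting from the released parameters starts with loading the `.pth` state dict into a freshly built model. Checkpoints saved through `torch.nn.DataParallel` store every key with a leading `module.` prefix, so a plain single-GPU model will not accept them as-is. A minimal sketch of the key renaming on a plain dict, assuming the checkpoint follows that convention; `strip_dataparallel_prefix` is a hypothetical helper name, not part of IIM:

```python
from collections import OrderedDict


def strip_dataparallel_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading `prefix` removed from each key.

    Hypothetical helper: assumes the checkpoint was saved from nn.DataParallel,
    which prepends 'module.' to every parameter name.
    """
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]  # drop only the leading prefix
        new_state_dict[key] = value
    return new_state_dict
```

Unlike `k.replace('module.', '')`, this removes only a leading prefix, so a key that happens to contain `module.` further in is left intact. The returned dict can then be passed to the model's `load_state_dict`.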

csz-006 commented on July 28, 2024

Did you manage to convert it? Could you share your ONNX conversion code?

csz-006 commented on July 28, 2024

Here is the code for converting to ONNX. Note that the input must have shape 1×3×h×w.
```python
import os
from collections import OrderedDict

import onnx
import torch
from onnxsim import simplify

from model.locator import Crowd_locator  # from the IIM repository

GPU_ID = '0'
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_ID
torch.backends.cudnn.benchmark = True


def onnx_export(net_name, model_path, onnx_name):
    net = Crowd_locator(net_name, GPU_ID, pretrained=False)
    net.cuda()
    state_dict = torch.load(model_path, map_location='cuda')
    if len(GPU_ID.split(',')) > 1:
        net.load_state_dict(state_dict)
    else:
        # Single GPU: strip the 'module.' prefix left by nn.DataParallel
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            new_state_dict[k.replace('module.', '')] = v
        net.load_state_dict(new_state_dict)
    net.eval()

    # Print the name of every layer in the model
    for name, _ in net.named_modules():
        print(name)

    # Fixed 1x3xHxW input; change H and W (e.g. 640, 640) to export another size
    dummy_input = torch.zeros(1, 3, 512, 1024).cuda()
    torch.onnx.export(
        net, dummy_input, onnx_name,
        verbose=True,
        input_names=['image'],
        output_names=['predict'],
        opset_version=12,
        dynamic_axes=None,  # static shape: the graph only accepts 1x3x512x1024
    )
    return onnx_name


def onnx_sim(onnx_path, save_path):
    model_onnx = onnx.load_model(onnx_path)
    model_sim, check = simplify(model_onnx)
    assert check, 'Simplified ONNX model failed validation'
    onnx.save(model_sim, save_path)
    print('Static graph simplification complete')


if __name__ == '__main__':
    net_name = 'HR_Net'  # or 'VGG16_FPN'
    model_path = '/IIM/Preweights/NWPU-HR-ep_241_F1_0.802_Pre_0.841_Rec_0.766_mae_55.6_mse_330.9.pth'
    onnx_path = '/IIM/Preweights/1024_HRnet_Crowd_count_512_1024_opset12.onnx'

    onnx_export(net_name, model_path, onnx_path)
    # Optional: simplify the exported graph with onnx-simplifier
    # onnx_sim(onnx_path, onnx_path.replace('.onnx', '-sim.onnx'))
    print('Done')
```
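Because the export uses `dynamic_axes=None`, the resulting ONNX graph has a fixed input of shape 1×3×512×1024, so an image in the usual H×W×3 layout has to be rearranged before inference. A minimal numpy sketch, assuming the image is already resized to 512×1024; `to_nchw_batch` is a hypothetical helper, and any mean/std normalization the IIM test pipeline applies is not reproduced here:

```python
import numpy as np


def to_nchw_batch(image_hwc, expected_hw=(512, 1024)):
    """Convert an H x W x 3 image into a 1 x 3 x H x W float32 batch.

    Hypothetical helper: assumes the ONNX model was exported with a static
    input of shape (1, 3, *expected_hw); normalization is NOT applied here.
    """
    h, w = expected_hw
    if image_hwc.shape != (h, w, 3):
        raise ValueError(f"expected shape {(h, w, 3)}, got {image_hwc.shape}")
    chw = np.transpose(image_hwc, (2, 0, 1))  # HWC -> CHW
    # Add the batch dimension and make the buffer contiguous float32
    return np.ascontiguousarray(chw[np.newaxis], dtype=np.float32)
```

With onnxruntime, the batch could then be fed as `ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider']).run(None, {'image': x})`, where `'image'` is the input name given to `torch.onnx.export`.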