Comments (5)
Hmm, this is because your model cannot pass jit.trace. The underlying reason is that the LayerNorm2d at the path you gave is implemented with torch.autograd.Function, which would require writing custom conversion-mapping logic for the model. A simpler alternative is to swap it back to the plain torch.nn.LayerNorm or torch.nn.functional.layer_norm.
import torch
import torch.nn as nn
import torch.nn.functional as F

def new_layer_norm(self, x):
    # Normalize over all dims except batch; expand the per-channel
    # affine params to match the normalized shape
    normalized_shape = x.shape[1:]
    return F.layer_norm(x, normalized_shape,
                        self.weight.view(-1, 1, 1).expand(normalized_shape),
                        self.bias.view(-1, 1, 1).expand(normalized_shape),
                        self.eps)

def patch_layer_norm(model):
    for name, module in model.named_modules():
        if type(module).__name__ == 'LayerNorm2d':
            # __get__ binds the function as a method of this instance
            module.forward = new_layer_norm.__get__(module, type(module))

patch_layer_norm(model)
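The `__get__` call binds the plain function as a method of that specific instance, so it shadows the class's forward. A minimal self-contained check of the pattern, with a toy LayerNorm2d (a hypothetical stand-in with per-channel weight/bias, not the model's real one):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical stand-in for the model's LayerNorm2d; the real one
# dispatches to a torch.autograd.Function in its forward.
class LayerNorm2d(nn.Module):
    def __init__(self, channels, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))
        self.eps = eps

    def forward(self, x):
        raise NotImplementedError  # would call the autograd.Function here

def new_layer_norm(self, x):
    normalized_shape = x.shape[1:]
    return F.layer_norm(x, normalized_shape,
                        self.weight.view(-1, 1, 1).expand(normalized_shape),
                        self.bias.view(-1, 1, 1).expand(normalized_shape),
                        self.eps)

m = LayerNorm2d(4)
# Bind the replacement as a method of this instance; the instance
# attribute shadows the class-level forward
m.forward = new_layer_norm.__get__(m, type(m))
out = m(torch.randn(2, 4, 8, 8))
print(out.shape)  # torch.Size([2, 4, 8, 8])
```

With weight=1 and bias=0, each sample is normalized to zero mean over (C, H, W), which is easy to verify on the output.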
from tinyneuralnetwork.
I did find that it's a trace problem, and I plan to try PyTorch's LayerNorm as well, but I still can't help asking whether this could be supported :)
@YilanWang The earlier route also works: you need to implement the symbolic method of LayerNormFunction, see https://pytorch.org/docs/master/onnx.html#static-symbolic-method and https://pytorch.org/docs/master/onnx.html#c-operators. In short, change the method name inside g.op to aten::layer_norm and pass the arguments the way that function expects them.
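For reference, the static symbolic method could look roughly like the sketch below. This is only a sketch under assumptions: the real forward/backward live in the model (here forward simply delegates to F.layer_norm for illustration), and the exact attribute spelling that g.op accepts for aten::layer_norm can vary across PyTorch versions.

```python
import torch
import torch.nn.functional as F

class LayerNormFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        # Simplified stand-in for the model's real forward
        normalized_shape = x.shape[1:]
        return F.layer_norm(x, normalized_shape,
                            weight.view(-1, 1, 1).expand(normalized_shape),
                            bias.view(-1, 1, 1).expand(normalized_shape),
                            eps)

    @staticmethod
    def symbolic(g, x, weight, bias, eps):
        # Emit the built-in aten::layer_norm instead of a custom op so the
        # exporter/converter recognizes it. Argument order follows the schema
        # aten::layer_norm(input, normalized_shape, weight, bias, eps, cudnn_enable);
        # the "_i"/"_f" suffixes mark int/float attributes.
        shape = x.type().sizes()[1:]
        return g.op("aten::layer_norm", x, weight, bias,
                    normalized_shape_i=shape, eps_f=eps, cudnn_enable_i=False)

x = torch.randn(2, 4, 8, 8)
out = LayerNormFunction.apply(x, torch.ones(4), torch.zeros(4), 1e-6)
print(out.shape)  # torch.Size([2, 4, 8, 8])
```

The symbolic method is only consulted during export, so ordinary forward passes are unaffected.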
@YilanWang After digging into this, it turns out it can indeed be supported. One remaining issue is that such a graph cannot be saved, so you need to trace first and then convert, which skips the save step.
@YilanWang Combined with #293 this should work now; try the code below.
from tinynn.converter import TFLiteConverter

model.eval()
# Trace first, then hand the traced graph to the converter (skips the save step)
graph = torch.jit.trace(model, dummy_input)
converter = TFLiteConverter(graph, dummy_input, ...)
converter.convert()
Related Issues (20)
- Whether to support pytorch to keras HOT 1
- TransposeConv wrong shape? HOT 15
- change input to INT8 after converting to tflite HOT 2
- [converter] implement torch's `aten::scaled_dot_product_attention` operator HOT 2
- Request: clamp would be more efficient to go to Bounded Relu than Maximum + Minimum HOT 3
- Do not support PReLU module? HOT 5
- torch.max not working HOT 2
- OneShotChannelPruner results in the miss of some operators HOT 4
- KeyError when executing quantization HOT 5
- PyTorch to TFLite conversion with int8 quantization HOT 4
- Does tinynn support following int16 quantization? HOT 1
- jit.trace succeed but tinynn tracer failed HOT 1
- It became larger after converting to tflite model HOT 4
- how to do Post-training integer quantization with int16 activation HOT 4
- unnecessary float() variables cause quantization to fail. HOT 7
- aten::index nodes take multiple indices in PyTorch model but cause an error when trying to convert to TFLite HOT 1
- aten::repeat_interleave is considered an unsupported Tensor and causing an error when trying to convert to TFLite HOT 2
- Can a model converted to tflite support signatures? HOT 9
- [PTQ Converter] 'Linear+relu' module conversion failed