Comments (6)
torch.onnx.export is based on TorchScript technology (aka torch.jit.script), which is not compatible with Dynamo by default. Instead, try using torch.onnx.dynamo_export to export models to ONNX based on either torch.nn.Module or Dynamo's graphs.
Refer to the tutorial to get started at https://pytorch.org/tutorials/beginner/onnx/export_simple_model_to_onnx_tutorial.html
from pytorch.
@thiagocrepaldi thanks. I did try with dynamo_export too, and I'm almost positive I ran across the same exact issue; I'll double check. Nor does this tutorial help at all with getting it to work with autograd. I can get torch.onnx.export working fine with Dynamo, but without autograd.
I would be extremely grateful if you could point to an example of using aot_module with torch.onnx.dynamo_export. I've searched a lot and couldn't find one example anywhere. My assumption is that it's because it doesn't work. :P
@xadupre and @wschin can help you in more depth, as they are the maintainers of the ONNX backend for torch.compile. An untested repro would be something like:
```python
import torch

class Linear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(128, 10)
        self.activation = torch.nn.ReLU()

    def forward(self, *inputs):
        out = self.linear(inputs[0])
        out = self.activation(out)
        return out

model = Linear()
model.train()
loss_fn = torch.nn.MSELoss()

input = torch.randn((64, 128), requires_grad=True)
labels = torch.randn((64, 10))  # regression targets should not require grad

compiled_model = torch.compile(model, backend="onnxrt")
output = compiled_model(input)  # pass the tensor directly, not unpacked
loss = loss_fn(output, labels)
loss.backward()
```
Refer to this tutorial: https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html
@thiagocrepaldi thanks again. I do know how to use torch.compile for inference, but I believe your example does not work for backward-graph compilation; as I understand it, that's what AOT Autograd is for. It would be preferable to have something that works directly on a module (either nn.Module or fx.GraphModule is fine), as aot_module is supposed to, but at this point I'd take any example that works with ONNX (I have yet to see any non-trivial examples). For example, I've tried defining backward functions and compiling them with both torch.compile and aot_function, but with no success. I look forward to hearing what others suggest.