This is just a guess, but are you by any chance processing each input image of the batch (or post-processing detections) separately inside a for-loop? If so, this behaviour is likely due to how torch traces the model during ONNX export, and you will need to modify your forward pass. Alternatively, you can use torch.jit.script.
Where the forward pass can go wrong
Check your model for anything that fixes a tensor dimension as a Python integer during export. Setting dynamic axes asks the exporter to use variable shapes for the corresponding tensors, but explicit constants baked into the traced graph override them.
```python
# WRONG - WILL EXPORT WITH STATIC BATCH SIZE
def forward(self, batch):
    bs, c, h, w = batch.shape
    # bs is saved as a constant integer during export
    for i in range(bs):
        do_something()
```
```python
# WRONG - WILL EXPORT WITH STATIC BATCH SIZE
def forward(self, batch):
    # iterating over tensors is not supported for dynamic batch sizes;
    # the ONNX model will iterate the same number of times as during export
    for i in batch:
        do_something()
```
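The unrolling can be observed directly by tracing a small module. A minimal sketch (the `Loopy` module and tensor shapes are made up for illustration):

```python
import torch

class Loopy(torch.nn.Module):
    def forward(self, batch):
        bs = batch.shape[0]  # becomes a plain Python int during tracing
        out = []
        for i in range(bs):  # unrolled bs times into the traced graph
            out.append(batch[i] * 2)
        return torch.stack(out)

# trace with batch size 2
traced = torch.jit.trace(Loopy(), torch.randn(2, 3))

# the loop was baked in for batch size 2: feeding a batch of 4
# still yields only 2 rows, silently dropping the rest
print(traced(torch.randn(4, 3)).shape)  # torch.Size([2, 3])
```

The tracer also emits a TracerWarning here, which is a good signal that the exported graph will not generalize across batch sizes.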
Potential fixes
Use tensor.size() instead of tensor.shape
```python
# CORRECT - WILL EXPORT WITH DYNAMIC AXES
def forward(self, batch):
    # This calls a function instead of reading an attribute,
    # so the value stays dynamic
    bs = batch.size(0)
    for i in range(bs):
        do_something()
```
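As a quick check that size(0) stays symbolic in the traced graph, here is a minimal sketch where the value feeds a reshape (the `TinyFlatten` module and shapes are made up):

```python
import torch

class TinyFlatten(torch.nn.Module):
    def forward(self, batch):
        # size(0) is recorded as a graph operation, so the reshape
        # follows whatever batch size arrives at runtime
        return batch.view(batch.size(0), -1)

# trace with batch size 2 ...
traced = torch.jit.trace(TinyFlatten(), torch.randn(2, 3, 4))

# ... and run with batch size 5: the shape follows the input
print(traced(torch.randn(5, 3, 4)).shape)  # torch.Size([5, 12])
```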
Script parts of the model to preserve control flows and different input sizes
```python
# CORRECT - WILL EXPORT WITH DYNAMIC AXES
# Script parts of the forward pass, e.g. single functions
@torch.jit._script_if_tracing
def do_something(batch):
    for i in batch:
        do_something_else()

def forward(self, batch):
    # function will be scripted, dynamic shapes preserved
    do_something(batch)
```
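A runnable variant of this pattern, using the public torch.jit.script_if_tracing spelling available in recent PyTorch (the `double_each` function, `Wrapper` module, and shapes are illustrative):

```python
import torch

@torch.jit.script_if_tracing  # public alias of _script_if_tracing in recent PyTorch
def double_each(batch: torch.Tensor) -> torch.Tensor:
    out = torch.zeros_like(batch)
    for i in range(batch.size(0)):  # scripted, so the loop stays dynamic
        out[i] = batch[i] * 2
    return out

class Wrapper(torch.nn.Module):
    def forward(self, batch):
        return double_each(batch)

# trace with batch size 2, run with batch size 5:
# the scripted loop follows the runtime batch size
traced = torch.jit.trace(Wrapper(), torch.randn(2, 3))
print(traced(torch.randn(5, 3)).shape)  # torch.Size([5, 3])
```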
Export the whole module as a ScriptModule, preserving all control flows and input sizes
```python
# CORRECT - WILL EXPORT WITH DYNAMIC AXES
script_module = torch.jit.script(model)
torch.onnx.export(
    script_module,
    ...
)
```