Shipping a model to the edge: PyTorch → ONNX → TensorRT
A hands-on, current (2026) path for taking a PyTorch CV model to ONNX and a TensorRT engine: export, parity, FP16/INT8 build, and latency gating.
When I run a detector or a segmentation net inside a vehicle, the model that scored well on my workstation is not the artifact that ships. What ships is a TensorRT engine: a hardware-specific, quantized, fused binary that hits a fixed latency budget on the exact GPU in the box (a Jetson Orin, a DRIVE platform, or a discrete NVIDIA card behind the head unit). A 30 FPS camera gives me ~33 ms per frame, and that has to cover capture, preprocess, inference, and postprocess, so inference gets only a fraction of it.
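To make that concrete, here is the budget arithmetic as a small Python sketch. The stage timings in the example call are hypothetical placeholders, not measurements, and the margin_ms parameter is an assumed allowance for jitter rather than part of the budget described above.

```python
def frame_budget_ms(fps: float) -> float:
    """Wall-clock time available per frame at a given camera frame rate."""
    return 1000.0 / fps


def inference_budget_ms(fps: float, capture_ms: float, preprocess_ms: float,
                        postprocess_ms: float, margin_ms: float = 0.0) -> float:
    """Time left for the engine once the other per-frame stages and a safety margin are paid for."""
    remaining = frame_budget_ms(fps) - capture_ms - preprocess_ms - postprocess_ms - margin_ms
    if remaining <= 0:
        raise ValueError(f"non-inference stages already use the whole {frame_budget_ms(fps):.1f} ms frame")
    return remaining


# Hypothetical stage timings, for illustration only; measure your own on the target.
print(f"{frame_budget_ms(30):.1f} ms per frame")  # 33.3 ms per frame
print(f"{inference_budget_ms(30, capture_ms=5, preprocess_ms=4, postprocess_ms=6, margin_ms=3):.1f} ms for inference")  # 15.3 ms
```

Measured stage times replace the placeholders; whatever is left is the number the engine has to beat on the target.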
This is the path I actually use: PyTorch to ONNX to TensorRT, with parity checks and INT8 calibration in between. Everything below is runnable on a CUDA machine.
Versions matter more than usual here
TensorRT is tightly coupled to your CUDA, driver, and (on Jetson) JetPack version. An engine is not portable: built for one GPU arch and TensorRT version, it will refuse to load on another. Build on the target, or on an identical environment.
python -c "import torch; print(torch.__version__)" # 2.6+ for the dynamo exporter
python -c "import tensorrt as trt; print(trt.__version__)"
trtexec --version
nvidia-smi # confirm driver + CUDA match what TensorRT was built against
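A mismatch only shows up as a failed deserialize. One option is to store a small fingerprint next to each engine at build time and compare it before loading. This is a minimal sketch: the JSON sidecar, its field names, and the helper names are hypothetical, not a TensorRT feature.

```python
import json
from typing import Dict, List


def engine_fingerprint(trt_version: str, gpu_name: str, compute_capability: str) -> Dict[str, str]:
    """What an engine depends on: the TensorRT version and the GPU it was built for."""
    return {"tensorrt": trt_version, "gpu": gpu_name, "compute_capability": compute_capability}


def save_fingerprint(path: str, fingerprint: Dict[str, str]) -> None:
    """Write the fingerprint as a JSON sidecar next to the .engine file."""
    with open(path, "w") as f:
        json.dump(fingerprint, f, indent=2, sort_keys=True)


def load_fingerprint(path: str) -> Dict[str, str]:
    with open(path) as f:
        return json.load(f)


def mismatches(built: Dict[str, str], current: Dict[str, str]) -> List[str]:
    """Human-readable differences between the build and runtime environments."""
    return [f"{key}: built with {built.get(key)!r}, running {current.get(key)!r}"
            for key in sorted(set(built) | set(current))
            if built.get(key) != current.get(key)]
```

At build time you might fill it with trt.__version__, torch.cuda.get_device_name(0) and "%d.%d" % torch.cuda.get_device_capability(0), save it as resnet50_fp16.engine.json, and refuse to deserialize when mismatches() returns anything.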
One caveat up front: TensorRT 11.x removed --int8/--calib from trtexec and moved post-training quantization offline into NVIDIA ModelOpt. The INT8 calibration flow I show below is the TensorRT 10.x path (currently 10.15/10.16), which is still what most edge and JetPack stacks run. The implicit-quantization calibrator APIs are deprecated in 10.x but still functional. If you are on 11.x, you quantize the ONNX with ModelOpt first, then build a plain engine. Check trt.__version__ before you copy commands.
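Here is a tiny sketch of that check. It assumes the 10.x/11.x split described above, and the function and its return labels are hypothetical.

```python
def quantization_path(trt_version: str) -> str:
    """Map a TensorRT version string (e.g. trt.__version__) to the INT8 workflow described in this post."""
    major = int(trt_version.split(".")[0])
    if major >= 11:
        return "modelopt"    # quantize the ONNX offline with ModelOpt, then build a plain engine
    if major == 10:
        return "calibrator"  # deprecated implicit-quantization calibrator APIs still work
    raise RuntimeError(f"TensorRT {trt_version}: this post assumes 10.x or newer")
```

Call it as quantization_path(trt.__version__) at the top of a build script and branch on the result.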
Export to ONNX (use the dynamo exporter)
I’ll use a torchvision model so you can reproduce it, but a custom CV head behaves the same.
# export.py
import torch
from torchvision.models import resnet50, ResNet50_Weights
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2).eval().cuda()
dummy = torch.randn(1, 3, 224, 224, device="cuda")
torch.onnx.export(
    model, (dummy,), "resnet50.onnx",
    dynamo=True,  # recommended path since PyTorch 2.5; opset >= 18
    opset_version=18,
    input_names=["images"], output_names=["logits"],
    dynamic_axes={"images": {0: "batch"}, "logits": {0: "batch"}},
)
Three things bite people here:
- The dynamo exporter does not reliably honor your input_names/output_names. It often emits autogenerated tensor names regardless of what you pass. Do not assume the names are images/logits downstream. Read them back from the model and use those everywhere (parity, profiles, runtime). Every Python script below derives names instead of hardcoding them.
- Dynamic axes. I make only the batch dim dynamic. The temptation is to mark H and W dynamic too “for flexibility.” Don’t, unless you truly need it. Dynamic spatial dims force TensorRT to build for a shape range, which makes the build slower and sometimes selects worse kernels. My camera resolution is fixed, so I fix H/W and only vary batch if I fuse multiple cameras.
- Unsupported ops. The dynamo path (torch.export + FX) handles far more than the old TorchScript path, but custom ops, exotic interpolation modes, or some NMS variants still fail. When export errors, the message names the op. Fixes, in order of preference: rewrite the layer with supported ops, move the op into pre/postprocess outside the graph, or write a TensorRT plugin (last resort, real work).
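This example illustrates the second fix, moving an op out of the graph. If a detector's NMS will not export, one option is to end the graph at raw boxes and scores and run greedy NMS in host code. It is a minimal single-class sketch in pure Python, assuming (x1, y1, x2, y2) boxes. A production version would be vectorized and per-class.

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)


def iou(a: Box, b: Box) -> float:
    """Intersection-over-union of two axis-aligned boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes: Sequence[Box], scores: Sequence[float], iou_thresh: float = 0.5) -> List[int]:
    """Indices of boxes kept by greedy NMS, highest score first."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep: List[int] = []
    for i in order:
        # Keep a box only if it does not overlap any already-kept, higher-scoring box too much.
        if all(iou(boxes[i], boxes[k]) <= iou_thresh for k in keep):
            keep.append(i)
    return keep
```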
Simplify and re-check the graph, then print the real I/O names:
pip install onnx onnxslim "onnxruntime-gpu>=1.20"
python -c "import onnx, onnxslim; onnx.save(onnxslim.slim(onnx.load('resnet50.onnx')), 'resnet50.onnx')"
python - <<'PY'
import onnx
m = onnx.load("resnet50.onnx"); onnx.checker.check_model(m)
print("inputs :", [i.name for i in m.graph.input])
print("outputs:", [o.name for o in m.graph.output])
PY
Note the printed names; you will need the real input name for the trtexec commands below.
Validate parity before you go near TensorRT
If ONNX already disagrees with PyTorch, no amount of TensorRT tuning will save you. Catch it here, and read the input name from the session rather than guessing it.
# parity.py
import numpy as np, torch, onnxruntime as ort
from torchvision.models import resnet50, ResNet50_Weights
x = torch.randn(1, 3, 224, 224)
torch_out = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2).eval()(x).detach().numpy()
sess = ort.InferenceSession("resnet50.onnx", providers=["CUDAExecutionProvider"])
in_name = sess.get_inputs()[0].name # do not assume "images"
ort_out = sess.run(None, {in_name: x.numpy()})[0]
print("max abs diff:", np.abs(torch_out - ort_out).max())
np.testing.assert_allclose(torch_out, ort_out, rtol=1e-3, atol=1e-3)
print("FP32 parity OK")
atol/rtol around 1e-3 is reasonable for FP32. If this fails, suspect a layer the exporter lowered differently (often normalization or interpolation).
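For reference, np.testing.assert_allclose(actual, desired, rtol, atol) passes only when every element satisfies |actual - desired| <= atol + rtol * |desired|. The pure-Python restatement below shows that atol governs logits near zero and rtol governs large ones. It also adds top-1 agreement, a coarser check that is often more useful once you move to FP16/INT8, where an elementwise 1e-3 tolerance is usually too tight. The helper names are hypothetical.

```python
from typing import List, Sequence


def allclose_violations(actual: Sequence[float], desired: Sequence[float],
                        rtol: float = 1e-3, atol: float = 1e-3) -> List[int]:
    """Indices failing numpy's rule |actual - desired| <= atol + rtol * |desired|."""
    return [i for i, (a, d) in enumerate(zip(actual, desired))
            if abs(a - d) > atol + rtol * abs(d)]


def argmax(row: Sequence[float]) -> int:
    return max(range(len(row)), key=lambda j: row[j])


def top1_agreement(logits_a: Sequence[Sequence[float]], logits_b: Sequence[Sequence[float]]) -> float:
    """Fraction of samples whose predicted class is the same in both runs."""
    matches = sum(argmax(a) == argmax(b) for a, b in zip(logits_a, logits_b))
    return matches / len(logits_a)
```

For example, allclose_violations([0.0015, 10.01, 100.09], [0.0, 10.0, 100.0]) returns [0]. The near-zero logit misses by 0.0015 against a 0.001 tolerance. The 100.0 logit misses by 0.09 against a 0.101 tolerance and passes.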
Build the engine, way 1: trtexec
trtexec is the fastest way to get a working engine and a latency number. FP16 first: it’s nearly free accuracy-wise on most CV models and a large speedup. Replace images with your real input name.
trtexec \
--onnx=resnet50.onnx \
--saveEngine=resnet50_fp16.engine \
--fp16 \
--minShapes=images:1x3x224x224 \
--optShapes=images:8x3x224x224 \
--maxShapes=images:16x3x224x224 \
--verbose
Those three *Shapes flags define the optimization profile for the dynamic batch axis. This is the number-one dynamic-shape gotcha: if you export with a dynamic axis but forget the profile, the Python builder refuses to build, and trtexec falls back to batch 1 with only a warning in the log. optShapes is what TensorRT tunes kernels for, so set it to your real serving batch.
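The input name has to be the real one from the graph, so generating these flags is safer than typing them. Here is a minimal sketch with a hypothetical helper. It also checks min <= opt <= max per dimension, which TensorRT requires of a profile.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def _dims(shape: Shape) -> str:
    return "x".join(str(d) for d in shape)


def shape_flags(input_name: str, min_shape: Shape, opt_shape: Shape, max_shape: Shape) -> str:
    """Build trtexec --minShapes/--optShapes/--maxShapes for one input."""
    if not len(min_shape) == len(opt_shape) == len(max_shape):
        raise ValueError("min/opt/max shapes must have the same rank")
    for lo, mid, hi in zip(min_shape, opt_shape, max_shape):
        if not lo <= mid <= hi:
            raise ValueError(f"need min <= opt <= max per dim, got {min_shape} {opt_shape} {max_shape}")
    return (f"--minShapes={input_name}:{_dims(min_shape)} "
            f"--optShapes={input_name}:{_dims(opt_shape)} "
            f"--maxShapes={input_name}:{_dims(max_shape)}")
```

Calling shape_flags(in_name, (1, 3, 224, 224), (8, 3, 224, 224), (16, 3, 224, 224)) produces the three flags in the command above. Here in_name is read from the ONNX graph.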
Build the engine, way 2: the Python API
I use the Python builder in CI and whenever I need an INT8 calibrator. This is the TensorRT 10.x API; it reads the input name from the parsed network so it never goes stale.
# build.py
import tensorrt as trt
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(0) # explicit batch is the default in TRT 10
parser = trt.OnnxParser(network, logger)
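with open("resnet50.onnx", "rb") as f:
    # not an assert: asserts are stripped under `python -O`, which would skip parsing entirely
    if not parser.parse(f.read()):
        raise RuntimeError([str(parser.get_error(i)) for i in range(parser.num_errors)])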
in_name = network.get_input(0).name # real name from the graph
config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30) # 4 GiB
config.set_flag(trt.BuilderFlag.FP16)
profile = builder.create_optimization_profile()
profile.set_shape(in_name, (1,3,224,224), (8,3,224,224), (16,3,224,224))
config.add_optimization_profile(profile)
serialized = builder.build_serialized_network(network, config)
assert serialized is not None, "engine build failed"
with open("resnet50_fp16.engine", "wb") as f:
    f.write(serialized)
print("engine built")
INT8 with a calibration dataset
INT8 can roughly halve latency again over FP16 on many CV models, but it needs a few hundred representative images to learn per-tensor scales. Representative is the word: calibrate on real frames from the deployment camera, not ImageNet, or accuracy drifts in ways your test set won’t catch.
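One way to get representative frames without hand-picking them is to dump a long recording from the deployment camera to images and sample evenly across it. The set then spans lighting and scene changes instead of a few consecutive seconds. This is a minimal sketch. It assumes file names sort in capture order (for example, zero-padded frame indices), and the extensions and counts are placeholders. It trims to whole batches because the calibrator below stops at the first incomplete batch.

```python
from typing import List, Sequence

IMAGE_EXTS = (".jpg", ".jpeg", ".png")  # assumption: frames were dumped as image files


def pick_calibration_frames(names: Sequence[str], count: int, batch_size: int) -> List[str]:
    """Evenly subsample `count` image names (sorted = capture order) and trim to whole batches."""
    frames = sorted(n for n in names if n.lower().endswith(IMAGE_EXTS))
    if not frames:
        raise ValueError("no image files found")
    count = min(count, len(frames))
    step = len(frames) / count                      # >= 1, so picked indices are distinct
    picked = [frames[int(i * step)] for i in range(count)]
    usable = len(picked) - len(picked) % batch_size  # drop the incomplete last batch
    return picked[:usable]
```

For example: files = [os.path.join("calib_images", n) for n in pick_calibration_frames(os.listdir("calib_images"), 512, 8)].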
# calibrator.py
import os, numpy as np, tensorrt as trt
from PIL import Image
import pycuda.driver as cuda
import pycuda.autoinit  # noqa: F401  initializes a CUDA context
MEAN = np.array([0.485, 0.456, 0.406], np.float32)  # ImageNet stats the torchvision weights expect
STD = np.array([0.229, 0.224, 0.225], np.float32)
class Calibrator(trt.IInt8EntropyCalibrator2):
    def __init__(self, files, cache="calib.cache", bs=8):
        super().__init__()  # the TensorRT base class must be initialized explicitly
        self.files, self.bs, self.cache, self.i = files, bs, cache, 0
        self.dev = cuda.mem_alloc(bs * 3 * 224 * 224 * 4)  # one float32 batch on the GPU
    def get_batch_size(self):
        return self.bs
    def get_batch(self, names):
        # None tells TensorRT the data is exhausted; a final partial batch is skipped.
        if self.i + self.bs > len(self.files):
            return None
        batch = np.stack([self._load(f) for f in self.files[self.i:self.i + self.bs]])
        self.i += self.bs
        cuda.memcpy_htod(self.dev, np.ascontiguousarray(batch, np.float32))
        return [int(self.dev)]
    def _load(self, f):
        # Must match runtime preprocessing exactly: resize, scale to [0, 1], normalize, HWC -> CHW.
        a = np.asarray(Image.open(f).convert("RGB").resize((224, 224)), np.float32) / 255.0
        a = (a - MEAN) / STD
        return a.transpose(2, 0, 1)
    def read_calibration_cache(self):
        if not os.path.exists(self.cache):
            return None
        with open(self.cache, "rb") as f:
            return f.read()
    def write_calibration_cache(self, cache):
        with open(self.cache, "wb") as f:
            f.write(cache)
Install pycuda (pip install pycuda) and wire the calibrator into the build config from build.py:
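import os
from calibrator import Calibrator
config.set_flag(trt.BuilderFlag.INT8)  # keep the FP16 flag too, so non-INT8 layers can run in FP16 instead of FP32
files = sorted(os.path.join("calib_images", f) for f in os.listdir("calib_images")
               if f.lower().endswith((".jpg", ".jpeg", ".png")))  # skip non-images; stable order
config.int8_calibrator = Calibrator(files)
config.set_calibration_profile(profile)  # calibration runs at the profile's opt shape: batch 8 = Calibrator bs
# then build as before, but write the result to resnet50_int8.engine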
On 10.x you can also have trtexec build the INT8 engine from an existing cache (--int8 --calib=calib.cache), but generating that cache still takes the calibrator above; the Python route is the one I rely on.
The honest gotcha: INT8 accuracy loss is not uniform. Detection mAP and small-object recall degrade more than top-1 classification. If a class matters for safety (pedestrians, cyclists), measure that class specifically. When INT8 drops too much, keep sensitive layers in FP16 (set per-layer precision) or use ModelOpt quantization-aware training. Sometimes the right call is “FP16 only,” and that’s fine.
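Here is a minimal sketch of the per-layer route in the TensorRT 10.x Python API: after parsing, set layer.precision on the sensitive layers, then tell the builder to honor it. The helper and the glob patterns are hypothetical. First list the real layer names with [network.get_layer(i).name for i in range(network.num_layers)]. Target compute layers (convolutions, matmuls) rather than shape or index ops.

```python
import fnmatch
from typing import Iterable, List


def pin_layers(network, patterns: Iterable[str], dtype) -> List[str]:
    """Set layer.precision = dtype on every layer whose name matches one of the glob patterns.

    `network` is a tensorrt.INetworkDefinition (only num_layers, get_layer(i) and layer.name
    are used); `dtype` would be e.g. trt.float16. Returns the names of the pinned layers.
    """
    patterns = list(patterns)
    pinned: List[str] = []
    for i in range(network.num_layers):
        layer = network.get_layer(i)
        if any(fnmatch.fnmatchcase(layer.name, p) for p in patterns):
            layer.precision = dtype
            pinned.append(layer.name)
    return pinned
```

In build.py this would run after parsing, for example pin_layers(network, ["/head/*"], trt.float16), followed by config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS). PREFER_PRECISION_CONSTRAINTS is the softer variant: it falls back with a warning instead of failing the build.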
Benchmark latency and verify accuracy didn’t drift
# steady-state latency, not a single cold run
trtexec --loadEngine=resnet50_int8.engine --shapes=images:8x3x224x224 \
--iterations=1000 --avgRuns=100 --useCudaGraph
Read the median GPU compute time, not host wall time, and only on the target device: desktop numbers don’t transfer to Orin. Then run the engine on a real validation set and compare metrics against FP32. Current cuda-python (13.x) removed the old from cuda import cudart shim, so import from cuda.bindings:
# infer.py (pip install "cuda-python>=12.6")
import numpy as np, tensorrt as trt
from cuda.bindings import runtime as cudart  # 13.x layout; 12.6+ also supports this
def chk(err):
    if err != cudart.cudaError_t.cudaSuccess:
        raise RuntimeError(cudart.cudaGetErrorString(err))
logger = trt.Logger(trt.Logger.WARNING)
with open("resnet50_int8.engine", "rb") as f:
    engine = trt.Runtime(logger).deserialize_cuda_engine(f.read())
assert engine is not None, "deserialize failed (arch/TRT/CUDA mismatch?)"
ctx = engine.create_execution_context()
# Select I/O tensors by mode rather than by index or hardcoded "images"/"logits".
names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)]
in_name = next(n for n in names if engine.get_tensor_mode(n) == trt.TensorIOMode.INPUT)
out_name = next(n for n in names if engine.get_tensor_mode(n) == trt.TensorIOMode.OUTPUT)
ctx.set_input_shape(in_name, (8, 3, 224, 224))
inp = np.random.randn(8, 3, 224, 224).astype(np.float32)
out = np.empty(tuple(ctx.get_tensor_shape(out_name)), np.float32)  # (8, 1000) for ResNet-50
err, d_in = cudart.cudaMalloc(inp.nbytes); chk(err)
err, d_out = cudart.cudaMalloc(out.nbytes); chk(err)
err, stream = cudart.cudaStreamCreate(); chk(err)
H2D = cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
D2H = cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost
chk(cudart.cudaMemcpyAsync(d_in, inp.ctypes.data, inp.nbytes, H2D, stream)[0])
ctx.set_tensor_address(in_name, int(d_in))
ctx.set_tensor_address(out_name, int(d_out))
if not ctx.execute_async_v3(stream):
    raise RuntimeError("execute_async_v3 failed to enqueue")
chk(cudart.cudaMemcpyAsync(out.ctypes.data, d_out, out.nbytes, D2H, stream)[0])
chk(cudart.cudaStreamSynchronize(stream)[0])  # returns a 1-tuple (err,); required before reading `out`
print("top-1:", out.argmax(1))
chk(cudart.cudaFree(d_in)[0]); chk(cudart.cudaFree(d_out)[0])
chk(cudart.cudaStreamDestroy(stream)[0])
My ship/no-ship gate is simple: the INT8 engine must hit the per-frame latency budget on the target and stay within an agreed accuracy delta on the metric that matters. If it only passes one, it doesn’t ship.
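That gate is easy to encode so CI can enforce it. This is a minimal sketch. It assumes you already have per-run GPU compute times from the target and a higher-is-better metric (say, recall on the safety-critical class), and the thresholds and names are hypothetical. It uses the median to match the reading above; if tail latency matters more for your budget, swapping in a high percentile is a one-line change.

```python
import statistics
from dataclasses import dataclass, field
from typing import List, Sequence


@dataclass
class GateResult:
    ship: bool
    reasons: List[str] = field(default_factory=list)


def ship_gate(gpu_compute_ms: Sequence[float], budget_ms: float,
              fp32_metric: float, candidate_metric: float, max_drop: float) -> GateResult:
    """Ship only if median latency fits the budget AND the metric drop is within tolerance."""
    reasons: List[str] = []
    median_ms = statistics.median(gpu_compute_ms)
    if median_ms > budget_ms:
        reasons.append(f"median {median_ms:.2f} ms exceeds budget {budget_ms:.2f} ms")
    drop = fp32_metric - candidate_metric
    if drop > max_drop:
        reasons.append(f"metric dropped {drop:.4f} (> {max_drop:.4f} allowed)")
    return GateResult(ship=not reasons, reasons=reasons)
```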
Where this goes wrong
- Version/driver mismatch. “Engine deserialization failed” almost always means the engine was built with a different TensorRT/CUDA/arch than the runtime. Rebuild on the target. Pin your toolchain in a container.
- Tensor names. The dynamo exporter renames I/O. Always read names from the model, session, or engine; never hardcode them.
- Dynamic shapes. Don’t mark dims dynamic you don’t need; always supply a matching optimization profile; set optShapes to your real batch.
- Unsupported ops. Caught at export or parse time. Rewrite, move to host code, or (rarely) write a plugin.
- INT8 drift. Calibrate on real deployment frames, measure the safety-critical classes, fall back to mixed or FP16 precision when needed.
Get this scripted once and shipping a new model version becomes a 10-minute job: export, parity, build, calibrate, benchmark, gate.
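The scripting itself can be as thin as running each stage in order and stopping at the first failure. This is a minimal sketch. It assumes each step is a command that exits non-zero on failure, which the assert-based Python scripts above do when an assert fires. The helper names are hypothetical.

```python
import subprocess
from typing import Callable, List, Sequence, Tuple

Stage = Tuple[str, Callable[[], bool]]


def command(argv: Sequence[str]) -> Callable[[], bool]:
    """Wrap a command line as a stage that passes when it exits with status 0."""
    return lambda: subprocess.run(list(argv)).returncode == 0


def run_pipeline(stages: Sequence[Stage]) -> List[str]:
    """Run stages in order, stop at the first failure, and return the names that passed."""
    passed: List[str] = []
    for name, step in stages:
        print(f"== {name}")
        if not step():
            print(f"!! {name} failed; stopping here")
            break
        passed.append(name)
    return passed
```

For example: run_pipeline([("export", command([sys.executable, "export.py"])), ("parity", command([sys.executable, "parity.py"])), ("build", command([sys.executable, "build.py"])), ("infer", command([sys.executable, "infer.py"]))]), with a trtexec benchmark command added as another stage.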