Verified parity
The claim
Every supported model exports with cos = +1.000000 numerical parity to the PyTorch reference. First-action max_abs differences sit between 2 × 10⁻⁷ and 8 × 10⁻⁷, the floor of float32 arithmetic. These numbers are not “close enough.” They are the limit of what’s representable.
Why this matters
Most “deploy your model” tools settle for “close enough” and quietly diverge in production. A 0.008 absolute error sounds small until your robot is operating a 7-DOF arm at 30 Hz, accumulating that error every step. Within thirty seconds of operation the policy and the served model are doing visibly different things, and the failure mode is silent — no error, no log, just slightly wrong actions.
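To make the accumulation concrete, a back-of-envelope bound. This assumes worst-case linear drift (every step’s error pushing the same direction), which is a deliberate simplification of real closed-loop dynamics:

```python
# Worst-case linear accumulation of a per-step action error.
per_step_error = 0.008   # absolute error per action
rate_hz = 30             # control rate
seconds = 30             # runtime

accumulated = per_step_error * rate_hz * seconds  # ~7.2 units of drift
```

Even if feedback cancels most of it, the budget a 0.008-per-step error consumes over thirty seconds is orders of magnitude above the float32 floor.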
Reflex treats parity as a load-bearing claim. If a measurement fails, the tool exits non-zero rather than shipping with the divergence.
How it’s verified
Two layers of checks:
- Per-export parity check. Every `reflex export` runs the model on shared seeded inputs in both PyTorch and ONNX, computes `max_abs_diff` per fixture, and refuses to write the export if the threshold fails (`< 1e-4` by default, `< 1e-5` strict).
- Production-default unrolled-loop parity. Flow-matching VLAs (SmolVLA, pi0, pi0.5) integrate a velocity field over 10 Euler steps. We unroll the entire loop into ONNX and verify parity against the canonical 10-step PyTorch sampling — not just one step.
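A compressed sketch of both checks: a fixed 10-step Euler loop in the form that unrolls under tracing, plus the elementwise comparison the per-fixture check performs. All names here are illustrative, not Reflex’s internal API, and the toy velocity field stands in for a real denoiser head:

```python
import numpy as np

def euler_sample(velocity_fn, x0, num_steps=10):
    # Fixed-step Euler integration from t=0 to t=1; a plain Python loop
    # like this unrolls into num_steps copies of the field when traced.
    x, dt = x0, 1.0 / num_steps
    for i in range(num_steps):
        x = x + dt * velocity_fn(x, i * dt)
    return x

def fixture_parity(ref_out, exported_out, threshold=1e-4):
    # Per-fixture measurement: max / mean absolute elementwise difference.
    diff = np.abs(ref_out - exported_out)
    return float(diff.max()), float(diff.mean()), bool(diff.max() < threshold)

# Toy linear field standing in for the VLA's velocity head.
field = lambda x, t: -x
rng = np.random.default_rng(0)
x0 = rng.standard_normal((50, 7)).astype(np.float32)

ref = euler_sample(field, x0.astype(np.float64)).astype(np.float32)  # "reference" run
exported = euler_sample(field, x0)                                   # "exported" run
max_abs, mean_abs, passed = fixture_parity(ref, exported)
assert passed  # well under the 1e-4 default threshold
```

The real check runs the same seeded fixtures through PyTorch and an ONNX Runtime session instead of two NumPy paths; the comparison logic is the same.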
For GR00T’s DDPM-style diffusion (4 canonical steps), we export a single-step DiT and wrap the loop in `reflex serve`, then verify end-to-end that the loop also matches PyTorch (max_abs 4.77 × 10⁻⁷).
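The shape of that serve-side loop, with a stand-in callable where the exported single-step DiT session would be called (function and variable names are illustrative):

```python
import numpy as np

def serve_diffusion_loop(denoise_step, noisy_action, num_steps=4):
    # The DiT is exported as one denoising step; the host loop (as the
    # server would run it) invokes that step num_steps times.
    x = noisy_action
    for t in reversed(range(num_steps)):
        x = denoise_step(x, t)
    return x

# Stand-in step: halves the signal each call (a real step runs the ONNX session).
halve = lambda x, t: x * 0.5
x_T = np.ones(7, dtype=np.float32)
x_0 = serve_diffusion_loop(halve, x_T)
assert np.allclose(x_0, 0.0625)  # 0.5 ** 4 after the 4 canonical steps
```

Because the loop lives in the host rather than the graph, the end-to-end check has to run the full serve path, not just the exported step.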
Verify it yourself
```
reflex validate export ./p0 --model lerobot/pi0_base --threshold 1e-4
```

Sample passing output:
```
Per-fixture results
fixture_idx  max_abs_diff  mean_abs_diff  passed
0            3.21e-06      8.40e-07      PASS
1            2.98e-06      7.92e-07      PASS
...

Summary
max_abs_diff_across_all  3.21e-06
passed                   PASS
```

Exit codes: 0 pass, 1 fail (any fixture above threshold), 2 error.
Pass `--output-json` for CI consumption.
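A sketch of consuming a JSON report in CI. The field names below are assumptions about the schema, not the documented format; inspect the actual output before relying on them:

```python
import json

# Hypothetical report shape (field names are guesses, not the real schema).
raw = """{
  "fixtures": [
    {"fixture_idx": 0, "max_abs_diff": 3.21e-06, "passed": true},
    {"fixture_idx": 1, "max_abs_diff": 2.98e-06, "passed": true}
  ],
  "passed": true
}"""

report = json.loads(raw)
worst = max(f["max_abs_diff"] for f in report["fixtures"])
assert report["passed"] and worst < 1e-4
print(f"worst fixture diff: {worst:.2e}")
```

In practice the exit code alone is enough to gate a pipeline; the JSON is for dashboards and trend tracking.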
What was hard about getting here
Getting pi0 / pi0.5 to export at machine precision required three interacting patches under `torch.export`:

- Patching `F.pad` for causal masks (the default trace produced incorrect mask shapes for pi0’s PaliGemma backbone)
- Freezing `DynamicLayer.update` (transformers 5.x’s KV cache dynamism doesn’t trace cleanly without it)
- Manually computing `past_kv.get_seq_length()` for mask assembly (the symbolic shape trace lost the connection)
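All three patches share one mechanic: temporarily replace an attribute for the duration of the export, then restore it. A generic sketch of that pattern (the helper itself is illustrative, not Reflex’s actual code):

```python
import contextlib

@contextlib.contextmanager
def patched(obj, attr, replacement):
    """Swap obj.attr for replacement inside the block, restore on exit."""
    original = getattr(obj, attr)
    setattr(obj, attr, replacement)
    try:
        yield
    finally:
        setattr(obj, attr, original)  # restore even if the export raises

# Usage shape: apply export-safe variants only while tracing, e.g.
#   with patched(F, "pad", export_safe_pad):
#       torch.export.export(model, example_inputs)
```

Restoring in `finally` matters: a failed export must not leave the live model monkey-patched for subsequent eager-mode use.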
GR00T’s simpler DiT graph (no `DynamicCache`, no PaliGemma masking) traces cleanly via plain `torch.onnx.export(opset=19)` — no patches needed.