Verified parity

Every supported model exports with a cosine similarity of +1.000000 against the PyTorch reference. First-action max_abs differences sit between 2 × 10⁻⁷ and 8 × 10⁻⁷, within a few multiples of float32 machine epsilon (about 1.2 × 10⁻⁷). These numbers are not “close enough.” They are the floor of float32 arithmetic.

Most “deploy your model” tools settle for “close enough” and quietly diverge in production. A 0.008 absolute error sounds small until your robot is operating a 7-DOF arm at 30 Hz, accumulating that error every step. Within thirty seconds of operation the policy and the served model are doing visibly different things, and the failure mode is silent — no error, no log, just slightly wrong actions.
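To put rough numbers on that claim, the sketch below works through the arithmetic. It assumes a constant per-step error of 0.008 in action units and a policy whose outputs are summed over time (delta actions). Both growth models are illustrative bounds under those assumptions, not measurements.

```python
import math

PER_STEP_ERROR = 0.008  # absolute error per action, from the text above
CONTROL_HZ = 30         # control rate, from the text above
SECONDS = 30            # time horizon, from the text above

steps = CONTROL_HZ * SECONDS

# Worst case (assumption): every step's error has the same sign,
# so the errors add linearly.
worst_case_drift = PER_STEP_ERROR * steps

# Milder case (assumption): independent zero-mean errors,
# which grow like sqrt(steps).
random_walk_drift = PER_STEP_ERROR * math.sqrt(steps)

print(f"steps             = {steps}")
print(f"worst-case drift  = {worst_case_drift:.2f}")
print(f"random-walk drift = {random_walk_drift:.2f}")
```

Even the milder case is thirty times the per-step error after 900 steps, which is why a per-step tolerance alone says little about closed-loop behavior.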

Reflex treats parity as a load-bearing claim. If a measurement fails, the tool exits non-zero rather than shipping with the divergence.

Two layers of checks:

  1. Per-export parity check. Every reflex export runs the model on shared seeded inputs in both PyTorch and ONNX, computes max_abs_diff per fixture, and refuses to write the export if any fixture reaches the threshold (1e-4 by default, 1e-5 in strict mode).

  2. Production default unrolled-loop parity. Flow-matching VLAs (SmolVLA, pi0, pi0.5) integrate a velocity field over 10 Euler steps. We unroll the entire loop into ONNX and verify parity against the canonical 10-step PyTorch sampling — not just one step.

For GR00T’s DDPM-style diffusion (4 canonical steps), we export a single-step DiT and wrap the loop in reflex serve, then verify end-to-end that the loop also matches PyTorch (max_abs 4.77 × 10⁻⁷).
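The difference between checking one step and checking the whole loop can be seen in a toy example. This is a minimal sketch, not Reflex's implementation: `reference_velocity` and `exported_velocity` are hypothetical stand-ins for the PyTorch and exported velocity networks, the 1e-6 perturbation is synthetic, and integrating from t = 0 to t = 1 is an assumed time convention.

```python
import math

def euler_sample(velocity, x0, num_steps=10):
    """Integrate dx/dt = velocity(x, t) from t=0 to t=1 with fixed-step Euler."""
    dt = 1.0 / num_steps
    x = list(x0)
    for i in range(num_steps):
        t = i * dt
        v = velocity(x, t)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x

def reference_velocity(x, t):
    # Toy nonlinear field standing in for the reference network (hypothetical).
    return [math.sin(xi) + t for xi in x]

def exported_velocity(x, t, eps=1e-6):
    # Same field plus a small constant offset, standing in for export error.
    return [vi + eps for vi in reference_velocity(x, t)]

def max_abs_diff(a, b):
    return max(abs(ai - bi) for ai, bi in zip(a, b))

x0 = [0.5, -1.0, 2.0]

# Single-step comparison: one velocity evaluation at t=0.
one_step = max_abs_diff(reference_velocity(x0, 0.0), exported_velocity(x0, 0.0))

# Full-loop comparison: all 10 Euler steps, as the sampler actually runs.
full_loop = max_abs_diff(
    euler_sample(reference_velocity, x0),
    euler_sample(exported_velocity, x0),
)

print(f"one-step max_abs_diff  = {one_step:.3e}")
print(f"full-loop max_abs_diff = {full_loop:.3e}")
```

In this toy field the end-to-end difference is larger than the single-step difference, because each step's error is carried forward and reshaped by the field. A single-step check would not reveal that, which is the reason to measure parity on the final sampled output.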

reflex validate export ./p0 --model lerobot/pi0_base --threshold 1e-4

Sample passing output:

Per-fixture results
fixture_idx  max_abs_diff  mean_abs_diff  passed
0            3.21e-06      8.40e-07       PASS
1            2.98e-06      7.92e-07       PASS
...
Summary
max_abs_diff_across_all  3.21e-06
passed                   PASS

Exit codes: 0 pass, 1 fail (any fixture above threshold), 2 error.
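A report with this shape could be computed along the following lines. This is a sketch under stated assumptions rather than the tool's actual code: `run_parity` and the three toy models are hypothetical names, the fixtures are plain Python lists instead of tensors, and exit code 2 (error) is not modeled.

```python
import random

def max_abs_diff(a, b):
    return max(abs(x - y) for x, y in zip(a, b))

def mean_abs_diff(a, b):
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)

def run_parity(reference_fn, exported_fn, num_fixtures=2, dim=8,
               threshold=1e-4, seed=0):
    """Compare two callables on shared seeded inputs.

    Returns (rows, exit_code): exit_code is 0 if every fixture is
    below the threshold and 1 otherwise.
    """
    rng = random.Random(seed)  # same seed gives the same fixtures every run
    rows = []
    for idx in range(num_fixtures):
        fixture = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        ref = reference_fn(fixture)
        out = exported_fn(fixture)
        worst = max_abs_diff(ref, out)
        rows.append({
            "fixture_idx": idx,
            "max_abs_diff": worst,
            "mean_abs_diff": mean_abs_diff(ref, out),
            "passed": worst < threshold,
        })
    exit_code = 0 if all(row["passed"] for row in rows) else 1
    return rows, exit_code

def reference_model(x):
    # Toy stand-in for the PyTorch model.
    return [2.0 * v + 1.0 for v in x]

def close_model(x):
    # Toy export that differs by about 1e-6 per element.
    return [v + 1e-6 for v in reference_model(x)]

def drifting_model(x):
    # Toy export that differs by 8e-3 per element.
    return [v + 8e-3 for v in reference_model(x)]

rows_ok, code_ok = run_parity(reference_model, close_model)
rows_bad, code_bad = run_parity(reference_model, drifting_model)
print(code_ok, code_bad)
```

A single fixture at or above the threshold is enough to flip the exit code, matching the "any fixture above threshold" rule.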

Use the --output-json flag for CI consumption.
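In a CI job, the exit code alone is enough to gate a build. The wrapper below is a minimal sketch: `gate` and `describe_exit` are hypothetical helper names, the command is copied from the example above, and the JSON output is left out because its schema is not described here.

```python
import subprocess
import sys

# Command copied from the example above.
VALIDATE_CMD = [
    "reflex", "validate", "export", "./p0",
    "--model", "lerobot/pi0_base",
    "--threshold", "1e-4",
]

# Exit code meanings as documented above.
EXIT_MEANINGS = {0: "pass", 1: "fail", 2: "error"}

def describe_exit(code):
    """Translate an exit code into the documented meaning."""
    return EXIT_MEANINGS.get(code, f"unexpected exit code {code}")

def gate(cmd=VALIDATE_CMD):
    """Run the validation command and return its exit code unchanged."""
    result = subprocess.run(cmd)
    print(f"parity check: {describe_exit(result.returncode)}")
    return result.returncode
```

Calling `sys.exit(gate())` from a CI script would propagate the result, so a failed parity check fails the job instead of letting a divergent export through.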

Getting pi0 / pi0.5 to export at machine precision required three interacting patches under torch.export:

  • F.pad for causal masks (the default trace produced incorrect mask shapes for pi0’s PaliGemma backbone)
  • Freezing DynamicLayer.update (transformers 5.x’s KV cache dynamism doesn’t trace cleanly without it)
  • Manually computing past_kv.get_seq_length() for mask assembly (the symbolic shape trace lost the connection)

GR00T’s simpler DiT graph (no DynamicCache, no PaliGemma masking) traces cleanly via plain torch.onnx.export with opset_version=19. No patches are needed.