Source code for protomotions.utils.export_utils

# SPDX-FileCopyrightText: Copyright (c) 2025 The ProtoMotions Developers
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Utilities for exporting trained models to ONNX format.

This module provides functions to export TensorDict-based models to ONNX format
using torch.onnx.dynamo_export. The exported models can be used for deployment
and inference in production environments.

Key Functions:
    - export_onnx: Export a TensorDictModule to ONNX format
    - export_ppo_actor: Export a trained PPO actor's mu network to ONNX
    - export_ppo_critic: Export a trained PPO critic to ONNX
    - export_ppo_model: Export a trained PPO model (actor and critic) to ONNX
"""

import json
from pathlib import Path
from typing import Any, Dict, Optional

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModuleBase


class ONNXExportWrapper(torch.nn.Module):
    """Wrapper for TensorDictModule that accepts ``**kwargs`` for ONNX export.

    TensorDictModules expect a TensorDict argument, but torch.onnx.dynamo_export
    unpacks inputs as kwargs. This wrapper bridges the gap.
    """

    def __init__(self, module: TensorDictModuleBase, in_keys: list):
        super().__init__()
        self.module = module
        self.in_keys = in_keys

    def forward(self, **kwargs):
        """Forward that reconstructs a TensorDict from kwargs."""
        # Reconstruct TensorDict from kwargs
        batch_size = kwargs[self.in_keys[0]].shape[0]
        td = TensorDict(kwargs, batch_size=batch_size)
        # Forward through the original module
        output_td = self.module(td)
        # Return tuple of outputs (ONNX expects tuple, not dict)
        return tuple(output_td[key] for key in self.module.out_keys)
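
# Illustrative sketch of the wrapper's calling convention (the Linear layer is
# a hypothetical stand-in for a real policy network):
#
#     from tensordict.nn import TensorDictModule
#     net = TensorDictModule(torch.nn.Linear(8, 2), in_keys=["obs"], out_keys=["mu"])
#     wrapper = ONNXExportWrapper(net, in_keys=["obs"])
#     (mu,) = wrapper(obs=torch.randn(4, 8))  # plain kwargs in, tuple out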

@torch.inference_mode()
def export_onnx(
    module: TensorDictModuleBase,
    td: TensorDict,
    path: str,
    meta: Optional[Dict[str, Any]] = None,
    validate: bool = True,
):
    """Export a TensorDictModule to ONNX format.

    Uses torch.onnx.dynamo_export to export the module. Creates a wrapper that
    converts between TensorDict and ``**kwargs`` for ONNX compatibility.

    Args:
        module: TensorDictModule to export.
        td: Sample TensorDict input (used for tracing).
        path: Path to save the ONNX model (must end with .onnx).
        meta: Optional additional metadata to save.
        validate: If True, validates the exported model with onnxruntime.

    Raises:
        ValueError: If path doesn't end with .onnx.

    Example:
        >>> from protomotions.agents.ppo.model import PPOModel
        >>> from tensordict import TensorDict
        >>> model = PPOModel(config)
        >>> sample_input = TensorDict({"obs": torch.randn(1, 128)}, batch_size=1)
        >>> export_onnx(model, sample_input, "policy.onnx")
    """
    if not path.endswith(".onnx"):
        raise ValueError(f"Export path must end with .onnx, got {path}.")

    # Move to CPU and select only the required input keys
    td = td.cpu().select(*module.in_keys, strict=True)
    module = module.cpu()
    module.eval()

    print(f"Exporting model to ONNX (PyTorch {torch.__version__})...")
    print(f"  Input keys: {module.in_keys}")
    print(f"  Output keys: {module.out_keys}")

    # Create a wrapper that accepts **kwargs instead of a TensorDict
    wrapper = ONNXExportWrapper(module, module.in_keys)
    wrapper.eval()

    # Export using dynamo with the unpacked dict
    onnx_program = torch.onnx.dynamo_export(wrapper, **td.to_dict())
    onnx_program.save(path)
    print(f"✓ Exported ONNX model to {path}")

    # Get the actual ONNX input/output names. Note: onnxruntime is required
    # here for metadata, not just for validation, so it is imported
    # unconditionally.
    import onnxruntime as ort

    ort_session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    onnx_input_names = [inp.name for inp in ort_session.get_inputs()]
    onnx_output_names = [out.name for out in ort_session.get_outputs()]

    # Save metadata alongside the model
    meta_path = path.replace(".onnx", ".json")
    if meta is None:
        meta = {}
    meta["in_keys"] = module.in_keys
    meta["out_keys"] = module.out_keys
    meta["in_shapes"] = [list(td[k].shape) for k in module.in_keys]
    meta["onnx_input_names"] = onnx_input_names
    meta["onnx_output_names"] = onnx_output_names
    meta["input_mapping"] = {
        onnx_name: semantic_name
        for onnx_name, semantic_name in zip(onnx_input_names, module.in_keys)
    }
    meta["output_mapping"] = {
        onnx_name: semantic_name
        for onnx_name, semantic_name in zip(onnx_output_names, module.out_keys)
    }
    with open(meta_path, "w") as f:
        json.dump(meta, f, indent=4)
    print(f"✓ Exported metadata to {meta_path}")

    # Validate with onnxruntime, reusing the session created above
    if validate:
        try:
            def to_numpy(tensor):
                return (
                    tensor.detach().cpu().numpy()
                    if tensor.requires_grad
                    else tensor.cpu().numpy()
                )

            onnx_input = tuple(td[k] for k in module.in_keys)
            onnxruntime_input = {
                inp.name: to_numpy(v)
                for inp, v in zip(ort_session.get_inputs(), onnx_input)
            }
            ort_output = ort_session.run(None, onnxruntime_input)
            assert len(ort_output) == len(module.out_keys), (
                f"Output length mismatch: {len(ort_output)} vs {len(module.out_keys)}"
            )
            print("✓ ONNX model validation successful!")
        except Exception as e:
            print(f"⚠ Warning: ONNX validation failed: {e}")
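
# Illustrative sketch of consuming an exported model with onnxruntime, using
# the saved metadata to map ONNX input names back to semantic observation keys
# (the file names and the 128-dim observation are hypothetical):
#
#     import json
#     import numpy as np
#     import onnxruntime as ort
#
#     session = ort.InferenceSession("policy.onnx", providers=["CPUExecutionProvider"])
#     with open("policy.json") as f:
#         meta = json.load(f)
#     obs = {"obs": np.random.randn(1, 128).astype(np.float32)}
#     feed = {onnx_name: obs[key] for onnx_name, key in meta["input_mapping"].items()}
#     outputs = session.run(None, feed)  # ordered to match meta["out_keys"]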

def export_ppo_actor(
    actor, sample_obs: Dict[str, torch.Tensor], path: str, validate: bool = True
):
    """Export a PPO actor's mu network to ONNX.

    Exports the mean network (mu) of a PPO actor, which is the core policy
    network without the distribution layer. Uses real observations from the
    environment to ensure proper tracing.

    Args:
        actor: PPOActor instance to export.
        sample_obs: Sample observation dict from environment (via agent.get_obs()).
        path: Path to save the ONNX model.
        validate: If True, validates the exported model.

    Example:
        >>> # Get real observations from the environment
        >>> env.reset()
        >>> sample_obs = agent.get_obs()
        >>> export_ppo_actor(agent.model._actor, sample_obs, "ppo_actor.onnx")
    """
    # Create TensorDict from sample observations
    batch_size = sample_obs[list(sample_obs.keys())[0]].shape[0]
    td = TensorDict(sample_obs, batch_size=batch_size)

    # Export the mu network (policy mean)
    meta = {
        "model_type": "PPOActor",
        "observation_keys": list(sample_obs.keys()),
        "observation_shapes": {k: list(v.shape) for k, v in sample_obs.items()},
    }
    export_onnx(actor, td, path, meta=meta, validate=validate)

def export_ppo_critic(
    critic, sample_obs: Dict[str, torch.Tensor], path: str, validate: bool = True
):
    """Export a PPO critic network to ONNX.

    Uses real observations from the environment to ensure proper tracing.

    Args:
        critic: PPO critic (MultiHeadedMLP) instance to export.
        sample_obs: Sample observation dict from environment (via agent.get_obs()).
        path: Path to save the ONNX model.
        validate: If True, validates the exported model.

    Example:
        >>> # Get real observations from the environment
        >>> env.reset()
        >>> sample_obs = agent.get_obs()
        >>> export_ppo_critic(agent.model._critic, sample_obs, "ppo_critic.onnx")
    """
    # Create TensorDict from sample observations
    batch_size = sample_obs[list(sample_obs.keys())[0]].shape[0]
    td = TensorDict(sample_obs, batch_size=batch_size)

    meta = {
        "model_type": "PPOCritic",
        "num_out": critic.config.num_out,
        "observation_keys": list(sample_obs.keys()),
        "observation_shapes": {k: list(v.shape) for k, v in sample_obs.items()},
    }
    export_onnx(critic, td, path, meta=meta, validate=validate)

def export_ppo_model(
    model, sample_obs: Dict[str, torch.Tensor], output_dir: str, validate: bool = True
):
    """Export a complete PPO model (actor and critic) to ONNX.

    Exports both the actor and critic networks to separate ONNX files in the
    specified directory.

    Args:
        model: PPOModel instance to export.
        sample_obs: Sample observation dict for tracing.
        output_dir: Directory to save the ONNX models.
        validate: If True, validates the exported models.

    Returns:
        Dict with paths to the exported ONNX files and their JSON metadata.

    Example:
        >>> model = trained_agent.model
        >>> sample_obs = {"obs": torch.randn(1, 128)}
        >>> paths = export_ppo_model(model, sample_obs, "exported_models/")
        >>> print(paths["actor"], paths["critic"])
        exported_models/actor.onnx exported_models/critic.onnx
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    actor_path = str(output_dir / "actor.onnx")
    critic_path = str(output_dir / "critic.onnx")

    print("Exporting PPO Actor...")
    export_ppo_actor(model._actor, sample_obs, actor_path, validate=validate)

    print("\nExporting PPO Critic...")
    export_ppo_critic(model._critic, sample_obs, critic_path, validate=validate)

    print(f"\nExport complete! Models saved to {output_dir}")

    return {
        "actor": actor_path,
        "critic": critic_path,
        "metadata": {
            "actor_meta": str(output_dir / "actor.json"),
            "critic_meta": str(output_dir / "critic.json"),
        },
    }
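
# Illustrative sketch of deploying the exported actor alone (`obs_batch` is a
# hypothetical numpy observation matching the traced input shape; a single
# observation input is assumed):
#
#     # `paths` as returned by export_ppo_model(...)
#     session = ort.InferenceSession(paths["actor"], providers=["CPUExecutionProvider"])
#     (mu,) = session.run(None, {session.get_inputs()[0].name: obs_batch})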