protomotions.utils.export_utils module
Utilities for exporting trained models to ONNX format.
This module provides functions to export TensorDict-based models to ONNX format using torch.onnx.dynamo_export. The exported models can be used for deployment and inference in production environments.
- Key Functions:
  - export_onnx: Export a TensorDictModule to ONNX format
  - export_ppo_model: Export a trained PPO model to ONNX
- class protomotions.utils.export_utils.ONNXExportWrapper(module, in_keys)
Wrapper for TensorDictModule that accepts **kwargs for ONNX export.
TensorDictModules expect a single TensorDict argument, but torch.onnx.dynamo_export unpacks inputs as keyword arguments. This wrapper bridges the gap by accepting the keyword arguments and passing them to the wrapped module as a TensorDict.
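To make the bridging concrete, here is a minimal pure-Python sketch of the same pattern. It is not the ProtoMotions implementation: a plain dict stands in for TensorDict, a plain callable stands in for the wrapped TensorDictModule, and the class name KwargsToDictWrapper is hypothetical. Only the (module, in_keys) constructor arguments come from the signature above.

```python
from typing import Any, Callable, Dict, Sequence


class KwargsToDictWrapper:
    """Hypothetical stand-in for ONNXExportWrapper.

    A plain dict plays the role of a TensorDict and a plain callable plays
    the role of the wrapped TensorDictModule.
    """

    def __init__(self, module: Callable[[Dict[str, Any]], Any],
                 in_keys: Sequence[str]) -> None:
        self.module = module
        self.in_keys = list(in_keys)

    def __call__(self, **kwargs: Any) -> Any:
        # Fail loudly if the caller did not pass one of the declared inputs.
        missing = [k for k in self.in_keys if k not in kwargs]
        if missing:
            raise KeyError(f"missing inputs: {missing}")
        # Repack the keyword arguments into the single mapping the module
        # expects; keys not listed in in_keys are dropped.
        batch = {k: kwargs[k] for k in self.in_keys}
        return self.module(batch)


def toy_policy(batch: Dict[str, Any]) -> list:
    # Stand-in for a TensorDictModule that reads "obs" and returns an action.
    return [2.0 * x for x in batch["obs"]]


wrapped = KwargsToDictWrapper(toy_policy, in_keys=["obs"])
action = wrapped(obs=[1.0, 2.0])  # called with kwargs, as an exporter would
```

In the real wrapper the repacked mapping would presumably be a TensorDict built inside forward, so the exporter only ever sees an ordinary keyword-argument signature.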
- protomotions.utils.export_utils.export_ppo_actor(actor, sample_obs, path, validate=True)
Export a PPO actor’s mu network to ONNX.
Exports the mean network (mu) of a PPO actor, which is the core policy network without the distribution layer. Uses real observations from the environment to ensure proper tracing.
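- Parameters:
  - actor: the PPO actor whose mu network is exported.
  - sample_obs: real observations from the environment, used to trace the network.
  - path: destination path for the ONNX file.
  - validate: whether to validate the exported model (defaults to True).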
Example
>>> # Get real observations from environment
>>> env.reset()
>>> sample_obs = agent.get_obs()
>>> export_ppo_actor(agent.model._actor, sample_obs, "ppo_actor.onnx")
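Because only the mean network is exported, any sampling from the action distribution has to happen outside the ONNX graph. The sketch below assumes a diagonal Gaussian policy (the source does not state the distribution type) and a hypothetical log_std vector taken from the trained actor; select_action is an illustrative name, not part of export_utils.

```python
import math
import random
from typing import List, Optional, Sequence


def select_action(mu: Sequence[float],
                  log_std: Optional[Sequence[float]] = None,
                  rng: Optional[random.Random] = None) -> List[float]:
    """Turn the exported mean network's output into an action.

    Assumes a diagonal Gaussian policy; log_std is a hypothetical
    per-dimension log standard deviation kept outside the ONNX graph.
    """
    if log_std is None:
        # Deterministic deployment: act on the mean directly.
        return list(mu)
    if len(log_std) != len(mu):
        raise ValueError("mu and log_std must have the same length")
    rng = rng if rng is not None else random.Random()
    # Stochastic deployment: a_i = mu_i + exp(log_std_i) * eps_i, eps_i ~ N(0, 1).
    return [m + math.exp(s) * rng.gauss(0.0, 1.0) for m, s in zip(mu, log_std)]


mu = [0.25, -0.5]  # e.g. one row of the exported network's output
greedy = select_action(mu)
noisy = select_action(mu, log_std=[-1.0, -1.0], rng=random.Random(0))
```

For most deployments the deterministic branch (acting on the mean) is what the exported mu network is used for directly.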
- protomotions.utils.export_utils.export_ppo_critic(critic, sample_obs, path, validate=True)
Export a PPO critic network to ONNX.
Uses real observations from the environment to ensure proper tracing.
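- Parameters:
  - critic: the PPO critic network to export.
  - sample_obs: real observations from the environment, used to trace the network.
  - path: destination path for the ONNX file.
  - validate: whether to validate the exported model (defaults to True).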
Example
>>> # Get real observations from environment
>>> env.reset()
>>> sample_obs = agent.get_obs()
>>> export_ppo_critic(agent.model._critic, sample_obs, "ppo_critic.onnx")
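The source does not say what validate=True checks. One common check, sketched here as an assumption, is to run the original network and the exported one on the same inputs and compare the outputs within a tolerance. Plain callables stand in for both networks (in practice the exported side might wrap an ONNX Runtime session), and outputs_match is a hypothetical name.

```python
import math
from typing import Callable, Sequence


def outputs_match(reference: Callable[[Sequence[float]], Sequence[float]],
                  exported: Callable[[Sequence[float]], Sequence[float]],
                  samples: Sequence[Sequence[float]],
                  rel_tol: float = 1e-5, abs_tol: float = 1e-6) -> bool:
    """Return True if both callables agree on every sample within tolerance."""
    for x in samples:
        ref_out = reference(x)
        exp_out = exported(x)
        if len(ref_out) != len(exp_out):
            return False
        # Element-wise comparison; small float differences can appear because
        # the exported graph may fuse or reorder operations.
        if not all(math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
                   for a, b in zip(ref_out, exp_out)):
            return False
    return True


def critic(x):
    # Toy "value network": a single linear output.
    return [0.5 * x[0] - 0.25 * x[1]]


def critic_exported(x):
    # Same function with a tiny float discrepancy, as an export might introduce.
    return [0.5 * x[0] - 0.25 * x[1] + 1e-9]
```

Using real environment observations as samples, as the docstrings recommend for tracing, makes such a comparison more representative than random inputs.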
- protomotions.utils.export_utils.export_ppo_model(model, sample_obs, output_dir, validate=True)
Export a complete PPO model (actor and critic) to ONNX.
Exports both the actor and critic networks to separate ONNX files in the specified directory.
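- Parameters:
  - model: the trained PPO model containing the actor and critic.
  - sample_obs: sample observations used to trace both networks.
  - output_dir: directory in which the ONNX files are written.
  - validate: whether to validate the exported models (defaults to True).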
- Returns:
Dict with paths to exported files.
Example
>>> import torch
>>> model = trained_agent.model
>>> sample_obs = {"obs": torch.randn(1, 128)}
>>> paths = export_ppo_model(model, sample_obs, "exported_models/")
>>> print(paths)
{'actor': 'exported_models/actor.onnx', 'critic': 'exported_models/critic.onnx'}
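The example above suggests that export_ppo_model writes actor.onnx and critic.onnx into output_dir. A minimal sketch of that naming scheme, assuming it is a plain path join (the source shows only the resulting paths); plan_export_paths is a hypothetical helper and does not touch the filesystem.

```python
import os
from typing import Dict, Iterable


def plan_export_paths(output_dir: str,
                      names: Iterable[str] = ("actor", "critic")) -> Dict[str, str]:
    """Map each network name to <output_dir>/<name>.onnx.

    Hypothetical helper mirroring the file layout in the example above;
    it only computes paths and creates nothing on disk.
    """
    return {name: os.path.join(output_dir, f"{name}.onnx") for name in names}


paths = plan_export_paths("exported_models/")
```

A real exporter would also need to create the directory (for example with os.makedirs(output_dir, exist_ok=True)) before writing either file.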