# SPDX-FileCopyrightText: Copyright (c) 2025 The ProtoMotions Developers
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Base evaluator for agent evaluation and metrics computation.
This module provides the base evaluation infrastructure for computing performance
metrics during training and evaluation. Evaluators run periodic assessments of
agent performance and compute task-specific metrics.
Key Classes:
- BaseEvaluator: Base class for all evaluators
- SmoothnessMetricPlugin: Plugin for computing motion smoothness metrics
Key Features:
- Periodic evaluation during training
- Motion quality metrics computation
- Episode statistics aggregation
- Smoothness and jerk analysis
- Distributed evaluation support
"""
import torch
from typing import Dict, Optional, Tuple, Any
from lightning.fabric import Fabric
from protomotions.agents.evaluators.metrics import MotionMetrics
from protomotions.agents.evaluators.smoothness_evaluator import SmoothnessEvaluator
from protomotions.envs.base_env.env import BaseEnv
from protomotions.agents.evaluators.config import EvaluatorConfig
class SmoothnessMetricPlugin:
"""Plugin for computing smoothness metrics from motion data."""
def __init__(
self, evaluator, window_sec: float = 0.4, high_jerk_threshold: float = 6500.0
):
"""
Initialize the smoothness metric plugin.
Args:
evaluator: The parent evaluator instance
window_sec: Window size in seconds for smoothness computation
high_jerk_threshold: Threshold for classifying high jerk frames
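
        Example (illustrative sketch; assumes ``evaluator`` exposes ``device``,
        ``env.dt``, and ``env.robot_config.kinematic_info.num_bodies``):
            >>> plugin = SmoothnessMetricPlugin(evaluator, window_sec=0.4, high_jerk_threshold=6500.0)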
"""
self.smoothness_evaluator = SmoothnessEvaluator(
device=evaluator.device,
dt=evaluator.env.dt,
window_sec=window_sec,
high_jerk_threshold=high_jerk_threshold,
)
self.num_bodies = evaluator.env.robot_config.kinematic_info.num_bodies
def compute(self, metrics: Dict[str, MotionMetrics]) -> Dict[str, float]:
"""
Compute smoothness metrics from collected motion data.
Args:
metrics: Dictionary of MotionMetrics
Returns:
Dictionary of smoothness metrics with "eval/" prefix
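
        Example (illustrative sketch; assumes ``metrics`` holds the MotionMetrics
        collected by the parent evaluator):
            >>> extra = plugin.compute(metrics)
            >>> all(key.startswith("eval/") for key in extra)
            True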
"""
smoothness_metrics = self.smoothness_evaluator.compute_smoothness_metrics(
metrics, self.num_bodies
)
# Add logging for each smoothness metric
result = {}
for k, v in smoothness_metrics.items():
print(f"Smoothness metric: {k}, value: {v}")
result[f"eval/{k}"] = v
return result
class BaseEvaluator:
"""Base class for agent evaluation and metrics computation.
Runs periodic evaluations during training to assess agent performance.
Collects episode statistics, computes task-specific metrics, and provides
feedback for checkpoint selection (best model saving).
Args:
agent: The agent being evaluated.
fabric: Lightning Fabric instance for distributed evaluation.
config: Evaluator configuration specifying eval frequency and length.
Example:
>>> evaluator = BaseEvaluator(agent, fabric, config)
>>> metrics, score = evaluator.evaluate()
"""
def __init__(self, agent: Any, fabric: Fabric, config: EvaluatorConfig):
"""
Initialize the evaluator.
Args:
agent: The agent to evaluate
fabric: Lightning Fabric instance for distributed training
"""
self.agent = agent
self.fabric = fabric
self.config = config
# Plugin system for additional metrics
self.metric_plugins = []
self._register_plugins()
# Counter for tracking evaluation calls
self.eval_count = 0
@property
def device(self) -> torch.device:
"""Device for computations (from fabric)."""
return self.fabric.device
@property
def env(self) -> BaseEnv:
"""Environment instance (from agent)."""
return self.agent.env
@property
def root_dir(self):
"""Root directory for saving outputs (from agent)."""
return self.agent.root_dir
@torch.no_grad()
def evaluate(self) -> Tuple[Dict, Optional[float]]:
"""
Evaluate the agent and calculate metrics.
This is the main entry point that orchestrates the evaluation process.
Returns:
Tuple containing:
- Dict of evaluation metrics for logging
- Optional score value for determining best model
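
        Example (sketch; assumes ``evaluator`` is a concrete subclass that
        implements ``run_evaluation``):
            >>> log_dict, score = evaluator.evaluate()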
"""
self.agent.eval()
# Initialize metrics and prepare evaluation context
metrics = self.initialize_eval()
if not metrics:
return {}, 0
# Run evaluation
self.run_evaluation(metrics)
# Process evaluation results
evaluation_log, evaluated_score = self.process_eval_results(metrics)
# Cleanup after evaluation
self.cleanup_after_evaluation()
# Increment eval counter
self.eval_count += 1
return evaluation_log, evaluated_score
    def initialize_eval(self) -> Dict:
        """
        Initialize the metrics dictionary with required keys and prepare the
        evaluation context.

        Returns:
            Dictionary of metrics to populate during evaluation (empty if there
            is nothing to evaluate).
        """
        return {}
def run_evaluation(self, metrics: Dict) -> None:
"""
Run the evaluation process and collect metrics.
Args:
metrics: Dictionary to collect evaluation metrics
"""
raise NotImplementedError("Run evaluation not implemented for base evaluator.")
    def process_eval_results(
        self, metrics: Dict
    ) -> Tuple[Dict, Optional[float]]:
        """
        Process collected metrics and prepare them for logging.

        Args:
            metrics: Dictionary of collected metrics

        Returns:
            Tuple containing:
                - Dict of processed metrics for logging
                - Optional score value for determining best model
        """
        return {}, None
def cleanup_after_evaluation(self) -> None:
"""Clean up after evaluation (reset env state, etc.)"""
pass
def _create_base_metrics(
self,
metric_keys: list,
num_motions: int,
motion_num_frames: torch.Tensor,
max_eval_steps: int,
) -> Dict[str, MotionMetrics]:
"""
Create MotionMetrics objects for a list of keys.
Args:
metric_keys: List of metric keys to create
num_motions: Number of motions to evaluate
motion_num_frames: Number of frames per motion
max_eval_steps: Maximum evaluation steps
Returns:
Dictionary of MotionMetrics objects
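
        Example (illustrative sketch; the metric keys here are hypothetical):
            >>> frames = torch.tensor([120, 90, 300], device=self.device)
            >>> metrics = self._create_base_metrics(["tracking_err", "reward"], 3, frames, 300)
            >>> sorted(metrics)
            ['reward', 'tracking_err']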
"""
metrics = {}
for k in metric_keys:
metrics[k] = MotionMetrics(
num_motions, motion_num_frames, max_eval_steps, device=self.device
)
return metrics
def _add_robot_state_metrics(
self,
metrics: Dict[str, MotionMetrics],
num_motions: int,
motion_num_frames: torch.Tensor,
max_eval_steps: int,
) -> None:
"""
Add metrics for raw robot state (dof_pos, rigid_body_pos, etc.).
This is needed for derived metrics like smoothness.
Args:
metrics: Existing metrics dict to add to
num_motions: Number of motions to evaluate
motion_num_frames: Number of frames per motion
max_eval_steps: Maximum evaluation steps
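
        Example (sketch; the resulting keys follow the simulator's flattened
        ``RobotState`` shape mapping, e.g. ``dof_pos`` or ``rigid_body_pos``):
            >>> self._add_robot_state_metrics(metrics, num_motions, frames, max_steps)
            >>> "dof_pos" in metrics  # only if the simulator exposes this field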
"""
# Default implementation for humanoid robot state
if not hasattr(self.env, "simulator"):
return
try:
from protomotions.simulator.base_simulator.simulator_state import RobotState
dummy_state: RobotState = self.env.simulator.get_robot_state()
shape_mapping = dummy_state.get_shape_mapping(flattened=True)
for k, shape in shape_mapping.items():
metrics[k] = MotionMetrics(
num_motions,
motion_num_frames,
max_eval_steps,
num_sub_features=shape[0],
device=self.device,
)
except Exception as e:
print(f"Warning: Could not add robot state metrics: {e}")
def _register_plugins(self) -> None:
"""Register metric computation plugins. Override in subclasses."""
pass
def _register_smoothness_plugin(
self, window_sec: float = 0.4, high_jerk_threshold: float = 6500.0
) -> bool:
"""
Convenience method to register smoothness metric plugin.
Args:
window_sec: Window size in seconds for smoothness computation
high_jerk_threshold: Threshold for classifying high jerk frames
Returns:
True if plugin was registered successfully, False otherwise
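
        Example (sketch; typically called from ``_register_plugins`` in a subclass):
            >>> def _register_plugins(self) -> None:
            ...     self._register_smoothness_plugin(window_sec=0.4)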
"""
try:
self.metric_plugins.append(
SmoothnessMetricPlugin(self, window_sec, high_jerk_threshold)
)
return True
except ValueError as e:
print(f"Skipping smoothness plugin: {e}")
return False
def _compute_additional_metrics(
self, metrics: Dict[str, MotionMetrics]
) -> Dict[str, float]:
"""
Run all registered metric plugins to compute additional metrics.
Args:
metrics: Dictionary of collected MotionMetrics
Returns:
Dictionary of additional computed metrics
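
        Example (sketch; merge plugin outputs into the evaluation log before logging):
            >>> evaluation_log.update(self._compute_additional_metrics(metrics))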
"""
additional_metrics = {}
for plugin in self.metric_plugins:
try:
plugin_metrics = plugin.compute(metrics)
additional_metrics.update(plugin_metrics)
except Exception as e:
print(f"Warning: Plugin {plugin.__class__.__name__} failed: {e}")
return additional_metrics
def _gen_metrics(
self, metrics: Dict[str, MotionMetrics], keys_to_log: list, prefix: str = "eval"
) -> Dict[str, float]:
"""
Log metrics with mean/max/min aggregations across motions.
For each metric, computes:
- mean: average across all per-motion means (overall performance)
- max: maximum of per-motion means (worst performing motion)
- min: minimum of per-motion means (best performing motion)
This gives you 3 separate line plot groups that track over time:
- {prefix}_mean/{metric}: How well you perform on average
- {prefix}_max/{metric}: How well you perform on the hardest motion
- {prefix}_min/{metric}: How well you perform on the easiest motion
Args:
metrics: Dictionary of MotionMetrics
keys_to_log: List of metric keys to log
prefix: Base prefix for logged metric names (default: "eval")
Returns:
Dictionary of logged metrics
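
        Example (illustrative sketch; ``tracking_err`` is a hypothetical metric key):
            >>> log = self._gen_metrics(metrics, ["tracking_err"])
            >>> sorted(log)
            ['eval_max/tracking_err', 'eval_mean/tracking_err', 'eval_min/tracking_err']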
"""
to_log = {}
for k in keys_to_log:
if k in metrics:
to_log[f"{prefix}_mean/{k}"] = metrics[k].mean_mean_reduce().item()
to_log[f"{prefix}_max/{k}"] = metrics[k].mean_max_reduce().item()
to_log[f"{prefix}_min/{k}"] = metrics[k].mean_min_reduce().item()
return to_log
def _save_list_to_file(
self, items: list, filename: str, subdirectory: Optional[str] = None
) -> None:
"""
Save a list of items to a text file (one per line).
Args:
items: List of items to save
filename: Name of output file
subdirectory: Optional subdirectory within root_dir
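
        Example (sketch; the file and subdirectory names are hypothetical):
            >>> self._save_list_to_file(["motion_003", "motion_017"], "failed_motions.txt", subdirectory="eval")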
"""
if subdirectory:
output_dir = self.root_dir / subdirectory
output_dir.mkdir(parents=True, exist_ok=True)
output_path = output_dir / filename
else:
output_path = self.root_dir / filename
print(f"Saving to: {output_path}")
with open(output_path, "w") as f:
for item in items:
f.write(f"{item}\n")
def _plot_per_frame_metrics(
self,
metrics: Dict[str, MotionMetrics],
keys_to_plot: Optional[list] = None,
motion_id: int = 0,
custom_colors: Optional[Dict[str, str]] = None,
output_filename: str = "metrics_per_frame_plot.png",
) -> None:
"""
Plot per-frame metrics vs time for a single motion.
Only plots single-feature metrics (ignores multi-feature metrics).
Args:
metrics: Dictionary of MotionMetrics objects
keys_to_plot: List of keys to plot (None = plot all single-feature metrics)
motion_id: Which motion to plot (default: 0)
custom_colors: Optional dict mapping metric keys to colors
output_filename: Name of output file
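
        Example (sketch; ``tracking_err`` is a hypothetical single-feature metric key):
            >>> self._plot_per_frame_metrics(metrics, keys_to_plot=["tracking_err"], motion_id=0)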
"""
try:
import matplotlib.pyplot as plt
import numpy as np
except ImportError:
print("matplotlib not available, skipping plotting")
return
dt = self.env.dt
custom_colors = custom_colors or {}
# Filter to only single-feature metrics
single_feature_metrics = {}
valid_frames = {}
# Determine which keys to plot
if keys_to_plot is None:
keys_to_plot = list(metrics.keys())
for k in keys_to_plot:
if k in metrics and metrics[k].num_sub_features == 1:
single_feature_metrics[k] = metrics[k]
valid_frames[k] = metrics[k].frame_counts[motion_id].item()
if not single_feature_metrics:
print("No single-feature metrics found for plotting")
return
# Create subplots for each single-feature metric
num_metrics = len(single_feature_metrics)
fig, axes = plt.subplots(num_metrics, 1, figsize=(12, 4 * num_metrics))
if num_metrics == 1:
axes = [axes]
for i, k in enumerate(single_feature_metrics.keys()):
metric = single_feature_metrics[k]
num_valid_frames = valid_frames[k]
if num_valid_frames == 0:
axes[i].text(
0.5,
0.5,
f"No data for {k}",
horizontalalignment="center",
verticalalignment="center",
transform=axes[i].transAxes,
)
axes[i].set_title(f"{k}")
continue
# Extract data for the single motion (single feature)
data = metric.data[motion_id, :num_valid_frames, 0].cpu().numpy()
time_steps = np.arange(num_valid_frames) * dt
# Use custom color if provided, otherwise matplotlib default
plot_kwargs = {"label": k, "linewidth": 2}
if k in custom_colors:
plot_kwargs["color"] = custom_colors[k]
axes[i].plot(time_steps, data, **plot_kwargs)
axes[i].set_xlabel("Time (s)")
axes[i].set_ylabel(f"{k}")
axes[i].set_title(f"{k} vs Time")
axes[i].grid(True, alpha=0.3)
axes[i].legend()
plt.tight_layout()
# Save the plot
if hasattr(self, "root_dir") and self.root_dir is not None:
plot_path = self.root_dir / output_filename
plt.savefig(plot_path, dpi=150, bbox_inches="tight")
print(f"Per-frame metrics plot saved to: {plot_path}")
plt.close(fig)
print("Per-frame metrics plotted successfully")
def simple_test_policy(self, collect_metrics: bool = False) -> None:
"""
Simple evaluation loop for testing the policy.
Args:
            collect_metrics: Whether to collect metrics during evaluation (unused in this base loop; subclasses may override to record metrics)
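
        Example (sketch; runs until interrupted with Ctrl+C):
            >>> evaluator.simple_test_policy()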
"""
self.agent.eval()
done_indices = None # Force reset on first entry
step = 0
print("Evaluating policy...")
try:
while True:
obs, _ = self.env.reset(done_indices)
obs = self.agent.add_agent_info_to_obs(obs)
obs_td = self.agent.obs_dict_to_tensordict(obs)
# Obtain actor predictions
model_outs = self.agent.model(obs_td)
if "mean_action" in model_outs:
actions = model_outs["mean_action"]
else:
actions = model_outs["action"]
# Step the environment
obs, rewards, dones, terminated, extras = self.env.step(actions)
obs = self.agent.add_agent_info_to_obs(obs)
obs_td = self.agent.obs_dict_to_tensordict(obs)
done_indices = dones.nonzero(as_tuple=False).squeeze(-1)
step += 1
except KeyboardInterrupt:
print("\nEvaluation interrupted by Ctrl+C, exiting...")
return None