# Source code for protomotions.agents.evaluators.metrics

# SPDX-FileCopyrightText: Copyright (c) 2025 The ProtoMotions Developers
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from typing import Optional, Callable


class MotionMetrics:
    """
    Store and compute metrics for motion data.

    Stores raw data in the shape [num_motions, max_motion_len, num_sub_features]
    and supports basic reduction operations for computing final metrics.
    """

    def __init__(
        self,
        num_motions: int,
        motion_lens: torch.Tensor,
        max_motion_len: int,
        num_sub_features: int = 1,
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32,
    ):
        """
        Initialize the metrics tracker.

        Args:
            num_motions: Number of motions to track
            motion_lens: Number of frames in each motion sequence
            max_motion_len: Conservative maximum number of frames allocated for
                data storage, so shapes stay consistent across different GPUs
                when aggregating
            num_sub_features: Number of sub-features per data point (default: 1)
            device: Device to store the tensors on
            dtype: Data type for the tensors
        """
        self.num_motions = num_motions
        self.num_sub_features = num_sub_features
        self.device = device
        self.dtype = dtype
        self.motion_lens = motion_lens
        self.max_motion_len = max_motion_len

        # Raw data storage
        self.data = torch.zeros(
            (num_motions, self.max_motion_len, num_sub_features),
            device=device,
            dtype=dtype,
        )
        # Counters tracking the number of filled frames per motion
        self.frame_counts = torch.zeros(num_motions, device=device, dtype=torch.long)

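    # Illustrative construction sketch (hypothetical values, not part of the
    # class): two motions of 3 and 2 frames, one scalar sub-feature per frame,
    # with max_motion_len padded to 4 for cross-GPU shape consistency:
    #
    #   metrics = MotionMetrics(
    #       num_motions=2,
    #       motion_lens=torch.tensor([3, 2]),
    #       max_motion_len=4,
    #   )
    #   metrics.data.shape    # torch.Size([2, 4, 1])
    #   metrics.frame_counts  # tensor([0, 0])
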
    def update(
        self,
        motion_ids: torch.Tensor,
        values: torch.Tensor,
        frame_indices: Optional[torch.Tensor] = None,
    ) -> None:
        """
        Update the metrics data for the specified motions.

        Args:
            motion_ids: Tensor of motion IDs to update [batch_size]
            values: Tensor of values to update [batch_size, num_sub_features]
            frame_indices: Optional tensor of frame indices [batch_size].
                If None, the current frame count of each motion is used.
        """
        if values.ndim == 1:
            values = values.unsqueeze(1)
        assert motion_ids.shape[0] == values.shape[0]
        # Assert that motion_ids contains no duplicates
        assert torch.unique(motion_ids).shape[0] == motion_ids.shape[0]

        if frame_indices is None:
            # Use the current counts as frame indices
            frame_indices = self.frame_counts[motion_ids]
        assert frame_indices.shape[0] == values.shape[0]

        # Update the data using batched operations with per-motion length checks
        valid_mask = frame_indices < self.motion_lens[motion_ids]
        if valid_mask.any():
            valid_motion_ids = motion_ids[valid_mask]
            valid_frame_indices = frame_indices[valid_mask]
            valid_values = values[valid_mask]

            # Use advanced indexing to update only the valid entries
            self.data[valid_motion_ids, valid_frame_indices] = valid_values
            # Increment the frame counts
            self.frame_counts[valid_motion_ids] += 1

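    # Illustrative update sketch (continuing the hypothetical example above):
    # push one scalar value per motion; frame_indices defaults to the running
    # per-motion frame count, which is incremented after a valid write:
    #
    #   metrics.update(torch.tensor([0, 1]), torch.tensor([0.5, 0.2]))
    #   metrics.frame_counts   # tensor([1, 1])
    #   metrics.data[0, 0, 0]  # tensor(0.5)
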
    def get_unfilled_mask(self) -> torch.Tensor:
        """
        Get a mask of the unfilled values in the data.
        """
        # Build an index matrix and mark values beyond each motion's frame count
        indices = torch.arange(self.max_motion_len, device=self.device).unsqueeze(
            0
        )  # [1, max_motion_len]
        frame_counts = self.frame_counts.unsqueeze(1)  # [num_motions, 1]
        mask = indices >= frame_counts  # [num_motions, max_motion_len]

        # Expand the mask to cover the sub-features dimension
        mask = mask.unsqueeze(-1).expand(
            -1, -1, self.num_sub_features
        )  # [num_motions, max_motion_len, num_sub_features]
        return mask

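    # Illustrative mask sketch (continuing the hypothetical example above):
    # after one update per motion, only frame 0 of each motion is filled, so
    # get_unfilled_mask()[:, :, 0] would be
    #
    #   tensor([[False,  True,  True,  True],
    #           [False,  True,  True,  True]])
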
    def max_reduce_each_motion(
        self,
        with_frame: bool = False,
    ) -> torch.Tensor:
        """
        Reduce the data by taking the max of each motion.
        """
        mask = self.get_unfilled_mask()
        # Apply the mask to set values beyond the frame count to -inf
        data = self.data.masked_fill(mask, float("-inf"))
        # Take the max of each motion
        max_values, max_frames = data.max(dim=1)  # [num_motions, num_sub_features]
        if self.num_sub_features == 1:
            max_values = max_values[:, 0]
        if with_frame:
            return max_values, max_frames
        return max_values

    def min_reduce_each_motion(self) -> torch.Tensor:
        """
        Reduce the data by taking the min of each motion.
        """
        mask = self.get_unfilled_mask()
        # Apply the mask to set values beyond the frame count to inf
        data = self.data.masked_fill(mask, float("inf"))
        # Take the min of each motion
        min_values = data.min(dim=1).values  # [num_motions, num_sub_features]
        if self.num_sub_features == 1:
            return min_values[:, 0]
        return min_values

    def mean_reduce_each_motion(self) -> torch.Tensor:
        """
        Reduce the data by taking the mean of each motion.
        """
        mask = self.get_unfilled_mask()
        # Zero out values beyond the frame count so they do not affect the sum
        data = self.data.masked_fill(mask, 0)
        # Sum the valid values for each motion
        sum_values = data.sum(dim=1)  # [num_motions, num_sub_features]

        # Actual number of frames per motion, clamped to avoid division by zero
        frame_counts = self.frame_counts.unsqueeze(-1).clamp(min=1)  # [num_motions, 1]
        # Divide the sum by the actual frame count to get the mean
        mean_values = sum_values / frame_counts  # [num_motions, num_sub_features]

        # Explicitly set the mean of motions with zero frames to 0 (or NaN if preferred)
        zero_frame_mask = (self.frame_counts == 0).unsqueeze(-1)
        mean_values = mean_values.masked_fill(zero_frame_mask, 0.0)

        if self.num_sub_features == 1:
            return mean_values[:, 0]
        return mean_values

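    # Illustrative reduction sketch (hypothetical values): with motion 0
    # holding valid frames [0.5, 0.3] and motion 1 holding [0.2], the unfilled
    # padding is excluded from every reduction:
    #
    #   metrics.max_reduce_each_motion()   # tensor([0.5000, 0.2000])
    #   metrics.min_reduce_each_motion()   # tensor([0.3000, 0.2000])
    #   metrics.mean_reduce_each_motion()  # tensor([0.4000, 0.2000])
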
    def ops_mean_reduce(
        self,
        op: Callable,
    ) -> torch.Tensor:
        """
        First reduce the data by applying `op` to each motion, then take the
        mean across motions.
        """
        # Check whether op is a bound method of this instance
        if hasattr(op, "__self__") and op.__self__ is self:
            op_values = op()  # Call the bound method directly
        else:
            op_values = op(self)  # Call the external function with self as argument

        # Only keep motions that have at least one valid frame
        op_values_valid = op_values[self.frame_counts > 0]
        # Take the mean across the number of motions
        mean_values = op_values_valid.mean(dim=0)  # [num_sub_features] or scalar
        return mean_values

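    # Illustrative dispatch sketch: a bound method of this instance is called
    # directly, while any external callable receives the tracker as its
    # argument. median_reduce below is a hypothetical example, not part of the
    # class:
    #
    #   metrics.ops_mean_reduce(metrics.max_reduce_each_motion)  # bound method
    #
    #   def median_reduce(m: "MotionMetrics") -> torch.Tensor:
    #       data = m.data.masked_fill(m.get_unfilled_mask(), float("nan"))
    #       return data.nanmedian(dim=1).values[:, 0]
    #
    #   metrics.ops_mean_reduce(median_reduce)  # external callable
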
    def max_mean_reduce(self) -> torch.Tensor:
        return self.ops_mean_reduce(self.max_reduce_each_motion)

    def min_mean_reduce(self) -> torch.Tensor:
        return self.ops_mean_reduce(self.min_reduce_each_motion)

    def mean_mean_reduce(self) -> torch.Tensor:
        return self.ops_mean_reduce(self.mean_reduce_each_motion)

    def mean_max_reduce(self) -> torch.Tensor:
        """
        First reduce each motion by taking the mean over valid frames,
        then take the max across all motions.

        Returns:
            torch.Tensor: Maximum of the per-motion means (worst performing motion)
        """
        mean_values = (
            self.mean_reduce_each_motion()
        )  # [num_motions] or [num_motions, num_sub_features]

        # Only consider motions with valid frames
        valid_mask = self.frame_counts > 0
        if not valid_mask.any():
            # No valid motions; return zeros with the appropriate shape
            if self.num_sub_features == 1:
                return torch.tensor(0.0, device=self.device, dtype=self.dtype)
            return torch.zeros(
                self.num_sub_features, device=self.device, dtype=self.dtype
            )

        mean_values_valid = mean_values[valid_mask]
        max_value = (
            mean_values_valid.max(dim=0).values
            if mean_values_valid.ndim > 1
            else mean_values_valid.max()
        )
        return max_value

    def mean_min_reduce(self) -> torch.Tensor:
        """
        First reduce each motion by taking the mean over valid frames,
        then take the min across all motions.

        Returns:
            torch.Tensor: Minimum of the per-motion means (best performing motion)
        """
        mean_values = (
            self.mean_reduce_each_motion()
        )  # [num_motions] or [num_motions, num_sub_features]

        # Only consider motions with valid frames
        valid_mask = self.frame_counts > 0
        if not valid_mask.any():
            # No valid motions; return zeros with the appropriate shape
            if self.num_sub_features == 1:
                return torch.tensor(0.0, device=self.device, dtype=self.dtype)
            return torch.zeros(
                self.num_sub_features, device=self.device, dtype=self.dtype
            )

        mean_values_valid = mean_values[valid_mask]
        min_value = (
            mean_values_valid.min(dim=0).values
            if mean_values_valid.ndim > 1
            else mean_values_valid.min()
        )
        return min_value

    def compute_finite_difference_jitter_reduce_each_motion(
        self,
        num_bodies: int,
        aggregate_method: str = "mean",
        order: int = 2,
        field_description: str = "data",
    ) -> torch.Tensor:
        """
        Generic method to compute jitter using finite differences of the
        specified order. The output is zero-padded at the beginning to match
        the input length.

        Args:
            num_bodies: Number of rigid bodies (to reshape the flattened data)
            aggregate_method: How to aggregate across bodies ("mean", "max", "sum")
            order: Order of the finite differences (1 for velocity-like,
                2 for acceleration-like)
            field_description: Description of the field for error messages

        Returns:
            torch.Tensor: Jitter values with shape [num_motions, max_motion_len]
                (same as the input)
        """
        assert self.num_sub_features == num_bodies * 3, (
            f"Expected num_sub_features={num_bodies * 3} for {field_description}, "
            f"got {self.num_sub_features}"
        )
        assert order in [
            1,
            2,
        ], f"Only 1st and 2nd order finite differences are supported, got {order}"

        # Get the mask for valid data
        mask = (
            self.get_unfilled_mask()
        )  # [num_motions, max_motion_len, num_sub_features]
        # Apply the mask to the data (set invalid entries to 0)
        data = self.data.masked_fill(
            mask, 0.0
        )  # [num_motions, max_motion_len, num_bodies*3]

        # Reshape to separate bodies: [num_motions, max_motion_len, num_bodies, 3]
        data_reshaped = data.view(self.num_motions, self.max_motion_len, num_bodies, 3)

        # Check whether we have enough frames
        if self.max_motion_len < order + 1:
            # Not enough frames for the specified order of differences; return all zeros
            jitter = torch.zeros(
                self.num_motions, self.max_motion_len, device=self.device
            )
            return jitter

        if order == 1:
            # 1st order finite difference: data[t+1] - data[t]
            finite_diffs = (
                data_reshaped[:, 1:, :, :] - data_reshaped[:, :-1, :, :]
            )  # [num_motions, max_motion_len-1, num_bodies, 3]
            # Pad with zeros at the beginning (dtype matched to the data so
            # torch.cat does not fail for non-float32 trackers):
            # [num_motions, max_motion_len, num_bodies, 3]
            finite_diffs = torch.cat(
                [
                    torch.zeros(
                        self.num_motions,
                        1,
                        num_bodies,
                        3,
                        device=self.device,
                        dtype=self.dtype,
                    ),
                    finite_diffs,
                ],
                dim=1,
            )
        elif order == 2:
            # 2nd order finite difference: data[t+1] - 2*data[t] + data[t-1]
            data_t_minus_1 = data_reshaped[
                :, :-2, :, :
            ]  # [num_motions, max_motion_len-2, num_bodies, 3]
            data_t = data_reshaped[
                :, 1:-1, :, :
            ]  # [num_motions, max_motion_len-2, num_bodies, 3]
            data_t_plus_1 = data_reshaped[
                :, 2:, :, :
            ]  # [num_motions, max_motion_len-2, num_bodies, 3]
            finite_diffs = (
                data_t_plus_1 - 2 * data_t + data_t_minus_1
            )  # [num_motions, max_motion_len-2, num_bodies, 3]
            # Pad with zeros at the beginning: [num_motions, max_motion_len, num_bodies, 3]
            finite_diffs = torch.cat(
                [
                    torch.zeros(
                        self.num_motions,
                        2,
                        num_bodies,
                        3,
                        device=self.device,
                        dtype=self.dtype,
                    ),
                    finite_diffs,
                ],
                dim=1,
            )

        # Compute the L2 norm for each body: [num_motions, max_motion_len, num_bodies]
        jitter_per_body = torch.norm(finite_diffs, dim=-1)

        # Aggregate across bodies
        if aggregate_method == "mean":
            jitter = jitter_per_body.mean(dim=-1)  # [num_motions, max_motion_len]
        elif aggregate_method == "max":
            jitter = jitter_per_body.max(dim=-1).values  # [num_motions, max_motion_len]
        elif aggregate_method == "sum":
            jitter = jitter_per_body.sum(dim=-1)  # [num_motions, max_motion_len]
        else:
            raise ValueError(f"Unknown aggregate_method: {aggregate_method}")

        # Re-apply the validity mask so jitter is 0 for invalid frames
        frame_counts = self.frame_counts.unsqueeze(1)  # [num_motions, 1]
        indices = torch.arange(self.max_motion_len, device=self.device).unsqueeze(
            0
        )  # [1, max_motion_len]
        valid_frame_mask = indices < frame_counts  # [num_motions, max_motion_len]
        # Set jitter to 0 for invalid frames
        jitter = jitter.masked_fill(~valid_frame_mask, 0.0)
        return jitter

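    # Worked finite-difference sketch (hypothetical values): for a single body
    # whose x-coordinate over four valid frames is [0, 1, 4, 9],
    #
    #   order=1: data[t+1] - data[t]               -> [_, 1, 3, 5]
    #   order=2: data[t+1] - 2*data[t] + data[t-1] -> [_, _, 2, 2]
    #
    # where "_" marks the zero-padding prepended so the output keeps the
    # input length of max_motion_len.
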
    def compute_jitter_reduce_each_motion(
        self, num_bodies: int, aggregate_method: str = "mean"
    ) -> torch.Tensor:
        """
        Compute jitter (2nd order finite differences of positions) and reduce
        across the body dimensions.

        This method is specifically designed for rigid_body_pos data with shape
        [num_motions, max_motion_len, num_bodies*3]. It computes the L2 norm of
        the 2nd order finite differences (pos[t+1] - 2*pos[t] + pos[t-1]) for
        each body, then aggregates across all bodies using the specified method.
        The output is zero-padded at the beginning to match the input length.

        Args:
            num_bodies: Number of rigid bodies (to reshape the flattened data)
            aggregate_method: How to aggregate across bodies ("mean", "max", "sum")

        Returns:
            torch.Tensor: Jitter values with shape [num_motions, max_motion_len]
                (same as the input)
        """
        return self.compute_finite_difference_jitter_reduce_each_motion(
            num_bodies=num_bodies,
            aggregate_method=aggregate_method,
            order=2,
            field_description="rigid_body_pos",
        )

    def compute_rotation_jitter_reduce_each_motion(
        self, num_bodies: int, aggregate_method: str = "mean"
    ) -> torch.Tensor:
        """
        Compute rotation jitter (1st order finite differences of angular
        velocities) and reduce across the body dimensions.

        This method is specifically designed for rigid_body_ang_vel data with
        shape [num_motions, max_motion_len, num_bodies*3]. It computes the L2
        norm of the 1st order finite differences (ang_vel[t+1] - ang_vel[t])
        for each body, then aggregates across all bodies using the specified
        method. The output is zero-padded at the beginning to match the input
        length.

        Args:
            num_bodies: Number of rigid bodies (to reshape the flattened data)
            aggregate_method: How to aggregate across bodies ("mean", "max", "sum")

        Returns:
            torch.Tensor: Rotation jitter values with shape
                [num_motions, max_motion_len] (same as the input)
        """
        return self.compute_finite_difference_jitter_reduce_each_motion(
            num_bodies=num_bodies,
            aggregate_method=aggregate_method,
            order=1,
            field_description="rigid_body_ang_vel",
        )

    def jitter_mean_reduce_each_motion(
        self, num_bodies: int, aggregate_method: str = "mean"
    ) -> torch.Tensor:
        """
        Compute jitter and then take the mean over time for each motion.

        Args:
            num_bodies: Number of rigid bodies
            aggregate_method: How to aggregate across bodies ("mean", "max", "sum")

        Returns:
            torch.Tensor: Mean jitter value for each motion [num_motions]
        """
        return self._generic_jitter_mean_reduce_each_motion(
            num_bodies=num_bodies,
            aggregate_method=aggregate_method,
            jitter_method=self.compute_jitter_reduce_each_motion,
        )

    def rotation_jitter_mean_reduce_each_motion(
        self, num_bodies: int, aggregate_method: str = "mean"
    ) -> torch.Tensor:
        """
        Compute rotation jitter and then take the mean over time for each motion.

        Args:
            num_bodies: Number of rigid bodies
            aggregate_method: How to aggregate across bodies ("mean", "max", "sum")

        Returns:
            torch.Tensor: Mean rotation jitter value for each motion [num_motions]
        """
        return self._generic_jitter_mean_reduce_each_motion(
            num_bodies=num_bodies,
            aggregate_method=aggregate_method,
            jitter_method=self.compute_rotation_jitter_reduce_each_motion,
        )

    def _generic_jitter_mean_reduce_each_motion(
        self, num_bodies: int, aggregate_method: str, jitter_method: Callable
    ) -> torch.Tensor:
        """
        Generic helper to compute jitter and take the mean over time for each motion.

        Args:
            num_bodies: Number of rigid bodies
            aggregate_method: How to aggregate across bodies ("mean", "max", "sum")
            jitter_method: The method to call for computing jitter

        Returns:
            torch.Tensor: Mean jitter value for each motion [num_motions]
        """
        jitter = jitter_method(
            num_bodies, aggregate_method
        )  # [num_motions, max_motion_len]

        # Apply the unfilled mask so only valid frames are considered
        mask = self.get_unfilled_mask()[
            :, :, 0
        ]  # [num_motions, max_motion_len] (use the first sub-feature's mask)
        jitter_masked = jitter.masked_fill(mask, 0.0)

        # Sum the jitter values and divide by the valid frame count
        jitter_sum = jitter_masked.sum(dim=1)  # [num_motions]
        valid_frame_counts = self.frame_counts.clamp(
            min=1
        )  # [num_motions], clamped to avoid division by zero
        mean_jitter = jitter_sum / valid_frame_counts  # [num_motions]

        # Set the mean to 0 for motions with no valid frames
        zero_frame_mask = self.frame_counts == 0
        mean_jitter = mean_jitter.masked_fill(zero_frame_mask, 0.0)
        return mean_jitter

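    # Illustrative usage sketch (hypothetical values): a tracker storing
    # flattened positions of 24 bodies uses num_sub_features = 24 * 3 = 72,
    # and the per-motion mean jitter is then
    #
    #   pos_metrics = MotionMetrics(
    #       num_motions=2,
    #       motion_lens=torch.tensor([3, 2]),
    #       max_motion_len=4,
    #       num_sub_features=72,
    #   )
    #   ...  # fill via pos_metrics.update(...)
    #   pos_metrics.jitter_mean_reduce_each_motion(num_bodies=24)  # shape [2]
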
    def copy_from(
        self,
        other: "MotionMetrics",
    ) -> None:
        """Copy data from another MotionMetrics object."""
        self.data = other.data.clone()
        self.frame_counts = other.frame_counts.clone()

    def copy_from_motion_ids(
        self,
        other: "MotionMetrics",
        motion_ids: torch.Tensor,
    ) -> None:
        """Copy data from another MotionMetrics object for specific motions."""
        self.data[motion_ids] = other.data[motion_ids]
        self.frame_counts[motion_ids] = other.frame_counts[motion_ids]

    def merge_from(
        self,
        other: "MotionMetrics",
    ) -> None:
        """Merge data from another MotionMetrics object."""
        assert self.max_motion_len == other.max_motion_len
        assert self.num_sub_features == other.num_sub_features
        self.data = torch.cat([self.data, other.data], dim=0)
        self.frame_counts = torch.cat([self.frame_counts, other.frame_counts], dim=0)
        self.motion_lens = torch.cat([self.motion_lens, other.motion_lens], dim=0)
        self.num_motions = self.data.shape[0]

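    # Illustrative aggregation sketch (hypothetical tracker names): merging
    # concatenates along the motion dimension, which is why max_motion_len
    # must already agree across trackers:
    #
    #   rank0_metrics.merge_from(rank1_metrics)
    #   rank0_metrics.num_motions  # sum of both trackers' motion counts
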
    def reset(self) -> None:
        """Reset all stored data and frame counts."""
        self.data.zero_()
        self.frame_counts.zero_()

    def to(self, device: torch.device) -> "MotionMetrics":
        """Move the metrics to the specified device."""
        self.device = device
        self.data = self.data.to(device)
        self.frame_counts = self.frame_counts.to(device)
        # Move motion_lens as well, so the per-motion length checks in update()
        # stay on the same device as the incoming motion_ids
        self.motion_lens = self.motion_lens.to(device)
        return self