# Source code for protomotions.utils.motion_interpolation_utils

# SPDX-FileCopyrightText: Copyright (c) 2025 The ProtoMotions Developers
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Motion interpolation utilities.

Provides functions for smoothly interpolating between motion frames,
including linear position interpolation and spherical quaternion interpolation (SLERP).
"""

import torch
from protomotions.utils import rotations


def interpolate_pos(pos0, pos1, blend):
    """Blend linearly from ``pos0`` towards ``pos1``.

    Args:
        pos0: Starting positions, a 2-D or 3-D tensor
            (e.g. [batch, 3] or [batch, bodies, 3]).
        pos1: Ending positions, same shape as ``pos0``. Its rank decides
            how many trailing axes are appended to ``blend``.
        blend: Blend factor [batch]; 0 selects ``pos0``, 1 selects ``pos1``.

    Returns:
        Interpolated positions with the same shape as pos0/pos1.

    Raises:
        ValueError: If ``pos1`` is neither 2-D nor 3-D.
    """
    trailing_axes = pos1.dim() - 1
    if trailing_axes not in (1, 2):
        raise ValueError(f"pos1 has {pos1.dim()} dimensions, expected 2 or 3")

    # Append one singleton axis per non-batch dimension so the per-sample
    # weight broadcasts across the remaining axes of the positions.
    weight = blend
    for _ in range(trailing_axes):
        weight = weight.unsqueeze(-1)

    start_term = (1.0 - weight) * pos0
    end_term = weight * pos1
    return start_term + end_term
def interpolate_quat(rot0, rot1, blend):
    """Interpolate between quaternions with spherical linear interpolation.

    The blend factor is expanded with trailing singleton axes to match the
    rank of ``rot1`` and the interpolation itself is delegated to
    ``rotations.slerp``.

    Args:
        rot0: Starting quaternions [batch, 4] or [batch, bodies, 4].
        rot1: Ending quaternions [batch, 4] or [batch, bodies, 4].
        blend: Blend factor [batch]; 0 selects ``rot0``, 1 selects ``rot1``.

    Returns:
        Interpolated quaternions with the same shape as rot0/rot1.

    Raises:
        ValueError: If ``rot1`` is neither 2-D nor 3-D.
    """
    rank = rot1.dim()
    if rank not in (2, 3):
        raise ValueError(f"rot1 has {rot1.dim()} dimensions, expected 2 or 3")

    # One trailing axis is always needed for the quaternion components;
    # a second one covers the per-body axis of 3-D inputs.
    weight = blend.unsqueeze(-1)
    if rank == 3:
        weight = weight.unsqueeze(-1)

    return rotations.slerp(rot0, rot1, weight)
def calc_frame_blend(time, length, num_frames, dt):
    """Calculate frame indices and blend factor for interpolation.

    The query time is converted to a phase in [0, 1], which selects the
    frame at or before the query time and its successor (capped at the last
    frame). The blend factor is the fraction of ``dt`` elapsed since the
    first frame, clipped to [0, 1].

    Args:
        time (torch.Tensor): Current time.
        length (torch.Tensor): Length of the motion sequence in seconds.
        num_frames (torch.Tensor): Number of frames in the motion sequence.
        dt (torch.Tensor): Time step between frames.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Frame index 0,
        frame index 1, and blend factor in [0, 1].
    """
    # Times outside the motion are pinned to its first/last frame.
    phase = torch.clip(time / length, 0.0, 1.0)

    last_frame = num_frames - 1
    # .long() truncates, selecting the frame at or before the query time.
    frame_idx0 = (phase * last_frame).long()
    frame_idx1 = torch.min(frame_idx0 + 1, last_frame)

    # The raw ratio uses the unclipped time, so it leaves [0, 1] when time is
    # outside [0, length] or when rounding puts the truncated index on the
    # other side of a frame boundary. Clip it so callers never extrapolate.
    blend = torch.clip((time - frame_idx0 * dt) / dt, 0.0, 1.0)

    return frame_idx0, frame_idx1, blend