protomotions.agents.ase.model module

class protomotions.agents.ase.model.ASEDiscriminatorEncoder(*args, **kwargs)

Bases: Discriminator

Discriminator with MI encoder head for ASE.

Inherits from Discriminator and adds an encoder head for mutual information learning.

config: ASEDiscriminatorEncoderConfig
__init__(config)
forward(tensordict)

Forward pass computing discriminator and MI encoder outputs.

Parameters:

tensordict (TensorDict) – TensorDict containing observations and latents.

Returns:

TensorDict with disc_logits and mi_enc_output added.

Return type:

TensorDict
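A minimal sketch of the shared-trunk, two-head layout that the forward pass above describes. The class name `TwoHeadDiscriminator`, the layer sizes, the `obs` key, the plain dict standing in for a TensorDict, and the normalization of the encoder output are all illustrative assumptions, not the ProtoMotions implementation.

```python
import torch
from torch import nn


class TwoHeadDiscriminator(nn.Module):
    """Hypothetical stand-in: shared trunk, discriminator head, MI encoder head."""

    def __init__(self, obs_dim: int, latent_dim: int, hidden: int = 64):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU())
        self.disc_head = nn.Linear(hidden, 1)
        self.enc_head = nn.Linear(hidden, latent_dim)

    def forward(self, td: dict) -> dict:
        features = self.trunk(td["obs"])  # "obs" is an assumed key name
        td["disc_logits"] = self.disc_head(features)
        # Assumption: project the prediction onto the unit hypersphere.
        td["mi_enc_output"] = nn.functional.normalize(
            self.enc_head(features), dim=-1
        )
        return td


torch.manual_seed(0)
model = TwoHeadDiscriminator(obs_dim=8, latent_dim=4)
out = model({"obs": torch.randn(3, 8)})
print(out["disc_logits"].shape, out["mi_enc_output"].shape)
```

Both heads read the same trunk features, so the two outputs are added to the input container in a single pass.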

compute_mi_reward(tensordict, mi_hypersphere_reward_shift)

Computes the mutual information (MI) reward from the encoder output and the latents.

Parameters:
  • tensordict (TensorDict) – TensorDict with mi_enc_output and latents.

  • mi_hypersphere_reward_shift (bool) – Whether to shift the reward into the range [0, 1].

Returns:

Mutual Information reward tensor.

Return type:

torch.Tensor
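One plausible reading of the reward shift, sketched under stated assumptions: for unit-norm vectors the dot product between prediction and latent lies in [-1, 1], and the shift maps it affinely onto [0, 1]. The function name `mi_reward_sketch` and the clamp used when the shift is disabled are assumptions, not confirmed behavior of `compute_mi_reward`.

```python
import torch


def mi_reward_sketch(enc_pred: torch.Tensor, latent: torch.Tensor, shift: bool) -> torch.Tensor:
    # Dot product per sample; in [-1, 1] when both inputs are unit vectors.
    similarity = torch.sum(enc_pred * latent, dim=-1, keepdim=True)
    if shift:
        return (similarity + 1.0) / 2.0  # map [-1, 1] onto [0, 1]
    return similarity.clamp_min(0.0)  # assumed fallback: clip negatives to zero


latent = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
enc_pred = torch.tensor([[1.0, 0.0], [1.0, 0.0], [0.0, -1.0]])  # aligned, orthogonal, opposite
shifted = mi_reward_sketch(enc_pred, latent, shift=True)
clipped = mi_reward_sketch(enc_pred, latent, shift=False)
print(shifted.squeeze(-1), clipped.squeeze(-1))
```

With the shift, an orthogonal prediction still earns a reward of 0.5; with the clamp, it earns 0.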

calc_von_mises_fisher_enc_error(enc_pred, latent)

Calculates the von Mises-Fisher error between predicted and true latent vectors.

Parameters:
  • enc_pred (torch.Tensor) – Predicted encoded latent vector. Shape (batch_size, latent_dim).

  • latent (torch.Tensor) – True latent vector. Shape (batch_size, latent_dim).

Returns:

von Mises-Fisher error. Shape (batch_size, 1).

Return type:

torch.Tensor
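For unit vectors, the von Mises-Fisher log-density with mean direction `latent` equals a positive concentration times the dot product with `enc_pred`, plus a constant, so the negative dot product is a natural error term. The sketch below assumes that form; the name `vmf_enc_error_sketch` is hypothetical, and the exact expression used by the library is not confirmed here.

```python
import torch


def vmf_enc_error_sketch(enc_pred: torch.Tensor, latent: torch.Tensor) -> torch.Tensor:
    # Negative dot product per sample; keepdim yields shape (batch_size, 1).
    return -torch.sum(enc_pred * latent, dim=-1, keepdim=True)


latent = torch.nn.functional.normalize(
    torch.tensor([[3.0, 4.0], [1.0, 0.0]]), dim=-1
)
err_match = vmf_enc_error_sketch(latent, latent)      # perfect prediction
err_opposite = vmf_enc_error_sketch(-latent, latent)  # opposite direction
print(err_match.shape, err_match.squeeze(-1), err_opposite.squeeze(-1))
```

Under this form the error is lowest (-1) when the prediction matches the latent and highest (+1) when it points the opposite way.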

all_weights()

Returns all weights from all sequential modules (trunk + discriminator + encoder).

Uses explicit walking to avoid duplicates in nested structures.

Returns:

List of all weight parameters.

Return type:

List[nn.Parameter]
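A hedged sketch of how weights might be gathered from several sequential modules without duplicates. The helper `collect_linear_weights` and the module layout are hypothetical; the sketch deduplicates by parameter identity, which may differ from the explicit walk the library performs.

```python
from typing import Iterable, List

from torch import nn


def collect_linear_weights(sequentials: Iterable[nn.Sequential]) -> List[nn.Parameter]:
    weights: List[nn.Parameter] = []
    seen = set()
    for seq in sequentials:
        for module in seq.modules():  # walks nested containers too
            if isinstance(module, nn.Linear) and id(module.weight) not in seen:
                seen.add(id(module.weight))
                weights.append(module.weight)
    return weights


trunk = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 16))
disc_head = nn.Sequential(nn.Linear(16, 1))
enc_head = nn.Sequential(nn.Linear(16, 4))
nested = nn.Sequential(trunk, disc_head)  # exposes the same layers a second time

weights = collect_linear_weights([trunk, disc_head, enc_head, nested])
print(len(weights))
```

Only weight matrices are collected; biases are skipped, which matches the "weight parameters" wording of the return description.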

all_discriminator_weights()

Returns weights of discriminator part only (excludes encoder head).

Explicitly walks through sequential_models to avoid including encoder head.

Returns:

List of discriminator weight parameters.

Return type:

List[nn.Parameter]

logit_weights()

Returns the weights of the final discriminator layer.

Returns:

List containing the weight parameter of the discriminator’s output layer.

Return type:

List[nn.Parameter]
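A list like this is typically consumed by a regularization term. The sketch below assumes an L2 penalty on the output-layer weights; the coefficient `logit_reg_coef` and the standalone `output_layer` are hypothetical, and whether ProtoMotions applies exactly this penalty is not stated in this page.

```python
import torch
from torch import nn

output_layer = nn.Linear(16, 1)  # stand-in for the discriminator's output layer
logit_weights = [output_layer.weight]  # same shape of result as documented above

logit_reg_coef = 0.01  # hypothetical coefficient
# Sum of squared weights, scaled by the coefficient.
logit_reg = logit_reg_coef * sum(torch.sum(torch.square(w)) for w in logit_weights)
logit_reg.backward()

print(logit_reg.item(), output_layer.weight.grad.shape)
```

Because the list holds only the weight matrix, the bias receives no gradient from this term.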

all_enc_weights()

Returns all weights on the encoder path (shared trunk + encoder head).

Returns:

List of encoder weight parameters.

Return type:

List[nn.Parameter]

enc_weights()

Returns the weights of the final encoder layer only.

Returns:

List containing the weight parameter of the encoder’s output layer.

Return type:

List[nn.Parameter]