protomotions.agents.ase.model module
- class protomotions.agents.ase.model.ASEDiscriminatorEncoder(*args, **kwargs)
Bases: Discriminator
Discriminator with MI encoder head for ASE.
Inherits from Discriminator and adds an encoder head for mutual information learning.
- config: ASEDiscriminatorEncoderConfig
- forward(tensordict)
Forward pass computing discriminator and MI encoder outputs.
- Parameters:
tensordict (TensorDict) – TensorDict containing observations and latents.
- Returns:
TensorDict with disc_logits and mi_enc_output added.
- Return type:
TensorDict
- compute_mi_reward(tensordict, mi_hypersphere_reward_shift)
Computes the mutual-information-based reward.
- Parameters:
tensordict (TensorDict) – TensorDict with mi_enc_output and latents.
mi_hypersphere_reward_shift (bool) – Whether to shift reward to [0, 1].
- Returns:
Mutual Information reward tensor.
- Return type:
torch.Tensor
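The effect of mi_hypersphere_reward_shift can be sketched in a few lines. This is a dependency-free illustration using plain Python floats, not the library's implementation: the helper name `mi_reward_from_error` is hypothetical, the input is assumed to be a per-sample encoder error in [-1, 1] (a negative dot product of unit vectors), and the clipping used when the shift is disabled is an assumption borrowed from common ASE formulations rather than something this page states.

```python
def mi_reward_from_error(enc_error, shift):
    """Map per-sample encoder errors to rewards (hypothetical helper).

    enc_error: list of floats, assumed to lie in [-1, 1].
    shift: plays the role of mi_hypersphere_reward_shift.
    """
    rewards = []
    for err in enc_error:
        similarity = -err  # higher means the prediction matches the latent
        if shift:
            # Affine map from [-1, 1] onto [0, 1].
            rewards.append((similarity + 1.0) / 2.0)
        else:
            # Assumed alternative: clip negative similarities to zero.
            rewards.append(max(similarity, 0.0))
    return rewards


errors = [-1.0, 0.0, 1.0]  # perfect match, orthogonal, opposite
shifted = mi_reward_from_error(errors, shift=True)
clipped = mi_reward_from_error(errors, shift=False)
```

With the shift enabled, an orthogonal prediction still earns a reward of 0.5, whereas the clipped variant gives it zero; which behavior is preferable depends on how the reward is combined with other terms.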
- calc_von_mises_fisher_enc_error(enc_pred, latent)
Calculates the von Mises-Fisher error between predicted and true latent vectors.
- Parameters:
enc_pred (torch.Tensor) – Predicted encoded latent vector. Shape (batch_size, latent_dim).
latent (torch.Tensor) – True latent vector. Shape (batch_size, latent_dim).
- Returns:
Von Mises-Fisher error. Shape (batch_size, 1).
- Return type:
torch.Tensor
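For intuition, a von Mises-Fisher density on the unit hypersphere is proportional to exp(κ μᵀz), so its negative log-likelihood is, up to the scale κ and an additive constant, the negative dot product between the predicted direction and the latent. The sketch below assumes the error is exactly that negative dot product over unit-norm rows; it uses nested Python lists in place of torch.Tensor so it runs without dependencies, and the function names `normalize` and `vmf_enc_error` are hypothetical.

```python
import math


def normalize(v):
    """Scale a vector to unit length so it lies on the hypersphere."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def vmf_enc_error(enc_pred, latent):
    """Negative row-wise dot product (assumed form of the error).

    Both inputs are lists of rows with shape (batch_size, latent_dim).
    The result keeps a trailing singleton axis: shape (batch_size, 1).
    """
    return [
        [-sum(p * z for p, z in zip(pred_row, latent_row))]
        for pred_row, latent_row in zip(enc_pred, latent)
    ]


latent = [normalize([3.0, 4.0]), normalize([1.0, 0.0])]
enc_pred = [normalize([3.0, 4.0]), normalize([0.0, 1.0])]
errors = vmf_enc_error(enc_pred, latent)
```

Under this assumption the error is lowest (close to -1) when the prediction points in the same direction as the latent and rises to 0 for orthogonal vectors.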
- all_weights()
Returns all weights from all sequential modules (trunk + discriminator + encoder).
Uses explicit walking to avoid duplicates in nested structures.
- Returns:
List of all weight parameters.
- Return type:
List[nn.Parameter]
- all_discriminator_weights()
Returns weights of discriminator part only (excludes encoder head).
Explicitly walks through sequential_models to avoid including encoder head.
- Returns:
List of discriminator weight parameters.
- Return type:
List[nn.Parameter]
- logit_weights()
Returns the weights of the final discriminator layer.
- Returns:
List containing the weight parameter of the discriminator’s output layer.
- Return type:
List[nn.Parameter]
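The relationship between the three weight accessors can be illustrated with a toy structure. Everything in this sketch is hypothetical: `Layer` stands in for a layer that owns a weight, nested lists stand in for sequential modules, and the trunk is assumed to be shared between the discriminator and encoder paths, which is what makes de-duplication necessary. It shows one way to walk nested containers while skipping weights that were already collected, not how the class actually does it.

```python
class Layer:
    """Hypothetical stand-in for a layer that owns one weight object."""

    def __init__(self, name):
        self.name = name
        self.weight = object()  # unique identity, like a parameter


def walk_weights(module, seen):
    """Depth-first walk over nested lists, skipping weights already seen."""
    weights = []
    for child in module:
        if isinstance(child, list):
            weights.extend(walk_weights(child, seen))
        elif id(child.weight) not in seen:
            seen.add(id(child.weight))
            weights.append(child.weight)
    return weights


# Assumed layout: one trunk shared by two heads.
trunk = [Layer("trunk.0"), Layer("trunk.1")]
disc_head = [Layer("disc_out")]
encoder_head = [Layer("enc_out")]
discriminator_path = [trunk, disc_head]
encoder_path = [trunk, encoder_head]

# All weights: both paths, with the shared trunk counted once.
seen = set()
all_w = walk_weights(discriminator_path, seen) + walk_weights(encoder_path, seen)

# Discriminator-only weights: the encoder head is never visited.
disc_w = walk_weights(discriminator_path, set())

# Logit weights: only the final discriminator layer.
logit_w = [disc_head[-1].weight]
```

In this toy layout the full walk yields four weights, the discriminator-only walk yields three, and the logit list holds exactly one; a naive concatenation of both paths would instead count the trunk twice.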