
Embedding

The embedding layer converts a TensorDict[B, S] of step records into the flat token sequence [B, S*tokens_per_step, D] consumed by the backbone.

StepEmbedder

mouse.models.embedding.embedding.StepEmbedder

StepEmbedder(hidden_dim: int, max_num_actions: int, max_num_obs_continuous: int, max_num_obs_discrete: int, max_num_obs_image: int, max_num_time_steps: int, include_action_token: bool, include_done_token: bool, include_reward_token: bool, include_obs_continuous: bool, include_obs_discrete: bool, include_obs_image: bool, include_time_token: bool, include_type_token: bool, token_data_len: int, num_compute_tokens: int = 0, concat_modalities: bool = False, fourier_in_min: float = -1.0, fourier_in_max: float = 1.0, std: float = 0.02)

Bases: Module

Converts a batch of step records [B, S] into embedding sequences [B, S*tokens_per_step, D].

tokens_per_step is fixed at construction time so that the backbone always receives a consistently-shaped input.

Two embedding modes are available:

Sum mode (concat_modalities=False, default): Every modality contributes token_data_len tokens; contributions are summed at each position so the output is always exactly token_data_len data tokens per step regardless of which modalities are active.

Concat mode (concat_modalities=True): Each active modality occupies its own dedicated block of token_data_len tokens, laid out sequentially. tokens_per_step = num_active_modalities * token_data_len + num_compute_tokens.

Compute tokens (num_compute_tokens > 0): K learned scratch tokens are appended after the data tokens in every step block. The backbone can attend to and write into them as working memory. The step representation is always pooled from the last token (the last compute token when K > 0).
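
The token count per step follows directly from the two modes and the compute-token setting. As a minimal sketch of that arithmetic (the helper name `tokens_per_step` and its argument names are illustrative and not part of the library):

```python
def tokens_per_step(num_active_modalities: int, token_data_len: int,
                    num_compute_tokens: int = 0, concat_modalities: bool = False) -> int:
    """Illustrative helper (not a library function): tokens emitted per step."""
    if num_compute_tokens < 0:
        raise ValueError("num_compute_tokens must be >= 0")
    # Sum mode: all modalities share the same token_data_len slots.
    # Concat mode: each active modality gets its own block of token_data_len slots.
    if concat_modalities:
        data_slots = num_active_modalities * token_data_len
    else:
        data_slots = token_data_len
    # Compute tokens are appended after the data slots.
    return data_slots + num_compute_tokens


sum_mode = tokens_per_step(3, token_data_len=2, num_compute_tokens=1)
concat_mode = tokens_per_step(3, token_data_len=2, num_compute_tokens=1, concat_modalities=True)
print(sum_mode, concat_mode)  # 3 7
```

With three active modalities, T = 2 and K = 1, sum mode yields 2 + 1 tokens per step and concat mode yields 3 * 2 + 1.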

Parameters:

hidden_dim (int, required): Model hidden dimension D.
max_num_actions (int, required): Size of the action embedding table; must be > 0 when include_action_token is True.
max_num_obs_continuous (int, required): Continuous obs vector length; must be > 0 when include_obs_continuous is True.
max_num_obs_discrete (int, required): Discrete obs vector length; must be > 0 when include_obs_discrete is True.
max_num_obs_image (int, required): Total pixel count per image; must be > 0 when include_obs_image is True.
max_num_time_steps (int, required): TIME embedding table size; must be > 0 when include_time_token is True.
include_action_token (bool, required): Emit an ACTION token per step.
include_done_token (bool, required): Emit a DONE token per step.
include_reward_token (bool, required): Emit a REWARD token per step.
include_obs_continuous (bool, required): Emit an OBS_CONTINUOUS token per step.
include_obs_discrete (bool, required): Emit an OBS_DISCRETE token per step.
include_obs_image (bool, required): Emit an OBS_IMAGE token per step.
include_time_token (bool, required): Emit a TIME token per step.
include_type_token (bool, required): Add the learned type embedding to every token.
token_data_len (int, required): Number of tokens T produced per modality.
num_compute_tokens (int, default 0): Number of learned scratch tokens K appended after the data tokens within each step block. 0 disables compute tokens.
concat_modalities (bool, default False): When True, modality embeddings are concatenated sequentially rather than summed. Each active modality occupies its own token_data_len slots.
fourier_in_min (float, default -1.0): Smallest input scale the RFF resolves; forwarded as in_min to RandomFourierFeatures, which requires 0 < in_min < in_max.
fourier_in_max (float, default 1.0): Largest input scale the RFF covers; forwarded as in_max to RandomFourierFeatures.
std (float, default 0.02): Initialisation std for embedding tables.

Source code in mouse/models/embedding/embedding.py
def __init__(
    self,
    hidden_dim: int,
    max_num_actions: int,
    max_num_obs_continuous: int,
    max_num_obs_discrete: int,
    max_num_obs_image: int,
    max_num_time_steps: int,
    include_action_token: bool,
    include_done_token: bool,
    include_reward_token: bool,
    include_obs_continuous: bool,
    include_obs_discrete: bool,
    include_obs_image: bool,
    include_time_token: bool,
    include_type_token: bool,
    token_data_len: int,
    num_compute_tokens: int = 0,
    concat_modalities: bool = False,
    fourier_in_min: float = -1.0,
    fourier_in_max: float = 1.0,
    std: float = 0.02,
) -> None:
    super().__init__()

    _size_checks = [
        ("include_action_token", include_action_token, "max_num_actions", max_num_actions),
        ("include_obs_continuous", include_obs_continuous, "max_num_obs_continuous", max_num_obs_continuous),
        ("include_obs_discrete", include_obs_discrete, "max_num_obs_discrete", max_num_obs_discrete),
        ("include_obs_image", include_obs_image, "max_num_obs_image", max_num_obs_image),
        ("include_time_token", include_time_token, "max_num_time_steps", max_num_time_steps),
    ]
    for inc_name, inc_val, size_name, size_val in _size_checks:
        if inc_val and int(size_val) <= 0:
            raise ValueError(f"{inc_name} is True but {size_name} is {size_val} (must be > 0).")

    if int(num_compute_tokens) < 0:
        raise ValueError(f"num_compute_tokens must be >= 0, got {num_compute_tokens}.")

    self.hidden_dim = int(hidden_dim)
    self.include_action_token = bool(include_action_token)
    self.include_time_token = bool(include_time_token)
    self.include_done_token = bool(include_done_token)
    self.include_reward_token = bool(include_reward_token)
    self.include_obs_continuous = bool(include_obs_continuous)
    self.include_obs_discrete = bool(include_obs_discrete)
    self.include_obs_image = bool(include_obs_image)
    self.include_type_token = bool(include_type_token)
    self.num_compute_tokens = int(num_compute_tokens)
    self.concat_modalities = bool(concat_modalities)
    self.token_data_len = int(token_data_len)

    # Count active data modalities (determines tokens_per_step in concat mode).
    self._num_data_modalities: int = sum([
        include_time_token,
        include_action_token,
        include_obs_continuous,
        include_obs_discrete,
        include_obs_image,
        include_reward_token,
        include_done_token,
    ])

    # Compute tokens_per_step.
    T = self.token_data_len
    K = self.num_compute_tokens
    if concat_modalities:
        data_slots = self._num_data_modalities * T
    else:
        data_slots = T
    self.tokens_per_step: int = data_slots + K

    # Shared type embedding (only used internally, not in returned token_types).
    self.type_embedder = TypeEmbedder(hidden_dim=hidden_dim, token_data_len=T, embedding_std=std)

    # Action (optional)
    self.action_embedder = (
        ActionEmbedder(hidden_dim=hidden_dim, token_data_len=T, max_num_actions=int(max_num_actions), embedding_std=std)
        if include_action_token else None
    )

    # Time (optional)
    self.time_embedder = (
        TimeEmbedder(hidden_dim=hidden_dim, token_data_len=T, max_num_time_steps=int(max_num_time_steps), embedding_std=std)
        if include_time_token else None
    )

    # Done (optional)
    self.done_embedder = (
        DoneEmbedder(hidden_dim=hidden_dim, token_data_len=T, embedding_std=std)
        if include_done_token else None
    )

    # Reward (optional)
    self.reward_embedder = (
        RewardEmbedder(hidden_dim=hidden_dim, token_data_len=T, in_min=fourier_in_min, in_max=fourier_in_max, embedding_std=std)
        if include_reward_token else None
    )

    # Continuous obs (optional)
    if include_obs_continuous:
        self.obs_continuous_embedder = ObsContinuousEmbedder(
            hidden_dim=hidden_dim, max_num_obs=int(max_num_obs_continuous), token_data_len=T,
            in_min=fourier_in_min, in_max=fourier_in_max, embedding_std=std,
        )
    else:
        self.obs_continuous_embedder = None

    # Discrete obs (optional)
    self.obs_discrete_embedder = (
        ObsDiscreteEmbedder(
            hidden_dim=hidden_dim, max_num_obs=int(max_num_obs_discrete), token_data_len=T, embedding_std=std
        )
        if include_obs_discrete else None
    )

    # Image obs (optional)
    self.obs_image_embedder = (
        ObsImageEmbedder(
            hidden_dim=hidden_dim, max_num_obs=int(max_num_obs_image), token_data_len=T, embedding_std=std
        )
        if include_obs_image else None
    )

    # Compute tokens — one shared learned embedding per compute slot, broadcast over (B, S).
    if K > 0:
        self.compute_embed = nn.Parameter(torch.empty(K, int(hidden_dim)))
        nn.init.normal_(self.compute_embed, std=std)
    else:
        self.compute_embed = None  # type: ignore[assignment]

forward

forward(step_stream: TensorDict) -> tuple[torch.Tensor, torch.Tensor]

Embed a batch of steps.

Parameters:

step_stream (TensorDict, required): TensorDict of shape [B, S].

Returns:

embeds (Tensor): [B, S*tokens_per_step, D], the per-position embedding vectors.
token_types (Tensor): [B, S*tokens_per_step] int64, the TokenType id at each position (used by the backbone to build the attention mask). Data positions carry their modality's TokenType; compute positions carry TokenType.COMPUTE; unused padding positions carry TokenType.PAD (0).

Source code in mouse/models/embedding/embedding.py
def forward(
    self,
    step_stream: TensorDict,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Embed a batch of steps.

    Args:
        step_stream: TensorDict of shape ``[B, S]``.

    Returns:
        embeds:      ``[B, S*tokens_per_step, D]`` — per-position embedding vectors.
        token_types: ``[B, S*tokens_per_step]`` int64 — ``TokenType`` id at each
                     position (used by the backbone to build the attention mask).
                     Data positions carry their modality's ``TokenType``; compute
                     positions carry ``TokenType.COMPUTE``; unused padding positions
                     carry ``TokenType.PAD`` (0).
    """
    device = next(self.parameters()).device
    step_stream = step_stream.to(device)

    B, S = int(step_stream.batch_size[0]), int(step_stream.batch_size[1])
    T, D = self.token_data_len, self.hidden_dim

    # Fetch each active modality from the step stream.
    action   = step_stream["action"]          if self.include_action_token   else None
    reward   = step_stream["reward"]          if self.include_reward_token   else None
    done     = step_stream["done"]            if self.include_done_token     else None
    time_idx = step_stream["time"]            if self.include_time_token     else None
    obs_cont = step_stream["obs_continuous"]  if self.include_obs_continuous else None
    obs_disc = step_stream["obs_discrete"]    if self.include_obs_discrete   else None
    obs_img  = step_stream["obs_image"]       if self.include_obs_image      else None

    dtype = torch.get_default_dtype()

    if self.concat_modalities:
        data_embeds, data_types = self._forward_concat(
            B, S, T, D, device, dtype,
            action, reward, done, time_idx, obs_cont, obs_disc, obs_img,
        )
    else:
        data_embeds, data_types = self._forward_sum(
            B, S, T, D, device, dtype,
            action, reward, done, time_idx, obs_cont, obs_disc, obs_img,
        )

    # Append compute tokens.
    K = self.num_compute_tokens
    if K > 0:
        c = self.compute_embed.to(dtype=dtype)          # [K, D]
        c = c.view(1, 1, K, D).expand(B, S, K, D)
        embeds = torch.cat([data_embeds, c], dim=2)     # [B, S, data_slots+K, D]

        c_types = torch.full((B, S, K), int(TokenType.COMPUTE), device=device, dtype=torch.long)
        token_types = torch.cat([data_types, c_types], dim=2)  # [B, S, total]
    else:
        embeds = data_embeds
        token_types = data_types

    total_T = embeds.shape[2]
    return embeds.reshape(B, S * total_T, D), token_types.reshape(B, S * total_T)
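
Because the output is flattened step by step, a flat sequence position can be mapped back to its step and slot with integer division. The sketch below assumes the layout described above (data slots first, then the K compute slots, with the step representation pooled from the last token); the function names are hypothetical and not part of the library.

```python
def locate(position: int, tokens_per_step: int) -> tuple[int, int]:
    """Map a flat sequence position to (step index, slot within the step)."""
    return divmod(position, tokens_per_step)


def is_compute_slot(slot: int, tokens_per_step: int, num_compute_tokens: int) -> bool:
    """Compute tokens occupy the last num_compute_tokens slots of each step block."""
    return slot >= tokens_per_step - num_compute_tokens


def pooled_position(step: int, tokens_per_step: int) -> int:
    """Flat index of the last token of a step, where the step representation is pooled."""
    return (step + 1) * tokens_per_step - 1


# Example: 4 data slots + 2 compute slots per step.
step, slot = locate(9, tokens_per_step=6)
print(step, slot)                   # 1 3
print(is_compute_slot(slot, 6, 2))  # False
print(pooled_position(1, 6))        # 11
```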

TokenType

mouse.models.embedding.embedding.TokenType

Bases: IntEnum

Token type identifiers used by StepEmbedder when building the embedding sequence.


Per-modality embedders

These are constructed and owned by StepEmbedder. They are documented here for reference.

mouse.models.embedding.embedding.ActionEmbedder

ActionEmbedder(hidden_dim: int, token_data_len: int, max_num_actions: int, embedding_std: float = 0.02)

Bases: Module

Embeds a discrete action id → flat content vector [N, T*D].

Source code in mouse/models/embedding/embedding.py
def __init__(self, hidden_dim: int, token_data_len: int, max_num_actions: int, embedding_std: float = 0.02) -> None:
    super().__init__()
    self.embed = ScaledEmbedding(
        num_embeddings=max_num_actions, embedding_dim=hidden_dim * token_data_len, scale=embedding_std
    )

forward

forward(action: Tensor) -> torch.Tensor

Parameters:

action (Tensor, required): [N] int64 action indices.

Returns: [N, T*D] content embedding.

Source code in mouse/models/embedding/embedding.py
def forward(self, action: torch.Tensor) -> torch.Tensor:
    """Args:
        action: ``[N]`` int64 action indices.
    Returns:
        ``[N, T*D]`` content embedding.
    """
    return self.embed(action)

mouse.models.embedding.embedding.RewardEmbedder

RewardEmbedder(hidden_dim: int, token_data_len: int, in_min: float, in_max: float, embedding_std: float = 0.02)

Bases: Module

Embeds a scalar reward via Random Fourier Features → flat content vector [N, T*D].

Source code in mouse/models/embedding/embedding.py
def __init__(
    self,
    hidden_dim: int,
    token_data_len: int,
    in_min: float,
    in_max: float,
    embedding_std: float = 0.02,
) -> None:
    super().__init__()
    rff_scale = embedding_std / 0.5 ** 0.5
    self.rff = RandomFourierFeatures(
        num_features=hidden_dim * token_data_len, in_min=in_min, in_max=in_max, output_scale=rff_scale
    )

forward

forward(reward: Tensor) -> torch.Tensor

Parameters:

reward (Tensor, required): [N] float32 scalar rewards.

Returns: [N, T*D] content embedding.

Source code in mouse/models/embedding/embedding.py
def forward(self, reward: torch.Tensor) -> torch.Tensor:
    """Args:
        reward: ``[N]`` float32 scalar rewards.
    Returns:
        ``[N, T*D]`` content embedding.
    """
    return self.rff(reward, 0)

mouse.models.embedding.embedding.DoneEmbedder

DoneEmbedder(hidden_dim: int, token_data_len: int, embedding_std: float = 0.02)

Bases: Module

Embeds a ternary done flag → flat content vector [N, T*D].

Source code in mouse/models/embedding/embedding.py
def __init__(self, hidden_dim: int, token_data_len: int, embedding_std: float = 0.02) -> None:
    super().__init__()
    self.embed = ScaledEmbedding(num_embeddings=3, embedding_dim=hidden_dim * token_data_len, scale=embedding_std)

forward

forward(done: Tensor) -> torch.Tensor

Parameters:

done (Tensor, required): [N] int64 in {0, 1, 2}.

Returns: [N, T*D] content embedding.

Source code in mouse/models/embedding/embedding.py
def forward(self, done: torch.Tensor) -> torch.Tensor:
    """Args:
        done: ``[N]`` int64 in {0, 1, 2}.
    Returns:
        ``[N, T*D]`` content embedding.
    """
    return self.embed(done)

mouse.models.embedding.embedding.TimeEmbedder

TimeEmbedder(hidden_dim: int, token_data_len: int, max_num_time_steps: int, embedding_std: float = 0.02)

Bases: Module

Embeds episode step index → flat content vector [N, T*D].

Positions with time_idx < 0 are treated as absent and produce a zero vector.

Source code in mouse/models/embedding/embedding.py
def __init__(
    self, hidden_dim: int, token_data_len: int, max_num_time_steps: int, embedding_std: float = 0.02
) -> None:
    super().__init__()
    self.embed = ScaledEmbedding(
        num_embeddings=max_num_time_steps, embedding_dim=hidden_dim * token_data_len, scale=embedding_std
    )

forward

forward(time_idx: Tensor) -> torch.Tensor

Parameters:

time_idx (Tensor, required): [N] int64; negative values mean the field is absent.

Returns: [N, T*D] content embedding (zero where time_idx < 0).

Source code in mouse/models/embedding/embedding.py
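def forward(self, time_idx: torch.Tensor) -> torch.Tensor:
    """Args:
        time_idx: ``[N]`` int64; negative values mean the field is absent.
    Returns:
        ``[N, T*D]`` content embedding (zero where time_idx < 0).
    """
    # nn.Embedding rejects negative indices, so look up a clamped index and
    # zero the rows whose original index was negative (absent field).
    present = time_idx >= 0
    out = self.embed(time_idx.clamp(min=0))
    return out * present.unsqueeze(-1).to(out.dtype)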

mouse.models.embedding.embedding.ObsContinuousEmbedder

ObsContinuousEmbedder(hidden_dim: int, max_num_obs: int, token_data_len: int, in_min: float, in_max: float, embedding_std: float = 0.02)

Bases: Module

Embeds continuous observations → flat content vector [N, T*D].

Each obs dimension is projected via a position-indexed RFF; all contributions are summed to give the final [N, T*D] output.

Source code in mouse/models/embedding/embedding.py
def __init__(
    self,
    hidden_dim: int,
    max_num_obs: int,
    token_data_len: int,
    in_min: float,
    in_max: float,
    embedding_std: float = 0.02,
) -> None:
    super().__init__()
    self.max_num_obs = max_num_obs
    # cos has std ≈ 1/√2; sum over max_num_obs dims grows by √max_num_obs → divide both out
    rff_scale = embedding_std / (0.5 ** 0.5 * max_num_obs ** 0.5)
    self.rff = RandomFourierFeatures(
        num_features=hidden_dim * token_data_len, in_min=in_min, in_max=in_max,
        num_freq_sets=max_num_obs, output_scale=rff_scale,
    )

forward

forward(obs: Tensor) -> torch.Tensor

Parameters:

obs (Tensor, required): [*batch, max_num_obs] float32 observations.

Returns: [*batch, T*D] content embedding.

Source code in mouse/models/embedding/embedding.py
def forward(self, obs: torch.Tensor) -> torch.Tensor:
    """Args:
        obs: ``[*batch, max_num_obs]`` float32 observations.
    Returns:
        ``[*batch, T*D]`` content embedding.
    """
    positions = torch.arange(self.max_num_obs, device=obs.device).expand_as(obs)
    return self.rff(obs.float(), positions).sum(dim=-2)

mouse.models.embedding.embedding.ObsContinuousLinearEmbedder

ObsContinuousLinearEmbedder(hidden_dim: int, max_num_obs: int, token_data_len: int, input_std: float = 1.0, embedding_std: float = 0.02)

Bases: Module

Embeds continuous observations → flat content vector [N, T*D].

Each obs dimension is projected via a position-specific learned linear map applied directly to the scalar value; all contributions are summed to give the final [N, T*D] output. Unlike ObsContinuousEmbedder, this uses no random features: the obs value scales a learned direction.

Parameters:

hidden_dim (int, required): Model hidden dimension D.
max_num_obs (int, required): Length of the continuous obs vector.
token_data_len (int, required): Number of tokens T per step.
input_std (float, default 1.0): Expected std of the incoming obs values, used to normalise the linear initialisation.
embedding_std (float, default 0.02): Desired output std of the embedding.

Source code in mouse/models/embedding/embedding.py
def __init__(
    self,
    hidden_dim: int,
    max_num_obs: int,
    token_data_len: int,
    input_std: float = 1.0,
    embedding_std: float = 0.02,
) -> None:
    super().__init__()
    self.max_num_obs = max_num_obs
    # Kaiming uniform for in_features=1 has std = 1/√3 (Uniform[-1,1]).
    # ScaledPosLinear multiplies those weights by scale, so per-dim output std =
    # scale × (1/√3) × input_std.  Divide scale by (1/√3) × √max_num_obs to
    # hit embedding_std after summing max_num_obs independent dims.
    _kaiming_std = 3.0 ** -0.5
    self.projs = ScaledPosLinear(
        num_positions=max_num_obs,
        in_features=1,
        out_features=hidden_dim * token_data_len,
        scale=embedding_std / (_kaiming_std * input_std * max_num_obs ** 0.5),
    )

forward

forward(obs: Tensor) -> torch.Tensor

Parameters:

obs (Tensor, required): [*batch, max_num_obs] float32 observations.

Returns: [*batch, T*D] content embedding.

Source code in mouse/models/embedding/embedding.py
def forward(self, obs: torch.Tensor) -> torch.Tensor:
    """Args:
        obs: ``[*batch, max_num_obs]`` float32 observations.
    Returns:
        ``[*batch, T*D]`` content embedding.
    """
    positions = torch.arange(self.max_num_obs, device=obs.device).expand_as(obs)
    return self.projs(obs.float().unsqueeze(-1), positions).sum(dim=-2)

mouse.models.embedding.embedding.ObsDiscreteEmbedder

ObsDiscreteEmbedder(hidden_dim: int, max_num_obs: int, token_data_len: int, embedding_std: float = 0.02)

Bases: Module

Embeds a scalar discrete state index → flat content vector [N, T*D].

The state index is looked up in a learned embedding table of size max_num_obs (the state-space cardinality).

Source code in mouse/models/embedding/embedding.py
def __init__(
    self,
    hidden_dim: int,
    max_num_obs: int,
    token_data_len: int,
    embedding_std: float = 0.02,
) -> None:
    super().__init__()
    self.max_num_obs = max_num_obs
    # Summing max_num_obs independent N(0, scale) rows inflates std by √max_num_obs.
    # Note: if obs values are uniform integers in [0, max_num_obs-1], collisions further
    # inflate by ≈√(2·max_num_obs-1)/√max_num_obs; √max_num_obs is a reasonable approximation.
    self.embed = ScaledEmbedding(
        num_embeddings=max_num_obs, embedding_dim=hidden_dim * token_data_len,
        scale=embedding_std / max_num_obs ** 0.5,
    )

forward

forward(obs: Tensor) -> torch.Tensor

Parameters:

obs (Tensor, required): [*batch] int64 discrete state index.

Returns: [*batch, T*D] content embedding.

Source code in mouse/models/embedding/embedding.py
def forward(self, obs: torch.Tensor) -> torch.Tensor:
    """Args:
        obs: ``[*batch]`` int64 discrete state index.
    Returns:
        ``[*batch, T*D]`` content embedding.
    """
    return self.embed(obs)

mouse.models.embedding.embedding.ObsImageEmbedder

ObsImageEmbedder(hidden_dim: int, max_num_obs: int, token_data_len: int, embedding_std: float = 0.02)

Bases: Module

Embeds image pixels → flat content vector [N, T*D].

Each pixel is projected via a position-specific linear map on the normalised pixel value; all contributions are summed to give the final [N, T*D] output.

Source code in mouse/models/embedding/embedding.py
def __init__(
    self,
    hidden_dim: int,
    max_num_obs: int,
    token_data_len: int,
    embedding_std: float = 0.02,
) -> None:
    super().__init__()
    self.max_num_obs = max_num_obs
    # pixel_norm_std: std of NormalizedPixel output  ≈ std of Uniform[-1,1] = 1/√3
    # _kaiming_std:   std of Kaiming-uniform weights for in_features=1 = 1/√3
    # per-dim output std = scale × _kaiming_std × pixel_norm_std
    # Divide scale by that product × √max_num_obs to hit embedding_std after summing.
    pixel_norm_std = 3.0 ** -0.5
    _kaiming_std = 3.0 ** -0.5
    self.norm = NormalizedPixel()
    self.projs = ScaledPosLinear(
        num_positions=max_num_obs, in_features=1, out_features=hidden_dim * token_data_len,
        scale=embedding_std / (_kaiming_std * pixel_norm_std * max_num_obs ** 0.5),
    )

forward

forward(obs: Tensor) -> torch.Tensor

Parameters:

obs (Tensor, required): [*batch, max_num_obs] int64/float pixel values.

Returns: [*batch, T*D] content embedding.

Source code in mouse/models/embedding/embedding.py
def forward(self, obs: torch.Tensor) -> torch.Tensor:
    """Args:
        obs: ``[*batch, max_num_obs]`` int64/float pixel values.
    Returns:
        ``[*batch, T*D]`` content embedding.
    """
    positions = torch.arange(self.max_num_obs, device=obs.device).expand_as(obs)
    normalized = self.norm(obs.float()).unsqueeze(-1)            # [*batch, max_num_obs, 1]
    return self.projs(normalized, positions).sum(dim=-2)
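
The scale used in the constructor can be checked numerically. The following is a rough simulation rather than the library's code: it treats both the weights and the normalised pixels as independent Uniform[-1, 1] draws (the assumption stated in the source comments) and uses plain Python lists in place of tensors. All variable names are illustrative.

```python
import math
import random

rng = random.Random(0)
num_pixels = 64            # plays the role of max_num_obs
embedding_std = 0.02       # target std of the summed embedding
uniform_std = 3.0 ** -0.5  # std of Uniform[-1, 1]

# Same form as the scale in ObsImageEmbedder: divide out the weight std,
# the input std, and the sqrt(n) growth from summing n independent terms.
scale = embedding_std / (uniform_std * uniform_std * math.sqrt(num_pixels))

samples = []
for _ in range(5_000):
    total = 0.0
    for _ in range(num_pixels):
        weight = scale * rng.uniform(-1.0, 1.0)  # stand-in for a scaled Kaiming-uniform weight
        pixel = rng.uniform(-1.0, 1.0)           # stand-in for a normalised pixel value
        total += weight * pixel
    samples.append(total)

mean = sum(samples) / len(samples)
measured_std = math.sqrt(sum((s - mean) ** 2 for s in samples) / len(samples))
print(f"{measured_std:.4f}")  # should land near embedding_std = 0.02
```

Each product has variance scale² × (1/3) × (1/3), and summing num_pixels independent products multiplies the variance by num_pixels, which is why the three factors appear in the denominator.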

mouse.models.embedding.embedding.TypeEmbedder

TypeEmbedder(hidden_dim: int, token_data_len: int, embedding_std: float = 0.02)

Bases: Module

Shared token-type embedding table. Maps a TokenType id → flat content vector [N, T*D].

Source code in mouse/models/embedding/embedding.py
def __init__(self, hidden_dim: int, token_data_len: int, embedding_std: float = 0.02) -> None:
    super().__init__()
    self.embed = ScaledEmbedding(num_embeddings=8, embedding_dim=hidden_dim * token_data_len, scale=embedding_std)

Encoding

mouse.models.embedding.encoding.RandomFourierFeatures

RandomFourierFeatures(num_features: int, in_min: float = 0.01, in_max: float = 100.0, num_freq_sets: int = 1, output_scale: float = 1.0, dtype: dtype = torch.float32)

Bases: Module

Random Fourier Features (Rahimi & Recht, 2007) with log-uniform frequencies.

Each feature is cos(ωx + b) where ω is sampled log-uniformly over [1/in_max, 1/in_min] and b ~ Uniform(0, 2π). The random phase breaks the even-function symmetry of plain cos, so the encoding is sign-sensitive (x and −x map to different features) while using half the output size of a sin+cos encoding.

in_min and in_max are expressed in input-space units:
- the lowest frequency (ω = 1/in_max) advances one radian of phase per in_max units of input (one full cycle per 2π·in_max units);
- the highest frequency (ω = 1/in_min) advances one radian of phase per in_min units of input (one full cycle per 2π·in_min units).

num_freq_sets independent banks of (ω, b) pairs are sampled and stored as buffers of shape (num_freq_sets, num_features). forward takes freq_idx, either an int or an integer tensor of the same shape as x, to select which set each element uses. Both buffers are persistent so they are saved with the checkpoint.

After the cosine, a learnable per-(freq set, feature) weight is applied (shape (num_freq_sets, num_features); initialised to output_scale). At initialisation the output therefore has per-dimension std ≈ output_scale / √2.

Source code in mouse/models/embedding/encoding.py
def __init__(
    self,
    num_features: int,
    in_min: float = 1e-2,
    in_max: float = 1e2,
    num_freq_sets: int = 1,
    output_scale: float = 1.0,
    dtype: torch.dtype = torch.float32,
):
    super().__init__()
    if num_features <= 0:
        raise ValueError("num_features must be > 0")
    if in_min <= 0 or in_max <= 0:
        raise ValueError("in_min and in_max must be > 0")
    if in_min >= in_max:
        raise ValueError("in_min must be < in_max")
    if num_freq_sets < 1:
        raise ValueError("num_freq_sets must be >= 1")
    # ω = 1/x so that one radian of phase spans x units of input: ω_min = 1/in_max, ω_max = 1/in_min
    log_w_min = math.log(1.0 / in_max)
    log_w_max = math.log(1.0 / in_min)
    freqs = torch.empty(num_freq_sets, num_features, dtype=dtype).uniform_(log_w_min, log_w_max).exp()
    phases = torch.empty(num_freq_sets, num_features, dtype=dtype).uniform_(0.0, 2.0 * math.pi)
    self.register_buffer("freqs", freqs, persistent=True)
    self.register_buffer("phases", phases, persistent=True)
    self.weight = nn.Parameter(torch.full((num_freq_sets, num_features), float(output_scale), dtype=dtype))

forward

forward(x: Tensor, freq_idx: Tensor | int) -> torch.Tensor

Map scalar inputs to RFF embeddings.

Parameters:

x (Tensor, required): Scalar inputs, shape (*batch,).
freq_idx (Tensor | int, required): Which frequency set(s) to use. Either an int constant (the same set is broadcast over all elements) or an integer tensor of the same shape as x (one set per element).

Returns: Tensor of shape (*batch, num_features).

Source code in mouse/models/embedding/encoding.py
def forward(self, x: torch.Tensor, freq_idx: torch.Tensor | int) -> torch.Tensor:
    """Map scalar inputs to RFF embeddings.

    Args:
        x:        Scalar inputs, shape ``(*batch,)``.
        freq_idx: Which frequency set(s) to use. Either:
                  - an ``int`` constant — same set broadcast over all elements, or
                  - an integer tensor of the same shape as ``x`` — one set per element.
    Returns:
        Tensor of shape ``(*batch, num_features)``.
    """
    if isinstance(freq_idx, torch.Tensor):
        assert x.shape == freq_idx.shape, (
            f"x and freq_idx must have the same shape, got {x.shape} and {freq_idx.shape}"
        )
    freqs = self.get_buffer("freqs")    # (num_freq_sets, num_features)
    phases = self.get_buffer("phases")  # (num_freq_sets, num_features)
    x = x.to(dtype=freqs.dtype)
    w = freqs[freq_idx]   # (num_features,) or (*batch, num_features)
    b = phases[freq_idx]  # (num_features,) or (*batch, num_features)
    aw = self.weight[freq_idx]
    return aw * (x.unsqueeze(-1) * w + b).cos()
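
To make the encoding concrete, here is a minimal pure-Python sketch of a single frequency bank. It is not the library implementation: it uses lists instead of tensors, a fixed `output_scale` instead of a learnable weight, and the helper names `make_rff` and `rff` are made up for illustration.

```python
import math
import random


def make_rff(num_features: int, in_min: float, in_max: float, seed: int = 0):
    """Sample one bank of (frequency, phase) pairs with log-uniform frequencies."""
    if not 0 < in_min < in_max:
        raise ValueError("need 0 < in_min < in_max")
    rng = random.Random(seed)
    log_lo, log_hi = math.log(1.0 / in_max), math.log(1.0 / in_min)
    freqs = [math.exp(rng.uniform(log_lo, log_hi)) for _ in range(num_features)]
    phases = [rng.uniform(0.0, 2.0 * math.pi) for _ in range(num_features)]
    return freqs, phases


def rff(x: float, freqs, phases, output_scale: float = 1.0):
    """Encode a scalar as output_scale * cos(w * x + b) for each (w, b) pair."""
    return [output_scale * math.cos(w * x + b) for w, b in zip(freqs, phases)]


freqs, phases = make_rff(num_features=20_000, in_min=0.01, in_max=100.0)
features = rff(0.3, freqs, phases)
mean = sum(features) / len(features)
std = math.sqrt(sum((f - mean) ** 2 for f in features) / len(features))
print(round(std, 2))  # close to 1/sqrt(2) for a large bank
```

The measured spread across features illustrates why callers such as RewardEmbedder divide the target std by √0.5 when choosing output_scale.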

mouse.models.embedding.encoding.NormalizedPixel

Bases: Module

Maps integer pixel values (0-255) to [-1, 1].
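
The source of NormalizedPixel is not shown on this page. One affine map consistent with the description sends 0 to -1 and 255 to 1; the function below is an assumption for illustration and may differ from the library's actual formula.

```python
def normalize_pixel(value: float) -> float:
    """Assumed mapping: linearly rescale a pixel value in [0, 255] to [-1, 1]."""
    return value / 127.5 - 1.0


print(normalize_pixel(0), normalize_pixel(255))  # -1.0 1.0
```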


Linear layers

mouse.models.embedding.linear.ScaledEmbedding

ScaledEmbedding(num_embeddings: int, embedding_dim: int, scale: float = 1.0, **kwargs)

Bases: Embedding

nn.Embedding with default Normal(0, 1) init multiplied by scale.

Source code in mouse/models/embedding/linear.py
def __init__(self, num_embeddings: int, embedding_dim: int, scale: float = 1.0, **kwargs) -> None:
    super().__init__(num_embeddings, embedding_dim, **kwargs)
    self.weight.data.mul_(scale)

mouse.models.embedding.linear.ScaledLinear

ScaledLinear(in_features: int, out_features: int, scale: float, bias: bool = True, device: device | str | None = None, dtype: dtype | None = None)

Bases: Linear

Linear layer with Kaiming-uniform init multiplied by scale.

Kaiming-uniform is applied first (a=sqrt(5), same as nn.Linear default), then both weights and bias are multiplied by scale in-place.

Source code in mouse/models/embedding/linear.py
def __init__(
    self,
    in_features: int,
    out_features: int,
    scale: float,
    bias: bool = True,
    device: torch.device | str | None = None,
    dtype: torch.dtype | None = None,
) -> None:
    super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
    sc = float(scale)
    if sc < 0.0:
        raise ValueError(f"scale must be >= 0, got {scale!r}.")
    with torch.no_grad():
        self.weight.mul_(sc)
        if self.bias is not None:
            self.bias.mul_(sc)
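
For reference, the weight spread implied by this initialisation can be worked out by hand. The sketch below reproduces the Kaiming-uniform bound formula (leaky-ReLU gain with negative slope a) in plain Python; the helper names are illustrative and the snippet does not call PyTorch.

```python
import math


def kaiming_uniform_bound(fan_in: int, a: float = math.sqrt(5.0)) -> float:
    """Bound of the Uniform(-bound, bound) draw used by Kaiming-uniform init."""
    gain = math.sqrt(2.0 / (1.0 + a * a))  # leaky-ReLU gain for negative slope a
    std = gain / math.sqrt(fan_in)
    return math.sqrt(3.0) * std


def scaled_weight_std(fan_in: int, scale: float) -> float:
    """Std of the weights after multiplying the Kaiming-uniform draw by scale."""
    # A Uniform(-bound, bound) variable has std bound / sqrt(3).
    return scale * kaiming_uniform_bound(fan_in) / math.sqrt(3.0)


print(kaiming_uniform_bound(1))
print(scaled_weight_std(1, scale=0.02))
```

With a = sqrt(5) the bound simplifies to 1/sqrt(fan_in), so for in_features=1 the unscaled weights are Uniform[-1, 1] with std 1/√3, which is the `_kaiming_std` constant used by the embedders above.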

mouse.models.embedding.linear.PosLinear

PosLinear(num_positions: int, in_features: int, out_features: int, device: device | str | None = None, dtype: dtype | None = None)

Bases: Module

Position-conditioned linear projection.

Stores one independent (weight, bias) pair per position index. At forward time, each element selects its projection via an integer position index, enabling per-dimension embeddings without separate nn.Linear instances.

Parameters:

num_positions (int, required): Number of distinct positions (embedding table size).
in_features (int, required): Input feature dimension.
out_features (int, required): Output feature dimension.
device (device | str | None, default None): Tensor device for parameters.
dtype (dtype | None, default None): Tensor dtype for parameters.

Source code in mouse/models/embedding/linear.py
def __init__(
    self,
    num_positions: int,
    in_features: int,
    out_features: int,
    device: torch.device | str | None = None,
    dtype: torch.dtype | None = None,
) -> None:
    super().__init__()
    self.in_features = in_features
    self.out_features = out_features
    factory = {"device": device, "dtype": dtype}
    self.weight = nn.Parameter(torch.empty(num_positions, out_features, in_features, **factory))
    self.bias = nn.Parameter(torch.empty(num_positions, out_features, **factory))
    nn.init.kaiming_uniform_(self.weight.view(-1, in_features), a=5**0.5)
    nn.init.zeros_(self.bias)

forward

forward(x: Tensor, pos: Tensor) -> torch.Tensor

Apply the position-specific projection.

Parameters:

x (Tensor, required): Input tensor [*batch, in_features].
pos (Tensor, required): Integer position indices [*batch]; same leading shape as x.

Returns:

Tensor: Output tensor [*batch, out_features].

Source code in mouse/models/embedding/linear.py
def forward(self, x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
    """Apply the position-specific projection.

    Args:
        x: Input tensor ``[*batch, in_features]``.
        pos: Integer position indices ``[*batch]``; same leading shape as ``x``.

    Returns:
        Output tensor ``[*batch, out_features]``.
    """
    x_flat = x.reshape(-1, x.shape[-1])
    pos_flat = pos.reshape(-1)
    w = self.weight[pos_flat]
    b = self.bias[pos_flat]
    out = torch.bmm(w, x_flat.unsqueeze(-1)).squeeze(-1) + b
    out_shape: tuple[int, ...] = tuple(int(s) for s in pos.shape) + (self.out_features,)
    return out.view(out_shape)
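
The gather-then-multiply step above amounts to choosing a separate (weight, bias) pair for each element. A minimal list-based sketch of the same computation, without batching or PyTorch and with a hypothetical function name, might look like this:

```python
def pos_linear(x, pos, weight, bias):
    """Apply a per-position linear map.

    x:      list of input vectors, each of length in_features
    pos:    list of position indices, one per input vector
    weight: weight[p][o][i], one (out_features x in_features) matrix per position
    bias:   bias[p][o], one bias vector per position
    """
    outputs = []
    for vec, p in zip(x, pos):
        w, b = weight[p], bias[p]
        outputs.append([sum(w_oi * v for w_oi, v in zip(row, vec)) + b_o
                        for row, b_o in zip(w, b)])
    return outputs


# Two positions, in_features=1, out_features=2.
weight = [[[1.0], [2.0]],    # position 0
          [[-1.0], [0.5]]]   # position 1
bias = [[0.0, 0.0], [1.0, 1.0]]
result = pos_linear([[3.0], [3.0]], [0, 1], weight, bias)
print(result)  # [[3.0, 6.0], [-2.0, 2.5]]
```

The same input value yields different outputs at positions 0 and 1, which is what lets each obs dimension or pixel have its own learned direction.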

mouse.models.embedding.linear.ScaledPosLinear

ScaledPosLinear(num_positions: int, in_features: int, out_features: int, scale: float = 1.0, device: device | str | None = None, dtype: dtype | None = None)

Bases: PosLinear

PosLinear with Kaiming-uniform weights multiplied by scale.

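Kaiming-uniform weights (a=sqrt(5)) have std 1/√(3·in_features), so for independent zero-mean inputs the output std at initialisation is scale × input_std / √3. Use scale = √3 × output_std / input_std to match a desired output standard deviation.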

Source code in mouse/models/embedding/linear.py
def __init__(
    self,
    num_positions: int,
    in_features: int,
    out_features: int,
    scale: float = 1.0,
    device: torch.device | str | None = None,
    dtype: torch.dtype | None = None,
) -> None:
    super().__init__(
        num_positions=num_positions,
        in_features=in_features,
        out_features=out_features,
        device=device,
        dtype=dtype,
    )
    if scale <= 0.0:
        raise ValueError(f"scale must be > 0, got {scale!r}.")
    with torch.no_grad():
        self.weight.mul_(scale)
        if self.bias is not None:
            self.bias.mul_(scale)