Embedding
The embedding layer converts a TensorDict[B, S] of step records into the flat token sequence [B, S*T, D] consumed by the backbone, where T is the number of tokens per step (tokens_per_step) and D is the hidden dimension.
StepEmbedder
mouse.models.embedding.embedding.StepEmbedder
StepEmbedder(hidden_dim: int, max_num_actions: int, max_num_obs_continuous: int, max_num_obs_discrete: int, max_num_obs_image: int, max_num_time_steps: int, include_action_token: bool, include_done_token: bool, include_reward_token: bool, include_obs_continuous: bool, include_obs_discrete: bool, include_obs_image: bool, include_time_token: bool, include_type_token: bool, token_data_len: int, num_compute_tokens: int = 0, concat_modalities: bool = False, fourier_in_min: float = -1.0, fourier_in_max: float = 1.0, std: float = 0.02)
Bases: Module
Converts a batch of step records [B, S] into an embedding sequence [B, S*T, D], where T = tokens_per_step and D = hidden_dim.
tokens_per_step is fixed at construction time so that the backbone always
receives input of a consistent shape.
Two embedding modes are available:
Sum mode (concat_modalities=False, default):
Every active modality contributes token_data_len tokens; the contributions are
summed position-wise, so the output always has exactly
token_data_len data tokens per step regardless of which modalities
are active.
tokens_per_step = token_data_len + num_compute_tokens.
Concat mode (concat_modalities=True):
Each active modality occupies its own dedicated block of token_data_len
tokens, laid out sequentially.
tokens_per_step = num_active_modalities * token_data_len + num_compute_tokens.
Compute tokens (num_compute_tokens > 0):
K = num_compute_tokens learned scratch tokens are appended after the data tokens in every
step block. The backbone can attend to and write into them as working
memory. The step representation is always pooled from the last token of the step block
(the last compute token when K > 0).
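To make the two layouts concrete, this small pure-Python sketch labels every position of one step block and checks the tokens_per_step formulas. The helper step_layout and the COMPUTE label are hypothetical illustrations, not part of the library.

```python
def step_layout(active_modalities, token_data_len, num_compute_tokens=0,
                concat_modalities=False):
    """Label each position of one step block (hypothetical helper, not library code)."""
    if concat_modalities:
        # Concat mode: one dedicated block of token_data_len positions per modality.
        data = [f"{m}[{i}]" for m in active_modalities for i in range(token_data_len)]
    else:
        # Sum mode: every active modality is summed into the same token_data_len positions.
        summed = "+".join(active_modalities)
        data = [f"{summed}[{i}]" for i in range(token_data_len)]
    # Learned scratch tokens always come after the data tokens.
    compute = [f"COMPUTE[{k}]" for k in range(num_compute_tokens)]
    return data + compute


mods = ["ACTION", "REWARD", "OBS_CONTINUOUS"]
print(len(step_layout(mods, 2, 1)))                          # sum mode: 2 + 1 = 3
print(len(step_layout(mods, 2, 1, concat_modalities=True)))  # concat mode: 3 * 2 + 1 = 7
print(step_layout(mods, 2, 1)[-1])                           # pooled position: COMPUTE[0]
```

In both modes the last position of the block is a compute token whenever num_compute_tokens > 0, which is why pooling from the last token picks up the scratch memory.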
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Model hidden dimension D. | required |
| max_num_actions | int | Size of the action embedding table. | required |
| max_num_obs_continuous | int | Continuous obs vector length; must be > 0 when include_obs_continuous is True. | required |
| max_num_obs_discrete | int | Discrete obs vector length; must be > 0 when include_obs_discrete is True. | required |
| max_num_obs_image | int | Total pixel count per image; must be > 0 when include_obs_image is True. | required |
| max_num_time_steps | int | TIME embedding table size; must be > 0 when include_time_token is True. | required |
| include_action_token | bool | Emit an ACTION token per step. | required |
| include_done_token | bool | Emit a DONE token per step. | required |
| include_reward_token | bool | Emit a REWARD token per step. | required |
| include_obs_continuous | bool | Emit an OBS_CONTINUOUS token per step. | required |
| include_obs_discrete | bool | Emit an OBS_DISCRETE token per step. | required |
| include_obs_image | bool | Emit an OBS_IMAGE token per step. | required |
| include_time_token | bool | Emit a TIME token per step. | required |
| include_type_token | bool | Add the learned type embedding to every token. | required |
| token_data_len | int | Number of data tokens each active modality contributes per step. | required |
| num_compute_tokens | int | Number of learned scratch tokens K appended after the data tokens in every step block. | 0 |
| concat_modalities | bool | When True, each active modality gets its own block of token_data_len tokens (concat mode); when False, modality contributions are summed position-wise (sum mode). | False |
| fourier_in_min | float | Smallest input value the RFF resolves. | -1.0 |
| fourier_in_max | float | Largest input value the RFF covers. | 1.0 |
| std | float | Initialisation std for embedding tables. | 0.02 |
Source code in mouse/models/embedding/embedding.py
forward
Embed a batch of steps.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| step_stream | TensorDict | TensorDict of shape [B, S] holding the step records. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| embeds | Tensor | Embedding sequence of shape [B, S*T, D]. |
| token_types | Tensor | TokenType identifiers for the embedded tokens. |
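forward returns the whole sequence flattened to [B, S*T, D]. A minimal sketch, assuming positions are laid out step by step (each step block occupies T consecutive positions), of how the per-step blocks and the last-token step representation can be recovered from such a tensor; the shapes are illustrative.

```python
import torch

B, S, T, D = 2, 4, 3, 8  # batch, steps, tokens_per_step, hidden_dim (example values)

# Stand-in for the embeds returned by forward (or the backbone output over them).
embeds = torch.randn(B, S * T, D)

# Assumption: positions are step-major, so each step owns T consecutive positions.
step_blocks = embeds.view(B, S, T, D)

# Pool each step from the last position of its block
# (the last compute token when num_compute_tokens > 0).
step_repr = step_blocks[:, :, -1, :]  # [B, S, D]
print(step_repr.shape)  # torch.Size([2, 4, 8])
```

The same selection can be written without a reshape as embeds[:, T - 1 :: T, :].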
Source code in mouse/models/embedding/embedding.py
TokenType
mouse.models.embedding.embedding.TokenType
Bases: IntEnum
Token type identifiers used by StepEmbedder when building the embedding sequence.
Per-modality embedders
These are constructed and owned by StepEmbedder and are documented here for reference. In their shapes, T = token_data_len and D = hidden_dim.
mouse.models.embedding.embedding.ActionEmbedder
ActionEmbedder(hidden_dim: int, token_data_len: int, max_num_actions: int, embedding_std: float = 0.02)
Bases: Module
Embeds a discrete action id → flat content vector [N, T*D].
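The docstring only states the input and output. A minimal sketch, assuming a single lookup table of width T*D initialised from Normal(0, embedding_std); ActionEmbedderSketch is a hypothetical name, and the real class may use ScaledEmbedding or a different layout.

```python
import torch
import torch.nn as nn


class ActionEmbedderSketch(nn.Module):
    """Hypothetical stand-in for ActionEmbedder; not the library implementation."""

    def __init__(self, hidden_dim, token_data_len, max_num_actions, embedding_std=0.02):
        super().__init__()
        # One learned row of length T*D per action id.
        self.table = nn.Embedding(max_num_actions, token_data_len * hidden_dim)
        nn.init.normal_(self.table.weight, mean=0.0, std=embedding_std)

    def forward(self, action):
        # action: integer ids of shape [N] -> content vectors of shape [N, T*D]
        return self.table(action)


emb = ActionEmbedderSketch(hidden_dim=8, token_data_len=2, max_num_actions=5)
out = emb(torch.tensor([0, 3, 4]))
print(out.shape)  # torch.Size([3, 16])
```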
Source code in mouse/models/embedding/embedding.py
forward
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| action | Tensor | Discrete action ids. | required |
Returns:
[N, T*D] content embedding.
mouse.models.embedding.embedding.RewardEmbedder
RewardEmbedder(hidden_dim: int, token_data_len: int, in_min: float, in_max: float, embedding_std: float = 0.02)
Bases: Module
Embeds a scalar reward via Random Fourier Features → flat content vector [N, T*D].
Source code in mouse/models/embedding/embedding.py
forward
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| reward | Tensor | Scalar rewards. | required |
Returns:
[N, T*D] content embedding.
mouse.models.embedding.embedding.DoneEmbedder
Bases: Module
Embeds a ternary done flag → flat content vector [N, T*D].
Source code in mouse/models/embedding/embedding.py
forward
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| done | Tensor | Ternary done flags. | required |
Returns:
[N, T*D] content embedding.
mouse.models.embedding.embedding.TimeEmbedder
TimeEmbedder(hidden_dim: int, token_data_len: int, max_num_time_steps: int, embedding_std: float = 0.02)
Bases: Module
Embeds episode step index → flat content vector [N, T*D].
Positions with time_idx < 0 are treated as absent and produce a zero vector.
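One way to get the documented behaviour for absent positions (time_idx < 0 gives a zero vector) is to clamp the index before the lookup and mask the result afterwards. This is a sketch under that assumption; TimeEmbedderSketch and its initialisation are hypothetical.

```python
import torch
import torch.nn as nn


class TimeEmbedderSketch(nn.Module):
    """Hypothetical stand-in for TimeEmbedder; not the library implementation."""

    def __init__(self, hidden_dim, token_data_len, max_num_time_steps, embedding_std=0.02):
        super().__init__()
        self.table = nn.Embedding(max_num_time_steps, token_data_len * hidden_dim)
        nn.init.normal_(self.table.weight, mean=0.0, std=embedding_std)

    def forward(self, time_idx):
        present = time_idx >= 0
        # Clamp absent (negative) indices to a valid row so the lookup is safe...
        vec = self.table(time_idx.clamp(min=0))
        # ...then zero those rows so absent positions contribute nothing.
        return vec * present.unsqueeze(-1).to(vec.dtype)


emb = TimeEmbedderSketch(hidden_dim=4, token_data_len=2, max_num_time_steps=10)
out = emb(torch.tensor([3, -1, 0]))
print(out[1].abs().sum().item())  # 0.0
```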
Source code in mouse/models/embedding/embedding.py
forward
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| time_idx | Tensor | Episode step indices; negative values mark absent positions. | required |
Returns:
[N, T*D] content embedding (zero where time_idx < 0).
Source code in mouse/models/embedding/embedding.py
mouse.models.embedding.embedding.ObsContinuousEmbedder
ObsContinuousEmbedder(hidden_dim: int, max_num_obs: int, token_data_len: int, in_min: float, in_max: float, embedding_std: float = 0.02)
Bases: Module
Embeds continuous observations → flat content vector [N, T*D].
Each obs dimension is projected via a position-indexed RFF; all contributions
are summed to give the final [N, T*D] output.
Source code in mouse/models/embedding/embedding.py
forward
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| obs | Tensor | Continuous observation values. | required |
Returns:
[*batch, T*D] content embedding.
Source code in mouse/models/embedding/embedding.py
mouse.models.embedding.embedding.ObsContinuousLinearEmbedder
ObsContinuousLinearEmbedder(hidden_dim: int, max_num_obs: int, token_data_len: int, input_std: float = 1.0, embedding_std: float = 0.02)
Bases: Module
Embeds continuous observations → flat content vector [N, T*D].
Each obs dimension is projected via a position-specific learned linear map
applied directly to the scalar value; all contributions are summed to give
the final [N, T*D] output. Unlike ObsContinuousEmbedder, this
uses no random features: the obs value simply scales a learned direction.
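Scaling a learned direction per dimension and summing reduces to a weighted sum of per-dimension vectors. A minimal sketch, assuming no bias term and an initial weight std of embedding_std / (input_std * sqrt(max_num_obs)) so that the summed output starts near embedding_std; both choices, and the class name ObsContinuousLinearSketch, are assumptions.

```python
import math

import torch
import torch.nn as nn


class ObsContinuousLinearSketch(nn.Module):
    """Hypothetical stand-in for ObsContinuousLinearEmbedder (no bias term)."""

    def __init__(self, hidden_dim, max_num_obs, token_data_len,
                 input_std=1.0, embedding_std=0.02):
        super().__init__()
        # Assumed init: the sum of max_num_obs independent terms has std ~ embedding_std.
        w_std = embedding_std / (input_std * math.sqrt(max_num_obs))
        # One learned direction of length T*D per obs dimension.
        self.weight = nn.Parameter(torch.randn(max_num_obs, token_data_len * hidden_dim) * w_std)

    def forward(self, obs):
        # obs: [*batch, max_num_obs]; each scalar scales its own direction, then all are summed.
        return torch.einsum("...m,mf->...f", obs, self.weight)


emb = ObsContinuousLinearSketch(hidden_dim=8, max_num_obs=3, token_data_len=2)
obs = torch.tensor([[1.0, 0.0, 0.0], [0.5, 0.5, 0.0]])
out = emb(obs)
print(out.shape)  # torch.Size([2, 16])
```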
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_dim | int | Model hidden dimension. | required |
| max_num_obs | int | Length of the continuous obs vector. | required |
| token_data_len | int | Number of tokens T in the [N, T*D] output. | required |
| input_std | float | Expected std of the incoming obs values, used to normalise the linear initialisation. | 1.0 |
| embedding_std | float | Desired output std of the embedding. | 0.02 |
Source code in mouse/models/embedding/embedding.py
forward
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| obs | Tensor | Continuous observation values. | required |
Returns:
[*batch, T*D] content embedding.
Source code in mouse/models/embedding/embedding.py
mouse.models.embedding.embedding.ObsDiscreteEmbedder
ObsDiscreteEmbedder(hidden_dim: int, max_num_obs: int, token_data_len: int, embedding_std: float = 0.02)
Bases: Module
Embeds a scalar discrete state index → flat content vector [N, T*D].
The state index is looked up in a learned embedding table of size
max_num_obs (the state-space cardinality).
Source code in mouse/models/embedding/embedding.py
forward
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| obs | Tensor | Discrete state indices. | required |
Returns:
[*batch, T*D] content embedding.
mouse.models.embedding.embedding.ObsImageEmbedder
ObsImageEmbedder(hidden_dim: int, max_num_obs: int, token_data_len: int, embedding_std: float = 0.02)
Bases: Module
Embeds image pixels → flat content vector [N, T*D].
Each pixel is projected via a position-specific linear map on the normalised
pixel value; all contributions are summed to give the final [N, T*D] output.
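A sketch of the per-pixel scheme, assuming NormalizedPixel is the affine map x / 127.5 - 1 and that each pixel position owns a (weight, bias) pair, as in PosLinear with in_features=1. The helper normalize_pixels, the initialisation, and ObsImageSketch are hypothetical.

```python
import torch
import torch.nn as nn


def normalize_pixels(pixels):
    # Assumed affine map from integer values 0..255 onto [-1, 1].
    return pixels.float() / 127.5 - 1.0


class ObsImageSketch(nn.Module):
    """Hypothetical stand-in for ObsImageEmbedder; not the library implementation."""

    def __init__(self, hidden_dim, max_num_obs, token_data_len, embedding_std=0.02):
        super().__init__()
        out = token_data_len * hidden_dim
        # One (weight, bias) pair per pixel position.
        self.weight = nn.Parameter(torch.randn(max_num_obs, out) * embedding_std)
        self.bias = nn.Parameter(torch.zeros(max_num_obs, out))

    def forward(self, pixels):
        # pixels: [*batch, max_num_obs] integers in 0..255 (flattened image).
        x = normalize_pixels(pixels).unsqueeze(-1)  # [*batch, P, 1]
        per_pixel = x * self.weight + self.bias     # [*batch, P, T*D]
        return per_pixel.sum(dim=-2)                # [*batch, T*D]


emb = ObsImageSketch(hidden_dim=4, max_num_obs=6, token_data_len=2)
img = torch.randint(0, 256, (3, 6))
print(emb(img).shape)  # torch.Size([3, 8])
```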
Source code in mouse/models/embedding/embedding.py
forward
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| obs | Tensor | Image pixel values. | required |
Returns:
[*batch, T*D] content embedding.
Source code in mouse/models/embedding/embedding.py
mouse.models.embedding.embedding.TypeEmbedder
Bases: Module
Shared token-type embedding table. Maps a TokenType → [N, T*D].
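A shared type table can be sketched as one nn.Embedding row per token type, added to the content vector of every token of that type (what include_type_token enables). The member names below come from the include_* flags of StepEmbedder, but the enum values, the forward signature, and both class names are hypothetical.

```python
from enum import IntEnum

import torch
import torch.nn as nn


class TokenTypeSketch(IntEnum):
    # Hypothetical members and values; the real TokenType may differ.
    ACTION = 0
    DONE = 1
    REWARD = 2
    OBS_CONTINUOUS = 3
    OBS_DISCRETE = 4
    OBS_IMAGE = 5
    TIME = 6


class TypeEmbedderSketch(nn.Module):
    """Hypothetical shared type table: TokenType -> [N, T*D]."""

    def __init__(self, hidden_dim, token_data_len, embedding_std=0.02):
        super().__init__()
        self.table = nn.Embedding(len(TokenTypeSketch), token_data_len * hidden_dim)
        nn.init.normal_(self.table.weight, mean=0.0, std=embedding_std)

    def forward(self, token_type, n):
        # Repeat one type vector for all N rows so it can be added to content vectors.
        idx = torch.full((n,), int(token_type), dtype=torch.long)
        return self.table(idx)


types = TypeEmbedderSketch(hidden_dim=4, token_data_len=2)
content = torch.zeros(5, 8)
tokens = content + types(TokenTypeSketch.REWARD, n=5)
print(tokens.shape)  # torch.Size([5, 8])
```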
Source code in mouse/models/embedding/embedding.py
Encoding
mouse.models.embedding.encoding.RandomFourierFeatures
RandomFourierFeatures(num_features: int, in_min: float = 0.01, in_max: float = 100.0, num_freq_sets: int = 1, output_scale: float = 1.0, dtype: dtype = torch.float32)
Bases: Module
Random Fourier Features (Rahimi & Recht, 2007) with log-uniform frequencies.
Each feature is cos(ωx + b), where ω is sampled log-uniformly over
[1/in_max, 1/in_min] and b ~ Uniform(0, 2π). The random phase breaks the
even symmetry of a plain cosine, so the encoding is sign-sensitive
(x and −x map to different features), while using half as many outputs as a sin+cos encoding with the same frequencies.
in_min and in_max are expressed in input-space units:
- the lowest frequency (ω = 1/in_max) advances the phase by one radian per in_max units, setting the coarsest scale;
- the highest frequency (ω = 1/in_min) advances the phase by one radian per in_min units, setting the finest scale.
num_freq_sets independent sets of (ω, b) pairs are sampled, forming banks of
shape (num_freq_sets, num_features). forward takes an integer freq_idx
of the same shape as x that selects which set each element uses.
Both buffers are persistent, so they are saved with the checkpoint.
After the cosine, a learnable per-(frequency set, feature) weight of shape
(num_freq_sets, num_features), initialised to output_scale, multiplies each feature.
Since the cosine of a uniformly distributed phase has std 1/√2, each output dimension has std ≈ output_scale / √2 at initialisation.
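The description above translates fairly directly into code. A sketch under the stated rules (log-uniform ω, uniform phase, persistent buffers, learnable per-feature weight initialised to output_scale); RFFSketch and its buffer names are hypothetical, and the dtype argument is omitted.

```python
import math

import torch
import torch.nn as nn


class RFFSketch(nn.Module):
    """Hypothetical sketch of log-uniform random Fourier features with a random phase."""

    def __init__(self, num_features, in_min=0.01, in_max=100.0, num_freq_sets=1,
                 output_scale=1.0):
        super().__init__()
        shape = (num_freq_sets, num_features)
        # omega ~ log-uniform on [1/in_max, 1/in_min]; b ~ Uniform(0, 2*pi).
        log_lo, log_hi = math.log(1.0 / in_max), math.log(1.0 / in_min)
        omega = torch.exp(torch.empty(shape).uniform_(log_lo, log_hi))
        phase = torch.empty(shape).uniform_(0.0, 2.0 * math.pi)
        # Persistent buffers, so they are saved in the state dict / checkpoint.
        self.register_buffer("omega", omega)
        self.register_buffer("phase", phase)
        # Learnable per-(freq set, feature) weight, initialised to output_scale.
        self.weight = nn.Parameter(torch.full(shape, float(output_scale)))

    def forward(self, x, freq_idx):
        # x: [*batch]; freq_idx: an int, or an integer tensor shaped like x.
        if isinstance(freq_idx, int):
            freq_idx = torch.full_like(x, freq_idx, dtype=torch.long)
        omega, phase, w = self.omega[freq_idx], self.phase[freq_idx], self.weight[freq_idx]
        return w * torch.cos(omega * x.unsqueeze(-1) + phase)  # [*batch, num_features]


rff = RFFSketch(num_features=64, in_min=0.1, in_max=10.0, num_freq_sets=3)
x = torch.tensor([0.5, -0.5, 2.0])
feats = rff(x, torch.tensor([0, 0, 2]))
print(feats.shape)  # torch.Size([3, 64])
```

The first two rows use the same frequency set for x and −x and still differ, which is the sign sensitivity the random phase provides.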
Source code in mouse/models/embedding/encoding.py
forward
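Map scalar inputs to RFF embeddings.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Scalar inputs, shape (*batch). | required |
| freq_idx | Tensor \| int | Which frequency set(s) to use. Either an int, which applies one set to every element, or an integer tensor of the same shape as x, which selects a set per element. | required |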
Returns:
Tensor of shape (*batch, num_features).
Source code in mouse/models/embedding/encoding.py
mouse.models.embedding.encoding.NormalizedPixel
Bases: Module
Maps integer pixel values (0-255) to [-1, 1].
Linear layers
mouse.models.embedding.linear.ScaledEmbedding
Bases: Embedding
nn.Embedding with default Normal(0, 1) init multiplied by scale.
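The init rule amounts to letting nn.Embedding draw its default Normal(0, 1) weights and then scaling them in place. A sketch; the constructor signature of ScaledEmbeddingSketch is hypothetical, since the real one is not shown here.

```python
import torch
import torch.nn as nn


class ScaledEmbeddingSketch(nn.Embedding):
    """Hypothetical nn.Embedding whose default Normal(0, 1) init is multiplied by `scale`."""

    def __init__(self, num_embeddings, embedding_dim, scale):
        super().__init__(num_embeddings, embedding_dim)  # weight ~ Normal(0, 1)
        with torch.no_grad():
            self.weight.mul_(scale)


torch.manual_seed(0)
emb = ScaledEmbeddingSketch(1000, 64, scale=0.02)
print(round(emb.weight.std().item(), 3))  # approximately 0.02
```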
Source code in mouse/models/embedding/linear.py
mouse.models.embedding.linear.ScaledLinear
ScaledLinear(in_features: int, out_features: int, scale: float, bias: bool = True, device: device | str | None = None, dtype: dtype | None = None)
Bases: Linear
Linear layer with Kaiming-uniform init multiplied by scale.
Kaiming-uniform is applied first (a=sqrt(5), same as nn.Linear default),
then both weights and bias are multiplied by scale in-place.
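Because nn.Linear's own reset_parameters already applies Kaiming-uniform (a=sqrt(5)) to the weight and a uniform init to the bias, the rule above can be sketched by scaling both after the parent constructor runs. ScaledLinearSketch is hypothetical and omits the device and dtype arguments.

```python
import math

import torch
import torch.nn as nn


class ScaledLinearSketch(nn.Linear):
    """Hypothetical nn.Linear with its default init multiplied by `scale`."""

    def __init__(self, in_features, out_features, scale, bias=True):
        # nn.Linear.__init__ runs reset_parameters(): Kaiming-uniform weights with
        # a=sqrt(5) and a uniform bias in [-1/sqrt(fan_in), 1/sqrt(fan_in)].
        super().__init__(in_features, out_features, bias=bias)
        with torch.no_grad():
            self.weight.mul_(scale)
            if self.bias is not None:
                self.bias.mul_(scale)


layer = ScaledLinearSketch(16, 4, scale=0.1)
# With a=sqrt(5) the Kaiming-uniform bound is 1/sqrt(fan_in); here it is scaled by 0.1.
print(layer.weight.abs().max().item() <= 0.1 / math.sqrt(16) + 1e-6)  # True
```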
Source code in mouse/models/embedding/linear.py
mouse.models.embedding.linear.PosLinear
PosLinear(num_positions: int, in_features: int, out_features: int, device: device | str | None = None, dtype: dtype | None = None)
Bases: Module
Position-conditioned linear projection.
Stores one independent (weight, bias) pair per position index. At
forward time, each element selects its projection via an integer position
index, enabling per-dimension embeddings without separate nn.Linear
instances.
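Selecting a per-position projection can be done by indexing a stacked weight tensor with pos and contracting with einsum. A sketch assuming weights stored as [num_positions, in_features, out_features]; that layout, the 0.02 placeholder init, and PosLinearSketch are assumptions.

```python
import torch
import torch.nn as nn


class PosLinearSketch(nn.Module):
    """Hypothetical: one independent (weight, bias) pair per position, chosen by index."""

    def __init__(self, num_positions, in_features, out_features):
        super().__init__()
        # Placeholder init; the real layer's initialisation is not specified here.
        self.weight = nn.Parameter(torch.randn(num_positions, in_features, out_features) * 0.02)
        self.bias = nn.Parameter(torch.zeros(num_positions, out_features))

    def forward(self, x, pos):
        # x: [*batch, in_features]; pos: integer tensor [*batch] choosing each row's projection.
        w = self.weight[pos]  # [*batch, in_features, out_features]
        b = self.bias[pos]    # [*batch, out_features]
        return torch.einsum("...i,...io->...o", x, w) + b


layer = PosLinearSketch(num_positions=3, in_features=2, out_features=5)
x = torch.randn(4, 2)
pos = torch.tensor([0, 2, 2, 1])
y = layer(x, pos)
print(y.shape)  # torch.Size([4, 5])
```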
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| num_positions | int | Number of distinct positions (embedding table size). | required |
| in_features | int | Input feature dimension. | required |
| out_features | int | Output feature dimension. | required |
| device | device \| str \| None | Tensor device for parameters. | None |
| dtype | dtype \| None | Tensor dtype for parameters. | None |
Source code in mouse/models/embedding/linear.py
forward
Apply the position-specific projection.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor. | required |
| pos | Tensor | Integer position indices. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Output tensor. |
Source code in mouse/models/embedding/linear.py
mouse.models.embedding.linear.ScaledPosLinear
ScaledPosLinear(num_positions: int, in_features: int, out_features: int, scale: float = 1.0, device: device | str | None = None, dtype: dtype | None = None)
Bases: PosLinear
PosLinear with Kaiming-uniform weights multiplied by scale.
Use scale = output_std / input_std to match a desired output standard deviation.