Base
The mouse.models.base module contains the abstract Model base class and the load_model factory function. Use load_model in virtually all cases — it reads config.json and automatically instantiates the right concrete class.
load_model
mouse.models.base.load_model
load_model(repo_id_or_path: str, force_download: bool = False, local_dir: str | Path | None = None, **kwargs: Any) -> 'Model'
Load a MOUSE model from a local directory or HuggingFace Hub repo.
Automatically selects the correct model class (ModelLlama, ModelQwen3, or
ModelNone) by inspecting backbone_kwargs in config.json — no need
to know the class up front.
Detection logic
- backbone_kwargs is empty → ModelNone
- backbone_kwargs has head_dim → ModelQwen3
- backbone_kwargs is non-empty without head_dim → ModelLlama
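The detection rules above can be sketched as a small standalone function. This is only an illustration of the documented logic, not the library's actual code: `detect_model_class` is a hypothetical helper, "has head_dim" is assumed to mean that the key is present, and `hidden_size` is a made-up key used to get a non-empty dict.

```python
from typing import Any


def detect_model_class(backbone_kwargs: dict[str, Any]) -> str:
    """Return the class name the documented rules would select (hypothetical helper)."""
    if not backbone_kwargs:
        # Empty backbone_kwargs selects ModelNone.
        return "ModelNone"
    if "head_dim" in backbone_kwargs:
        # Assumption: "has head_dim" means the key is present in the dict.
        return "ModelQwen3"
    # Non-empty and no head_dim selects ModelLlama.
    return "ModelLlama"


print(detect_model_class({}))                     # ModelNone
print(detect_model_class({"head_dim": 128}))      # ModelQwen3
print(detect_model_class({"hidden_size": 512}))   # ModelLlama ("hidden_size" is a made-up key)
```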
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| repo_id_or_path | str | A local path to a checkpoint directory or a HF Hub repo id (e.g. …). | required |
| force_download | bool | If … | False |
| local_dir | str \| Path \| None | Directory where Hub files are saved after download. When set, … | None |
| **kwargs | Any | Forwarded verbatim to … | {} |
Returns:
| Type | Description |
|---|---|
| 'Model' | The loaded model instance. |
Source code in mouse/models/base.py
Model
mouse.models.base.Model
Model(hidden_dim: int, backbone_kwargs: dict, embedding_kwargs: dict, sp_head_kwargs: dict, dqn_head_kwargs: dict, sv_head_kwargs: dict, vec_dqn_head_kwargs: dict, action_head: str | None = None)
Bases: Module
Base class for context-conditioned models, composed of a StepEmbedder, a backbone, and output heads.
The forward pass takes a TensorDict [B, S] of step records.
StepEmbedder maps each step to tokens_per_step embedding vectors,
producing a flat token sequence [B, S * tokens_per_step, D] for the
backbone. The last token of each step is pooled as the step representation
and passed to each enabled output head.
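The shape bookkeeping can be made concrete with a little index arithmetic. The sketch below assumes that the tokens of each step are laid out contiguously in the flat sequence, which the text implies but does not state. `last_token_indices` is a hypothetical helper, not part of the library.

```python
def last_token_indices(num_steps: int, tokens_per_step: int) -> list[int]:
    """Flat positions of the final token of each step (hypothetical helper).

    Assumes step s occupies positions
    [s * tokens_per_step, (s + 1) * tokens_per_step) in the flat sequence.
    """
    return [(s + 1) * tokens_per_step - 1 for s in range(num_steps)]


S, tokens_per_step = 4, 3
flat_len = S * tokens_per_step                   # length of the backbone input sequence
print(flat_len)                                  # 12
print(last_token_indices(S, tokens_per_step))    # [2, 5, 8, 11]
```

Under this assumption, selecting those positions from a [B, S * tokens_per_step, D] tensor yields one vector per step, i.e. a [B, S, D] tensor of step representations.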
Enabled heads are determined by num_layers > 0 in their kwargs:
- sp: SwiGLU MLP → logits [B, S, A]
- dqn: DQN twin-head → logits [B, S, A] (+ dqn_target)
- vec_dqn: VecDQN head → vectors [B, S, A, D] (+ vec_dqn_target); use get_action or vec_dqn_scores to get scalar scores
- sv: SwiGLU MLP → logits [B, S, A]
action_head selects which head get_action reads. It is stored in
config.json and loaded automatically. If omitted, the most expressive
enabled head is chosen: vec_dqn > dqn > sp > sv.
Use get_action(out, temperature, num_actions) to sample or greedily
select actions without manual score conversion.
backbone_forward
backbone_forward(embeds: Tensor, token_type: Tensor, cache: dict[str, Any] | None = None, use_cache: bool = False, cache_position: Tensor | None = None, **kwargs: Any) -> tuple[torch.Tensor, dict[str, Any] | None]
Run backbone; return (hidden states [B, T, D], cache dict or None).
head
Run all enabled heads on pooled step representations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| h | Tensor | Step representations … | required |
| batch_size | tuple[int, int] | … | required |
Returns:
| Type | Description |
|---|---|
| TensorDict | TensorDict … Logit heads (…) … shape … |
forward
forward(step_stream: TensorDict, cache: dict[str, Any] | None = None, use_cache: bool = False, cache_position: Tensor | None = None) -> tuple[TensorDict, dict[str, Any] | None]
Run a full forward pass over a batch of step sequences.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| step_stream | TensorDict | TensorDict … | required |
| cache | dict[str, Any] \| None | KV-cache dict from a previous call, or … | None |
| use_cache | bool | If … | False |
| cache_position | Tensor \| None | Token position indices … | None |
Returns:
| Name | Type | Description |
|---|---|---|
| out | TensorDict | TensorDict … |
| cache | dict[str, Any] \| None | Updated KV-cache dict, or … |
polyak_update
Soft-update all target heads toward their online counterparts.
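For readers unfamiliar with the term, a Polyak (soft) update moves each target parameter a small fraction of the way toward its online counterpart. The sketch below shows the standard rule on plain Python lists. The mixing coefficient `tau` and the function name `soft_update` are assumptions for illustration; how this method chooses or stores its coefficient is not documented here.

```python
def soft_update(target: list[float], online: list[float], tau: float) -> list[float]:
    """Standard Polyak averaging: target <- (1 - tau) * target + tau * online.

    `tau` is a hypothetical mixing coefficient in [0, 1]; tau = 1 copies the
    online parameters, tau = 0 leaves the target unchanged.
    """
    if len(target) != len(online):
        raise ValueError("target and online must have the same length")
    return [(1.0 - tau) * t + tau * o for t, o in zip(target, online)]


target_params = [0.0, 2.0]
online_params = [1.0, 4.0]
print(soft_update(target_params, online_params, tau=0.5))  # [0.5, 3.0]
```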
get_action
get_action(out: TensorDict, temperature: float = 1.0, num_actions: int | None = None) -> torch.Tensor
Select an action from model output at the last step position.
Uses self.action_head, which is set at construction time (or
auto-detected from enabled heads).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| out | TensorDict | Model output TensorDict … | required |
| temperature | float | Sampling temperature. … | 1.0 |
| num_actions | int \| None | If given, trim scores to the first … | None |
Returns:
| Type | Description |
|---|---|
| Tensor | … |