
Base

The mouse.models.base module contains the abstract Model base class and the load_model factory function. Use load_model in virtually all cases — it reads config.json and automatically instantiates the right concrete class.

load_model

mouse.models.base.load_model

load_model(repo_id_or_path: str, force_download: bool = False, local_dir: str | Path | None = None, **kwargs: Any) -> 'Model'

Load a MOUSE model from a local directory or HuggingFace Hub repo.

Automatically selects the correct model class (ModelLlama, ModelQwen3, or ModelNone) by inspecting backbone_kwargs in config.json — no need to know the class up front.

Detection logic
  • backbone_kwargs is empty → ModelNone
  • backbone_kwargs has head_dim → ModelQwen3
  • backbone_kwargs is non-empty without head_dim → ModelLlama
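The detection rule is small enough to restate as a pure function. The sketch below uses a hypothetical helper, detect_model_class, which is not part of mouse. It returns only the class name as a string and follows the same order of checks as the source listing below, including treating a missing backbone_kwargs key as empty.

```python
import json


def detect_model_class(config: dict) -> str:
    """Return the name of the concrete class load_model would pick.

    Hypothetical helper; mirrors the detection order described above.
    """
    backbone_kwargs = config.get("backbone_kwargs", {})
    if not backbone_kwargs:            # missing, null, or {} -> no backbone
        return "ModelNone"
    if "head_dim" in backbone_kwargs:  # head_dim marks a Qwen3 backbone
        return "ModelQwen3"
    return "ModelLlama"                # any other non-empty backbone config


if __name__ == "__main__":
    cfg = json.loads('{"backbone_kwargs": {"num_layers": 4, "head_dim": 64}}')
    print(detect_model_class(cfg))  # ModelQwen3
```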

Parameters:

  • repo_id_or_path (str, required): A local path to a checkpoint directory or a HF Hub repo id (e.g. "your-org/your-model").
  • force_download (bool, default False): If True, bypass the HF Hub cache and re-download. Ignored for local paths. For Hub loads it is also forwarded to from_pretrained.
  • local_dir (str | Path | None, default None): Directory where Hub files are saved after download. When set, hf_hub_download writes files there and from_pretrained loads from that directory instead of the Hub cache. Ignored for local paths.
  • **kwargs (Any, optional): Forwarded to cls.from_pretrained (e.g. map_location, revision, token). For Hub loads, force_download is added to them.

Returns:

  • Model: The loaded model instance.

Source code in mouse/models/base.py
def load_model(
    repo_id_or_path: str,
    force_download: bool = False,
    local_dir: str | Path | None = None,
    **kwargs: Any,
) -> "Model":
    """Load a MOUSE model from a local directory or HuggingFace Hub repo.

    Automatically selects the correct model class (ModelLlama, ModelQwen3, or
    ModelNone) by inspecting ``backbone_kwargs`` in ``config.json`` — no need
    to know the class up front.

    Detection logic:
        - ``backbone_kwargs`` is empty  → ModelNone
        - ``backbone_kwargs`` has ``head_dim`` → ModelQwen3
        - ``backbone_kwargs`` is non-empty without ``head_dim`` → ModelLlama

    Args:
        repo_id_or_path: A local path to a checkpoint directory or a HF Hub
            repo id (e.g. ``"your-org/your-model"``).
        force_download: If ``True``, bypass the HF Hub cache and re-download.
            Ignored for local paths.
        local_dir: Directory where Hub files are saved after download.  When
            set, ``hf_hub_download`` writes files there and
            ``from_pretrained`` loads from that directory instead of the
            Hub cache.  Ignored for local paths.
        **kwargs: Forwarded verbatim to ``cls.from_pretrained`` (e.g.
            ``map_location``, ``revision``, ``token``).

    Returns:
        The loaded model instance.
    """
    local = Path(repo_id_or_path)
    if local.exists():
        with (local / "config.json").open() as fh:
            config = json.load(fh)
    else:
        from huggingface_hub import hf_hub_download
        hf_kwargs: dict[str, Any] = {"force_download": force_download}
        if local_dir is not None:
            hf_kwargs["local_dir"] = str(local_dir)
        config_file = hf_hub_download(repo_id=repo_id_or_path, filename="config.json", **hf_kwargs)
        with open(config_file) as fh:
            config = json.load(fh)
        kwargs = {**kwargs, "force_download": force_download}
        if local_dir is not None:
            repo_id_or_path = str(local_dir)

    backbone_kwargs = config.get("backbone_kwargs", {})

    if not backbone_kwargs:
        from mouse.models.backbone.none import ModelNone
        return ModelNone.from_pretrained(repo_id_or_path, **kwargs)
    if "head_dim" in backbone_kwargs:
        from mouse.models.backbone.qwen3 import ModelQwen3
        return ModelQwen3.from_pretrained(repo_id_or_path, **kwargs)
    from mouse.models.backbone.llama import ModelLlama
    return ModelLlama.from_pretrained(repo_id_or_path, **kwargs)

Model

mouse.models.base.Model

Model(hidden_dim: int, backbone_kwargs: dict, embedding_kwargs: dict, sp_head_kwargs: dict, dqn_head_kwargs: dict, sv_head_kwargs: dict, vec_dqn_head_kwargs: dict, action_head: str | None = None)

Bases: Module

Base class for context-conditioned models, composed of a StepEmbedder, a backbone, and output heads.

The forward pass takes a TensorDict [B, S] of step records. StepEmbedder maps each step to tokens_per_step embedding vectors, producing a flat token sequence [B, S * tokens_per_step, D] for the backbone. The last token of each step is pooled as the step representation and passed to each enabled output head.
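The pooling step reduces to index arithmetic. In the flat sequence of length S * tokens_per_step, the last token of step s sits at index s * tokens_per_step + tokens_per_step - 1. The sketch below uses a hypothetical helper and nested Python lists in place of tensors: one batch row with toy values. It only illustrates the layout that forward relies on when it takes h.view(B, S, T, D)[:, :, -1, :].

```python
def pool_last_token(flat_tokens: list[list[float]], tokens_per_step: int) -> list[list[float]]:
    """Pick the last token of every step from one flattened sequence.

    flat_tokens has length S * tokens_per_step; each entry is a D-dim vector.
    Returns S vectors, i.e. one row of the [B, S, D] step representation.
    """
    if len(flat_tokens) % tokens_per_step != 0:
        raise ValueError("sequence length must be a multiple of tokens_per_step")
    num_steps = len(flat_tokens) // tokens_per_step
    return [flat_tokens[s * tokens_per_step + tokens_per_step - 1] for s in range(num_steps)]


if __name__ == "__main__":
    # S = 3 steps, tokens_per_step = 2, D = 1: token i has value i.
    tokens = [[float(i)] for i in range(6)]
    print(pool_last_token(tokens, tokens_per_step=2))  # [[1.0], [3.0], [5.0]]
```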

Enabled heads are determined by num_layers > 0 in their kwargs:

  • sp — SwiGLU MLP → logits [B, S, A]
  • dqn — DQN twin-head → logits [B, S, A] (+ dqn_target)
  • vec_dqn — VecDQN head → vectors [B, S, A, D] (+ vec_dqn_target); use get_action or vec_dqn_scores to get scalar scores
  • sv — SwiGLU MLP → logits [B, S, A]

action_head selects which head get_action reads. It is stored in config.json and loaded automatically. If omitted, the most expressive enabled head is chosen: vec_dqn > dqn > sp > sv.

Use get_action(out, temperature, num_actions) to sample or greedily select actions without manual score conversion.

Source code in mouse/models/base.py
def __init__(
    self,
    hidden_dim: int,
    backbone_kwargs: dict,
    embedding_kwargs: dict,
    sp_head_kwargs: dict,
    dqn_head_kwargs: dict,
    sv_head_kwargs: dict,
    vec_dqn_head_kwargs: dict,
    action_head: str | None = None,
):
    super().__init__()

    self.hidden_dim = int(hidden_dim)
    embedding_kwargs = {k: v for k, v in embedding_kwargs.items() if k != "obs_continuous_encoder"}
    self.max_num_actions = int(embedding_kwargs["max_num_actions"])

    self.embedder = StepEmbedder(hidden_dim=hidden_dim, **embedding_kwargs)

    self.sp_head = (
        SwiGLUHead(in_features=hidden_dim, out_features=self.max_num_actions, **sp_head_kwargs)
        if sp_head_kwargs.get("num_layers", 0) > 0 else None
    )

    self.dqn_head = (
        DQNHead(in_features=hidden_dim, out_features=self.max_num_actions, **dqn_head_kwargs)
        if dqn_head_kwargs.get("num_layers", 0) > 0 else None
    )

    self.vec_dqn_head = (
        VecDQNHead(
            in_features=hidden_dim,
            max_num_actions=self.max_num_actions,
            **vec_dqn_head_kwargs,
        )
        if vec_dqn_head_kwargs.get("num_layers", 0) > 0 else None
    )

    self.sv_head = (
        SwiGLUHead(in_features=hidden_dim, out_features=self.max_num_actions, **sv_head_kwargs)
        if sv_head_kwargs.get("num_layers", 0) > 0 else None
    )

    if action_head is not None:
        if action_head not in self._VALID_HEADS:
            raise ValueError(
                f"action_head must be one of {self._VALID_HEADS}, got {action_head!r}."
            )
        self.action_head: str = action_head
    else:
        # Auto-detect: most expressive enabled head wins.
        if self.vec_dqn_head is not None:
            self.action_head = "vec_dqn"
        elif self.dqn_head is not None:
            self.action_head = "dqn"
        elif self.sp_head is not None:
            self.action_head = "sp"
        elif self.sv_head is not None:
            self.action_head = "sv"
        else:
            raise ValueError("No output head is enabled; cannot determine action_head.")

    self._init_backbone(backbone_kwargs)

backbone_forward

backbone_forward(embeds: Tensor, token_type: Tensor, cache: dict[str, Any] | None = None, use_cache: bool = False, cache_position: Tensor | None = None, **kwargs: Any) -> tuple[torch.Tensor, dict[str, Any] | None]

Run the backbone and return (hidden states [B, T, D], cache dict or None). The base class raises NotImplementedError; each backbone subclass implements this method.
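To make the subclass contract concrete, here is a toy sketch in plain Python. BackboneBase and IdentityBackbone are hypothetical stand-ins, not mouse classes. The real subclasses (ModelLlama, ModelQwen3, ModelNone) run actual backbones, and their cache layout is not documented on this page. The toy keeps only the shape of the contract: hidden states come back as [B, T, D], and the cache is None unless use_cache=True.

```python
from typing import Any


class BackboneBase:
    """Stand-in for Model that keeps only the backbone_forward contract."""

    def backbone_forward(self, embeds, token_type, cache=None, use_cache=False,
                         cache_position=None, **kwargs: Any):
        raise NotImplementedError("Subclasses must implement backbone_forward.")


class IdentityBackbone(BackboneBase):
    """Toy backbone: returns embeds unchanged; its 'cache' only counts tokens."""

    def backbone_forward(self, embeds, token_type, cache=None, use_cache=False,
                         cache_position=None, **kwargs: Any):
        # embeds: nested lists shaped [B, T, D]; hidden states keep that shape.
        hidden = embeds
        if not use_cache:
            return hidden, None  # contract: no cache when use_cache=False
        seen = (cache or {}).get("num_tokens", 0)
        return hidden, {"num_tokens": seen + len(embeds[0])}
```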

Source code in mouse/models/base.py
def backbone_forward(
    self,
    embeds: torch.Tensor,
    token_type: torch.Tensor,
    cache: dict[str, Any] | None = None,
    use_cache: bool = False,
    cache_position: torch.Tensor | None = None,
    **kwargs: Any,
) -> tuple[torch.Tensor, dict[str, Any] | None]:
    """Run backbone; return (hidden states [B, T, D], cache dict or None)."""
    raise NotImplementedError("Subclasses must implement backbone_forward.")

head

head(h: Tensor, batch_size: tuple[int, int]) -> TensorDict

Run all enabled heads on pooled step representations.

Parameters:

  • h (Tensor, required): Step representations [B, S, D].
  • batch_size (tuple[int, int], required): (B, S), used to set the TensorDict batch dimensions.

Returns:

  • TensorDict: TensorDict [B, S] with a key for each enabled head. Logit heads (sp, dqn, dqn_target, sv) have shape [B, S, A]; vector heads (vec_dqn, vec_dqn_target) have shape [B, S, A, D]. Disabled heads are absent.
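The key layout of this TensorDict follows directly from which heads are enabled. Note that the DQN-style heads each contribute an extra *_target key. The hypothetical helper below lists the expected keys and shapes, using A for max_num_actions and D as in the docstring. Its result could, for example, be compared against tuple(out[key].shape) for each key of a real output.

```python
def expected_head_outputs(enabled: set[str], B: int, S: int, A: int, D: int) -> dict[str, tuple[int, ...]]:
    """Map each output key of Model.head to its expected shape (hypothetical helper)."""
    shapes: dict[str, tuple[int, ...]] = {}
    if "sp" in enabled:
        shapes["sp"] = (B, S, A)
    if "dqn" in enabled:
        shapes["dqn"] = (B, S, A)
        shapes["dqn_target"] = (B, S, A)         # target network, same shape
    if "vec_dqn" in enabled:
        shapes["vec_dqn"] = (B, S, A, D)         # one vector per action
        shapes["vec_dqn_target"] = (B, S, A, D)
    if "sv" in enabled:
        shapes["sv"] = (B, S, A)
    return shapes
```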

Source code in mouse/models/base.py
def head(self, h: torch.Tensor, batch_size: tuple[int, int]) -> TensorDict:
    """Run all enabled heads on pooled step representations.

    Args:
        h: Step representations ``[B, S, D]``.
        batch_size: ``(B, S)`` used to set the TensorDict batch dimensions.

    Returns:
        TensorDict ``[B, S]`` with a key for each enabled head.
        Logit heads (``sp``, ``dqn``, ``dqn_target``, ``sv``) have shape
        ``[B, S, A]``; vector heads (``vec_dqn``, ``vec_dqn_target``) have
        shape ``[B, S, A, D]``.  Disabled heads are absent.
    """
    tensors: dict[str, torch.Tensor] = {}
    if self.sp_head is not None:
        tensors["sp"] = self.sp_head(h)
    if self.dqn_head is not None:
        tensors["dqn"] = self.dqn_head(h)
        tensors["dqn_target"] = self.dqn_head.target_forward(h)
    if self.vec_dqn_head is not None:
        tensors["vec_dqn"] = self.vec_dqn_head(h)
        tensors["vec_dqn_target"] = self.vec_dqn_head.target_forward(h)
    if self.sv_head is not None:
        tensors["sv"] = self.sv_head(h)
    return TensorDict(tensors, batch_size=batch_size)

forward

forward(step_stream: TensorDict, cache: dict[str, Any] | None = None, use_cache: bool = False, cache_position: Tensor | None = None) -> tuple[TensorDict, dict[str, Any] | None]

Run a full forward pass over a batch of step sequences.

Parameters:

  • step_stream (TensorDict, required): TensorDict [B, S] of step records (observations, actions, rewards, etc. as configured by the embedder).
  • cache (dict[str, Any] | None, default None): KV-cache dict from a previous call, or None for a full prefill. Only meaningful when use_cache=True.
  • use_cache (bool, default False): If True, the backbone returns an updated cache that can be passed back on the next call for incremental decoding.
  • cache_position (Tensor | None, default None): Token position indices [T] for incremental decoding; leave None for full prefill.

Returns:

  • out (TensorDict): TensorDict [B, S] with one key per enabled head. Logit heads (sp, dqn, dqn_target, sv) have shape [B, S, A]. Vector heads (vec_dqn, vec_dqn_target) have shape [B, S, A, D]; use get_action or vec_dqn_scores to turn them into scalar scores.
  • cache (dict[str, Any] | None): Updated KV-cache dict, or None when use_cache=False.
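A typical incremental loop would prefill with cache=None and use_cache=True, then feed one new step at a time along with the returned cache. This page does not spell out the cache_position convention. The sketch below assumes the Hugging Face transformers convention, in which cache_position holds the absolute positions of the new tokens, continuing from the number of tokens already cached. The helper name is hypothetical.

```python
def cache_positions(num_cached_steps: int, num_new_steps: int, tokens_per_step: int) -> list[int]:
    """Absolute token positions for the new tokens of an incremental call.

    Assumes HF-style cache_position semantics: positions continue from the
    number of tokens already held in the cache.
    """
    start = num_cached_steps * tokens_per_step
    return list(range(start, start + num_new_steps * tokens_per_step))


if __name__ == "__main__":
    # Prefill 4 steps of 3 tokens each, then decode one more step.
    print(cache_positions(0, 4, 3))  # [0, 1, ..., 11]
    print(cache_positions(4, 1, 3))  # [12, 13, 14]
```

In a real call, the list would become a tensor such as torch.arange(start, start + n) on the model's device.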

Source code in mouse/models/base.py
def forward(
    self,
    step_stream: TensorDict,
    cache: dict[str, Any] | None = None,
    use_cache: bool = False,
    cache_position: torch.Tensor | None = None,
) -> tuple[TensorDict, dict[str, Any] | None]:
    """Run a full forward pass over a batch of step sequences.

    Args:
        step_stream: TensorDict ``[B, S]`` of step records (observations,
            actions, rewards, etc. as configured by the embedder).
        cache: KV-cache dict from a previous call, or ``None`` for a full
            prefill.  Only meaningful when ``use_cache=True``.
        use_cache: If ``True`` the backbone returns an updated cache that
            can be passed back on the next call for incremental decoding.
        cache_position: Token position indices ``[T]`` for incremental
            decoding; leave ``None`` for full prefill.

    Returns:
        out: TensorDict ``[B, S]`` with one key per enabled head.
             Logit heads — ``sp``, ``dqn``, ``dqn_target``, ``sv`` —
             have shape ``[B, S, A]``.  Vector heads — ``vec_dqn``,
             ``vec_dqn_target`` — have shape ``[B, S, A, D]``.
             Use ``get_action`` or ``vec_dqn_scores`` for the vector heads.
        cache: Updated KV-cache dict, or ``None`` when ``use_cache=False``.
    """
    B, S = int(step_stream.batch_size[0]), int(step_stream.batch_size[1])

    embeds, token_type = self.embedder(step_stream)
    h, new_cache = self.backbone_forward(
        embeds=embeds,
        token_type=token_type,
        cache=cache,
        use_cache=use_cache,
        cache_position=cache_position,
    )

    # Take the last token per step as the step representation
    T = self.embedder.tokens_per_step
    h_step = h.view(B, S, T, self.hidden_dim)[:, :, -1, :]  # [B, S, D]

    return self.head(h_step.float(), batch_size=(B, S)), new_cache

polyak_update

polyak_update(dqn_tau: float = 0.0, vec_dqn_tau: float = 0.0) -> None

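Soft-update all target heads toward their online counterparts. dqn_tau is passed to the DQN head and vec_dqn_tau to the VecDQN head; heads that are not enabled are skipped.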

Source code in mouse/models/base.py
def polyak_update(self, dqn_tau: float = 0.0, vec_dqn_tau: float = 0.0) -> None:
    """Soft-update all target heads toward their online counterparts."""
    if self.dqn_head is not None:
        self.dqn_head.polyak_update(tau=dqn_tau)
    if self.vec_dqn_head is not None:
        self.vec_dqn_head.polyak_update(tau=vec_dqn_tau)

get_action

get_action(out: TensorDict, temperature: float = 1.0, num_actions: int | None = None) -> torch.Tensor

Select an action from model output at the last step position.

Uses self.action_head, which is set at construction time (or auto-detected from enabled heads).

Parameters:

  • out (TensorDict, required): Model output TensorDict [B, S, ...].
  • temperature (float, default 1.0): Sampling temperature. 0.0 → greedy argmax; > 0 → softmax sampling.
  • num_actions (int | None, default None): If given, trim scores to the first num_actions columns before sampling (useful when the environment has fewer actions than the model's maximum).

Returns:

  • Tensor: [B] int64 tensor of selected actions.

Source code in mouse/models/base.py
def get_action(
    self,
    out: TensorDict,
    temperature: float = 1.0,
    num_actions: int | None = None,
) -> torch.Tensor:
    """Select an action from model output at the last step position.

    Uses ``self.action_head``, which is set at construction time (or
    auto-detected from enabled heads).

    Args:
        out: Model output TensorDict ``[B, S, ...]``.
        temperature: Sampling temperature. ``0.0`` → greedy argmax;
                     ``> 0`` → softmax sampling.
        num_actions: If given, trim scores to the first ``num_actions``
                     columns before sampling (useful when the environment
                     has fewer actions than the model's maximum).

    Returns:
        ``[B]`` int64 tensor of selected actions.
    """
    raw = out[self.action_head][:, -1]  # [B, A] or [B, A, D] for vec_dqn
    scores: torch.Tensor = vec_dqn_scores(raw) if self.action_head == "vec_dqn" else raw  # [B, A]
    if num_actions is not None:
        scores = scores[:, :num_actions]
    if temperature == 0.0:
        return scores.argmax(dim=-1)
    scores = scores - scores.max(dim=-1, keepdim=True).values  # numerical stability
    probs = F.softmax(scores / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)