Skip to content

Backbone

Concrete Model subclasses and the config dataclasses that build HuggingFace transformer backbones. You rarely need to instantiate these directly — use load_model instead.

ModelLlama

mouse.models.backbone.llama.ModelLlama

ModelLlama(hidden_dim: int, backbone_kwargs: dict, embedding_kwargs: dict, sp_head_kwargs: dict, dqn_head_kwargs: dict, sv_head_kwargs: dict, vec_dqn_head_kwargs: dict, action_head: str | None = None)

Bases: Model, PyTorchModelHubMixin

MOUSE model with a Llama transformer backbone.

Attends over the full [B, S*T, D] token sequence with causal SDPA. Supports KV-cache for incremental rollouts (use_cache=True).

Source code in mouse/models/base.py
def __init__(
    self,
    hidden_dim: int,
    backbone_kwargs: dict,
    embedding_kwargs: dict,
    sp_head_kwargs: dict,
    dqn_head_kwargs: dict,
    sv_head_kwargs: dict,
    vec_dqn_head_kwargs: dict,
    action_head: str | None = None,
):
    """Build the step embedder, the enabled output heads, and the backbone.

    Args:
        hidden_dim: Model width ``D``. Coerced to ``int`` once; that value is
            stored as ``self.hidden_dim`` and used for the embedder and every
            head so they all agree.
        backbone_kwargs: Passed unchanged to ``self._init_backbone``.
        embedding_kwargs: Keyword arguments for ``StepEmbedder``; must contain
            ``"max_num_actions"``. Any ``"obs_continuous_encoder"`` entry is
            dropped (from a copy; the caller's dict is untouched) before the
            embedder is built.
        sp_head_kwargs: Kwargs for the ``SwiGLUHead`` stored as ``sp_head``.
        dqn_head_kwargs: Kwargs for the ``DQNHead`` stored as ``dqn_head``.
        sv_head_kwargs: Kwargs for the ``SwiGLUHead`` stored as ``sv_head``.
        vec_dqn_head_kwargs: Kwargs for the ``VecDQNHead`` stored as
            ``vec_dqn_head``.
        action_head: Head used for action selection; must be one of
            ``self._VALID_HEADS``. When ``None``, the first enabled head in
            the order ``vec_dqn``, ``dqn``, ``sp``, ``sv`` is chosen.

    A head is built only when its kwargs contain ``num_layers > 0``;
    otherwise the corresponding attribute is ``None``. All heads are sized
    by ``max_num_actions``.

    Raises:
        KeyError: If ``embedding_kwargs`` has no ``"max_num_actions"`` entry.
        ValueError: If ``action_head`` is not in ``self._VALID_HEADS``, or if
            it is ``None`` and no head is enabled.
    """
    super().__init__()

    self.hidden_dim = int(hidden_dim)
    # Work on a filtered copy so the caller's config dict is not mutated.
    # NOTE(review): why this key is excluded is not visible here — confirm.
    embedding_kwargs = {k: v for k, v in embedding_kwargs.items() if k != "obs_continuous_encoder"}
    self.max_num_actions = int(embedding_kwargs["max_num_actions"])

    # Use the coerced width everywhere (not the raw argument) so submodules
    # are built with the same value exposed as ``self.hidden_dim``.
    self.embedder = StepEmbedder(hidden_dim=self.hidden_dim, **embedding_kwargs)

    # Submodule construction order is kept stable (sp, dqn, vec_dqn, sv):
    # it determines parameter registration order in the module.
    self.sp_head = (
        SwiGLUHead(in_features=self.hidden_dim, out_features=self.max_num_actions, **sp_head_kwargs)
        if sp_head_kwargs.get("num_layers", 0) > 0 else None
    )

    self.dqn_head = (
        DQNHead(in_features=self.hidden_dim, out_features=self.max_num_actions, **dqn_head_kwargs)
        if dqn_head_kwargs.get("num_layers", 0) > 0 else None
    )

    self.vec_dqn_head = (
        VecDQNHead(
            in_features=self.hidden_dim,
            max_num_actions=self.max_num_actions,
            **vec_dqn_head_kwargs,
        )
        if vec_dqn_head_kwargs.get("num_layers", 0) > 0 else None
    )

    self.sv_head = (
        SwiGLUHead(in_features=self.hidden_dim, out_features=self.max_num_actions, **sv_head_kwargs)
        if sv_head_kwargs.get("num_layers", 0) > 0 else None
    )

    if action_head is not None:
        if action_head not in self._VALID_HEADS:
            raise ValueError(
                f"action_head must be one of {self._VALID_HEADS}, got {action_head!r}."
            )
        # NOTE(review): an explicitly requested head is not checked to be
        # enabled (non-None) here; a disabled choice will only fail later.
        self.action_head: str = action_head
    else:
        # Auto-detect: most expressive enabled head wins.
        if self.vec_dqn_head is not None:
            self.action_head = "vec_dqn"
        elif self.dqn_head is not None:
            self.action_head = "dqn"
        elif self.sp_head is not None:
            self.action_head = "sp"
        elif self.sv_head is not None:
            self.action_head = "sv"
        else:
            raise ValueError("No output head is enabled; cannot determine action_head.")

    self._init_backbone(backbone_kwargs)

backbone_forward

backbone_forward(embeds: Tensor, token_type: Tensor, cache: dict[str, Any] | None = None, use_cache: bool = False, cache_position: Tensor | None = None, **kwargs: Any) -> tuple[torch.Tensor, dict[str, Any] | None]

Run the Llama backbone over the token sequence.

Parameters:

Name Type Description Default
embeds Tensor

[B, T_total, D] embedding tensor from StepEmbedder.

required
token_type Tensor

[B, T_total] int64 TokenType ids; PAD positions are masked out from attention.

required
cache dict[str, Any] | None

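KV-cache dict from a previous call, or None for full prefill. Only the "backbone" key is read, and the dict itself is not modified; when use_cache=True, the updated cache is returned in a new dict under the same key.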

None
use_cache bool

If True, return an updated KV-cache dict.

False
cache_position Tensor | None

Unused; present for interface compatibility.

None
**kwargs Any

Forwarded to the underlying LlamaModel.

{}

Returns:

Type Description
tuple[Tensor, dict[str, Any] | None]

Tuple of (hidden_states [B, T_total, D], cache_dict | None).

Source code in mouse/models/backbone/llama.py
def backbone_forward(
    self,
    embeds: torch.Tensor,
    token_type: torch.Tensor,
    cache: dict[str, Any] | None = None,
    use_cache: bool = False,
    cache_position: torch.Tensor | None = None,
    **kwargs: Any,
) -> tuple[torch.Tensor, dict[str, Any] | None]:
    """Run the Llama backbone over the token sequence.

    Args:
        embeds: ``[B, T_total, D]`` embedding tensor from ``StepEmbedder``.
        token_type: ``[B, T_total]`` int64 ``TokenType`` ids. If any entry is
            ``PAD``, a 0/1 attention mask (1 = non-``PAD``) is passed to the
            backbone; otherwise no mask is passed.
        cache: KV-cache dict from a previous call, or ``None`` for full
            prefill. Only ``cache["backbone"]`` is read; the dict itself is
            not modified.
        use_cache: If ``True``, return the updated KV-cache in a new dict
            under the ``"backbone"`` key.
        cache_position: Unused; present for interface compatibility.
        **kwargs: Forwarded to the underlying ``LlamaModel``.

    Returns:
        Tuple of ``(hidden_states [B, T_total, D], cache_dict | None)``.
    """
    past_key_values = (cache or {}).get("backbone", None)

    # Only materialise a mask when something is actually padded; otherwise
    # the backbone receives attention_mask=None.
    # NOTE(review): the mask spans only this call's tokens. With a non-empty
    # KV cache, HF may expect a mask over past + current keys — confirm that
    # padded incremental steps behave as intended.
    is_real_token = token_type != TokenType.PAD
    attention_mask = None if bool(is_real_token.all()) else is_real_token.long()

    outputs = self.backbone(
        inputs_embeds=embeds,
        past_key_values=past_key_values,
        use_cache=use_cache,
        position_ids=None,
        attention_mask=attention_mask,
        **kwargs,
    )

    if not use_cache:
        return outputs.last_hidden_state, None
    return outputs.last_hidden_state, {"backbone": outputs.past_key_values}

ModelQwen3

mouse.models.backbone.qwen3.ModelQwen3

ModelQwen3(hidden_dim: int, backbone_kwargs: dict, embedding_kwargs: dict, sp_head_kwargs: dict, dqn_head_kwargs: dict, sv_head_kwargs: dict, vec_dqn_head_kwargs: dict, action_head: str | None = None)

Bases: Model, PyTorchModelHubMixin

MOUSE model with a Qwen3 transformer backbone.

Attends over the full [B, S*T, D] token sequence with causal SDPA. Supports an explicit head_dim (set in backbone_kwargs) for grouped-query attention with a head size independent of the model width. Supports KV-cache for incremental rollouts (use_cache=True).

Source code in mouse/models/base.py
def __init__(
    self,
    hidden_dim: int,
    backbone_kwargs: dict,
    embedding_kwargs: dict,
    sp_head_kwargs: dict,
    dqn_head_kwargs: dict,
    sv_head_kwargs: dict,
    vec_dqn_head_kwargs: dict,
    action_head: str | None = None,
):
    """Build the step embedder, the enabled output heads, and the backbone.

    Args:
        hidden_dim: Model width ``D``. Coerced to ``int`` once; that value is
            stored as ``self.hidden_dim`` and used for the embedder and every
            head so they all agree.
        backbone_kwargs: Passed unchanged to ``self._init_backbone``.
        embedding_kwargs: Keyword arguments for ``StepEmbedder``; must contain
            ``"max_num_actions"``. Any ``"obs_continuous_encoder"`` entry is
            dropped (from a copy; the caller's dict is untouched) before the
            embedder is built.
        sp_head_kwargs: Kwargs for the ``SwiGLUHead`` stored as ``sp_head``.
        dqn_head_kwargs: Kwargs for the ``DQNHead`` stored as ``dqn_head``.
        sv_head_kwargs: Kwargs for the ``SwiGLUHead`` stored as ``sv_head``.
        vec_dqn_head_kwargs: Kwargs for the ``VecDQNHead`` stored as
            ``vec_dqn_head``.
        action_head: Head used for action selection; must be one of
            ``self._VALID_HEADS``. When ``None``, the first enabled head in
            the order ``vec_dqn``, ``dqn``, ``sp``, ``sv`` is chosen.

    A head is built only when its kwargs contain ``num_layers > 0``;
    otherwise the corresponding attribute is ``None``. All heads are sized
    by ``max_num_actions``.

    Raises:
        KeyError: If ``embedding_kwargs`` has no ``"max_num_actions"`` entry.
        ValueError: If ``action_head`` is not in ``self._VALID_HEADS``, or if
            it is ``None`` and no head is enabled.
    """
    super().__init__()

    self.hidden_dim = int(hidden_dim)
    # Work on a filtered copy so the caller's config dict is not mutated.
    # NOTE(review): why this key is excluded is not visible here — confirm.
    embedding_kwargs = {k: v for k, v in embedding_kwargs.items() if k != "obs_continuous_encoder"}
    self.max_num_actions = int(embedding_kwargs["max_num_actions"])

    # Use the coerced width everywhere (not the raw argument) so submodules
    # are built with the same value exposed as ``self.hidden_dim``.
    self.embedder = StepEmbedder(hidden_dim=self.hidden_dim, **embedding_kwargs)

    # Submodule construction order is kept stable (sp, dqn, vec_dqn, sv):
    # it determines parameter registration order in the module.
    self.sp_head = (
        SwiGLUHead(in_features=self.hidden_dim, out_features=self.max_num_actions, **sp_head_kwargs)
        if sp_head_kwargs.get("num_layers", 0) > 0 else None
    )

    self.dqn_head = (
        DQNHead(in_features=self.hidden_dim, out_features=self.max_num_actions, **dqn_head_kwargs)
        if dqn_head_kwargs.get("num_layers", 0) > 0 else None
    )

    self.vec_dqn_head = (
        VecDQNHead(
            in_features=self.hidden_dim,
            max_num_actions=self.max_num_actions,
            **vec_dqn_head_kwargs,
        )
        if vec_dqn_head_kwargs.get("num_layers", 0) > 0 else None
    )

    self.sv_head = (
        SwiGLUHead(in_features=self.hidden_dim, out_features=self.max_num_actions, **sv_head_kwargs)
        if sv_head_kwargs.get("num_layers", 0) > 0 else None
    )

    if action_head is not None:
        if action_head not in self._VALID_HEADS:
            raise ValueError(
                f"action_head must be one of {self._VALID_HEADS}, got {action_head!r}."
            )
        # NOTE(review): an explicitly requested head is not checked to be
        # enabled (non-None) here; a disabled choice will only fail later.
        self.action_head: str = action_head
    else:
        # Auto-detect: most expressive enabled head wins.
        if self.vec_dqn_head is not None:
            self.action_head = "vec_dqn"
        elif self.dqn_head is not None:
            self.action_head = "dqn"
        elif self.sp_head is not None:
            self.action_head = "sp"
        elif self.sv_head is not None:
            self.action_head = "sv"
        else:
            raise ValueError("No output head is enabled; cannot determine action_head.")

    self._init_backbone(backbone_kwargs)

backbone_forward

backbone_forward(embeds: Tensor, token_type: Tensor, cache: dict[str, Any] | None = None, use_cache: bool = False, cache_position: Tensor | None = None, **kwargs: Any) -> tuple[torch.Tensor, dict[str, Any] | None]

Run the Qwen3 backbone over the token sequence.

Parameters:

Name Type Description Default
embeds Tensor

[B, T_total, D] embedding tensor from StepEmbedder.

required
token_type Tensor

[B, T_total] int64 TokenType ids; PAD positions are masked out from attention.

required
cache dict[str, Any] | None

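KV-cache dict from a previous call, or None for full prefill. Only the "backbone" key is read, and the dict itself is not modified; when use_cache=True, the updated cache is returned in a new dict under the same key.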

None
use_cache bool

If True, return an updated KV-cache dict.

False
cache_position Tensor | None

Unused; present for interface compatibility.

None
**kwargs Any

Forwarded to the underlying Qwen3Model.

{}

Returns:

Type Description
tuple[Tensor, dict[str, Any] | None]

Tuple of (hidden_states [B, T_total, D], cache_dict | None).

Source code in mouse/models/backbone/qwen3.py
def backbone_forward(
    self,
    embeds: torch.Tensor,
    token_type: torch.Tensor,
    cache: dict[str, Any] | None = None,
    use_cache: bool = False,
    cache_position: torch.Tensor | None = None,
    **kwargs: Any,
) -> tuple[torch.Tensor, dict[str, Any] | None]:
    """Run the Qwen3 backbone over the token sequence.

    Args:
        embeds: ``[B, T_total, D]`` embedding tensor from ``StepEmbedder``.
        token_type: ``[B, T_total]`` int64 ``TokenType`` ids. If any entry is
            ``PAD``, a 0/1 attention mask (1 = non-``PAD``) is passed to the
            backbone; otherwise no mask is passed.
        cache: KV-cache dict from a previous call, or ``None`` for full
            prefill. Only ``cache["backbone"]`` is read; the dict itself is
            not modified.
        use_cache: If ``True``, return the updated KV-cache in a new dict
            under the ``"backbone"`` key.
        cache_position: Unused; present for interface compatibility.
        **kwargs: Forwarded to the underlying ``Qwen3Model``.

    Returns:
        Tuple of ``(hidden_states [B, T_total, D], cache_dict | None)``.
    """
    past_key_values = (cache or {}).get("backbone", None)

    # Only materialise a mask when something is actually padded; otherwise
    # the backbone receives attention_mask=None.
    # NOTE(review): the mask spans only this call's tokens. With a non-empty
    # KV cache, HF may expect a mask over past + current keys — confirm that
    # padded incremental steps behave as intended.
    is_real_token = token_type != TokenType.PAD
    attention_mask = None if bool(is_real_token.all()) else is_real_token.long()

    outputs = self.backbone(
        inputs_embeds=embeds,
        past_key_values=past_key_values,
        use_cache=use_cache,
        position_ids=None,
        attention_mask=attention_mask,
        **kwargs,
    )

    if not use_cache:
        return outputs.last_hidden_state, None
    return outputs.last_hidden_state, {"backbone": outputs.past_key_values}

ModelNone

mouse.models.backbone.none.ModelNone

ModelNone(hidden_dim: int, backbone_kwargs: dict, embedding_kwargs: dict, sp_head_kwargs: dict, dqn_head_kwargs: dict, sv_head_kwargs: dict, vec_dqn_head_kwargs: dict, action_head: str | None = None)

Bases: Model, PyTorchModelHubMixin

MOUSE model with no backbone; embeddings pass directly to the output heads.

Useful for ablations or lightweight baselines where no temporal context is required. backbone_kwargs must be empty (or absent) in config.json. KV-cache is not supported — always returns None for the cache.

Source code in mouse/models/base.py
def __init__(
    self,
    hidden_dim: int,
    backbone_kwargs: dict,
    embedding_kwargs: dict,
    sp_head_kwargs: dict,
    dqn_head_kwargs: dict,
    sv_head_kwargs: dict,
    vec_dqn_head_kwargs: dict,
    action_head: str | None = None,
):
    """Build the step embedder, the enabled output heads, and the backbone.

    Args:
        hidden_dim: Model width ``D``. Coerced to ``int`` once; that value is
            stored as ``self.hidden_dim`` and used for the embedder and every
            head so they all agree.
        backbone_kwargs: Passed unchanged to ``self._init_backbone``.
        embedding_kwargs: Keyword arguments for ``StepEmbedder``; must contain
            ``"max_num_actions"``. Any ``"obs_continuous_encoder"`` entry is
            dropped (from a copy; the caller's dict is untouched) before the
            embedder is built.
        sp_head_kwargs: Kwargs for the ``SwiGLUHead`` stored as ``sp_head``.
        dqn_head_kwargs: Kwargs for the ``DQNHead`` stored as ``dqn_head``.
        sv_head_kwargs: Kwargs for the ``SwiGLUHead`` stored as ``sv_head``.
        vec_dqn_head_kwargs: Kwargs for the ``VecDQNHead`` stored as
            ``vec_dqn_head``.
        action_head: Head used for action selection; must be one of
            ``self._VALID_HEADS``. When ``None``, the first enabled head in
            the order ``vec_dqn``, ``dqn``, ``sp``, ``sv`` is chosen.

    A head is built only when its kwargs contain ``num_layers > 0``;
    otherwise the corresponding attribute is ``None``. All heads are sized
    by ``max_num_actions``.

    Raises:
        KeyError: If ``embedding_kwargs`` has no ``"max_num_actions"`` entry.
        ValueError: If ``action_head`` is not in ``self._VALID_HEADS``, or if
            it is ``None`` and no head is enabled.
    """
    super().__init__()

    self.hidden_dim = int(hidden_dim)
    # Work on a filtered copy so the caller's config dict is not mutated.
    # NOTE(review): why this key is excluded is not visible here — confirm.
    embedding_kwargs = {k: v for k, v in embedding_kwargs.items() if k != "obs_continuous_encoder"}
    self.max_num_actions = int(embedding_kwargs["max_num_actions"])

    # Use the coerced width everywhere (not the raw argument) so submodules
    # are built with the same value exposed as ``self.hidden_dim``.
    self.embedder = StepEmbedder(hidden_dim=self.hidden_dim, **embedding_kwargs)

    # Submodule construction order is kept stable (sp, dqn, vec_dqn, sv):
    # it determines parameter registration order in the module.
    self.sp_head = (
        SwiGLUHead(in_features=self.hidden_dim, out_features=self.max_num_actions, **sp_head_kwargs)
        if sp_head_kwargs.get("num_layers", 0) > 0 else None
    )

    self.dqn_head = (
        DQNHead(in_features=self.hidden_dim, out_features=self.max_num_actions, **dqn_head_kwargs)
        if dqn_head_kwargs.get("num_layers", 0) > 0 else None
    )

    self.vec_dqn_head = (
        VecDQNHead(
            in_features=self.hidden_dim,
            max_num_actions=self.max_num_actions,
            **vec_dqn_head_kwargs,
        )
        if vec_dqn_head_kwargs.get("num_layers", 0) > 0 else None
    )

    self.sv_head = (
        SwiGLUHead(in_features=self.hidden_dim, out_features=self.max_num_actions, **sv_head_kwargs)
        if sv_head_kwargs.get("num_layers", 0) > 0 else None
    )

    if action_head is not None:
        if action_head not in self._VALID_HEADS:
            raise ValueError(
                f"action_head must be one of {self._VALID_HEADS}, got {action_head!r}."
            )
        # NOTE(review): an explicitly requested head is not checked to be
        # enabled (non-None) here; a disabled choice will only fail later.
        self.action_head: str = action_head
    else:
        # Auto-detect: most expressive enabled head wins.
        if self.vec_dqn_head is not None:
            self.action_head = "vec_dqn"
        elif self.dqn_head is not None:
            self.action_head = "dqn"
        elif self.sp_head is not None:
            self.action_head = "sp"
        elif self.sv_head is not None:
            self.action_head = "sv"
        else:
            raise ValueError("No output head is enabled; cannot determine action_head.")

    self._init_backbone(backbone_kwargs)

backbone_forward

backbone_forward(embeds: Tensor, token_type: Tensor, cache: dict[str, Any] | None = None, use_cache: bool = False, cache_position: Tensor | None = None, **kwargs: Any) -> tuple[torch.Tensor, dict[str, Any] | None]

Pass embeddings through unchanged; always returns None for cache.

Parameters:

Name Type Description Default
embeds Tensor

[B, T_total, D] embedding tensor.

required
token_type Tensor

Ignored.

required
cache dict[str, Any] | None

Ignored.

None
use_cache bool

Ignored.

False
cache_position Tensor | None

Ignored.

None

Returns:

Type Description
tuple[Tensor, dict[str, Any] | None]

Tuple of (embeds unchanged, None).

Source code in mouse/models/backbone/none.py
def backbone_forward(
    self,
    embeds: torch.Tensor,
    token_type: torch.Tensor,
    cache: dict[str, Any] | None = None,
    use_cache: bool = False,
    cache_position: torch.Tensor | None = None,
    **kwargs: Any,
) -> tuple[torch.Tensor, dict[str, Any] | None]:
    """Pass embeddings through unchanged; always returns ``None`` for cache.

    Args:
        embeds: ``[B, T_total, D]`` embedding tensor.
        token_type: Ignored.
        cache: Ignored.
        use_cache: Ignored.
        cache_position: Ignored.

    Returns:
        Tuple of ``(embeds unchanged, None)``.
    """
    return embeds, None

LlamaBackboneConfig

mouse.models.backbone.llama.LlamaBackboneConfig dataclass

LlamaBackboneConfig(num_layers: int, num_heads: int, num_key_value_heads: int | None = None, max_position_embeddings: int = 4096, expand: int = 4, intermediate_size: int | None = None, rope_parameters: dict | None = None, rms_norm_eps: float = 1e-05, attention_bias: bool = False)

Configuration for a Llama transformer backbone.

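Builds a HuggingFace LlamaModel with SDPA attention and no final layer norm (norm is replaced with nn.Identity). The model is created with vocab_size=1, so its token-embedding table is a single unused row; inputs are supplied directly as inputs_embeds.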

Parameters:

Name Type Description Default
num_layers int

Number of transformer decoder layers.

required
num_heads int

Number of query attention heads.

required
num_key_value_heads int | None

Key/value heads for GQA; defaults to num_heads.

None
max_position_embeddings int

Maximum sequence length for RoPE; should be at least sequence_length * tokens_per_step.

4096
expand int

FFN intermediate size multiplier: intermediate_size = hidden_dim * expand.

4
intermediate_size int | None

Exact FFN size; overrides expand * hidden_dim when set. Use this when loading from a pretrained model whose FFN size is not an integer multiple of the hidden dim.

None
rope_parameters dict | None

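Optional dict forwarded to LlamaConfig.rope_parameters for custom RoPE variants. Include every key the chosen rope_type requires; for example, {"rope_type": "llama3", ...} also needs factor, low_freq_factor, high_freq_factor and original_max_position_embeddings.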

None
rms_norm_eps float

Epsilon for RMSNorm layers.

1e-05
attention_bias bool

Whether to add bias to QKV and output projections.

False

build

build(hidden_dim: int) -> LlamaModel

Instantiate a LlamaModel with this config.

Parameters:

Name Type Description Default
hidden_dim int

Model hidden dimension D; must be divisible by num_heads.

required

Returns:

Type Description
LlamaModel

LlamaModel with the final norm replaced by nn.Identity.

Source code in mouse/models/backbone/llama.py
def build(self, hidden_dim: int) -> LlamaModel:
    """Instantiate a ``LlamaModel`` with this config.

    Args:
        hidden_dim: Model hidden dimension ``D``; must be divisible by ``num_heads``.

    Returns:
        ``LlamaModel`` with the final norm replaced by ``nn.Identity``.

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``num_heads``, or if
            ``num_key_value_heads`` is set but is not a positive divisor of
            ``num_heads`` (required for grouped-query attention). Checking
            here fails fast instead of at the first attention call.
    """
    _disable_cudnn_sdp()
    if hidden_dim % self.num_heads != 0:
        raise ValueError(
            f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({self.num_heads})."
        )
    kv_heads = self.num_key_value_heads
    # None is left for LlamaConfig to default to num_attention_heads.
    if kv_heads is not None and (kv_heads <= 0 or self.num_heads % kv_heads != 0):
        raise ValueError(
            f"num_key_value_heads ({kv_heads}) must be a positive divisor of "
            f"num_heads ({self.num_heads})."
        )
    # An explicit intermediate_size wins over the expand multiplier.
    ffn_size = self.intermediate_size if self.intermediate_size is not None else hidden_dim * self.expand
    config_kwargs: dict = dict(
        # Minimal vocabulary: the token-embedding table is a single row.
        vocab_size=1,
        hidden_size=hidden_dim,
        num_attention_heads=self.num_heads,
        num_key_value_heads=kv_heads,
        intermediate_size=ffn_size,
        max_position_embeddings=self.max_position_embeddings,
        attention_dropout=0.0,
        attention_bias=self.attention_bias,
        rms_norm_eps=self.rms_norm_eps,
        num_hidden_layers=self.num_layers,
    )
    # Only override RoPE settings when explicitly configured.
    if self.rope_parameters is not None:
        config_kwargs["rope_parameters"] = self.rope_parameters
    config = LlamaConfig(**config_kwargs)
    config._attn_implementation = "sdpa"
    model = LlamaModel(config)
    # Drop the final RMSNorm: the model returns un-normalised hidden states.
    model.norm = nn.Identity()  # type: ignore[assignment]
    return model

Qwen3BackboneConfig

mouse.models.backbone.qwen3.Qwen3BackboneConfig dataclass

Qwen3BackboneConfig(num_layers: int, num_heads: int, num_key_value_heads: int | None = None, head_dim: int | None = None, max_position_embeddings: int = 32768, expand: int = 3, intermediate_size: int | None = None, rope_parameters: dict | None = None, rms_norm_eps: float = 1e-06, attention_bias: bool = False, use_sliding_window: bool = False)

Configuration for a Qwen3 transformer backbone.

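Builds a HuggingFace Qwen3Model with SDPA attention and no final layer norm (norm is replaced with nn.Identity). The model is created with vocab_size=1, so its token-embedding table is a single unused row; inputs are supplied directly as inputs_embeds.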

Parameters:

Name Type Description Default
num_layers int

Number of transformer decoder layers.

required
num_heads int

Number of query attention heads.

required
num_key_value_heads int | None

Key/value heads for GQA; defaults to num_heads.

None
head_dim int | None

Per-head attention dimension. When None, defaults to hidden_dim // num_heads. Set explicitly to decouple model width from attention head size (useful for GQA with small num_key_value_heads).

None
max_position_embeddings int

Maximum sequence length for RoPE.

32768
expand int

FFN intermediate size multiplier: intermediate_size = hidden_dim * expand.

3
intermediate_size int | None

Exact FFN size; overrides expand * hidden_dim when set. Use this when loading from a pretrained model whose FFN size is not an integer multiple of the hidden dim.

None
rope_parameters dict | None

Optional dict forwarded to Qwen3Config.rope_parameters.

None
rms_norm_eps float

Epsilon for RMSNorm layers.

1e-06
attention_bias bool

Whether to add bias to QKV and output projections.

False
use_sliding_window bool

Enable sliding-window attention (Qwen3 feature).

False

build

build(hidden_dim: int) -> Qwen3Model

Instantiate a Qwen3Model with this config.

Parameters:

Name Type Description Default
hidden_dim int

Model hidden dimension D. When head_dim is None, must be divisible by num_heads.

required

Returns:

Type Description
Qwen3Model

Qwen3Model with the final norm replaced by nn.Identity.

Source code in mouse/models/backbone/qwen3.py
def build(self, hidden_dim: int) -> Qwen3Model:
    """Instantiate a ``Qwen3Model`` with this config.

    Args:
        hidden_dim: Model hidden dimension ``D``. When ``head_dim`` is ``None``,
            must be divisible by ``num_heads``.

    Returns:
        ``Qwen3Model`` with the final norm replaced by ``nn.Identity``.

    Raises:
        ValueError: If ``head_dim`` is ``None`` and ``hidden_dim`` is not
            divisible by ``num_heads``, or if ``num_key_value_heads`` is set
            but is not a positive divisor of ``num_heads`` (required for
            grouped-query attention). Checking here fails fast instead of at
            the first attention call.
    """
    _disable_cudnn_sdp()
    if self.head_dim is None:
        if hidden_dim % self.num_heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({self.num_heads})."
            )
        resolved_head_dim = hidden_dim // self.num_heads
    else:
        # Explicit head_dim decouples attention head size from model width.
        resolved_head_dim = int(self.head_dim)
    kv_heads = self.num_key_value_heads
    # None is left for Qwen3Config to default to num_attention_heads.
    if kv_heads is not None and (kv_heads <= 0 or self.num_heads % kv_heads != 0):
        raise ValueError(
            f"num_key_value_heads ({kv_heads}) must be a positive divisor of "
            f"num_heads ({self.num_heads})."
        )
    # An explicit intermediate_size wins over the expand multiplier.
    ffn_size = self.intermediate_size if self.intermediate_size is not None else hidden_dim * self.expand
    config_kwargs: dict = dict(
        # Minimal vocabulary: the token-embedding table is a single row.
        vocab_size=1,
        hidden_size=hidden_dim,
        num_attention_heads=self.num_heads,
        num_key_value_heads=kv_heads,
        head_dim=resolved_head_dim,
        intermediate_size=ffn_size,
        max_position_embeddings=self.max_position_embeddings,
        attention_dropout=0.0,
        attention_bias=self.attention_bias,
        rms_norm_eps=self.rms_norm_eps,
        num_hidden_layers=self.num_layers,
        use_sliding_window=self.use_sliding_window,
    )
    # Only override RoPE settings when explicitly configured.
    if self.rope_parameters is not None:
        config_kwargs["rope_parameters"] = self.rope_parameters
    config = Qwen3Config(**config_kwargs)
    config._attn_implementation = "sdpa"
    model = Qwen3Model(config)
    # Drop the final RMSNorm: the model returns un-normalised hidden states.
    model.norm = nn.Identity()  # type: ignore[assignment]
    return model