
Heads

Output heads are constructed and owned by Model. They are documented here for reference when building custom training loops.

DQNHead

mouse.models.heads.dqn.DQNHead

DQNHead(in_features: int, out_features: int, hidden_dim: int, num_layers: int, scale: float = 1.0, use_norm: bool = True)

Bases: BaseHeadWithTarget

SwiGLUHead paired with a target copy that is updated by Polyak averaging, i.e. an exponential moving average (EMA) of the online weights.

forward runs the online head. target_forward runs the target head without gradient tracking. Call polyak_update(tau) after each optimizer step to soft-update the target: θ_target ← τ·θ_online + (1−τ)·θ_target. With tau=1.0 the update reduces to a hard copy of the online weights into the target; use this to initialise the target.
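A minimal sketch of the soft update, assuming the target is a structurally identical copy of the online module whose parameters iterate in the same order. polyak_update here is a hypothetical free function, not the library's method, and nn.Linear stands in for the SwiGLUHead. Tensor.lerp_(end, w) computes self + w·(end − self), which is exactly the update formula above.

```python
import copy

import torch
from torch import nn


@torch.no_grad()
def polyak_update(online: nn.Module, target: nn.Module, tau: float) -> None:
    """Hypothetical helper: theta_target <- tau * theta_online + (1 - tau) * theta_target."""
    for p_online, p_target in zip(online.parameters(), target.parameters()):
        # lerp_(end, w) computes self + w * (end - self), i.e. the Polyak formula.
        p_target.lerp_(p_online, tau)


online = nn.Linear(4, 3)                 # stand-in for the online SwiGLUHead
target = copy.deepcopy(online)           # structurally identical target copy
target.requires_grad_(False)             # the target is never trained directly

with torch.no_grad():
    online.weight.add_(1.0)              # pretend an optimizer step moved the online weights

before = target.weight.clone()
polyak_update(online, target, tau=0.5)   # target moves halfway toward online
```

Only parameters are averaged in this sketch; modules with buffers (e.g. running statistics) would need them copied separately.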

Source code in mouse/models/heads/dqn.py
def __init__(
    self,
    in_features: int,
    out_features: int,
    hidden_dim: int,
    num_layers: int,
    scale: float = 1.0,
    use_norm: bool = True,
):
    super().__init__()
    head_kwargs = dict(
        in_features=in_features,
        out_features=out_features,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        scale=scale,
        use_norm=use_norm,
    )
    self.online = SwiGLUHead(**head_kwargs)
    self._init_target(self.online)

forward

forward(x: Tensor) -> torch.Tensor

Run the online head; returns Q-value logits [B, S, A].

Source code in mouse/models/heads/dqn.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Run the online head; returns Q-value logits ``[B, S, A]``."""
    return self.online(x)
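In a custom training loop the online and target outputs typically meet in a one-step TD loss. The sketch below assumes a standard DQN target r + γ·(1 − done)·max_a Q_target(x′, a) with per-position rewards shaped [B, S]; plain nn.Linear layers stand in for forward and target_forward, and every tensor is a random placeholder rather than real data.

```python
import torch
import torch.nn.functional as F
from torch import nn

B, S, D, A = 2, 5, 8, 4   # batch, sequence length, feature dim, number of actions
gamma = 0.99

# Hypothetical stand-ins for head.forward / head.target_forward.
online = nn.Linear(D, A)
target = nn.Linear(D, A)

x = torch.randn(B, S, D)                    # features at step t
x_next = torch.randn(B, S, D)               # features at step t+1
actions = torch.randint(0, A, (B, S))       # actions taken at step t
rewards = torch.randn(B, S)
dones = torch.zeros(B, S)                   # 1.0 where the episode ended

q = online(x)                                                  # [B, S, A]
q_taken = q.gather(-1, actions.unsqueeze(-1)).squeeze(-1)      # [B, S]

with torch.no_grad():                                          # target side: no gradients
    q_next = target(x_next).max(dim=-1).values                 # [B, S]
    td_target = rewards + gamma * (1.0 - dones) * q_next       # [B, S]

loss = F.smooth_l1_loss(q_taken, td_target)
loss.backward()                                                # gradients reach the online head only
```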

VecDQNHead

mouse.models.heads.vec_dqn.VecDQNHead

VecDQNHead(in_features: int, max_num_actions: int, vec_dim: int, hidden_dim: int, num_layers: int, scale: float = 1.0, bias_scale: float | None = None, use_norm: bool = True)

Bases: BaseHeadWithTarget

SwiGLUHead paired with a target copy that is updated by Polyak averaging, i.e. an exponential moving average (EMA) of the online weights.

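Like DQNHead, but each action produces a vec_dim-dimensional vector instead of a single scalar, so the output shape is [..., max_num_actions, vec_dim]. If bias_scale is given, every entry of the output projection's bias (when it has one) is initialised to that constant.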
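The constructor sizes the online head to max_num_actions * vec_dim outputs. A minimal sketch of turning that flat output into [..., max_num_actions, vec_dim], assuming each action's vector occupies a contiguous slice of the flat output (the actual forward is not shown on this page, so the layout is an assumption); nn.Linear stands in for the SwiGLUHead.

```python
import torch
from torch import nn

in_features, A, D = 16, 5, 2   # A = max_num_actions, D = vec_dim

# Hypothetical stand-in for the online SwiGLUHead: any map to A * D outputs.
flat_head = nn.Linear(in_features, A * D)

x = torch.randn(3, 7, in_features)       # [B, S, in_features]
flat = flat_head(x)                      # [B, S, A * D]
vecs = flat.unflatten(-1, (A, D))        # [B, S, A, D]; action k is flat[..., k*D:(k+1)*D]
```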

forward runs the online head. target_forward runs the target head without gradient tracking. Call polyak_update(tau) after each optimizer step to soft-update the target: θ_target ← τ·θ_online + (1−τ)·θ_target. With tau=1.0 the update reduces to a hard copy of the online weights into the target; use this to initialise the target.

Source code in mouse/models/heads/vec_dqn.py
def __init__(
    self,
    in_features: int,
    max_num_actions: int,
    vec_dim: int,
    hidden_dim: int,
    num_layers: int,
    scale: float = 1.0,
    bias_scale: float | None = None,
    use_norm: bool = True,
):
    super().__init__()
    self.A = max_num_actions
    self.D = vec_dim
    head_kwargs = dict(
        in_features=in_features,
        out_features=max_num_actions * vec_dim,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        scale=scale,
        use_norm=use_norm,
    )
    self.online = SwiGLUHead(**head_kwargs)
    if bias_scale is not None:
        out_bias = self.online.layers[-1].bias
        if out_bias is not None:
            with torch.no_grad():
                out_bias.fill_(float(bias_scale))
    self._init_target(self.online)

vec_dqn_scores

mouse.models.heads.vec_dqn.vec_dqn_scores

vec_dqn_scores(vecs: Tensor) -> torch.Tensor

Compute pairwise angular action scores from vec-DQN vectors.

For each pair of actions (i, a), computes the full signed angle φ_a − φ_i via atan2(sin, cos), then sums over all i to give action a's total angular lead.

The sin component is dot(rot90(vᵢ), vₐ) and the cos component is dot(vᵢ, vₐ). Using both avoids the aliasing of a sin-only score, which peaks at ±90° and falls back toward zero as the angle approaches ±180°, so a 60° lead and a 120° lead get the same sin score. The atan2 score is monotone across the full (−π, +π) range; aliasing occurs only once two action vectors are more than 180° apart, a threshold twice as large as the sin-only 90°.
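A small 2-D check of the aliasing argument, comparing the sin-only score with the atan2 score for action vectors 60° and 120° away from vᵢ. rot90 is written out directly as (x, y) → (−y, x), which is what rope_rotate does with θ = π/2; unit and rot90 are illustrative helpers, not library functions.

```python
import math

import torch


def unit(phi: float) -> torch.Tensor:
    """2-D unit vector at angle phi (radians)."""
    return torch.tensor([math.cos(phi), math.sin(phi)])


def rot90(v: torch.Tensor) -> torch.Tensor:
    """Counter-clockwise rotation by 90 degrees: (x, y) -> (-y, x)."""
    return torch.stack([-v[1], v[0]])


v_i = unit(0.0)
scores = {}
for deg in (60, 120):
    v_a = unit(math.radians(deg))
    sin_ia = torch.dot(rot90(v_i), v_a).item()   # sin(phi_a - phi_i)
    cos_ia = torch.dot(v_i, v_a).item()          # cos(phi_a - phi_i)
    scores[deg] = {
        "sin_only": sin_ia,
        "atan2_deg": math.degrees(math.atan2(sin_ia, cos_ia)),
    }
# The sin-only entries coincide for 60 and 120 degrees; the atan2 entries do not.
```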

For D = 2 (a single rotation plane) this is geometrically exact. For D > 2 (RoPE-style, with several rotation planes) it is a well-conditioned approximation, so use D = 2 when exact geometry matters. D must be even in either case, because rope_rotate rotates consecutive pairs of dimensions.
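Expanding the two dot products plane by plane makes the D > 2 behaviour explicit. This derivation is read off the code rather than taken from the original docs: write plane p (of P = D/2 planes) of the normalised vector vₐ as r_{a,p}(cos φ_{a,p}, sin φ_{a,p}), with Σ_p r_{a,p}² = 1, and use j for the imaginary unit.

```latex
\begin{aligned}
\sin_{ia} &= \sum_{p=1}^{P} r_{i,p}\, r_{a,p} \sin(\varphi_{a,p} - \varphi_{i,p}), \\
\cos_{ia} &= \sum_{p=1}^{P} r_{i,p}\, r_{a,p} \cos(\varphi_{a,p} - \varphi_{i,p}), \\
\operatorname{atan2}(\sin_{ia}, \cos_{ia}) &= \arg \sum_{p=1}^{P} r_{i,p}\, r_{a,p}\, e^{\mathrm{j}(\varphi_{a,p} - \varphi_{i,p})}.
\end{aligned}
```

For P = 1 (D = 2) the sum has a single term and the score is the exact wrapped difference φₐ − φᵢ. For P > 1 it is the angle of a magnitude-weighted sum of per-plane differences, which is still exact whenever every plane carries the same difference.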

Self-terms contribute atan2(0, 1) = 0 and require no masking.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| vecs | Tensor | [..., A, D]: raw (un-normalised) action vectors. | required |

Returns:

| Name | Type | Description |
|---|---|---|
| scores | Tensor | [..., A]: summed angular lead per action, in radians. |

Source code in mouse/models/heads/vec_dqn.py
def vec_dqn_scores(vecs: torch.Tensor) -> torch.Tensor:
    """Compute pairwise angular action scores from vec-DQN vectors.

    For each pair of actions ``(i, a)``, computes the full signed angle
    ``φ_a − φ_i`` via ``atan2(sin, cos)``, then sums over all ``i`` to give
    action ``a``'s total angular lead.

    The sin component is ``dot(rot90(vᵢ), vₐ)`` and the cos component is
    ``dot(vᵢ, vₐ)``.  Using both avoids the aliasing of a sin-only score,
    which saturates at ±90° and folds back toward zero toward 180°.  The
    atan2 score is monotone across the full (−π, +π) range — aliasing only
    occurs if two action vectors rotate past ±180° apart, which is twice as
    hard to reach.

    For ``D = 2`` (a single rotation plane) this is geometrically exact.
    For ``D > 2`` (RoPE with multiple planes) it is a well-conditioned
    approximation; the D = 2 case is recommended for exact geometry.

    Self-terms contribute ``atan2(0, 1) = 0`` and require no masking.

    Args:
        vecs: ``[..., A, D]`` — raw (un-normalised) action vectors.

    Returns:
        scores: ``[..., A]`` — summed angular lead per action, in radians.
    """
    leading = vecs.shape[:-2]
    A, D = vecs.shape[-2], vecs.shape[-1]
    vecs_norm = F.normalize(vecs, dim=-1)                              # [..., A, D]
    flat = vecs_norm.reshape(-1, D)
    theta90 = torch.full((flat.shape[0],), math.pi / 2, device=flat.device, dtype=flat.dtype)
    rot90 = rope_rotate(x=flat, theta=theta90).reshape(*leading, A, D)
    sin_ia = torch.einsum("...id,...ad->...ia", rot90, vecs_norm)      # [..., A, A]  sin(φ_a − φ_i)
    cos_ia = torch.einsum("...id,...ad->...ia", vecs_norm, vecs_norm)  # [..., A, A]  cos(φ_a − φ_i)
    return torch.atan2(sin_ia, cos_ia).sum(dim=-2)                     # [..., A]
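To check the D = 2 exactness claim numerically, the sketch below copies vec_dqn_scores and rope_rotate from the source on this page and compares the result with a direct reference that measures each vector's absolute angle with atan2 and wraps the pairwise differences into (−π, π]. The float64 dtype and the random inputs are arbitrary test choices.

```python
import math

import torch
import torch.nn.functional as F


def rope_rotate(x: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    """Copied from this page: rotate each consecutive pair of dims of x by theta."""
    rotated = torch.empty_like(x)
    rotated[..., ::2] = x[..., ::2] * theta.unsqueeze(-1).cos() - x[..., 1::2] * theta.unsqueeze(-1).sin()
    rotated[..., 1::2] = x[..., ::2] * theta.unsqueeze(-1).sin() + x[..., 1::2] * theta.unsqueeze(-1).cos()
    return rotated


def vec_dqn_scores(vecs: torch.Tensor) -> torch.Tensor:
    """Copied from this page: summed pairwise angular lead per action."""
    leading = vecs.shape[:-2]
    A, D = vecs.shape[-2], vecs.shape[-1]
    vecs_norm = F.normalize(vecs, dim=-1)
    flat = vecs_norm.reshape(-1, D)
    theta90 = torch.full((flat.shape[0],), math.pi / 2, device=flat.device, dtype=flat.dtype)
    rot90 = rope_rotate(x=flat, theta=theta90).reshape(*leading, A, D)
    sin_ia = torch.einsum("...id,...ad->...ia", rot90, vecs_norm)
    cos_ia = torch.einsum("...id,...ad->...ia", vecs_norm, vecs_norm)
    return torch.atan2(sin_ia, cos_ia).sum(dim=-2)


torch.manual_seed(0)
vecs = torch.randn(3, 4, 2, dtype=torch.float64)    # [B, A, D = 2]
scores = vec_dqn_scores(vecs)                       # [B, A]

# Direct D = 2 reference: absolute angle per vector, wrapped pairwise differences.
phi = torch.atan2(vecs[..., 1], vecs[..., 0])       # [B, A]
diff = phi.unsqueeze(-2) - phi.unsqueeze(-1)        # [B, i, a] = phi_a - phi_i
wrapped = torch.atan2(diff.sin(), diff.cos())       # wrap into (-pi, pi]
reference = wrapped.sum(dim=-2)                     # [B, A]
```

Because each pairwise term is antisymmetric in (i, a), the scores of one set of actions also sum to zero up to rounding (barring exact ±180° ties).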

rope_rotate

mouse.models.heads.vec_dqn.rope_rotate

rope_rotate(x: Tensor, theta: Tensor) -> torch.Tensor

Rotate each consecutive pair of dimensions (x[..., 2k], x[..., 2k+1]) of x counter-clockwise by theta. Every pair in a vector is rotated by the same angle.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | [..., D] where D is even. | required |
| theta | Tensor | [...]: same leading shape as x, one angle per vector. | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Rotated tensor of the same shape as x. |

Source code in mouse/models/heads/vec_dqn.py
def rope_rotate(x: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    """Rotate each consecutive pair of dimensions in ``x`` by ``theta``.

    Args:
        x:     ``[..., D]`` where ``D`` is even.
        theta: ``[...]`` — same leading shape as ``x``, one angle per vector.

    Returns:
        Rotated tensor of the same shape as ``x``.
    """
    rotated = torch.empty_like(x)
    rotated[..., ::2]  = x[..., ::2]  * theta.unsqueeze(-1).cos() - x[..., 1::2] * theta.unsqueeze(-1).sin()
    rotated[..., 1::2] = x[..., ::2]  * theta.unsqueeze(-1).sin() + x[..., 1::2] * theta.unsqueeze(-1).cos()
    return rotated
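A quick illustration of rope_rotate, copied from the source above, on one 4-dimensional vector: both planes rotate counter-clockwise by the same θ = π/2, and the vector's norm is preserved.

```python
import math

import torch


def rope_rotate(x: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    """Copied from this page: rotate each consecutive pair of dims of x by theta."""
    rotated = torch.empty_like(x)
    rotated[..., ::2] = x[..., ::2] * theta.unsqueeze(-1).cos() - x[..., 1::2] * theta.unsqueeze(-1).sin()
    rotated[..., 1::2] = x[..., ::2] * theta.unsqueeze(-1).sin() + x[..., 1::2] * theta.unsqueeze(-1).cos()
    return rotated


x = torch.tensor([[1.0, 0.0, 0.0, 2.0]])    # one vector, D = 4: planes (1, 0) and (0, 2)
theta = torch.tensor([math.pi / 2])          # one angle per vector, shared by both planes
y = rope_rotate(x, theta)                    # up to rounding: plane 0 -> (0, 1), plane 1 -> (-2, 0)
```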

SwiGLUHead

mouse.models.heads.swiglu.SwiGLUHead

SwiGLUHead(in_features: int, out_features: int, hidden_dim: int, num_layers: int, scale: float = 1.0, use_norm: bool = True)

Bases: BaseHead

MLP head built from stacked SwiGLU blocks with a scaled output projection.

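Architecture:

[RMSNorm →] SwiGLU(D→hidden) → SwiGLU(hidden→hidden) × (num_layers−2) → ScaledLinear(hidden→out)

With num_layers=1 there are no SwiGLU blocks and the head reduces to [RMSNorm →] ScaledLinear(D→out).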

The optional RMSNorm (use_norm=True, with a learnable per-feature scale) is applied to the input before the first layer. scale multiplies the initial weights of the output projection; set it small (e.g. 0.01) for a near-zero initial output.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_features | int | Input dimension D. | required |
| out_features | int | Output dimension (number of actions A, or A * vec_dim). | required |
| hidden_dim | int | Width of the SwiGLU hidden layers. | required |
| num_layers | int | Total depth including the final linear; must be >= 1. | required |
| scale | float | ScaledLinear weight init multiplier for the output projection. | 1.0 |
| use_norm | bool | Whether to prepend an RMSNorm layer. | True |

Source code in mouse/models/heads/swiglu.py
def __init__(
    self,
    in_features: int,
    out_features: int,
    hidden_dim: int,
    num_layers: int,
    scale: float = 1.0,
    use_norm: bool = True,
):
    super().__init__()
    if use_norm:
        self.norm = nn.RMSNorm(in_features, elementwise_affine=True, eps=1e-5)
    else:
        self.norm = None
    dims = [in_features] + [hidden_dim] * (num_layers - 1) + [out_features]
    self.layers = nn.Sequential(
        *[SwiGLU(in_features=dims[i], hidden_dim=dims[i+1]) for i in range(num_layers - 1)],
        ScaledLinear(in_features=dims[-2], out_features=dims[-1], scale=scale),
    )

SwiGLU

mouse.models.heads.swiglu.SwiGLU

SwiGLU(in_features: int, hidden_dim: int)

Bases: Module

Gated linear unit with SiLU on the gate: silu(x @ W1) * (x @ W2), computed with one fused Linear from in_features to 2 * hidden_dim. The output has hidden_dim features.

Source code in mouse/models/heads/swiglu.py
def __init__(self, in_features: int, hidden_dim: int) -> None:
    super().__init__()
    self.linear = nn.Linear(in_features, 2 * hidden_dim)
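A sketch of why one fused Linear is enough: splitting its weight rows (and bias) into two halves recovers separate W1 and W2, so the fused and unfused forms agree. Which half acts as the gate is an assumption (the first half here), since SwiGLU.forward is not shown on this page.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
in_features, hidden_dim = 6, 4
fused = nn.Linear(in_features, 2 * hidden_dim)     # same layout as SwiGLU.linear

x = torch.randn(3, in_features)
gate, value = fused(x).chunk(2, dim=-1)            # assumption: first half is the gate
out_fused = F.silu(gate) * value                   # [3, hidden_dim]

# Unfused form: W1 = first hidden_dim rows, W2 = last hidden_dim rows.
W1, W2 = fused.weight[:hidden_dim], fused.weight[hidden_dim:]
b1, b2 = fused.bias[:hidden_dim], fused.bias[hidden_dim:]
out_split = F.silu(x @ W1.T + b1) * (x @ W2.T + b2)
```

Fusing trades two small matrix multiplications for one larger one, which is usually faster on accelerators.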