Skip to content

Losses

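The top-level loss functions (dqn_loss, vec_dqn_loss, sp_loss, sv_loss) return (scalar_loss, dict[str, float]); the dict is ready to log directly to W&B or TensorBoard. The sp_* helper functions (sp_js_loss, sp_kl_loss, sp_soft_ce_loss) return only the scalar loss tensor.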

DQN loss

mouse.losses.dqn.DqnLossConfig dataclass

DqnLossConfig(weight: float = 0.0, gamma: float = 0.99, gamma_terminal: float = 0.0, gamma_truncated: float = 0.0, tau: float = 0.01, normalize_reward_mean: bool = False, normalize_reward_std: bool = False, normalize_reward_eps: float = 1e-08, normalize_reward_std_target: float = 1.0, use_xformed_reward: bool = False, cql_weight: float = 0.0, cql_scale_q_eps: float = 1.0, reward_scale: float = 1.0, reward_shift: float = 0.0)

Bases: LossConfig

Symmetric two-head one-step TD at PREDICTION (see dqn_loss).

mouse.losses.dqn.dqn_loss

dqn_loss(step_stream: TensorDict, out: TensorDict, cfg: DqnLossConfig) -> tuple[torch.Tensor, dict[str, float]]
Source code in mouse/losses/dqn.py
def dqn_loss(
    step_stream: TensorDict,
    out: TensorDict,
    cfg: DqnLossConfig,
) -> tuple[torch.Tensor, dict[str, float]]:
    """One-step TD (Q-learning) loss on consecutive sequence positions, with optional CQL.

    For every pair of positions ``(t, t+1)`` in each batch row, the online Q-value of the
    executed action, ``Q(s_t, a_t)``, is regressed (squared error) onto the detached target
    ``r_t' + discount_t * max_a Q_target(s_{t+1}, a)``. Here ``r_t'`` is the reward after
    optional per-row normalization and ``reward_scale`` / ``reward_shift``. The action,
    reward and done flag of the transition out of ``s_t`` are read from position ``t+1``.

    Args:
        step_stream: Batch providing ``"action"``, ``"done"`` and ``"reward"`` (or
            ``"xformed_reward"`` when ``cfg.use_xformed_reward``), indexed ``[B, S]``.
            ``done == 1`` marks a terminal transition (discount ``cfg.gamma_terminal``),
            ``done == 2`` a truncated one (discount ``cfg.gamma_truncated``); every other
            transition is discounted by ``cfg.gamma``.
        out: Model outputs with ``"dqn"`` (online) and ``"dqn_target"`` Q-values, ``[B, S, A]``.
        cfg: Discounting, reward normalization/scaling and CQL settings.

    Returns:
        The loss averaged over all ``B * (S - 1)`` transitions, and a dict of scalar metrics
        (Q-value statistics, mean TD target, loss, and ``cql_penalty`` when CQL is enabled).

    Raises:
        ValueError: If the sequence axis has fewer than 2 positions.
        KeyError: If ``cfg.use_xformed_reward`` is set but ``"xformed_reward"`` is missing.
    """
    q: torch.Tensor = out["dqn"]
    q_target: torch.Tensor = out["dqn_target"]

    B, S, A = q.shape
    value_dtype = q.dtype

    if S < 2:
        raise ValueError("Not enough valid q values in data.")

    action = step_stream["action"].to(dtype=torch.long)
    if cfg.use_xformed_reward:
        if "xformed_reward" not in step_stream.keys():
            raise KeyError(
                "use_xformed_reward=True but 'xformed_reward' is not in the batch. "
                "Ensure your dataset includes the 'xformed_reward' column."
            )
        reward = step_stream["xformed_reward"].to(dtype=value_dtype)
    else:
        reward = step_stream["reward"].to(dtype=value_dtype)
    terminals = (step_stream["done"] == 1).to(dtype=value_dtype)
    truncateds = (step_stream["done"] == 2).to(dtype=value_dtype)

    # Each token at position t encodes (obs_t, action_{t-1}, reward_{t-1}, done_{t-1}),
    # i.e. the action and reward stored at t are the ones that *produced* obs_t, not
    # the ones taken *from* obs_t.  Therefore the action, reward, and done that
    # correspond to the transition out of state t are stored one step ahead at t+1.
    # Consecutive (s, s+1) pairs within each batch row.
    curr_q = q[:, :-1, :]          # [B, S-1, A]  Q(s_t)
    next_q_target = q_target[:, 1:, :]  # [B, S-1, A]  Q_target(s_{t+1})
    next_actions = action[:, 1:]   # [B, S-1]     a_t (stored at t+1)
    next_rewards = reward[:, 1:]   # [B, S-1]     r_t (stored at t+1)
    next_terminals = terminals[:, 1:]  # [B, S-1]     terminal_t (stored at t+1)
    next_truncateds = truncateds[:, 1:]  # [B, S-1]     truncated_t (stored at t+1)

    q_values = curr_q.gather(dim=-1, index=next_actions.unsqueeze(-1)).squeeze(-1)  # [B, S-1]
    next_max_q_target = next_q_target.amax(dim=-1)                           # [B, S-1]

    # Optional per-row (over the sequence axis) reward normalization.
    if cfg.normalize_reward_mean:
        next_rewards = next_rewards - next_rewards.mean(dim=1, keepdim=True)
    if cfg.normalize_reward_std:
        next_rewards = (next_rewards / (next_rewards.std(dim=1, keepdim=True) + cfg.normalize_reward_eps)) * cfg.normalize_reward_std_target

    # done == 1 and done == 2 are mutually exclusive, so exactly one of the three
    # discount terms is active for each transition.
    discount = (
        cfg.gamma * (1.0 - next_terminals - next_truncateds)
        + cfg.gamma_terminal * next_terminals
        + cfg.gamma_truncated * next_truncateds
    )
    next_rewards_adjusted = next_rewards * cfg.reward_scale + cfg.reward_shift
    td_target = next_rewards_adjusted + discount * next_max_q_target
    td_target = td_target.to(dtype=q_values.dtype)

    loss = (q_values - td_target.detach()) ** 2

    if cfg.cql_weight > 0.0:
        # CQL penalty ∝ Q while TD loss ∝ Q², so a fixed weight becomes ineffective
        # as Q grows.  Multiplying by q_scale brings CQL up to Q² so the ratio
        # cfg.cql_weight stays constant throughout training.
        q_scale = (td_target.abs() + cfg.cql_scale_q_eps).detach()
        cql_penalty = curr_q.logsumexp(dim=-1) - q_values
        loss = loss + cfg.cql_weight * q_scale * cql_penalty

    loss = loss.mean()

    q_det = q_values.detach()
    named: dict[str, torch.Tensor] = {
        "q_values_mean": q_det.mean(),
        "q_values_std":  q_det.std(),
        "q_values_min":  q_det.min(),
        "q_values_max":  q_det.max(),
        "q_values_target": td_target.detach().mean(),
        "dqn_loss":      loss.detach(),
    }
    if cfg.cql_weight > 0.0:
        named["cql_penalty"] = cql_penalty.detach().mean()

    # One stacked .tolist() instead of a host sync per metric.
    metrics: dict[str, float] = dict(zip(named, torch.stack(list(named.values())).tolist()))

    return loss, metrics

Vector DQN loss

mouse.losses.vec_dqn.VecDqnLossConfig dataclass

VecDqnLossConfig(weight: float = 0.0, tau: float = 0.01, reward_scale: float = 1.0, reward_shift: float = 0.0, normalize_reward_mean: bool = False, normalize_reward_std: bool = False, normalize_reward_eps: float = 1e-08, normalize_reward_std_target: float = 1.0, use_xformed_reward: bool = False)

Bases: LossConfig

Vector-DQN cosine-similarity loss at PREDICTION (see vec_dqn_loss).

mouse.losses.vec_dqn.vec_dqn_loss

vec_dqn_loss(step_stream: TensorDict, online_vecs: Tensor, target_vecs: Tensor, cfg: VecDqnLossConfig) -> tuple[torch.Tensor, dict[str, float]]
Source code in mouse/losses/vec_dqn.py
def vec_dqn_loss(
    step_stream: TensorDict,
    online_vecs: torch.Tensor,
    target_vecs: torch.Tensor,
    cfg: VecDqnLossConfig,
) -> tuple[torch.Tensor, dict[str, float]]:
    """Vector-DQN loss: align the online action vector with a rotated bootstrap target.

    For each consecutive pair ``(t, t+1)``, the online vector of the executed action at
    ``s_t`` is pulled (via ``1 - cosine_similarity``) towards the target-network vector of
    the greedy action at ``s_{t+1}``. That target vector is rotated by ``rope_rotate`` with
    angle ``theta = reward_t * reward_scale + reward_shift``, and the rotated target is
    detached. All arithmetic is done in float32.

    Args:
        step_stream: Batch providing ``"action"`` and ``"reward"`` (or ``"xformed_reward"``
            when ``cfg.use_xformed_reward``), indexed ``[B, S]``.
        online_vecs: ``[B, S, A, D]`` per-action vectors from the online network.
        target_vecs: ``[B, S, A, D]`` per-action vectors from the target network.
        cfg: Reward normalization / scaling settings.

    Returns:
        The loss averaged over all ``B * (S - 1)`` transitions, and a dict of scalar metrics
        (the loss plus min/max/mean of ``|vec_dqn_scores| / pi`` at the last position).

    Raises:
        ValueError: If the sequence axis has fewer than 2 positions.
        KeyError: If ``cfg.use_xformed_reward`` is set but ``"xformed_reward"`` is missing.
    """
    B, S, A, D = online_vecs.shape
    dtype = torch.float32

    if S < 2:
        raise ValueError("Not enough valid vec_dqn vectors in data.")

    action = step_stream["action"].to(dtype=torch.long)
    if cfg.use_xformed_reward:
        if "xformed_reward" not in step_stream.keys():
            raise KeyError(
                "use_xformed_reward=True but 'xformed_reward' is not in the batch. "
                "Ensure your dataset includes the 'xformed_reward' column."
            )
        reward = step_stream["xformed_reward"].to(dtype=dtype)
    else:
        reward = step_stream["reward"].to(dtype=dtype)
    online_vecs = online_vecs.to(dtype=dtype)
    target_vecs = target_vecs.to(dtype=dtype)

    # Each token at position t encodes (obs_t, action_{t-1}, reward_{t-1}, done_{t-1}),
    # i.e. the action and reward stored at t are the ones that *produced* obs_t, not
    # the ones taken *from* obs_t.  Therefore the action and reward that correspond to
    # the transition out of state t are stored one step ahead at t+1.
    # Consecutive (s, s+1) pairs within each batch row.
    # NOTE(review): unlike dqn_loss, done flags are never read here, so the bootstrap is
    # not masked at episode ends — confirm this is intended.
    curr_vecs = online_vecs[:, :-1, :, :]   # [B, S-1, A, D]  vecs(s_t)
    next_vecs = target_vecs[:, 1:, :, :]    # [B, S-1, A, D]  vecs_target(s_{t+1})
    next_actions = action[:, 1:]            # [B, S-1]         a_t (stored at t+1)
    next_rewards = reward[:, 1:]            # [B, S-1]         r_t (stored at t+1)

    # curr: vector for the executed action at s_t (what we train).
    # next: vector for the GREEDY best action at s_{t+1} (bootstrap target),
    #       selected with the same rotate-90 scoring used at inference.
    action_idx_exp = next_actions.unsqueeze(-1).unsqueeze(-1).expand(B, S - 1, 1, D)
    curr_action_vecs = curr_vecs.gather(dim=2, index=action_idx_exp).squeeze(2)  # [B, S-1, D]

    greedy_idx = vec_dqn_scores(next_vecs).argmax(dim=-1)                         # [B, S-1]
    greedy_idx_exp = greedy_idx.unsqueeze(-1).unsqueeze(-1).expand(B, S - 1, 1, D)
    next_action_vecs = next_vecs.gather(dim=2, index=greedy_idx_exp).squeeze(2)   # [B, S-1, D]

    # Optional per-row (over the sequence axis) reward normalization.
    if cfg.normalize_reward_mean:
        next_rewards = next_rewards - next_rewards.mean(dim=1, keepdim=True)
    if cfg.normalize_reward_std:
        next_rewards = (next_rewards / (next_rewards.std(dim=1, keepdim=True) + cfg.normalize_reward_eps)) * cfg.normalize_reward_std_target

    theta = next_rewards * cfg.reward_scale + cfg.reward_shift                     # [B, S-1]
    rotated = rope_rotate(x=next_action_vecs, theta=theta)                         # [B, S-1, D]

    # Cosine similarity loss — detach target mirrors td_target.detach() in dqn_loss
    cosine_sim = F.cosine_similarity(curr_action_vecs, rotated.detach(), dim=-1)  # [B, S-1]
    loss = (1.0 - cosine_sim).mean()

    # online_vecs is already float32 (cast above), so no further conversion is needed.
    abs_scores = vec_dqn_scores(online_vecs[:, -1]).abs() / math.pi  # [B, A]
    named: dict[str, torch.Tensor] = {
        "vec_dqn_loss": loss.detach(),
        "vec_dqn_score_abs_min": abs_scores.min().detach(),
        "vec_dqn_score_abs_max": abs_scores.max().detach(),
        "vec_dqn_score_abs_mean": abs_scores.mean().detach(),
    }
    metrics: dict[str, float] = dict(zip(named, torch.stack(list(named.values())).tolist()))

    return loss, metrics

Supervised policy loss

mouse.losses.sp.SpLossConfig dataclass

SpLossConfig(weight: float = 0.0, label_smoothing: float = 0.0, loss_type: Literal['ce', 'ce-soft-fwd', 'ce-soft-bwd', 'js', 'kl-fwd', 'kl-bwd'] = 'ce', temperature: float = 1.0)

Bases: LossConfig

Supervised action loss at PREDICTION (see sp_loss).

mouse.losses.sp.sp_loss

sp_loss(step_stream: TensorDict, logits: Tensor, cfg: SpLossConfig) -> tuple[torch.Tensor, dict[str, float]]

Supervised policy loss over all [B, S] step positions.

Parameters:

Name Type Description Default
step_stream TensorDict

TensorDict of shape [B, S] containing q_star targets.

required
logits Tensor

[B, S, A] action logits.

required
cfg SpLossConfig

SP loss configuration (loss_type, temperature, label_smoothing).

required

Returns: Scalar loss and scalar metrics for logging (e.g. W&B).

Source code in mouse/losses/sp.py
def sp_loss(
    step_stream: TensorDict,
    logits: torch.Tensor,
    cfg: SpLossConfig,
) -> tuple[torch.Tensor, dict[str, float]]:
    """Supervised policy loss over all ``[B, S]`` step positions.

    ``logits`` and ``step_stream["q_star"]`` are flattened to ``[B*S, A]`` rows, and the
    student logits are scored against the teacher Q-values with the criterion selected by
    ``cfg.loss_type``:

    - ``"ce"``: hard cross-entropy against ``argmax(q_star)`` (honours ``label_smoothing``).
    - ``"ce-soft-fwd"`` / ``"ce-soft-bwd"``: ``sp_soft_ce_loss`` in that direction.
    - ``"js"``: ``sp_js_loss``.
    - ``"kl-fwd"`` / ``"kl-bwd"``: ``sp_kl_loss`` in that direction.

    Args:
        step_stream: TensorDict of shape ``[B, S]`` containing ``q_star`` targets.
        logits: ``[B, S, A]`` action logits.
        cfg: SP loss configuration (loss_type, temperature, label_smoothing).
    Returns:
        Scalar loss and scalar metrics for logging (e.g. W&B).
    Raises:
        ValueError: If there are no rows, if ``q_star`` has any NaN/inf entry, or if
            ``cfg.loss_type`` is not one of the names above.
    """
    temperature = float(cfg.temperature)

    num_actions = logits.shape[-1]
    flat_logits = logits.reshape(-1, num_actions)
    teacher_q = step_stream["q_star"].reshape(-1, num_actions).to(dtype=flat_logits.dtype)

    if teacher_q.shape[0] == 0:
        raise ValueError("sp_loss: batch is empty (no tokens).")
    if not torch.isfinite(teacher_q).all():
        raise ValueError("sp_loss: q_star contains non-finite values (NaN or inf).")

    kind = cfg.loss_type
    smoothing = cfg.label_smoothing
    # The "-fwd"/"-bwd" variants share a helper; the suffix is its `direction` argument.
    if kind == "ce":
        hard_labels = teacher_q.argmax(dim=-1).to(dtype=torch.long)
        loss = F.cross_entropy(flat_logits, hard_labels, label_smoothing=smoothing)
    elif kind in ("ce-soft-fwd", "ce-soft-bwd"):
        loss = sp_soft_ce_loss(
            q_targets=teacher_q,
            logits=flat_logits,
            temperature=temperature,
            label_smoothing=smoothing,
            direction=kind.rsplit("-", 1)[-1],
        )
    elif kind == "js":
        loss = sp_js_loss(
            q_targets=teacher_q,
            logits=flat_logits,
            temperature=temperature,
            label_smoothing=smoothing,
        )
    elif kind in ("kl-fwd", "kl-bwd"):
        loss = sp_kl_loss(
            q_targets=teacher_q,
            logits=flat_logits,
            temperature=temperature,
            label_smoothing=smoothing,
            direction=kind.rsplit("-", 1)[-1],
        )
    else:
        raise ValueError(
            f"Invalid SP loss loss_type: {kind!r} "
            "(expected 'ce', 'ce-soft-fwd', 'ce-soft-bwd', 'js', 'kl-fwd', or 'kl-bwd')."
        )

    return loss, {"sp_loss": float(loss.detach().item())}

mouse.losses.sp.sp_js_loss

sp_js_loss(q_targets: Tensor, logits: Tensor, temperature: float, label_smoothing: float = 0.0) -> torch.Tensor

Jensen–Shannon loss between teacher q_targets and student logits (aligned rows).

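Builds the teacher distribution P = softmax(q_targets / T), optionally label-smoothed, and the student distribution Q = softmax(logits); the temperature is not applied to the student. It then computes JS = 0.5 KL(P‖M) + 0.5 KL(Q‖M) with M = 0.5 (P + Q) and averages over rows. No T² scaling factor is applied.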

Parameters:

Name Type Description Default
q_targets Tensor

[N, A] teacher Q-values (e.g. q_star at PREDICTION rows).

required
logits Tensor

[N, A] student action logits at the same rows.

required
temperature float

Must be > 0; applied to the teacher Q-values only (the student logits are not temperature-scaled).

required
label_smoothing float

Mixes uniform mass into the teacher distribution (probability space).

0.0
Source code in mouse/losses/sp.py
def sp_js_loss(
    q_targets: torch.Tensor,
    logits: torch.Tensor,
    temperature: float,
    label_smoothing: float = 0.0,
) -> torch.Tensor:
    """Jensen–Shannon loss between teacher ``q_targets`` and student ``logits`` (aligned rows).

    The teacher distribution is ``P = softmax(q_targets / temperature)``, optionally label
    smoothed. The student distribution is ``Q = softmax(logits)``; the temperature is NOT
    applied to the student. The loss is ``JS = 0.5 KL(P‖M) + 0.5 KL(Q‖M)`` with
    ``M = 0.5 (P + Q)``, averaged over rows. No ``T²`` rescaling factor is applied, which
    matches ``sp_kl_loss`` and ``sp_soft_ce_loss``.

    Args:
        q_targets: ``[N, A]`` teacher Q-values (e.g. ``q_star`` at PREDICTION rows).
        logits: ``[N, A]`` student action logits at the same rows.
        temperature: Must be ``> 0``; divides the teacher Q-values only.
        label_smoothing: Mixes uniform mass into the teacher distribution (probability space).

    Returns:
        Scalar tensor: the mean per-row JS divergence.

    Raises:
        ValueError: If ``temperature <= 0``.
    """
    temp = float(temperature)
    if temp <= 0.0:
        raise ValueError(f"sp_js_loss temperature must be > 0, got {temp}.")
    log_teacher = F.log_softmax(q_targets / temp, dim=-1)
    log_student = F.log_softmax(logits, dim=-1)
    if label_smoothing > 0.0:
        num_actions = q_targets.shape[-1]
        log_teacher = ((1.0 - label_smoothing) * log_teacher.exp() + label_smoothing / num_actions).log()

    # log M = log(0.5 * (P + Q)), computed stably in log space.
    log_m = torch.logaddexp(log_teacher, log_student) - math.log(2.0)
    # KL(P‖M) and KL(Q‖M) via kl_div(input=log M, target=log P, log_target=True)
    # → exp(log P) * (log P - log M). nan_to_num: -inf padding in q_star gives 0·(-inf) → NaN otherwise.
    kl_pm = torch.nan_to_num(
        F.kl_div(log_m, log_teacher, log_target=True, reduction="none"),
        nan=0.0,
    ).sum(dim=-1)
    kl_qm = torch.nan_to_num(
        F.kl_div(log_m, log_student, log_target=True, reduction="none"),
        nan=0.0,
    ).sum(dim=-1)
    js = 0.5 * (kl_pm + kl_qm)
    return js.mean()

mouse.losses.sp.sp_kl_loss

sp_kl_loss(q_targets: Tensor, logits: Tensor, temperature: float, label_smoothing: float = 0.0, direction: str = 'fwd') -> torch.Tensor

Temperature-scaled KL loss between teacher q_targets and student logits (the temperature is applied to the teacher Q-values only).

Parameters:

Name Type Description Default
q_targets Tensor

[N, A] teacher Q-values.

required
logits Tensor

[N, A] student logits.

required
temperature float

Must be > 0.

required
label_smoothing float

Optional smoothing applied to teacher distribution only.

0.0
direction str

"fwd" computes KL(P_teacher || Q_student); "bwd" computes KL(Q_student || P_teacher).

'fwd'
Source code in mouse/losses/sp.py
def sp_kl_loss(
    q_targets: torch.Tensor,
    logits: torch.Tensor,
    temperature: float,
    label_smoothing: float = 0.0,
    direction: str = "fwd",
) -> torch.Tensor:
    """KL divergence between a temperature-scaled teacher and the student, averaged over rows.

    The teacher is ``P = softmax(q_targets / temperature)``, optionally mixed with
    ``label_smoothing`` uniform mass. The student is ``Q = softmax(logits)``, with no
    temperature applied.

    Args:
        q_targets: ``[N, A]`` teacher Q-values.
        logits: ``[N, A]`` student logits.
        temperature: Must be ``> 0``.
        label_smoothing: Optional smoothing applied to teacher distribution only.
        direction: ``"fwd"`` computes ``KL(P_teacher || Q_student)``;
            ``"bwd"`` computes ``KL(Q_student || P_teacher)``.

    Raises:
        ValueError: If ``temperature <= 0`` or ``direction`` is not ``"fwd"``/``"bwd"``.
    """
    t = float(temperature)
    if t <= 0.0:
        raise ValueError(f"sp_kl_loss temperature must be > 0, got {t}.")
    if direction not in ("fwd", "bwd"):
        raise ValueError(f"sp_kl_loss direction must be 'fwd' or 'bwd', got {direction!r}.")

    teacher_logp = F.log_softmax(q_targets / t, dim=-1)
    student_logp = F.log_softmax(logits, dim=-1)
    if label_smoothing > 0.0:
        uniform_mass = label_smoothing / q_targets.shape[-1]
        teacher_logp = ((1.0 - label_smoothing) * teacher_logp.exp() + uniform_mass).log()

    # With log_target=True, F.kl_div(input, target) is exp(target) * (target - input),
    # i.e. KL(target || input); order the arguments to get the requested direction.
    if direction == "fwd":
        inp, tgt = student_logp, teacher_logp
    else:
        inp, tgt = teacher_logp, student_logp
    pointwise = F.kl_div(inp, tgt, log_target=True, reduction="none")
    # NaN terms (e.g. 0 * -inf) are zeroed before summing over actions.
    return torch.nan_to_num(pointwise, nan=0.0).sum(dim=-1).mean()

mouse.losses.sp.sp_soft_ce_loss

sp_soft_ce_loss(q_targets: Tensor, logits: Tensor, temperature: float, label_smoothing: float = 0.0, direction: str = 'fwd') -> torch.Tensor

Directional soft cross-entropy between teacher q_targets and student logits.

Teacher targets are softmax(q_targets / temperature). Optional label smoothing is applied on the teacher distribution only.

  • direction="fwd" computes H(P_teacher, Q_student) = -sum P log Q.
  • direction="bwd" computes H(Q_student, P_teacher) = -sum Q log P.
Source code in mouse/losses/sp.py
def sp_soft_ce_loss(
    q_targets: torch.Tensor,
    logits: torch.Tensor,
    temperature: float,
    label_smoothing: float = 0.0,
    direction: str = "fwd",
) -> torch.Tensor:
    """Directional soft cross-entropy between a teacher distribution and student ``logits``.

    The teacher is ``P = softmax(q_targets / temperature)``, optionally mixed with
    ``label_smoothing`` uniform mass. The student is ``Q = softmax(logits)``, with no
    temperature applied.

    - ``direction="fwd"`` computes ``H(P_teacher, Q_student) = -sum P log Q``.
    - ``direction="bwd"`` computes ``H(Q_student, P_teacher) = -sum Q log P``.

    Per-row sums are averaged over rows. NaN terms (e.g. ``0 * -inf``) are zeroed by
    ``torch.nan_to_num`` before summing.

    Raises:
        ValueError: If ``temperature <= 0`` or ``direction`` is not ``"fwd"``/``"bwd"``.
    """
    t = float(temperature)
    if t <= 0.0:
        raise ValueError(f"sp_soft_ce_loss temperature must be > 0, got {t}.")
    if direction not in ("fwd", "bwd"):
        raise ValueError(f"sp_soft_ce_loss direction must be 'fwd' or 'bwd', got {direction!r}.")

    teacher_logp = F.log_softmax(q_targets / t, dim=-1)
    if label_smoothing > 0.0:
        uniform_mass = label_smoothing / q_targets.shape[-1]
        teacher_logp = ((1.0 - label_smoothing) * teacher_logp.exp() + uniform_mass).log()
    student_logp = F.log_softmax(logits, dim=-1)

    # Weight one distribution's log-probabilities by the other's probabilities.
    if direction == "fwd":
        weights, log_probs = teacher_logp.exp(), student_logp
    else:
        weights, log_probs = student_logp.exp(), teacher_logp
    per_row = torch.nan_to_num(-(weights * log_probs), nan=0.0).sum(dim=-1)
    return per_row.mean()

Supervised value loss

mouse.losses.sv.SvLossConfig dataclass

SvLossConfig(weight: float = 0.0, loss_type: Literal['mse', 'mae'] = 'mse')

Bases: LossConfig

Supervised q_star loss at PREDICTION (see sv_loss).

mouse.losses.sv.sv_loss

sv_loss(step_stream: TensorDict, logits: Tensor, cfg: SvLossConfig) -> tuple[torch.Tensor, dict[str, float]]

Supervised q_star loss over all [B, S] step positions, restricted to finite action slots.

q_star uses -inf as a sentinel for padded/invalid actions; only finite entries participate in the loss, so padding never contributes gradients.

Returns:

Type Description
tuple[Tensor, dict[str, float]]

Scalar loss and scalar metrics for logging (e.g. W&B).

Source code in mouse/losses/sv.py
def sv_loss(
    step_stream: TensorDict,
    logits: torch.Tensor,
    cfg: SvLossConfig,
) -> tuple[torch.Tensor, dict[str, float]]:
    """Regress ``logits`` onto ``step_stream["q_star"]`` over all ``[B, S]`` step positions.

    ``q_star`` marks padded/invalid actions with ``-inf``. Every non-finite entry (including
    NaN and ``+inf``) is masked out, so only finite action slots contribute to the loss and
    its gradients. ``cfg.loss_type`` picks mean squared error (``"mse"``) or mean absolute
    error (``"mae"``) over the surviving entries.

    Returns:
        Scalar loss and scalar metrics for logging (e.g. W&B).

    Raises:
        ValueError: If there are no rows, if no ``q_star`` entry is finite, or if
            ``cfg.loss_type`` is neither ``"mse"`` nor ``"mae"``.
    """
    n_actions = logits.shape[-1]
    pred = logits.reshape(-1, n_actions)
    target = step_stream["q_star"].reshape(-1, n_actions).to(dtype=pred.dtype)  # [B*S, A]

    if target.shape[0] == 0:
        raise ValueError("sv_loss: batch is empty (no tokens).")

    valid = torch.isfinite(target)  # [B*S, A]
    if not valid.any():
        raise ValueError("sv_loss: q_star contains no finite values (all NaN or -inf).")

    if cfg.loss_type == "mse":
        criterion = F.mse_loss
    elif cfg.loss_type == "mae":
        criterion = F.l1_loss
    else:
        raise ValueError(
            f"Invalid SV loss loss_type: {cfg.loss_type!r} (expected 'mse' or 'mae')."
        )

    loss = criterion(pred[valid], target[valid])
    return loss, {"sv_loss": float(loss.detach().item())}