Losses
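The top-level loss functions (dqn_loss, vec_dqn_loss, sp_loss, sv_loss) return (scalar_loss, dict[str, float]). The metrics dict can be logged directly to W&B or TensorBoard. The sp_* helper losses (sp_js_loss, sp_kl_loss, sp_soft_ce_loss) return only the scalar loss tensor.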
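Below is a minimal sketch of how a training loop might consume the (scalar_loss, metrics) convention. It sums the weighted scalar losses and merges the metric dicts under per-loss prefixes. The combine_losses helper is hypothetical. Treating each config's weight field as a multiplier, with 0.0 disabling the term, is an assumption and not documented behavior.

```python
import torch


def combine_losses(terms):
    """Sum weighted scalar losses and merge their metric dicts.

    terms: iterable of (name, weight, (loss, metrics)) triples, where
    (loss, metrics) is the tuple a top-level mouse loss returns.
    Hypothetical helper; mouse may aggregate losses differently.
    """
    total = torch.zeros(())
    logged = {}
    for name, weight, (loss, metrics) in terms:
        if weight == 0.0:
            # Assumption: a weight of 0.0 (the config default) disables the term.
            continue
        total = total + weight * loss
        # Prefix each metric with the loss name so keys from different losses never collide.
        for key, value in metrics.items():
            logged[f"{name}/{key}"] = float(value)
    logged["loss/total"] = float(total)
    return total, logged
```

The merged dict can then be handed to a logger, for example wandb.log(logged).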
DQN loss
mouse.losses.dqn.DqnLossConfig (dataclass)
DqnLossConfig(weight: float = 0.0, gamma: float = 0.99, gamma_terminal: float = 0.0, gamma_truncated: float = 0.0, tau: float = 0.01, normalize_reward_mean: bool = False, normalize_reward_std: bool = False, normalize_reward_eps: float = 1e-08, normalize_reward_std_target: float = 1.0, use_xformed_reward: bool = False, cql_weight: float = 0.0, cql_scale_q_eps: float = 1.0, reward_scale: float = 1.0, reward_shift: float = 0.0)
Bases: LossConfig
Symmetric two-head one-step TD at PREDICTION (see dqn_loss).
mouse.losses.dqn.dqn_loss
dqn_loss(step_stream: TensorDict, out: TensorDict, cfg: DqnLossConfig) -> tuple[torch.Tensor, dict[str, float]]
Source code in mouse/losses/dqn.py
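The DQN entry gives no formula, so what follows is only a generic one-step TD sketch under stated assumptions:

- reward_scale and reward_shift apply an affine transform to the reward.
- gamma_terminal and gamma_truncated replace gamma on terminal and truncated steps.
- tau is a Polyak soft-update coefficient for a target network.

The names td_target_sketch and soft_update_ are hypothetical. The sketch does not model the symmetric two-head structure, reward normalization or the CQL term.

```python
import torch
from torch import nn


def td_target_sketch(reward, next_q, terminal, truncated,
                     gamma=0.99, gamma_terminal=0.0, gamma_truncated=0.0,
                     reward_scale=1.0, reward_shift=0.0):
    """One-step TD target y = r' + g * max_a Q_target(s', a).

    Assumptions (not confirmed by the docs): r' = reward_scale * r + reward_shift,
    and g is gamma_terminal on terminal steps, gamma_truncated on truncated
    steps, and gamma otherwise (terminal takes precedence).
    """
    r = reward_scale * reward + reward_shift
    g = torch.full_like(r, gamma)
    g = torch.where(truncated, torch.full_like(r, gamma_truncated), g)
    g = torch.where(terminal, torch.full_like(r, gamma_terminal), g)
    # Targets are constants for the TD regression, so no gradient flows through them.
    return (r + g * next_q.max(dim=-1).values).detach()


def soft_update_(target: nn.Module, online: nn.Module, tau: float = 0.01) -> None:
    """Polyak averaging, target <- (1 - tau) * target + tau * online (assumed role of tau)."""
    with torch.no_grad():
        for t, o in zip(target.parameters(), online.parameters()):
            t.lerp_(o, tau)
```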
Vector DQN loss
mouse.losses.vec_dqn.VecDqnLossConfig (dataclass)
VecDqnLossConfig(weight: float = 0.0, tau: float = 0.01, reward_scale: float = 1.0, reward_shift: float = 0.0, normalize_reward_mean: bool = False, normalize_reward_std: bool = False, normalize_reward_eps: float = 1e-08, normalize_reward_std_target: float = 1.0, use_xformed_reward: bool = False)
Bases: LossConfig
Vector-DQN cosine-similarity loss at PREDICTION (see vec_dqn_loss).
mouse.losses.vec_dqn.vec_dqn_loss
vec_dqn_loss(step_stream: TensorDict, online_vecs: Tensor, target_vecs: Tensor, cfg: VecDqnLossConfig) -> tuple[torch.Tensor, dict[str, float]]
Source code in mouse/losses/vec_dqn.py
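Below is a minimal sketch of a cosine-similarity loss between online and target vectors. It assumes target_vecs is treated as a fixed (detached) regression target and that the loss is 1 - cos, averaged over all [B, S] positions. The sketch does not show how vec_dqn_loss builds target_vecs from rewards and tau. The name cosine_td_loss_sketch is hypothetical.

```python
import torch
import torch.nn.functional as F


def cosine_td_loss_sketch(online_vecs, target_vecs):
    """1 - cos(online, target), averaged over all leading positions.

    The target is detached so gradients flow only through online_vecs
    (an assumption about how target_vecs is used).
    """
    cos = F.cosine_similarity(online_vecs, target_vecs.detach(), dim=-1)
    loss = (1.0 - cos).mean()
    return loss, {"vec_dqn/cos_mean": cos.mean().item()}
```

Unlike a squared-error target, the cosine form ignores vector magnitude and only penalizes direction.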
Supervised policy loss
mouse.losses.sp.SpLossConfig (dataclass)
SpLossConfig(weight: float = 0.0, label_smoothing: float = 0.0, loss_type: Literal['ce', 'ce-soft-fwd', 'ce-soft-bwd', 'js', 'kl-fwd', 'kl-bwd'] = 'ce', temperature: float = 1.0)
Bases: LossConfig
Supervised action loss at PREDICTION (see sp_loss).
mouse.losses.sp.sp_loss
sp_loss(step_stream: TensorDict, logits: Tensor, cfg: SpLossConfig) -> tuple[torch.Tensor, dict[str, float]]
Supervised policy loss over all [B, S] step positions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| step_stream | TensorDict | TensorDict of shape | required |
| logits | Tensor | | required |
| cfg | SpLossConfig | SP loss configuration (loss_type, temperature, label_smoothing). | required |
Returns: Scalar loss and scalar metrics for logging (e.g. W&B).
Source code in mouse/losses/sp.py
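For loss_type='ce', one plausible reading of "over all [B, S] step positions" is to flatten the batch and sequence dimensions and apply ordinary cross-entropy. The sketch below makes two layout assumptions:

- logits have shape [B, S, A];
- action labels are integers with shape [B, S].

Where sp_loss reads those labels from step_stream is not documented. Whether temperature applies in the 'ce' branch is also not documented. Both points are assumptions here. The sketch relies on the label_smoothing argument of torch.nn.functional.cross_entropy. The name sp_ce_sketch is hypothetical.

```python
import torch
import torch.nn.functional as F


def sp_ce_sketch(logits, actions, temperature=1.0, label_smoothing=0.0):
    """Cross-entropy over every [B, S] position.

    logits: [B, S, A]; actions: [B, S] integer action indices (assumed layout).
    """
    B, S, A = logits.shape
    # Flatten batch and sequence so every step is one classification row.
    flat_logits = (logits / temperature).reshape(B * S, A)
    flat_actions = actions.reshape(B * S)
    loss = F.cross_entropy(flat_logits, flat_actions, label_smoothing=label_smoothing)
    acc = (flat_logits.argmax(dim=-1) == flat_actions).float().mean()
    return loss, {"sp/loss": loss.item(), "sp/acc": acc.item()}
```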
mouse.losses.sp.sp_js_loss
sp_js_loss(q_targets: Tensor, logits: Tensor, temperature: float, label_smoothing: float = 0.0) -> torch.Tensor
Jensen–Shannon loss between teacher q_targets and student logits (aligned rows).
The loss builds temperature-scaled soft distributions and applies optional label smoothing to the teacher only. It then computes JS = 0.5 KL(P‖M) + 0.5 KL(Q‖M) with M = 0.5 (P + Q), averages over rows and multiplies by T².
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| q_targets | Tensor | | required |
| logits | Tensor | | required |
| temperature | float | Must be | required |
| label_smoothing | float | Mixes uniform mass into the teacher distribution (probability space). | 0.0 |
Source code in mouse/losses/sp.py
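The documented JS formula can be written out directly. This sketch follows the documented steps:

1. Teacher and student softmax at temperature T.
2. Label smoothing on the teacher only.
3. The mixture M.
4. Mean over rows and T² scaling.

torch.xlogy keeps 0 · log 0 terms at zero. The name js_loss_sketch and the positivity check on temperature are assumptions; this is not the mouse implementation.

```python
import torch
import torch.nn.functional as F


def js_loss_sketch(q_targets, logits, temperature, label_smoothing=0.0):
    """Jensen-Shannon loss between teacher q_targets and student logits (aligned rows)."""
    if temperature <= 0:
        raise ValueError("temperature must be > 0 (assumed requirement)")
    p = F.softmax(q_targets / temperature, dim=-1)  # teacher P
    if label_smoothing > 0:
        # Mix uniform mass into the teacher only, in probability space.
        p = (1.0 - label_smoothing) * p + label_smoothing / p.shape[-1]
    q = F.softmax(logits / temperature, dim=-1)  # student Q
    m = 0.5 * (p + q)
    # KL(X || M) = sum x log x - sum x log m; xlogy returns 0 wherever x == 0.
    kl_pm = (torch.xlogy(p, p) - torch.xlogy(p, m)).sum(dim=-1)
    kl_qm = (torch.xlogy(q, q) - torch.xlogy(q, m)).sum(dim=-1)
    js = 0.5 * kl_pm + 0.5 * kl_qm
    return js.mean() * temperature ** 2
```

Without smoothing the result is symmetric in its two arguments, and per row it is bounded by T² · log 2.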
mouse.losses.sp.sp_kl_loss
sp_kl_loss(q_targets: Tensor, logits: Tensor, temperature: float, label_smoothing: float = 0.0, direction: str = 'fwd') -> torch.Tensor
Temperature-scaled KL loss between teacher q_targets and student logits.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| q_targets | Tensor | | required |
| logits | Tensor | | required |
| temperature | float | Must be | required |
| label_smoothing | float | Optional smoothing applied to the teacher distribution only. | 0.0 |
| direction | str | | 'fwd' |
Source code in mouse/losses/sp.py
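The KL entry does not spell out its direction convention. This sketch assumes it mirrors sp_soft_ce_loss: 'fwd' = KL(P_teacher‖Q_student) and 'bwd' = KL(Q_student‖P_teacher). It also assumes the same row mean and T² scaling as sp_js_loss. The name kl_loss_sketch is hypothetical.

```python
import torch
import torch.nn.functional as F


def kl_loss_sketch(q_targets, logits, temperature, label_smoothing=0.0, direction="fwd"):
    """Temperature-scaled KL between teacher and student distributions.

    Assumptions: "fwd" = KL(P_teacher || Q_student), "bwd" = KL(Q_student || P_teacher),
    averaged over rows and multiplied by T**2.
    """
    log_p = F.log_softmax(q_targets / temperature, dim=-1)  # teacher
    if label_smoothing > 0:
        # Smooth the teacher only, in probability space, then return to log space.
        p = (1.0 - label_smoothing) * log_p.exp() + label_smoothing / log_p.shape[-1]
        log_p = p.log()
    log_q = F.log_softmax(logits / temperature, dim=-1)  # student
    if direction == "fwd":
        kl = (log_p.exp() * (log_p - log_q)).sum(dim=-1)
    elif direction == "bwd":
        kl = (log_q.exp() * (log_q - log_p)).sum(dim=-1)
    else:
        raise ValueError(f"unknown direction: {direction!r}")
    return kl.mean() * temperature ** 2
```

The two directions behave differently:

- Forward KL is mass-covering: the student is penalized wherever the teacher has probability mass it misses.
- Backward KL is mode-seeking: the student is penalized for putting mass where the teacher has little.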
mouse.losses.sp.sp_soft_ce_loss
sp_soft_ce_loss(q_targets: Tensor, logits: Tensor, temperature: float, label_smoothing: float = 0.0, direction: str = 'fwd') -> torch.Tensor
Directional soft cross-entropy between teacher q_targets and student logits.
Teacher targets are softmax(q_targets / temperature). Optional label
smoothing is applied on the teacher distribution only.
direction="fwd"computesH(P_teacher, Q_student) = -sum P log Q.direction="bwd"computesH(Q_student, P_teacher) = -sum Q log P.
Source code in mouse/losses/sp.py
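The sketch below implements the two documented directions using exactly the formulas in the docstring. Averaging over rows is assumed. No T² factor is applied because the docstring does not mention one. The name soft_ce_sketch is hypothetical.

```python
import torch
import torch.nn.functional as F


def soft_ce_sketch(q_targets, logits, temperature, label_smoothing=0.0, direction="fwd"):
    """Directional soft cross-entropy.

    fwd: H(P_teacher, Q_student) = -sum P log Q
    bwd: H(Q_student, P_teacher) = -sum Q log P
    """
    p = F.softmax(q_targets / temperature, dim=-1)  # teacher P
    if label_smoothing > 0:
        # Smoothing touches the teacher distribution only.
        p = (1.0 - label_smoothing) * p + label_smoothing / p.shape[-1]
    log_q = F.log_softmax(logits / temperature, dim=-1)  # student log Q
    if direction == "fwd":
        h = -(p * log_q).sum(dim=-1)
    elif direction == "bwd":
        h = -(log_q.exp() * p.log()).sum(dim=-1)
    else:
        raise ValueError(f"unknown direction: {direction!r}")
    return h.mean()
```

Forward soft cross-entropy and forward KL are closely related:

- H(P, Q) = H(P) + KL(P‖Q), and the teacher entropy H(P) does not depend on the student. So 'fwd' soft cross-entropy yields the same student gradients as forward KL.
- The 'bwd' variant has no such equivalence with backward KL, because H(Q) depends on the student.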
Supervised value loss
mouse.losses.sv.SvLossConfig (dataclass)
Bases: LossConfig
Supervised q_star loss at PREDICTION (see sv_loss).
mouse.losses.sv.sv_loss
sv_loss(step_stream: TensorDict, logits: Tensor, cfg: SvLossConfig) -> tuple[torch.Tensor, dict[str, float]]
Supervised q_star loss over all [B, S] step positions, restricted to finite action slots.
q_star_tok uses -inf as a sentinel for padded/invalid actions; only finite entries
participate in the loss so padding never contributes gradients.
Returns:
| Type | Description |
|---|---|
| tuple[Tensor, dict[str, float]] | Scalar loss and scalar metrics for logging (e.g. W&B). |
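The -inf sentinel rule can be sketched as a masked loss. This sketch assumes a squared-error criterion between the predicted values (the logits argument) and q_star_tok, both shaped [B, S, A]. The actual criterion in sv_loss is not documented. The name masked_q_loss_sketch is hypothetical.

Sentinels are replaced before the subtraction, so no inf or nan enters the autograd graph. A plain multiply-by-mask would not guarantee this, since inf * 0 is nan.

```python
import torch


def masked_q_loss_sketch(pred_q, q_star_tok):
    """Squared error against q_star, restricted to finite target slots.

    pred_q, q_star_tok: [B, S, A]; -inf in q_star_tok marks padded/invalid actions.
    """
    mask = torch.isfinite(q_star_tok)
    # Replace sentinels first so (pred - target) never involves inf.
    target = torch.where(mask, q_star_tok, torch.zeros_like(q_star_tok))
    sq_err = (pred_q - target).pow(2) * mask.to(pred_q.dtype)
    # Average over valid slots only; clamp avoids division by zero if none are valid.
    n = mask.sum().clamp_min(1)
    loss = sq_err.sum() / n
    return loss, {"sv/mse": loss.item(), "sv/valid_frac": mask.float().mean().item()}
```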