Skip to content

Data

DatasetStore

mouse.data.dataset_store.DatasetStore

DatasetStore(max_action_dim: int = 1000, max_obs_continuous_dim: int = 0, max_obs_discrete_dim: int = 0, max_obs_image_pixels: int = 0)

HuggingFace Dataset-backed step buffer.

Parameters

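max_action_dim : Maximum number of discrete actions; must be positive. q_star rows are clipped to this many columns and padded with -inf.
max_obs_continuous_dim : Width of the continuous observation vector; longer observations are clipped and shorter ones zero-padded; 0 = no continuous obs.
max_obs_discrete_dim : Number of discrete observation dimensions; 0 = no discrete obs.
max_obs_image_pixels : Number of pixels per image observation (each image is flattened, then clipped or zero-padded to this length); 0 = no image obs.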

Source code in mouse/data/dataset_store.py
def __init__(
    self,
    max_action_dim: int = 1000,
    max_obs_continuous_dim: int = 0,
    max_obs_discrete_dim: int = 0,
    max_obs_image_pixels: int = 0,
) -> None:
    """Create an empty store with the given encoding limits.

    Args:
        max_action_dim: Column count that ``q_star`` rows are clipped and
            padded to; must be positive after ``int()`` conversion.
        max_obs_continuous_dim: Width of ``obs_continuous``; ``0`` disables it.
        max_obs_discrete_dim: ``0`` disables ``obs_discrete``.
        max_obs_image_pixels: Flattened image length; ``0`` disables ``obs_image``.

    Raises:
        ValueError: If ``int(max_action_dim)`` is not positive.
    """
    action_dim = int(max_action_dim)
    if not action_dim > 0:
        raise ValueError(f"max_action_dim must be positive, got {max_action_dim}")
    self._max_action_dim = action_dim
    # The observation limits are stored as plain ints, in declaration order.
    for attr_name, raw_value in (
        ("_max_obs_continuous_dim", max_obs_continuous_dim),
        ("_max_obs_discrete_dim", max_obs_discrete_dim),
        ("_max_obs_image_pixels", max_obs_image_pixels),
    ):
        setattr(self, attr_name, int(raw_value))

    # Source segment: an HF Dataset held by reference, never mutated.
    self._source: Dataset | None = None
    # Buffer segment: raw row dicts added through ``append`` (rollout path).
    self._rows: list[dict] = []

__getitem__

__getitem__(indices: Any) -> TensorDict

Return encoded step records for the given indices as a TensorDict[N].

Source code in mouse/data/dataset_store.py
def __getitem__(self, indices: Any) -> TensorDict:
    """Return encoded step records for the given indices as a TensorDict[N].

    Indices address the logical concatenation ``source ++ buffer``:
    positions ``[0, src_len)`` come from the HF source dataset and
    ``[src_len, src_len + buf_len)`` from the rollout buffer.  Negative
    indices count back from the end of that combined sequence, as with a
    Python list.

    Args:
        indices: Anything ``np.asarray`` turns into integer positions (an
            int, a list, a NumPy array, ...).  It is flattened, so the result
            always has exactly one batch dimension.

    Returns:
        A TensorDict whose rows follow the order of *indices*.

    Raises:
        IndexError: If any index lies outside ``[-total, total)``.
    """
    src_len = self._src_len
    buf_len = self._buf_len
    total = src_len + buf_len
    raw = np.asarray(indices).ravel()

    # Resolve negative positions against the *combined* length.  Without
    # this, a negative index in the mixed case was routed to the source and
    # counted back from the end of the source only, returning the wrong row.
    idx = np.where(raw < 0, raw + total, raw)
    out_of_range = (idx < 0) | (idx >= total)
    if out_of_range.any():
        bad = raw[out_of_range][0]
        raise IndexError(f"index {bad} is out of range for a store of {total} steps")

    if buf_len == 0:
        return self.encode_hf_rows(self._source[idx.tolist()])  # type: ignore[index]

    if src_len == 0:
        rows = [self._rows[int(i)] for i in idx]
        return self.encode_hf_rows(_rows_to_hf_dict(rows))

    # Mixed: some indices from source, some from buf.  Each side is encoded
    # separately and the results are stitched back into request order.
    src_mask = idx < src_len
    buf_mask = ~src_mask
    src_positions = np.where(src_mask)[0]
    buf_positions = np.where(buf_mask)[0]

    src_td = (
        self.encode_hf_rows(self._source[idx[src_mask].tolist()])
        if src_mask.any() else None
    )
    buf_td = None
    if buf_mask.any():
        # Buffer rows are addressed after the source, hence the offset.
        buf_rows = [self._rows[int(i) - src_len] for i in idx[buf_mask]]
        buf_td = self.encode_hf_rows(_rows_to_hf_dict(buf_rows))

    return _interleave_tds(len(idx), src_td, src_positions, buf_td, buf_positions)

append

append(data: dict[str, Any]) -> None

Append one transition to the rollout buffer.

Source code in mouse/data/dataset_store.py
def append(self, data: dict[str, Any]) -> None:
    """Add a single transition to the end of the rollout buffer.

    A shallow copy of *data* is stored, so later rebinding of the caller's
    keys does not affect the buffered row.

    Raises:
        ValueError: If *data* is empty.
    """
    if data:
        self._rows.append(dict(data))
        return
    raise ValueError("Row cannot be empty.")

encode_hf_rows

encode_hf_rows(rows: dict[str, Any]) -> TensorDict

Encode a HF Dataset batch (dict-of-lists) into a TensorDict[N].

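Required fields (action, reward, done) are always present. Optional fields are only included as keys when the corresponding column is present in rows — no zero-fill fallbacks, no silent substitutions: xformed_reward; q_star (from the metadata_q_star column); time (from episode_step); and obs_continuous, obs_discrete, obs_image (from observation, observation_discrete, observation_image — each additionally requires the matching max_obs_* setting to be positive).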

Source code in mouse/data/dataset_store.py
def encode_hf_rows(self, rows: dict[str, Any]) -> TensorDict:
    """Encode a HF Dataset batch (dict-of-lists) into a TensorDict[N].

    Required fields (``action``, ``reward``, ``done``) are always present.
    Optional fields (``xformed_reward``, ``q_star``, ``time``, ``obs_*``)
    are only included as keys when present in ``rows`` — no zero-fill
    fallbacks, no silent substitutions.

    Column → key mapping:

    * ``action`` → ``action`` (int64), ``reward`` → ``reward`` (float32),
      ``done`` → ``done`` (int64)
    * ``xformed_reward`` → ``xformed_reward`` (float32)
    * ``episode_step`` → ``time`` (int64)
    * ``metadata_q_star`` → ``q_star`` (float32, ``[N, max_action_dim]``;
      clipped to ``max_action_dim`` columns; padding and ``None`` rows are
      ``-inf``)
    * ``observation`` → ``obs_continuous`` (float64,
      ``[N, max_obs_continuous_dim]``; each row flattened, clipped and
      zero-padded) — only when ``max_obs_continuous_dim > 0``
    * ``observation_discrete`` → ``obs_discrete`` (int64) — only when
      ``max_obs_discrete_dim > 0``
    * ``observation_image`` → ``obs_image`` (int64,
      ``[N, max_obs_image_pixels]``; flattened, clipped and zero-padded,
      ``None`` rows stay zero) — only when ``max_obs_image_pixels > 0``

    For ``q_star``, ``obs_continuous`` and ``obs_discrete`` the *first*
    row decides inclusion: if it is ``None`` the key is omitted.  An empty
    batch (N == 0) has no first row, so those three keys are omitted too.

    Args:
        rows: Mapping from column name to a list of N per-row values, as
            produced by indexing a HF ``Dataset`` with a list of indices.

    Returns:
        TensorDict with batch size ``[N]``, where N is ``len(rows["action"])``.

    Raises:
        KeyError: If a required column is missing.
    """
    cols = set(rows.keys())
    # ``action`` is required, so its length defines the batch size (the
    # previous "first element of a set" choice picked an arbitrary column).
    n = len(rows["action"])
    tensors: dict[str, torch.Tensor] = {}

    # Required fields
    tensors["action"] = torch.from_numpy(np.asarray(rows["action"], dtype=np.int64))
    tensors["reward"] = torch.from_numpy(np.asarray(rows["reward"], dtype=np.float32))
    tensors["done"] = torch.from_numpy(np.asarray(rows["done"], dtype=np.int64))

    # Optional scalar fields
    if "xformed_reward" in cols:
        tensors["xformed_reward"] = torch.from_numpy(
            np.asarray(rows["xformed_reward"], dtype=np.float32)
        )
    if "episode_step" in cols:
        tensors["time"] = torch.from_numpy(
            np.asarray(rows["episode_step"], dtype=np.int64)
        )

    # q_star — only if data has it and the first entry is not None.  The
    # ``n > 0`` guard keeps an empty batch from raising IndexError.
    if "metadata_q_star" in cols:
        q_list = rows["metadata_q_star"]
        if n > 0 and q_list[0] is not None:
            q_dim = self._max_action_dim
            q_buf = np.full((n, q_dim), -np.inf, dtype=np.float32)
            if all(v is not None for v in q_list):
                q_arr = np.asarray(q_list, dtype=np.float32)
                if q_arr.ndim != 2:
                    # One scalar per row becomes a single column; deeper
                    # arrays are flattened row-major (consistent with the
                    # per-row ``ravel`` branch below).
                    q_arr = q_arr.reshape(n, -1)
                qdim = min(q_arr.shape[1], q_dim)
                q_buf[:, :qdim] = q_arr[:, :qdim]
            else:
                # Some rows lack q_star: fill row by row, leaving -inf rows.
                for i, v in enumerate(q_list):
                    if v is not None:
                        qa = np.asarray(v, dtype=np.float32).ravel()
                        qdim = min(qa.size, q_dim)
                        q_buf[i, :qdim] = qa[:qdim]
            tensors["q_star"] = torch.from_numpy(q_buf)

    # Continuous observation
    if "observation" in cols and self._max_obs_continuous_dim > 0:
        obs_list = rows["observation"]
        if n > 0 and obs_list[0] is not None:
            obs = np.asarray(obs_list, dtype=np.float64)
            if obs.ndim != 2:
                # Scalar observations become one column; multi-dimensional
                # observations are flattened row-major.
                obs = obs.reshape(n, -1)
            max_dim = self._max_obs_continuous_dim
            odim = min(obs.shape[1], max_dim)
            out = np.zeros((n, max_dim), dtype=np.float64)
            out[:, :odim] = obs[:, :odim]
            tensors["obs_continuous"] = torch.from_numpy(out)

    # Discrete observation
    if "observation_discrete" in cols and self._max_obs_discrete_dim > 0:
        obs_list = rows["observation_discrete"]
        if n > 0 and obs_list[0] is not None:
            tensors["obs_discrete"] = torch.from_numpy(
                np.asarray(obs_list, dtype=np.int64)
            )

    # Image observation — fixed width, so it is safe for empty batches too.
    if "observation_image" in cols and self._max_obs_image_pixels > 0:
        img_dim = self._max_obs_image_pixels
        img_buf = np.zeros((n, img_dim), dtype=np.int64)
        for i, img_val in enumerate(rows["observation_image"]):
            if img_val is not None:
                arr = self._image_to_uint8(img_val).ravel().astype(np.int64)
                d = min(arr.size, img_dim)
                img_buf[i, :d] = arr[:d]
        tensors["obs_image"] = torch.from_numpy(img_buf)

    return TensorDict(tensors, batch_size=[n])

from_dataset

from_dataset(ds: 'Dataset | datasets.DatasetDict', splits: 'list[str] | None' = None, split_pattern: 'str | list[str] | None' = None) -> None

Load data from a HuggingFace Dataset or DatasetDict.

When ds is a Dataset its rows are added to the source segment: the dataset is stored by reference (no copy) if nothing is loaded yet, and concatenated onto the existing source otherwise; an empty dataset is ignored. When ds is a DatasetDict, the splits to load are selected by splits (exact names), split_pattern (one or more glob patterns), or — if neither is provided — all splits. Calling from_dataset more than once concatenates onto what is already loaded.

Parameters:

Name Type Description Default
ds 'Dataset | datasets.DatasetDict'

A Dataset or DatasetDict (e.g. from load_dataset).

required
splits 'list[str] | None'

Exact split names to include. Raises KeyError if a name is not present. Mutually exclusive with split_pattern.

None
split_pattern 'str | list[str] | None'

Glob pattern string or list of glob pattern strings. Every split whose name matches any pattern is included. Uses fnmatch rules — * matches anything, ? matches one character. E.g. "train_*" loads train_frozenlake, train_lunar, … Raises KeyError if nothing matches.

None

Examples::

# Single split
store.from_dataset(load_dataset("org/ds", split="train"))

# All splits from a DatasetDict
store.from_dataset(load_dataset("org/ds"))

# Selected splits by exact name
store.from_dataset(load_dataset("org/ds"), splits=["train", "test"])

# Glob patterns — all train_ and eval_ splits
store.from_dataset(load_dataset("org/ds"), split_pattern=["train_*", "eval_*"])

# Single pattern
store.from_dataset(load_dataset("org/ds"), split_pattern="test_*")
Source code in mouse/data/dataset_store.py
def from_dataset(
    self,
    ds: "Dataset | datasets.DatasetDict",
    splits: "list[str] | None" = None,
    split_pattern: "str | list[str] | None" = None,
) -> None:
    """Load data from a HuggingFace ``Dataset`` or ``DatasetDict``.

    When *ds* is a ``Dataset`` the rows are appended to the source segment
    (stored by reference, without copying, when nothing is loaded yet).
    When *ds* is a ``DatasetDict``, the splits to load are selected by
    *splits* (exact names), *split_pattern* (one or more glob patterns), or
    — if neither is provided — all splits.  Calling ``from_dataset`` more
    than once concatenates onto what is already loaded.  Empty datasets are
    skipped.  All selected splits are concatenated in a single step, so a
    failure leaves the previously loaded source untouched.

    Args:
        ds: A ``Dataset`` or ``DatasetDict`` (e.g. from ``load_dataset``).
        splits: Exact split names to include (a single name may be passed
            as a plain string).  Raises ``KeyError`` if a name is not
            present.  Mutually exclusive with *split_pattern*.
        split_pattern: Glob pattern string or list of glob pattern strings.
            Every split whose name matches any pattern is included.
            Uses ``fnmatch`` rules — ``*`` matches anything, ``?`` matches
            one character.  E.g. ``"train_*"`` loads ``train_frozenlake``,
            ``train_lunar``, …  Raises ``KeyError`` if nothing matches.

        *splits* and *split_pattern* only apply to a ``DatasetDict``; they
        are ignored for a plain ``Dataset``.

    Raises:
        ValueError: If both *splits* and *split_pattern* are given for a
            ``DatasetDict``.
        KeyError: If a requested split is missing or no split matches.

    Examples::

        # Single split
        store.from_dataset(load_dataset("org/ds", split="train"))

        # All splits from a DatasetDict
        store.from_dataset(load_dataset("org/ds"))

        # Selected splits by exact name
        store.from_dataset(load_dataset("org/ds"), splits=["train", "test"])

        # Glob patterns — all train_ and eval_ splits
        store.from_dataset(load_dataset("org/ds"), split_pattern=["train_*", "eval_*"])

        # Single pattern
        store.from_dataset(load_dataset("org/ds"), split_pattern="test_*")
    """
    import fnmatch

    if isinstance(ds, datasets.DatasetDict):
        if splits is not None and split_pattern is not None:
            raise ValueError("Provide splits or split_pattern, not both.")

        available = list(ds.keys())
        if splits is not None:
            # A bare string is one split name; iterating it would yield
            # single characters and raise a misleading KeyError.
            keys = [splits] if isinstance(splits, str) else list(splits)
            for key in keys:
                if key not in ds:
                    raise KeyError(
                        f"Split {key!r} not found in DatasetDict. "
                        f"Available splits: {available}"
                    )
        elif split_pattern is not None:
            patterns = [split_pattern] if isinstance(split_pattern, str) else list(split_pattern)
            keys = [k for k in available if any(fnmatch.fnmatch(k, p) for p in patterns)]
            if not keys:
                raise KeyError(
                    f"No splits match pattern(s) {patterns!r}. "
                    f"Available splits: {available}"
                )
        else:
            keys = available
        parts = [ds[key] for key in keys]
    else:
        parts = [ds]

    # Drop empty datasets, then concatenate once instead of once per split.
    parts = [p for p in parts if len(p) > 0]
    if not parts:
        return
    if self._source is not None:
        parts.insert(0, self._source)
    # A lone dataset is kept by reference (zero-copy), as before.
    self._source = parts[0] if len(parts) == 1 else concatenate_datasets(parts)

to_dataset

to_dataset() -> Dataset

Return a HuggingFace Dataset of all steps.

Source-only: returns the source reference directly (zero-copy). Buffer-only: builds a Dataset from the buffered raw rows. Both: concatenates the source dataset with the dataset built from the buffer. Empty store: returns an empty Dataset.

Source code in mouse/data/dataset_store.py
def to_dataset(self) -> Dataset:
    """Materialise every stored step as one HuggingFace ``Dataset``.

    * Empty store: an empty ``Dataset``.
    * Source only: the source object itself (no copy).
    * Buffer only: a ``Dataset`` built from the buffered raw rows.
    * Both: the source followed by the buffered rows, concatenated.
    """
    buffered = None if self._buf_len <= 0 else self._build_dataset(self._rows)
    source = self._source
    if buffered is None:
        return Dataset.from_list([]) if source is None else source
    if source is None:
        return buffered  # type: ignore[return-value]
    return concatenate_datasets([source, buffered])

merge_stores_to_dataset classmethod

merge_stores_to_dataset(stores: list[DatasetStore]) -> Dataset

Concatenate multiple DatasetStores into one HF Dataset.

Source code in mouse/data/dataset_store.py
@classmethod
def merge_stores_to_dataset(cls, stores: list[DatasetStore]) -> Dataset:
    """Join the contents of several stores, in order, into one HF Dataset.

    Stores that hold no steps are skipped; if none hold any, an empty
    ``Dataset`` is returned.
    """
    non_empty: list[Dataset] = []
    for store in stores:
        exported = store.to_dataset()
        if len(exported) > 0:
            non_empty.append(exported)
    if not non_empty:
        return Dataset.from_list([])
    return concatenate_datasets(non_empty)

clear

clear() -> None

Reset the store.

Source code in mouse/data/dataset_store.py
def clear(self) -> None:
    """Drop all loaded and buffered steps.

    The source reference is released and the buffer list is emptied in
    place, so any existing alias of that list also sees it empty.
    """
    self._source = None
    del self._rows[:]

PrefetchBatchifier

mouse.data.batch.PrefetchBatchifier

PrefetchBatchifier(store: DatasetStore, sequence_length: int, batch_size: int, sampling: str = 'random', prefetch: int = 4, num_workers: int = 1, pin_memory: bool = False)

Background-thread batchifier for a DatasetStore source.

Parameters

store : DatasetStore with from_dataset already called. Only the loaded source segment is sampled; rows added with append are not used.
sequence_length : Number of consecutive steps per sequence.
batch_size : Number of sequences per batch.
sampling : How start indices are chosen; one of "batch", "random", "sequential", "last" (see _sample_starts).
prefetch : int Number of pre-encoded batches to keep ready. Higher values smooth over slow Arrow reads; each batch costs batch_size × sequence_length encoded steps in memory.
num_workers : int Number of background threads. sequential/batch modes ignore this and always use one worker to preserve epoch order. Pass 0 to skip threading entirely and fetch synchronously on the calling thread.
pin_memory : bool If True, workers call .pin_memory() on each batch before queuing it. Pinned CPU tensors enable DMA-backed, non-blocking host-to-device copies on the main thread. Only effective when the training device is CUDA. Ignored in synchronous mode (num_workers=0).

Source code in mouse/data/batch.py
def __init__(
    self,
    store: DatasetStore,
    sequence_length: int,
    batch_size: int,
    sampling: str = "random",
    prefetch: int = 4,
    num_workers: int = 1,
    pin_memory: bool = False,
) -> None:
    """Validate arguments, set up epoch-order state and start prefetch workers.

    Args:
        store: ``DatasetStore`` whose source segment is loaded; only that
            source (``store._source``) is sampled.
        sequence_length: Consecutive steps per sequence (>= 1).
        batch_size: Sequences per batch (>= 1).
        sampling: One of ``"batch"``, ``"random"``, ``"sequential"``,
            ``"last"``.
        prefetch: Maximum number of ready batches held in the queue (>= 1
            whenever background workers are used).
        num_workers: ``0`` fetches synchronously on the calling thread for
            every sampling mode.  Otherwise ``random``/``last`` use this many
            threads and ``sequential``/``batch`` use exactly one so batches
            stay in epoch order.
        pin_memory: Workers pin each batch before queuing it (ignored in
            synchronous mode).

    Raises:
        TypeError: If *store* is not a ``DatasetStore`` with a loaded source.
        ValueError: If *sampling* is unknown, *sequence_length* or
            *batch_size* is below 1, *num_workers* is negative, or
            *prefetch* is below 1 while background workers are used.
    """
    from mouse.data.dataset_store import DatasetStore as _DS

    if not isinstance(store, _DS) or store._source is None:
        raise TypeError(
            "PrefetchBatchifier requires a DatasetStore with a loaded source dataset. "
            "Call store.from_dataset(ds) first."
        )
    if sampling not in ("batch", "random", "sequential", "last"):
        raise ValueError(f"sampling must be one of batch/random/sequential/last, got {sampling!r}")
    if sequence_length < 1:
        raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if num_workers < 0:
        raise ValueError(f"num_workers must be >= 0, got {num_workers}")

    # num_workers == 0 always selects synchronous mode (the epoch-order state
    # below is documented as shared with synchronous mode).  Otherwise only
    # the order-insensitive modes fan out across several threads.
    if num_workers == 0:
        n_workers = 0
    elif sampling in ("random", "last"):
        n_workers = num_workers
    else:
        n_workers = 1
    if n_workers > 0 and prefetch < 1:
        # queue.Queue(maxsize<=0) is unbounded: workers would never block
        # and could buffer batches without limit.
        raise ValueError(f"prefetch must be >= 1 when using background workers, got {prefetch}")

    self.store = store
    self.sequence_length = sequence_length
    self.batch_size = batch_size
    self.sampling = sampling

    self._dataset = store._source
    self._n = len(self._dataset)

    # Epoch-order state for sequential / batch / synchronous modes.
    self._lock = threading.Lock()
    self._next_window: int = 0
    self._window_order: np.ndarray = self._new_epoch_order()

    # Must be set before any worker starts, since workers read it.
    self._pin_memory = pin_memory

    if n_workers == 0:
        # Synchronous mode: fetch directly on the calling thread.
        self._sync_rng: np.random.Generator | None = np.random.default_rng(seed=0)
        self._result_queue: queue.Queue[TensorDict] | None = None
        self._stop: threading.Event | None = None
        self._worker_error: BaseException | None = None
        self._workers: list[threading.Thread] = []
    else:
        self._sync_rng = None
        self._result_queue = queue.Queue(maxsize=prefetch)
        self._stop = threading.Event()
        self._worker_error = None
        # Each worker gets its own deterministically seeded RNG.
        self._workers = [
            threading.Thread(
                target=self._worker_loop,
                args=(np.random.default_rng(seed=i),),
                daemon=True,
                name=f"PrefetchBatchifier-{i}",
            )
            for i in range(n_workers)
        ]
        for w in self._workers:
            w.start()

total_batches property

total_batches: int

Approximate number of non-overlapping windows in the dataset.

next_batch

next_batch() -> TensorDict

Return the next pre-encoded batch, blocking until one is ready.

Source code in mouse/data/batch.py
def next_batch(self) -> TensorDict:
    """Return the next pre-encoded batch, blocking until one is ready.

    In synchronous mode the batch is fetched on the calling thread.
    Otherwise the result queue is polled in 50 ms slices so that a failed
    or exited worker is noticed instead of blocking forever.

    Raises:
        RuntimeError: If a worker recorded an exception (chained as the
            cause), or if every worker thread has exited and no batch is
            left in the queue.
    """
    if self._sync_rng is not None:
        return self._fetch_one_batch(self._sync_rng)
    result_queue = self._result_queue
    assert result_queue is not None
    while True:
        if self._worker_error is not None:
            raise RuntimeError("A prefetch worker raised an exception.") from self._worker_error
        try:
            return result_queue.get(timeout=0.05)
        except queue.Empty:
            pass
        if any(w.is_alive() for w in self._workers):
            continue
        # Every worker has exited.  One may have recorded an error or queued
        # a final batch after the timed-out get() above, so look once more
        # before giving up; otherwise the real cause would be hidden behind
        # the generic message.
        if self._worker_error is not None:
            raise RuntimeError("A prefetch worker raised an exception.") from self._worker_error
        try:
            return result_queue.get_nowait()
        except queue.Empty:
            raise RuntimeError("All prefetch workers stopped unexpectedly.") from None

close

close() -> None

Stop background workers and drain the queue.

Source code in mouse/data/batch.py
def close(self) -> None:
    """Stop background workers and drain the queue.

    Sets the stop event, then joins each worker (up to 2 s per worker)
    while repeatedly draining the result queue.  Draining during the join
    matters: a worker blocked on ``put()`` into a full queue can only
    observe the stop event once a slot is freed, and a single up-front
    drain may be refilled by another worker.  Does nothing in synchronous
    mode.  Safe to call more than once.
    """
    if self._stop is None:
        return  # synchronous mode — nothing to tear down
    import time

    self._stop.set()
    result_queue = self._result_queue
    assert result_queue is not None

    def _drain() -> None:
        # Discard every batch currently queued.
        while True:
            try:
                result_queue.get_nowait()
            except queue.Empty:
                return

    for w in self._workers:
        deadline = time.monotonic() + 2.0
        while w.is_alive() and time.monotonic() < deadline:
            _drain()
            w.join(timeout=0.05)
    # Drop anything put by the last workers before they exited.
    _drain()

TokenAugmenter

mouse.data.augment.TokenAugmenter

TokenAugmenter(augment: AugmentTokensConfig, max_num_actions: int, max_num_obs_discrete: int, device: device, generator: Generator | None = None)

Applies AugmentTokensConfig to a step TensorDict batch.

Call the augmenter with a step_stream to obtain a possibly augmented copy. `__call__` applies the permutations and scalars from the stored snapshot; mask_prob masks are sampled anew on each call. Call `update_augmentations` first (required whenever any augmentation is enabled).

Source code in mouse/data/augment.py
def __init__(
    self,
    augment: AugmentTokensConfig,
    max_num_actions: int,
    max_num_obs_discrete: int,
    device: torch.device,
    generator: torch.Generator | None = None,
) -> None:
    """Bind an augmentation config and the RNG used for all sampling.

    Args:
        augment: The augmentation configuration; must be an
            ``AugmentTokensConfig`` instance.
        max_num_actions: Length of each sampled action permutation.
        max_num_obs_discrete: Length of each sampled discrete-observation
            permutation.
        device: Device for the default generator when *generator* is omitted.
        generator: Optional RNG to use instead of a fresh
            ``torch.Generator(device=device)``.

    Raises:
        TypeError: If *augment* is not an ``AugmentTokensConfig``.
    """
    if not isinstance(augment, AugmentTokensConfig):
        kind = type(augment).__name__
        raise TypeError(f"augment must be AugmentTokensConfig, got {kind}")
    self._augment = augment
    self._max_num_actions = int(max_num_actions)
    self._max_num_obs_discrete = int(max_num_obs_discrete)
    if generator is None:
        generator = torch.Generator(device=device)
    self._generator = generator
    # Filled in by update_augmentations(); None until a snapshot is sampled.
    self._snapshot: AugmentSnapshot | None = None

update_augmentations

update_augmentations(step_stream: TensorDict) -> None

Sample permutations and scalar parameters for this batch and store them.

Source code in mouse/data/augment.py
@torch.no_grad()
def update_augmentations(self, step_stream: TensorDict) -> None:
    """Sample permutations and scalar parameters for this batch and store them.

    The sampled values are stored in ``self._snapshot`` (an
    ``AugmentSnapshot``) for a subsequent ``__call__`` on the same batch.
    When ``augment.any_enabled()`` is false the snapshot is reset to
    ``None`` and nothing is sampled.

    Only ``step_stream["action"]`` is read: its leading dimension gives the
    number of per-row permutations, and its device is where they are created.

    NOTE(review): every draw uses ``self._generator``, so the order of the
    sampling calls below determines the random stream; reordering them
    changes results for a fixed seed.
    """
    augment = self._augment
    if not augment.any_enabled():
        self._snapshot = None
        return

    action = step_stream["action"]
    # Leading dim of the action field; presumably the batch dim B of a [B, S] stream.
    B = int(action.shape[0])
    # Typing-only narrowing (assuming ``cast`` is ``typing.cast``).
    dev = cast(torch.device, action.device)
    g = self._generator

    # One independent permutation of the action ids per row.
    perm_action: torch.Tensor | None = None
    if augment.permute_action_enabled():
        perm_action = torch.stack(
            [torch.randperm(self._max_num_actions, device=dev, generator=g) for _ in range(B)],
            dim=0,
        )

    # NOTE(review): the size 3 is hard-coded — presumably the number of
    # distinct ``done`` values; confirm against the done encoding.
    perm_done: torch.Tensor | None = None
    if augment.permute_done:
        perm_done = torch.stack(
            [torch.randperm(3, device=dev, generator=g) for _ in range(B)],
            dim=0,
        )

    perm_obs_discrete: torch.Tensor | None = None
    if augment.permute_obs_discrete:
        perm_obs_discrete = torch.stack(
            [torch.randperm(self._max_num_obs_discrete, device=dev, generator=g) for _ in range(B)],
            dim=0,
        )

    # Scalars are one value per batch.  Each is sampled only when its spec is
    # "active" relative to its identity value (1.0 for scales, 0.0 for
    # shifts); presumably inactive means the spec would always yield that
    # identity — _augment_scalar_active is not shown here.
    r_scale: float | None = None
    r_shift: float | None = None
    if _augment_scalar_active(augment.scale_reward, 1.0):
        r_scale = _sample_scalar(augment.scale_reward, g)
    if _augment_scalar_active(augment.shift_reward, 0.0):
        r_shift = _sample_scalar(augment.shift_reward, g)

    o_scale: float | None = None
    o_shift: float | None = None
    if _augment_scalar_active(augment.scale_obs, 1.0):
        o_scale = _sample_scalar(augment.scale_obs, g)
    if _augment_scalar_active(augment.shift_obs, 0.0):
        o_shift = _sample_scalar(augment.shift_obs, g)

    im_scale: float | None = None
    im_shift: float | None = None
    if _augment_scalar_active(augment.scale_obs_image, 1.0):
        im_scale = _sample_scalar(augment.scale_obs_image, g)
    if _augment_scalar_active(augment.shift_obs_image, 0.0):
        im_shift = _sample_scalar(augment.shift_obs_image, g)

    # Disabled augmentations leave their snapshot field as None.
    self._snapshot = AugmentSnapshot(
        batch_size=B,
        device=dev,
        perm_action=perm_action,
        perm_done=perm_done,
        r_scale=r_scale,
        r_shift=r_shift,
        o_scale=o_scale,
        o_shift=o_shift,
        im_scale=im_scale,
        im_shift=im_shift,
        perm_obs_discrete=perm_obs_discrete,
    )

__call__

__call__(step_stream: TensorDict) -> TensorDict

Augment a training batch; returns a new TensorDict when augmentation runs.

Requires every field of step_stream to have shape [B, S]. Call `update_augmentations` with the same batch first. Permutations and scalars come from the stored snapshot; mask_prob masks are drawn fresh here.

Source code in mouse/data/augment.py
@torch.no_grad()
def __call__(
    self,
    step_stream: TensorDict,
) -> TensorDict:
    """Augment a training batch; returns a new TensorDict when augmentation runs.

    Requires ``step_stream`` shape ``[B, S]`` per field.
    Call :meth:`update_augmentations` with the same batch first.
    Permutations/scalars use :attr:`snapshot`; ``mask_prob`` is drawn fresh here.

    When ``augment.any_enabled()`` is false the input object itself is
    returned (no copy).  Otherwise a clone is made and the ``apply_*``
    helpers are called on that clone for their side effects (their return
    values are not used), so the caller's batch is left untouched.

    Order of application: field masks, action permutation, done permutation,
    discrete-obs permutation, then reward, continuous-obs and image
    scale/shift.  Masking draws from ``self._generator``, so this order
    also fixes the random stream.
    """
    augment = self._augment
    if not augment.any_enabled():
        return step_stream

    step_stream = step_stream.clone()
    # Presumably validates that the stored snapshot matches this batch
    # (the snapshot records batch_size and device) and returns it — confirm.
    snap = self._assert_snapshot_matches(step_stream)

    # MLM-style masks first (corrupt inputs before permute/scale)
    if augment.mask_prob.any_positive():
        apply_field_masks(step_stream=step_stream, mask_prob=augment.mask_prob, generator=self._generator)

    if augment.permute_action_enabled():
        assert snap.perm_action is not None
        # mode selects whether the permutation hits inputs, targets, or both.
        mode = augment.permute_action_mode()
        apply_permute_action_augmentation(
            step_stream=step_stream,
            perm=snap.perm_action,
            apply_to_input=mode in ("input", "both"),
            apply_to_target=mode in ("target", "both"),
        )

    if augment.permute_done:
        assert snap.perm_done is not None
        apply_permute_done_augmentation(step_stream=step_stream, perm=snap.perm_done)

    if augment.permute_obs_discrete:
        assert snap.perm_obs_discrete is not None
        apply_permute_obs_discrete_augmentation(step_stream=step_stream, perm=snap.perm_obs_discrete)

    # Scale and shift are applied as two separate calls (scale first), each
    # with the other parameter at its identity value.
    if _augment_scalar_active(spec=augment.scale_reward, identity_mean=1.0):
        assert snap.r_scale is not None
        apply_reward_scale_shift(step_stream=step_stream, scale=snap.r_scale, shift=0.0)

    if _augment_scalar_active(spec=augment.shift_reward, identity_mean=0.0):
        assert snap.r_shift is not None
        apply_reward_scale_shift(step_stream=step_stream, scale=1.0, shift=snap.r_shift)

    if _augment_scalar_active(spec=augment.scale_obs, identity_mean=1.0):
        assert snap.o_scale is not None
        apply_obs_continuous_scale_shift(step_stream=step_stream, scale=snap.o_scale, shift=0.0)

    if _augment_scalar_active(spec=augment.shift_obs, identity_mean=0.0):
        assert snap.o_shift is not None
        apply_obs_continuous_scale_shift(step_stream=step_stream, scale=1.0, shift=snap.o_shift)

    if _augment_scalar_active(spec=augment.scale_obs_image, identity_mean=1.0):
        assert snap.im_scale is not None
        apply_obs_image_scale_shift(step_stream=step_stream, scale=snap.im_scale, shift=0.0)

    if _augment_scalar_active(spec=augment.shift_obs_image, identity_mean=0.0):
        assert snap.im_shift is not None
        apply_obs_image_scale_shift(step_stream=step_stream, scale=1.0, shift=snap.im_shift)

    return step_stream

AugmentTokensConfig

mouse.data.augment.AugmentTokensConfig dataclass

AugmentTokensConfig(enabled: bool = True, permute_tokens: bool = False, scale_reward: AugmentScalarSpec = (lambda: AugmentScalarSpec(1.0, 0.0))(), shift_reward: AugmentScalarSpec = (lambda: AugmentScalarSpec(0.0, 0.0))(), scale_obs: AugmentScalarSpec = (lambda: AugmentScalarSpec(1.0, 0.0))(), shift_obs: AugmentScalarSpec = (lambda: AugmentScalarSpec(0.0, 0.0))(), scale_obs_image: AugmentScalarSpec = (lambda: AugmentScalarSpec(1.0, 0.0))(), shift_obs_image: AugmentScalarSpec = (lambda: AugmentScalarSpec(0.0, 0.0))(), permute_obs_discrete: bool = False, permute_action: Literal[False, 'input', 'target', 'both'] = False, permute_done: bool = False, mask_prob: AugmentMaskProbConfig = AugmentMaskProbConfig())

Optional training-time token augmentations (copied streams on train batches).

PREDICTION and COMPUTE tokens are never modified. See mouse.data.augment.TokenAugmenter. mask_prob enables MLM-style random zero-masking per token type. Set enabled: false to disable all augmentations.


AugmentScalarSpec

mouse.data.augment.AugmentScalarSpec dataclass

AugmentScalarSpec(mean: float, std: float = 0.0, low: float | None = None, high: float | None = None)

Scalar augmentation: uniform on [low, high) per batch, or Gaussian mean + std * N(0,1).

YAML configs typically use low/high. Setting low == high fixes the value; this is the identity when the value is 1.0 for a scale or 0.0 for a shift. Omit low/high for Gaussian sampling; std == 0 fixes the value at mean.


AugmentMaskProbConfig

mouse.data.augment.AugmentMaskProbConfig dataclass

AugmentMaskProbConfig(action: float = 0.0, reward: float = 0.0, done: float = 0.0, obs_continuous: float = 0.0, obs_discrete: float = 0.0, obs_image: float = 0.0, time: float = 0.0)

Per-type Bernoulli mask probability (MLM-style): each eligible token row is masked i.i.d.

Masked rows replace payloads with a neutral value (0 / zero float / black pixel); step_stream fields at those positions are aligned. PREDICTION and COMPUTE rows are never masked.