Data
DatasetStore
mouse.data.dataset_store.DatasetStore
DatasetStore(max_action_dim: int = 1000, max_obs_continuous_dim: int = 0, max_obs_discrete_dim: int = 0, max_obs_image_pixels: int = 0)
HuggingFace Dataset-backed step buffer.
Parameters
max_action_dim : int
Maximum number of discrete actions; used to clip q_star columns.
max_obs_continuous_dim : int
Number of continuous observation dimensions to retain; 0 = no continuous obs.
max_obs_discrete_dim : int
Number of discrete observation dimensions; 0 = no discrete obs.
max_obs_image_pixels : int
Number of pixels per image observation; 0 = no image obs.
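The following hypothetical sketch applies two of these caps to a single raw row held as a dict. clip_row is not part of mouse. It assumes "clip" and "retain" mean keeping the leading columns, it does not pad short rows, and it uses obs_continuous as an illustrative field name.

```python
def clip_row(row: dict, max_action_dim: int, max_obs_continuous_dim: int) -> dict:
    """Hypothetical helper: apply DatasetStore-style caps to one raw step row.

    Assumes clipping keeps the leading columns; short rows are not padded.
    """
    out = dict(row)  # shallow copy so the caller's row is left untouched
    if "q_star" in out:
        # keep only the first max_action_dim Q-values (presumably one per discrete action)
        out["q_star"] = out["q_star"][:max_action_dim]
    if max_obs_continuous_dim == 0:
        # 0 means "no continuous observations": drop the (hypothetically named) field
        out.pop("obs_continuous", None)
    elif "obs_continuous" in out:
        out["obs_continuous"] = out["obs_continuous"][:max_obs_continuous_dim]
    return out


row = {"action": 1, "q_star": [0.1, 0.2, 0.3], "obs_continuous": [1.0, 2.0, 3.0]}
print(clip_row(row, max_action_dim=2, max_obs_continuous_dim=0))
# {'action': 1, 'q_star': [0.1, 0.2]}
```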
__getitem__
Return encoded step records for the given indices as a TensorDict[N].
append
encode_hf_rows
Encode a HF Dataset batch (dict-of-lists) into a TensorDict[N].
Required fields (action, reward, done) are always present.
Optional fields (xformed_reward, q_star, time, obs_*)
are only included as keys when present in rows — no zero-fill
fallbacks, no silent substitutions.
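The key-selection rule above can be sketched on a plain dict-of-lists batch, leaving out the tensor encoding. select_step_fields is hypothetical. Raising KeyError for a missing required field, checking column lengths, and ignoring unrecognized columns are assumptions of this sketch rather than documented behavior.

```python
REQUIRED = ("action", "reward", "done")
OPTIONAL = ("xformed_reward", "q_star", "time")


def select_step_fields(rows: dict[str, list]) -> tuple[dict[str, list], int]:
    """Hypothetical sketch of which keys an encoded batch gets (no tensor encoding)."""
    missing = [k for k in REQUIRED if k not in rows]
    if missing:
        # required fields are never zero-filled; fail loudly instead (assumed error type)
        raise KeyError(f"missing required fields: {missing}")
    keys = list(REQUIRED)
    keys += [k for k in OPTIONAL if k in rows]  # optional: only when the batch has them
    keys += sorted(k for k in rows if k.startswith("obs_"))  # any obs_* columns present
    n = len(rows["action"])
    for k in keys:
        if len(rows[k]) != n:
            raise ValueError(f"column {k!r} has {len(rows[k])} rows, expected {n}")
    # columns not named above are ignored in this sketch
    return {k: rows[k] for k in keys}, n
```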
from_dataset
from_dataset(ds: 'Dataset | datasets.DatasetDict', splits: 'list[str] | None' = None, split_pattern: 'str | list[str] | None' = None) -> None
Load data from a HuggingFace Dataset or DatasetDict.
When ds is a Dataset, its rows are appended to the source segment
(zero-copy reference). When ds is a DatasetDict, the splits to
load are selected by splits (exact names) or split_pattern (one or
more glob patterns); if neither is provided, all splits are loaded.
Calling from_dataset more than once concatenates onto what is already loaded.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| ds | 'Dataset \| datasets.DatasetDict' | A HuggingFace Dataset or DatasetDict to load. | required |
| splits | 'list[str] \| None' | Exact split names to include. Raises … | None |
| split_pattern | 'str \| list[str] \| None' | Glob pattern string or list of glob pattern strings. Every split whose name matches any pattern is included. Uses … | None |
Examples:

```python
from datasets import load_dataset

from mouse.data.dataset_store import DatasetStore

store = DatasetStore()  # each from_dataset call below appends to what is already loaded

# Single split
store.from_dataset(load_dataset("org/ds", split="train"))
# All splits from a DatasetDict
store.from_dataset(load_dataset("org/ds"))
# Selected splits by exact name
store.from_dataset(load_dataset("org/ds"), splits=["train", "test"])
# Glob patterns: all train_ and eval_ splits
store.from_dataset(load_dataset("org/ds"), split_pattern=["train_*", "eval_*"])
# Single pattern
store.from_dataset(load_dataset("org/ds"), split_pattern="test_*")
```
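For a DatasetDict, the split-selection rule can be sketched over a list of split names. select_splits is hypothetical. It uses Python's fnmatch.fnmatchcase for shell-style globs, raises KeyError for unknown exact names, and rejects passing both splits and split_pattern. The matcher, error type, and precedence that DatasetStore actually uses may differ.

```python
from fnmatch import fnmatchcase


def select_splits(available, splits=None, split_pattern=None):
    """Hypothetical sketch of choosing DatasetDict splits by name or glob."""
    if splits is not None and split_pattern is not None:
        # assumption: how the real method combines both arguments is not documented here
        raise ValueError("pass either splits or split_pattern, not both")
    if splits is not None:
        unknown = [s for s in splits if s not in available]
        if unknown:
            raise KeyError(f"unknown splits: {unknown}")  # assumed error type
        return list(splits)
    if split_pattern is not None:
        # a single pattern string is treated like a one-element list
        patterns = [split_pattern] if isinstance(split_pattern, str) else list(split_pattern)
        # keep a split if it matches ANY pattern; the original split order is preserved
        return [name for name in available if any(fnmatchcase(name, p) for p in patterns)]
    return list(available)  # neither given: every split


avail = ["train_a", "train_b", "eval_x", "test_1"]
print(select_splits(avail, split_pattern=["train_*", "eval_*"]))
# ['train_a', 'train_b', 'eval_x']
```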
to_dataset
Return a HuggingFace Dataset of all steps.
Source-only: returns the source reference directly (zero-copy).
Buf-only: builds a Dataset from the persisted raw rows.
Both: concatenates the source and buf datasets.
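The three cases above amount to a small dispatch. The sketch below is hypothetical: to_dataset_sketch is not part of mouse, plain Python lists stand in for HuggingFace Datasets, and putting source rows before buffer rows in the concatenation is an assumption.

```python
def to_dataset_sketch(source, buf_rows):
    """Hypothetical three-way dispatch; Python lists stand in for HF Datasets."""
    if source is not None and not buf_rows:
        return source  # source only: hand back the same object, no copy
    built = list(buf_rows)  # buffer only: build a new "dataset" from the raw rows
    if source is None:
        return built
    return source + built  # both: a new concatenation, source rows first (assumed order)
```

With the real datasets library, the corresponding building blocks would be Dataset.from_list and concatenate_datasets; this page does not show whether DatasetStore uses them.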
merge_stores_to_dataset
classmethod
Concatenate multiple DatasetStores into one HF Dataset.
PrefetchBatchifier
mouse.data.batch.PrefetchBatchifier
PrefetchBatchifier(store: DatasetStore, sequence_length: int, batch_size: int, sampling: str = 'random', prefetch: int = 4, num_workers: int = 1, pin_memory: bool = False)
Background-thread batchifier for a DatasetStore source.
Parameters
store : DatasetStore
DatasetStore with from_dataset already called.
sequence_length : int
Number of consecutive steps per sequence.
batch_size : int
Number of sequences per batch.
sampling : str
How start indices are chosen; see _sample_starts.
prefetch : int
Pre-encoded batches to keep ready. Higher values smooth over slow
Arrow reads; each batch costs batch_size × sequence_length encoded
steps in memory.
num_workers : int
Number of background threads. The sequential and batch sampling modes
ignore this and always use one worker to preserve epoch order. Pass 0 to
skip threading entirely and fetch synchronously on the calling thread.
pin_memory : bool
If True, workers call .pin_memory() on each batch before queuing
it. Pinned CPU tensors enable DMA-backed, non-blocking host-to-device (H2D)
copies on the main thread. Only effective when the training device is CUDA.
Ignored in synchronous mode (num_workers=0).
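The start-index logic lives in _sample_starts, which is not documented here. As a hypothetical illustration of a "random" mode and of the prefetch memory cost described above, the helpers below (sample_random_starts, window, and prefetch_steps are invented names) pick uniform start indices whose windows fit inside the store. They ignore episode boundaries, which the real sampler may respect.

```python
import random


def sample_random_starts(num_steps: int, sequence_length: int, batch_size: int,
                         rng: random.Random) -> list[int]:
    """Hypothetical 'random' mode: uniform starts whose windows fit inside the store."""
    max_start = num_steps - sequence_length
    if max_start < 0:
        raise ValueError("store holds fewer steps than sequence_length")
    # randint is inclusive on both ends, so the last window ends at step num_steps - 1
    return [rng.randint(0, max_start) for _ in range(batch_size)]


def window(start: int, sequence_length: int) -> list[int]:
    """Consecutive step indices that make up one sequence."""
    return list(range(start, start + sequence_length))


def prefetch_steps(prefetch: int, batch_size: int, sequence_length: int) -> int:
    """Encoded steps held by a full prefetch queue (batch_size x sequence_length per batch)."""
    return prefetch * batch_size * sequence_length


rng = random.Random(0)
starts = sample_random_starts(num_steps=1_000, sequence_length=64, batch_size=4, rng=rng)
batch_indices = [window(s, 64) for s in starts]
print(prefetch_steps(prefetch=4, batch_size=32, sequence_length=64))  # 8192
```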
next_batch
Return the next pre-encoded batch, blocking until one is ready.
close
Stop background workers and drain the queue.
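next_batch and close follow the usual bounded producer/consumer pattern. The minimal standard-library sketch below assumes one bounded queue.Queue of size prefetch and a threading.Event as the stop signal. PrefetchSketch and make_batch are hypothetical, pin_memory is omitted, and mouse's internals may differ.

```python
import queue
import threading
from typing import Callable


class PrefetchSketch:
    """Hypothetical background prefetcher mirroring next_batch/close (not mouse's code)."""

    def __init__(self, make_batch: Callable[[], object], prefetch: int = 4, num_workers: int = 1):
        self._make_batch = make_batch
        self._queue: queue.Queue = queue.Queue(maxsize=prefetch)  # at most `prefetch` ready batches
        self._stop = threading.Event()
        self._workers = [threading.Thread(target=self._run, daemon=True) for _ in range(num_workers)]
        for worker in self._workers:
            worker.start()

    def _run(self) -> None:
        while not self._stop.is_set():
            batch = self._make_batch()
            # retry put() with a timeout so a full queue cannot block shutdown forever
            while not self._stop.is_set():
                try:
                    self._queue.put(batch, timeout=0.1)
                    break
                except queue.Full:
                    continue

    def next_batch(self):
        if not self._workers:
            return self._make_batch()  # num_workers=0: synchronous, on the calling thread
        return self._queue.get()  # blocks until a worker has queued a batch

    def close(self) -> None:
        self._stop.set()
        for worker in self._workers:
            worker.join()  # each worker exits after at most one put() timeout
        while True:  # drain batches that were produced but never consumed
            try:
                self._queue.get_nowait()
            except queue.Empty:
                break
```

With num_workers=1 the FIFO queue returns batches in production order, which fits the single-worker rule for sequential and batch sampling; with several workers, batches can arrive interleaved.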
TokenAugmenter
mouse.data.augment.TokenAugmenter
TokenAugmenter(augment: AugmentTokensConfig, max_num_actions: int, max_num_obs_discrete: int, device: device, generator: Generator | None = None)
Applies AugmentTokensConfig to a step TensorDict batch.
Call it with step_stream to obtain a possibly augmented copy.
__call__ applies permutations and scalars from the stored snapshot; mask_prob
masks are resampled on every call. Call update_augmentations first (required
whenever any augmentation is enabled).
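The two-phase protocol (snapshot first, then apply) can be illustrated without tensors. In this hypothetical sketch, AugmenterSketch is not mouse's class and actions and rewards are plain lists rather than a step_stream TensorDict. The real update_augmentations takes the batch, while this one takes no arguments, and the reward-scale range 0.5 to 1.5 is arbitrary. The sketch shows which parts are frozen in the snapshot and which are redrawn per call.

```python
import random
from typing import Optional


class AugmenterSketch:
    """Hypothetical two-phase augmenter: sample a snapshot, then apply it plus fresh masks."""

    def __init__(self, num_actions: int, mask_prob: float, seed: Optional[int] = None):
        self.num_actions = num_actions
        self.mask_prob = mask_prob
        self.rng = random.Random(seed)
        self.snapshot = None

    def update_augmentations(self) -> None:
        # Phase 1: draw a permutation of action ids and a reward scale, then store them.
        perm = list(range(self.num_actions))
        self.rng.shuffle(perm)
        self.snapshot = {"action_perm": perm, "reward_scale": self.rng.uniform(0.5, 1.5)}

    def __call__(self, actions: list, rewards: list):
        if self.snapshot is None:
            raise RuntimeError("call update_augmentations() first")
        perm = self.snapshot["action_perm"]
        scale = self.snapshot["reward_scale"]
        # Phase 2: reuse the stored snapshot; new lists are built, inputs stay unchanged.
        new_actions = [perm[a] for a in actions]
        new_rewards = [r * scale for r in rewards]
        # Mask decisions are not part of the snapshot: they are redrawn on every call.
        new_actions = [0 if self.rng.random() < self.mask_prob else a for a in new_actions]
        return new_actions, new_rewards
```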
update_augmentations
Sample permutations and scalar parameters for this batch and store them.
__call__
Augment a training batch; returns a new TensorDict when augmentation runs.
Requires step_stream shape [B, S] per field.
Call update_augmentations with the same batch first.
Permutations and scalars come from the snapshot attribute; mask_prob masks are drawn fresh here.
AugmentTokensConfig
mouse.data.augment.AugmentTokensConfig
dataclass
AugmentTokensConfig(enabled: bool = True, permute_tokens: bool = False, scale_reward: AugmentScalarSpec = AugmentScalarSpec(1.0, 0.0), shift_reward: AugmentScalarSpec = AugmentScalarSpec(0.0, 0.0), scale_obs: AugmentScalarSpec = AugmentScalarSpec(1.0, 0.0), shift_obs: AugmentScalarSpec = AugmentScalarSpec(0.0, 0.0), scale_obs_image: AugmentScalarSpec = AugmentScalarSpec(1.0, 0.0), shift_obs_image: AugmentScalarSpec = AugmentScalarSpec(0.0, 0.0), permute_obs_discrete: bool = False, permute_action: Literal[False, 'input', 'target', 'both'] = False, permute_done: bool = False, mask_prob: AugmentMaskProbConfig = AugmentMaskProbConfig())
Optional training-time token augmentations, applied to copies of the step streams in training batches.
PREDICTION and COMPUTE tokens are never modified. See mouse.data.augment.TokenAugmenter.
mask_prob enables MLM-style random zero-masking per token type. Set enabled: false to disable all augmentations.
AugmentScalarSpec
mouse.data.augment.AugmentScalarSpec
dataclass
AugmentScalarSpec(mean: float, std: float = 0.0, low: float | None = None, high: float | None = None)
Scalar augmentation: uniform on [low, high) per batch, or Gaussian mean + std * N(0, 1).
YAML configs typically use low/high; setting low == high fixes the scalar at that value (for example, 1.0 for a scale or 0.0 for a shift gives the identity).
Omit low/high to use the Gaussian; std == 0 fixes the scalar at mean.
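The sampling rule can be written out directly. ScalarSpecSketch and sample_scalar are hypothetical stand-ins for AugmentScalarSpec and whatever mouse uses to draw from it, and rejecting a spec that sets only one of low/high is this sketch's own choice.

```python
import random
from dataclasses import dataclass
from typing import Optional


@dataclass
class ScalarSpecSketch:
    """Stand-in mirroring AugmentScalarSpec's fields (hypothetical, not mouse's class)."""
    mean: float
    std: float = 0.0
    low: Optional[float] = None
    high: Optional[float] = None


def sample_scalar(spec: ScalarSpecSketch, rng: random.Random) -> float:
    if (spec.low is None) != (spec.high is None):
        raise ValueError("set both low and high, or neither")  # assumption of this sketch
    if spec.low is not None:
        # uniform on [low, high): rng.random() is in [0, 1); low == high returns low exactly
        return spec.low + (spec.high - spec.low) * rng.random()
    # Gaussian: mean + std * N(0, 1); std == 0 returns mean exactly
    return spec.mean + spec.std * rng.gauss(0.0, 1.0)
```

Under this rule, an AugmentTokensConfig default such as scale_reward = AugmentScalarSpec(1.0, 0.0) takes the std == 0 branch and always yields 1.0, leaving rewards unscaled.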
AugmentMaskProbConfig
mouse.data.augment.AugmentMaskProbConfig
dataclass
AugmentMaskProbConfig(action: float = 0.0, reward: float = 0.0, done: float = 0.0, obs_continuous: float = 0.0, obs_discrete: float = 0.0, obs_image: float = 0.0, time: float = 0.0)
Per-type Bernoulli mask probability (MLM-style): each eligible token row is masked independently with its type's probability.
A masked row's payload is replaced with a neutral value (0, a zero float, or a black pixel, depending on the token type); step_stream
fields at those positions are aligned. PREDICTION and COMPUTE rows are never masked.
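Per-row Bernoulli masking can be sketched on a flat list of (token_type, payload) pairs. mask_rows, NEVER_MASKED, and NEUTRAL are hypothetical. Payloads are scalars here rather than vectors or images, and the neutral value chosen for each type is an assumption based on the "0 / zero float / black pixel" description. The keys of mask_prob mirror the AugmentMaskProbConfig field names.

```python
import random

NEVER_MASKED = {"PREDICTION", "COMPUTE"}
# Assumed neutral payloads: 0 for discrete types, 0.0 for float types, 0 (black) for pixels.
NEUTRAL = {
    "action": 0, "done": 0, "obs_discrete": 0,
    "reward": 0.0, "obs_continuous": 0.0, "time": 0.0,
    "obs_image": 0,
}


def mask_rows(rows, mask_prob, rng):
    """Hypothetical sketch: rows are (token_type, payload) pairs; returns new rows and flags."""
    out, flags = [], []
    for token_type, payload in rows:
        p = mask_prob.get(token_type, 0.0)
        # independent Bernoulli(p) draw per eligible row; PREDICTION/COMPUTE are skipped
        hit = token_type not in NEVER_MASKED and rng.random() < p
        out.append((token_type, NEUTRAL[token_type] if hit else payload))
        flags.append(hit)
    return out, flags
```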