Meta-Optimization Using Sequential Experiences

MOUSE is a modular PyTorch library for in-context reinforcement learning. It provides the building blocks (embeddings, transformer backbones, output heads, losses, and data utilities) for training and deploying agents that adapt their behavior by attending over their own transition history, with no weight updates at inference time.
Use MOUSE to build or experiment with in-context RL agents, whether you train from an offline dataset, collect online rollouts, or combine both.
What is in-context meta-RL?
In standard reinforcement learning, an agent's policy is encoded in its weights, and learning happens through gradient updates over many environment interactions. In meta-reinforcement learning, the goal is an agent that adapts its behavior rapidly to new tasks from only a small amount of experience.
In-context meta-RL takes this further: instead of adapting through gradient steps, the agent adapts purely through its context, i.e. the sequence of transitions it has observed so far in the current episode or trial. The policy is not a fixed mapping from state to action; it is an emergent property of the transformer's attention over that history.
At each step the model sees the full sequence of past (observation, action, reward, done, time) tuples. By attending over this history, it can recognize what has and hasn't worked, infer the reward structure of the current task, and adjust its behavior accordingly, all within a single forward pass and without any weight update.
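To make the idea concrete without any MOUSE code, the sketch below replaces the transformer with a hand-written, parameter-free function of the history on a toy multi-armed bandit. The bandit, `history_policy`, and `run_trial` are hypothetical illustrations, not part of MOUSE. The function itself never changes between tasks, yet it ends up choosing different arms on different tasks because the history it reads differs.

```python
import random


def history_policy(history, n_arms, explore_steps=2):
    """A fixed, weight-free function of the history (hypothetical stand-in for the transformer).

    history: list of (action, reward) pairs observed so far in this trial.
    """
    counts = [0] * n_arms
    totals = [0.0] * n_arms
    for a, r in history:
        counts[a] += 1
        totals[a] += r
    # Try each arm a few times first, then act greedily on the empirical means
    for a in range(n_arms):
        if counts[a] < explore_steps:
            return a
    return max(range(n_arms), key=lambda a: totals[a] / counts[a])


def run_trial(arm_probs, steps=50, seed=0):
    """Run one trial on a Bernoulli bandit; adaptation happens only through `history`."""
    rng = random.Random(seed)
    history = []
    for _ in range(steps):
        a = history_policy(history, len(arm_probs))
        r = 1.0 if rng.random() < arm_probs[a] else 0.0
        history.append((a, r))
    return history


# Same function, two tasks: it settles on arm 0 in the first and arm 2 in the second
task_a = run_trial([1.0, 0.0, 0.0], steps=20)
task_b = run_trial([0.0, 0.0, 1.0], steps=20)
print([a for a, _ in task_a][-3:], [a for a, _ in task_b][-3:])
```

In-context RL replaces this fixed rule with a trained transformer, so the way the agent explores and exploits is learned from data rather than hand-coded.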
How MOUSE implements this
Each environment step is embedded into a block of tokens, one per modality (observation, action, reward, done, time). All step blocks are concatenated into a flat causal sequence, passed through a transformer backbone, and pooled back to one vector per step. Those vectors are then fed to output heads that produce action logits or Q-values.
As a result, the policy is never stored in fixed weights; it emerges from the transformer attending over the full history. See the Architecture page for details.
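The token layout described above can be sketched in plain PyTorch. This is a toy model under stated assumptions, not MOUSE's architecture: the class name `StepBlockTransformer`, the modality order inside each block, mean pooling within a block, and the single linear Q-head are all illustrative choices.

```python
import torch
import torch.nn as nn


class StepBlockTransformer(nn.Module):
    """Hypothetical sketch: one token per modality, causal transformer, pool per step."""

    def __init__(self, n_obs, n_actions, max_time, d_model=64, n_layers=2, n_heads=4):
        super().__init__()
        self.act_emb = nn.Embedding(n_actions, d_model)
        self.rew_emb = nn.Linear(1, d_model)
        self.done_emb = nn.Embedding(3, d_model)  # 0 = running, 1 = terminated, 2 = truncated
        self.time_emb = nn.Embedding(max_time, d_model)
        self.obs_emb = nn.Embedding(n_obs, d_model)
        layer = nn.TransformerEncoderLayer(
            d_model, n_heads, dim_feedforward=4 * d_model, batch_first=True
        )
        self.backbone = nn.TransformerEncoder(layer, n_layers, enable_nested_tensor=False)
        self.q_head = nn.Linear(d_model, n_actions)

    def forward(self, obs, action, reward, done, time):
        # All inputs have shape (B, T); reward is float, the rest are long.
        tokens = torch.stack(
            [
                self.act_emb(action),
                self.rew_emb(reward.unsqueeze(-1)),
                self.done_emb(done),
                self.time_emb(time),
                self.obs_emb(obs),
            ],
            dim=2,
        )  # (B, T, K=5, d): one block of K tokens per step
        B, T, K, d = tokens.shape
        flat = tokens.reshape(B, T * K, d)  # token k of step t sits at position t * K + k
        L = T * K
        # Causal mask: -inf above the diagonal blocks attention to later tokens
        mask = torch.triu(torch.full((L, L), float("-inf"), device=flat.device), diagonal=1)
        h = self.backbone(flat, mask=mask)
        per_step = h.reshape(B, T, K, d).mean(dim=2)  # pool back to one vector per step
        return self.q_head(per_step)  # (B, T, n_actions) Q-values


model = StepBlockTransformer(n_obs=16, n_actions=4, max_time=200).eval()
```

Because the mask is causal over the flat token sequence, the pooled vector for step t depends only on steps 0 through t, which is what allows the same model to act online one step at a time.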
Install
Creating and Uploading a Dataset
```python
import gymnasium as gym
from mouse.data.dataset_store import DatasetStore
from mouse.data.hub import push_stores_to_hub

env = gym.make("FrozenLake-v1", is_slippery=True)
# max_obs_discrete_dim=1 for a single integer observation (grid cell index)
store = DatasetStore(max_action_dim=4, max_obs_discrete_dim=1)

for episode in range(500):
    obs, _ = env.reset()
    # Each record pairs the current observation with the previous action and the
    # reward it produced; the first step of an episode uses placeholder values.
    action = 0
    reward = 0.0
    done_flag = 0
    for step_idx in range(200):
        store.append({
            "observation_discrete": [obs],
            "action": action,
            "reward": reward,
            "done": done_flag,
            "episode_step": step_idx,
        })
        action = env.action_space.sample()
        obs, reward, terminated, truncated, _ = env.step(action)
        # 0 = running, 1 = terminated, 2 = truncated
        done_flag = 1 if terminated else (2 if truncated else 0)
        if terminated or truncated:
            # Append the terminal transition before moving to the next episode
            store.append({
                "observation_discrete": [obs],
                "action": action,
                "reward": reward,
                "done": done_flag,
                "episode_step": step_idx + 1,
            })
            break

print(store)  # DatasetStore(steps=...)

# Push to the Hugging Face Hub; this creates the repo if it doesn't exist yet
push_stores_to_hub(
    [store],
    repo_id="your-org/your-dataset",
    split="train",
    private=True,
)
```
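The loop above stores each step with the previous action and the reward that action produced, next to the current observation. The hypothetical helper below (not a MOUSE function) reproduces that layout from one already-recorded episode, which may help when converting existing logs. It only builds plain dicts with the same keys the loop passes to `store.append`; whether `DatasetStore` accepts them unchanged is an assumption.

```python
def episode_to_records(observations, actions, rewards, terminated, truncated):
    """Convert one episode into per-step records using the layout of the loop above.

    observations: o_0 .. o_T     (T + 1 entries, including the final observation)
    actions:      a_0 .. a_{T-1} (T entries)
    rewards:      r_1 .. r_T     (T entries; rewards[t] is the reward for actions[t])
    terminated / truncated: how the episode ended
    """
    assert len(observations) == len(actions) + 1 == len(rewards) + 1
    final_flag = 1 if terminated else (2 if truncated else 0)
    records = []
    for t, obs in enumerate(observations):
        first = t == 0
        last = t == len(observations) - 1
        records.append({
            "observation_discrete": [obs],
            # Placeholders on the first step, otherwise the action/reward that led here
            "action": 0 if first else actions[t - 1],
            "reward": 0.0 if first else rewards[t - 1],
            # Only the terminal transition carries the done flag
            "done": final_flag if (last and not first) else 0,
            "episode_step": t,
        })
    return records


# A three-step FrozenLake-style episode that reaches the goal
for rec in episode_to_records([0, 4, 8], [1, 1], [0.0, 1.0], terminated=True, truncated=False):
    print(rec)
```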
Multiple splits: to publish separate train/eval splits, use `push_to_hub` directly instead of `push_stores_to_hub`.
Offline Training Example
```python
import torch
from datasets import load_dataset
from mouse.models.base import load_model
from mouse.data.dataset_store import DatasetStore
from mouse.data.batch import PrefetchBatchifier
from mouse.losses.dqn import DqnLossConfig, dqn_loss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load a pretrained MOUSE model and set it to training mode
model = load_model("your-org/your-model").train().to(device)

# gamma=0.99 discounts future rewards; tau=0.005 controls target head lag
dqn_cfg = DqnLossConfig(weight=1.0, gamma=0.99, tau=0.005)

# Create the optimizer and update the target Q-head after each optimizer step
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
optimizer.register_step_post_hook(
    lambda opt, args, kwargs: model.polyak_update(dqn_tau=dqn_cfg.tau)
)

# Point DatasetStore at your offline RL dataset on the Hugging Face Hub
store = DatasetStore(max_action_dim=18, max_obs_continuous_dim=8)
store.from_dataset(load_dataset("your-org/your-dataset", split="train"))

# Background threads pre-fetch and encode batches so next_batch() does not block on encoding
bf = PrefetchBatchifier(store, sequence_length=64, batch_size=32,
                        sampling="random", prefetch=4, num_workers=2,
                        pin_memory=True)

try:
    for step in range(100_000):
        step_stream = bf.next_batch().to(device)
        # Forward pass: out contains Q-values for every (step, action) pair
        out, _ = model(step_stream)
        loss, _ = dqn_loss(step_stream, out, dqn_cfg)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()  # Polyak update fires automatically via the post-step hook
finally:
    # Stop the prefetch workers even if training is interrupted
    bf.close()
```
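For intuition about what `gamma` and `tau` control, here is a generic one-step DQN loss and Polyak update in plain PyTorch. This is the textbook technique, not MOUSE's `dqn_loss` or `model.polyak_update`, whose exact targets, masking, and loss shape may differ; the function names `td_loss` and `polyak_update` and the choice of mean-squared error are assumptions for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def td_loss(q, q_next_target, actions, rewards, terminal, gamma=0.99):
    """Generic one-step DQN loss (illustrative, not MOUSE's dqn_loss).

    q:             (B, A) online Q-values for the current states
    q_next_target: (B, A) target-head Q-values for the next states
    actions:       (B,) long, actions taken
    rewards:       (B,) float, rewards received
    terminal:      (B,) float, 1.0 for true terminations (not truncations), else 0.0
    """
    q_taken = q.gather(1, actions.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        # Bootstrap from the target head unless the episode truly ended
        target = rewards + gamma * (1.0 - terminal) * q_next_target.max(dim=1).values
    return F.mse_loss(q_taken, target)


@torch.no_grad()
def polyak_update(online: nn.Module, target: nn.Module, tau: float = 0.005):
    # target <- (1 - tau) * target + tau * online; a smaller tau makes the target lag more
    for p_target, p_online in zip(target.parameters(), online.parameters()):
        p_target.mul_(1.0 - tau).add_(p_online, alpha=tau)
```

With the done encoding used in the dataset example, only flag 1 (terminated) would set `terminal` to 1.0 here; truncated steps (flag 2) still bootstrap, because the episode was cut off rather than finished.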
Inference Example
```python
import gymnasium as gym
import torch
from tensordict import TensorDict
from mouse.models.base import load_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_model("your-org/your-model").eval().to(device)
env = gym.make("FrozenLake-v1", is_slippery=True)

obs, _ = env.reset()
cache = None  # kept across episodes so the model can keep adapting in-context
step_idx = 0
action = 0
reward = 0.0
done_flag = 0
num_steps = 1_000  # total model calls to run (illustrative budget)

with torch.no_grad():
    for _ in range(num_steps):
        step_stream = TensorDict(
            {
                "action": torch.tensor([[action]], dtype=torch.long),
                "reward": torch.tensor([[reward]], dtype=torch.float32),
                "done": torch.tensor([[done_flag]], dtype=torch.long),
                "time": torch.tensor([[step_idx]], dtype=torch.long),
                "obs_discrete": torch.tensor([[obs]], dtype=torch.long),
            },
            batch_size=(1, 1),
        )
        out, cache = model(step_stream.to(device), cache=cache, use_cache=True)
        if done_flag:
            # The terminal transition (final reward and done flag) is now in the
            # cache, matching the dataset layout; start the next episode.
            obs, _ = env.reset()
            step_idx = 0
            action = 0
            reward = 0.0
            done_flag = 0
            continue
        action = model.get_action(out, temperature=0.0)[0].item()
        obs, reward, terminated, truncated, _ = env.step(action)
        # 0 = running, 1 = terminated, 2 = truncated
        done_flag = 1 if terminated else (2 if truncated else 0)
        step_idx += 1

env.close()
```
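The inference loop calls `model.get_action(out, temperature=0.0)`, which acts greedily. Its internals are not documented on this page, so the sketch below is only an assumption about what a temperature parameter conventionally means: temperature 0 takes the argmax over the action scores, and a positive temperature samples from a softmax of the scores divided by the temperature. `select_action` is a hypothetical standalone function, not part of MOUSE.

```python
import torch


def select_action(scores: torch.Tensor, temperature: float = 0.0, generator=None) -> torch.Tensor:
    """Hypothetical temperature-based action selection.

    scores: (..., A) action logits or Q-values. Returns (...) long action indices.
    """
    if temperature == 0.0:
        # Greedy: pick the highest-scoring action
        return scores.argmax(dim=-1)
    # Boltzmann sampling: lower temperature concentrates mass on the best action
    probs = torch.softmax(scores / temperature, dim=-1)
    flat = probs.reshape(-1, probs.shape[-1])
    samples = torch.multinomial(flat, num_samples=1, generator=generator)
    return samples.reshape(probs.shape[:-1])


scores = torch.tensor([[0.1, 2.0, -1.0]])
print(select_action(scores))  # greedy choice
print(select_action(scores, temperature=1.0, generator=torch.Generator().manual_seed(0)))
```

A small positive temperature keeps some exploration during evaluation, while temperature 0 always picks the same action for the same scores.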
Guides
- Data — downloading and processing offline RL datasets
- Model Architecture — design choices: embedding, backbone, heads
- Losses — DQN, VecDQN, SP, and SV loss functions
- Examples — training loop patterns
The full API reference is available here.