From 67a41e003dc85f1c5d3b54aaa509e5ef7c7bab42 Mon Sep 17 00:00:00 2001 From: dnddnjs Date: Mon, 8 Jun 2026 09:46:33 +0900 Subject: [PATCH] 4-atari-hard: Go-Explore (exploration phase) on Montezuma's Revenge + benchmark Go-Explore phase 1 (Ecoffet et al. 2019 / Nature 2021), no neural net: an archive of downscaled-frame cells (11x8, 9 gray levels), emulator state save/restore to return to frontier cells, repeated random actions to explore from them. 12 explorer processes over raw gymnasium ALE (envpool exposes no clone API, hence the separate env_go_explore.py). Result: best end-of-episode score 31,000 at 500M agent steps (~5.5h on a Mac Studio M4 Max), single seed, replay-verified (re-executing the stored 5,336-action demo from reset reproduces the score exactly). Deterministic protocol (no sticky actions) -- a trajectory-search result, not an RL policy score; see the README caveat. --- 4-atari-hard/2-go-explore.py | 438 +++++++++++++++++++++++++++++++++ 4-atari-hard/env_go_explore.py | 152 ++++++++++++ README.md | 10 + 3 files changed, 600 insertions(+) create mode 100644 4-atari-hard/2-go-explore.py create mode 100644 4-atari-hard/env_go_explore.py diff --git a/4-atari-hard/2-go-explore.py b/4-atari-hard/2-go-explore.py new file mode 100644 index 0000000..d9790e0 --- /dev/null +++ b/4-atari-hard/2-go-explore.py @@ -0,0 +1,438 @@ +"""Go-Explore Phase 1 (exploration phase) for Montezuma's Revenge. + +Ecoffet et al., 2019: "Go-Explore: a New Approach for Hard-Exploration +Problems" (arXiv:1901.10995); Nature 2021 version "First return, then +explore" (arXiv:2004.12919). No neural network: intrinsic-motivation methods +(RND etc.) suffer from detachment (forgetting promising frontiers) and +derailment (exploration noise breaking the return trip). Go-Explore fixes +both mechanically — remember everything in an archive, and RETURN exactly +via emulator state restore, then explore from there: + + archive: cell -> (best trajectory reaching it, emulator snapshot, score) + loop: sample cells (novelty-weighted) -> restore -> random exploration + -> add/update cells reached + +Design notes (verified against the official uber-research/go-explore code): + + 1. Cell key = grayscale frame -> cv2.resize to 11x8 (INTER_AREA) -> + quantize to 9 levels: floor(8 * p / 255). 88-byte key. + 2. Selection weight = 1 / sqrt(seen_times + 1) (Nature simplification); + sampling WITH replacement, batch of 100; the virtual DONE cell is + never selected. + 3. Exploration from a restored cell: up to K=100 agent steps, repeated + random actions (keep current action w.p. 0.95 -> geometric runs, + mean 20). Episode end = LIFE LOSS (or game over) -> the transition + maps to the DONE cell and the exploration episode aborts. + 4. Archive accept rule: replace/insert iff score is higher, or equal + score with a shorter trajectory. Scores are raw and unclipped. + On update the cell's counters reset and its snapshot/trajectory are + replaced; the *chosen* cell's chosen_since_new resets when anything + new is found. + 5. Trajectories are not stored per cell: a global append-only experience + log (prev_id linked list) + per-cell traj_last pointers reconstruct + any cell's action sequence — this is the demo source for a future + robustification phase, so the log is flushed to compressed chunks in + the run dir rather than discarded. + 6. ★ ALE pitfall (machine-verified): post-restore RAM/screen reads are + STALE until the next act. Cell keys come only from frames returned by + env.step(); the lives baseline travels in cell metadata. + 7. frames axis = agent steps actually EXECUTED by workers (frameskip 4 + applied; the hypothetical "replay from start" steps are not counted), + matching the harness budget/tier semantics. + 8. Phase-1 caveat: the score is a deterministic trajectory-search result, + NOT an RL-policy score. Never compare against sticky-action RL + numbers (e.g. the RND campaign) without this caveat. +""" +import multiprocessing as mp +import os +import pickle +import time + +import cv2 +import numpy as np + +from env_go_explore import ENV_IDS, RunLogger, make_restore_env, parse_args + + +TOTAL_FRAMES = 5_000_000 # agent steps executed (override with --total-frames) +BATCH_CELLS = 100 # cells sampled (with replacement) per iteration +EXPLORE_STEPS = 100 # K: max agent steps per exploration episode +ACTION_REPEAT_P = 0.95 # keep current action w.p. 0.95 (geometric, mean 20) +CELL_W, CELL_H = 11, 8 # downscale resolution (official fixed setting) +CELL_LEVELS = 8 # quantize to floor(8*p/255) -> values 0..8 +N_WORKERS = 12 # M4 Max 16 cores: leave headroom for master + OS +LOG_EVERY_BATCHES = 10 # metrics.jsonl cadence (~1M steps/100s at full speed) +EXPLOG_CHUNK = 1 << 22 # 4M entries per experience-log chunk (~40MB in RAM) +ROOM_RAM_BYTE = 3 # Montezuma current-room RAM index (diagnostic only) +DONE_KEY = (b"DONE", True) # virtual end-of-episode cell (never sampled) + + +def cell_key(frame): + """(210, 160) uint8 grayscale frame -> 88-byte archive key.""" + small = cv2.resize(frame, (CELL_W, CELL_H), interpolation=cv2.INTER_AREA) + return ((small / 255.0) * CELL_LEVELS).astype(np.uint8).tobytes() + + +class Cell: + """Archive entry. snapshot/lives describe the state AT this cell so a + worker can restore and keep exploring; traj_last points into the + experience log for trajectory reconstruction.""" + __slots__ = ("snapshot", "score", "traj_len", "traj_last", + "seen", "chosen", "chosen_since_new", "lives") + + def __init__(self, snapshot, score, traj_len, traj_last, lives): + self.snapshot = snapshot + self.score = score + self.traj_len = traj_len + self.traj_last = traj_last + self.lives = lives + self.seen = self.chosen = self.chosen_since_new = 0 + + +class ExperienceLog: + """Append-only step log as a prev_id linked list (design note 5). + + RAM holds only the active chunk; full chunks flush to + /chunk_NNNNN.npz (compressed — rewards/dones are almost all zero). + With dir=None (probes/tests) full chunks stay in RAM instead.""" + + def __init__(self, log_dir, chunk_size=EXPLOG_CHUNK, ancestor_dir=None): + self.dir = log_dir + if log_dir: + os.makedirs(log_dir, exist_ok=True) + self.chunk_size = chunk_size + self.ancestor = ancestor_dir # explog dir of the run we resumed FROM: + self.count = 0 # chunks flushed before the resume live there + self.n_flushed = 0 + self._ram_chunks = [] # dir=None mode only + self._cache = {} # chunk_idx -> loaded arrays (reconstruction) + self._new_chunk() + + def _new_chunk(self): + n = self.chunk_size + self.prev = np.empty(n, dtype=np.int64) + self.act = np.empty(n, dtype=np.uint8) + self.rew = np.empty(n, dtype=np.float32) + self.done = np.empty(n, dtype=np.uint8) + self.fill = 0 + + def append(self, prev_id, action, reward, done): + i = self.fill + self.prev[i], self.act[i], self.rew[i], self.done[i] = prev_id, action, reward, done + self.fill += 1 + idx = self.count + self.count += 1 + if self.fill == self.chunk_size: + self._flush() + return idx + + def _flush(self): + arrays = {"prev": self.prev[:self.fill], "act": self.act[:self.fill], + "rew": self.rew[:self.fill], "done": self.done[:self.fill]} + if self.dir: + tmp = os.path.join(self.dir, f"chunk_{self.n_flushed:05d}.tmp") + np.savez_compressed(tmp, **arrays) + os.replace(f"{tmp}.npz", os.path.join(self.dir, f"chunk_{self.n_flushed:05d}.npz")) + else: + self._ram_chunks.append({k: v.copy() for k, v in arrays.items()}) + self.n_flushed += 1 + self._new_chunk() + + def _chunk_path(self, chunk_idx): + """A flushed chunk lives in our own dir, or (after a cross-run-dir + resume) in the ancestor run's explog dir.""" + own = os.path.join(self.dir, f"chunk_{chunk_idx:05d}.npz") + if os.path.exists(own): + return own + if self.ancestor: + anc = os.path.join(self.ancestor, f"chunk_{chunk_idx:05d}.npz") + if os.path.exists(anc): + return anc + raise RuntimeError(f"explog chunk {chunk_idx} not found in {self.dir}" + + (f" or {self.ancestor}" if self.ancestor else "")) + + def _chunk(self, chunk_idx): + if chunk_idx == self.n_flushed: + return {"prev": self.prev, "act": self.act} + if self.dir: + if chunk_idx not in self._cache: + z = np.load(self._chunk_path(chunk_idx)) + self._cache[chunk_idx] = {"prev": z["prev"], "act": z["act"]} + return self._cache[chunk_idx] + return self._ram_chunks[chunk_idx] + + def reconstruct_actions(self, last_id): + """Walk the prev_id chain back to the root (-1); return actions in + forward order. This is how demos are rebuilt for replay/Phase 2.""" + actions = [] + idx = last_id + while idx >= 0: + c = self._chunk(idx // self.chunk_size) + off = idx % self.chunk_size + actions.append(int(c["act"][off])) + idx = int(c["prev"][off]) + return actions[::-1] + + def state(self): + return {"count": self.count, "n_flushed": self.n_flushed, + "chunk_size": self.chunk_size, + "cur_prev": self.prev[:self.fill].copy(), "cur_act": self.act[:self.fill].copy(), + "cur_rew": self.rew[:self.fill].copy(), "cur_done": self.done[:self.fill].copy()} + + def load_state(self, st): + assert st["chunk_size"] == self.chunk_size, "explog chunk_size mismatch" + if self.dir: # flushed chunks must be reachable (own dir or ancestor's) + self.n_flushed = st["n_flushed"] + for i in range(st["n_flushed"]): + self._chunk_path(i) # raises loudly if a chunk is missing + self.count, self.n_flushed = st["count"], st["n_flushed"] + self._new_chunk() + n = len(st["cur_prev"]) + self.prev[:n], self.act[:n] = st["cur_prev"], st["cur_act"] + self.rew[:n], self.done[:n] = st["cur_rew"], st["cur_done"] + self.fill = n + + +class Archive: + """Cell store + novelty-weighted selection + the accept rule. All updates + happen serially in the master process.""" + + def __init__(self): + self.cells = {} # (key_bytes, done_bool) -> Cell + self.rooms = set() # diagnostic only (RAM byte 3) + self.done_scores = [] # recent end-of-episode scores (logging) + + def seed_root(self, key, snapshot, lives): + self.cells[(key, False)] = Cell(snapshot, 0.0, 0, -1, lives) + + @property + def best_done_score(self): + c = self.cells.get(DONE_KEY) + return c.score if c else float("-inf") + + @property + def max_archive_score(self): + return max(c.score for k, c in self.cells.items() if k != DONE_KEY) + + def sample(self, n, rng): + """n cells with replacement, p ∝ 1/sqrt(seen+1); DONE excluded. + + Returns (key, CAPTURE) pairs that freeze the cell's snapshot/score/ + trajectory AT SAMPLING TIME. The trajectory walk must use the capture, + never the live cell: an earlier result in the same batch may replace + the cell, and stitching actions executed from the OLD state onto the + NEW prefix fabricates scores no single playthrough achieved + (2026-06-08 incident — caught by publish-time replay verification; + the official code ships these values inside the task for this reason).""" + keys = [k for k in self.cells if k != DONE_KEY] + w = np.array([1.0 / np.sqrt(self.cells[k].seen + 1.0) for k in keys]) + csum = np.cumsum(w) + picks = [] + for u in rng.random(n) * csum[-1]: + k = keys[min(int(np.searchsorted(csum, u)), len(keys) - 1)] + c = self.cells[k] + c.chosen += 1 + c.chosen_since_new += 1 + picks.append((k, {"snapshot": c.snapshot, "lives": c.lives, "score": c.score, + "traj_len": c.traj_len, "traj_last": c.traj_last})) + return picks + + def update_from_trajectory(self, chosen_key, capture, res, explog): + """Walk one exploration episode (master-side, serial): append to the + experience log, accumulate raw score from the SAMPLING-TIME capture + (never the live cell — see sample()), apply the accept rule (note 4).""" + chosen = self.cells.get(chosen_key) + cur_score = capture["score"] + cur_len = capture["traj_len"] + prev_id = capture["traj_last"] + found_new = False + seen_this_episode = set() + + for i in range(res["n_steps"]): + prev_id = explog.append(prev_id, res["actions"][i], res["rewards"][i], res["dones"][i]) + cur_score += float(res["rewards"][i]) + cur_len += 1 + done = bool(res["dones"][i]) + key = DONE_KEY if done else (res["keys"][i], False) + + cell = self.cells.get(key) + if cell is None: + self.cells[key] = Cell(res["snapshots"][i], cur_score, cur_len, prev_id, + res["lives"][i]) + self.cells[key].seen = 1 + seen_this_episode.add(key) + found_new = True + else: + if key not in seen_this_episode: + cell.seen += 1 + seen_this_episode.add(key) + if cur_score > cell.score or (cur_score == cell.score and cur_len < cell.traj_len): + cell.snapshot = res["snapshots"][i] + cell.score, cell.traj_len, cell.traj_last = cur_score, cur_len, prev_id + cell.lives = res["lives"][i] + cell.seen = cell.chosen = cell.chosen_since_new = 0 # reset_cell_on_update + found_new = True + if done: + self.done_scores.append(cur_score) + break + + if found_new and chosen is not None: + chosen.chosen_since_new = 0 + self.rooms.update(res["rooms"]) + + def state(self): + return {"cells": {k: {"snapshot": c.snapshot, "score": c.score, + "traj_len": c.traj_len, "traj_last": c.traj_last, + "seen": c.seen, "chosen": c.chosen, + "chosen_since_new": c.chosen_since_new, "lives": c.lives} + for k, c in self.cells.items()}, + "rooms": sorted(self.rooms), "done_scores": self.done_scores[-200:]} + + def load_state(self, st): + self.cells = {} + for k, d in st["cells"].items(): + c = Cell(d["snapshot"], d["score"], d["traj_len"], d["traj_last"], d["lives"]) + c.seen, c.chosen, c.chosen_since_new = d["seen"], d["chosen"], d["chosen_since_new"] + self.cells[k] = c + self.rooms = set(st["rooms"]) + self.done_scores = list(st["done_scores"]) + + +# --------------------------------------------------------------------------- +# Worker side. Top-level functions: mp 'spawn' re-imports this module, so the +# main body below stays behind the __main__ guard. Each worker owns one env. +# --------------------------------------------------------------------------- +_W = {} + + +def _worker_init(env_key): + _W["env"] = make_restore_env(env_key) + _W["ale"] = _W["env"].ale + + +def _explore_task(task): + """task = (snapshot bytes | None for root reset, lives, k, seed). + Restore -> up to k steps of repeated random actions; abort on life loss / + game over. Returns per-step arrays (keys/snapshots for archive insert).""" + snapshot, prev_lives, k, seed = task + env, ale = _W["env"], _W["ale"] + rng = np.random.default_rng(seed) + if snapshot is None: + env.reset(seed=0) + else: + ale.restoreState(pickle.loads(snapshot)) + # design note 6: NO reads here — the restored state is stale until we act + + n_actions = env.action_space.n + actions, rewards, dones, keys, snapshots, lives_list, rooms = [], [], [], [], [], [], set() + action = int(rng.integers(n_actions)) + for _ in range(k): + if rng.random() > ACTION_REPEAT_P: + action = int(rng.integers(n_actions)) + frame, reward, terminated, truncated, _ = env.step(action) + lives = ale.lives() + done = bool(terminated) or lives < prev_lives + actions.append(action) + rewards.append(float(reward)) + dones.append(done) + keys.append(cell_key(frame)) + snapshots.append(pickle.dumps(ale.cloneState())) + lives_list.append(lives) + rooms.add(int(ale.getRAM()[ROOM_RAM_BYTE])) + if done: + break + prev_lives = lives + return {"n_steps": len(actions), "actions": actions, "rewards": rewards, + "dones": dones, "keys": keys, "snapshots": snapshots, + "lives": lives_list, "rooms": rooms} + + +if __name__ == "__main__": + args = parse_args() + if args.total_frames: + TOTAL_FRAMES = args.total_frames + if args.n_workers: + N_WORKERS = args.n_workers + seed = args.seed if args.seed is not None else 0 + rng = np.random.default_rng(seed) + + logger = RunLogger(args.run_dir, args.ckpt_every) + explog_dir = os.path.join(args.run_dir, "explog") if args.run_dir else None + # cross-run-dir resume: flushed explog chunks live next to the checkpoint + # we resume from (the harness relaunches into a fresh run dir) + resume_path = logger.resolve_resume(args.resume) + ancestor = (os.path.join(os.path.dirname(os.path.dirname(resume_path)), "explog") + if resume_path else None) + explog = ExperienceLog(explog_dir, ancestor_dir=ancestor) + archive = Archive() + frames = 0 + batch = 0 + + def _state_fn(): + return {"version": 1, "frames": frames, "batch": batch, + "archive": archive.state(), "explog": explog.state(), + "rng": rng.bit_generator.state} + + # --- resume or seed the root cell --- + if resume_path: + import torch + ckpt = torch.load(resume_path, map_location="cpu", weights_only=False) + frames, batch = ckpt["frames"], ckpt["batch"] + archive.load_state(ckpt["archive"]) + explog.load_state(ckpt["explog"]) + rng.bit_generator.state = ckpt["rng"] + print(f"resumed from {resume_path}: frames={frames} batch={batch} " + f"cells={len(archive.cells)} explog={explog.count}") + else: + # root cell from a fresh reset (reset obs is NOT stale — note 6 only + # applies to restores) + env0 = make_restore_env(args.env) + frame0, _ = env0.reset(seed=0) + archive.seed_root(cell_key(np.asarray(frame0)), pickle.dumps(env0.ale.cloneState()), + env0.ale.lives()) + env0.close() + + print(f"env: {args.env} workers: {N_WORKERS} total_frames: {TOTAL_FRAMES:,} seed: {seed}") + ctx = mp.get_context("spawn") + t_start = time.time() + with ctx.Pool(N_WORKERS, initializer=_worker_init, initargs=(args.env,)) as pool: + while frames < TOTAL_FRAMES: + picks = archive.sample(BATCH_CELLS, rng) # (key, sampling-time capture) + tasks = [(cap["snapshot"], cap["lives"], EXPLORE_STEPS, int(rng.integers(2 ** 31))) + for _, cap in picks] + results = pool.map(_explore_task, tasks, chunksize=2) # ordered -> deterministic + for (key, cap), res in zip(picks, results): + archive.update_from_trajectory(key, cap, res, explog) + frames += res["n_steps"] + batch += 1 + + if batch % LOG_EVERY_BATCHES == 0: + best = archive.best_done_score + gate = best if best != float("-inf") else 0.0 + tail = archive.done_scores[-20:] + print(f"batch {batch:>6} frames {frames:>11,} cells {len(archive.cells):>6} " + f"best_done {gate:>8.0f} max_arch {archive.max_archive_score:>8.0f} " + f"rooms {len(archive.rooms):>3}") + logger.log(frames, { + "game_return_mean_lastK": gate, # semantics: best end-of-episode score (K=1) + "ep_return_mean": float(np.mean(tail)) if tail else 0.0, + "game_return_count": len(archive.done_scores), + "best_done_score": gate, + "max_archive_score": archive.max_archive_score, + "n_cells": len(archive.cells), + "rooms_found": len(archive.rooms), + "explog_entries": explog.count, + "batch": batch, + "nan_flag": 0, + }) + logger.checkpoint(frames, _state_fn, + gate=gate if gate > 0 else None) + + best = archive.best_done_score + final_score = best if best != float("-inf") else 0.0 + hours = (time.time() - t_start) / 3600 + print(f"done: frames {frames:,} cells {len(archive.cells)} rooms {len(archive.rooms)} " + f"best_done {final_score:.0f} ({hours:.2f}h)") + # final.json value_mean = best end-of-episode score (the official Phase-1 + # metric; see targets.yaml montezuma_goexplore for the protocol caveat) + logger.finalize(frames, [final_score], _state_fn, k=1) diff --git a/4-atari-hard/env_go_explore.py b/4-atari-hard/env_go_explore.py new file mode 100644 index 0000000..1e5420b --- /dev/null +++ b/4-atari-hard/env_go_explore.py @@ -0,0 +1,152 @@ +"""Go-Explore env setup (restore-based exploration on raw gymnasium ALE). + +Separate plumbing from this folder's `env.py` (the PPO/RND envpool stack): +Go-Explore's exploration phase needs the emulator's save/restore API +(ale.cloneState / restoreState), which envpool does not expose. Each +(worker) process owns a single raw ALE env built by `make_restore_env`. +The harness binds promotion markers to the script PLUS the sibling modules +it actually imports, so 2-go-explore.py is hashed with THIS file, not env.py. + +Protocol (Ecoffet et al. 2019/2021, exploration phase): fully deterministic — +frameskip 4, NO sticky actions, no no-ops, seed 0. Stochasticity only enters +in the (separate, later) robustification phase. The TimeLimit wrapper is +stripped (`.unwrapped`): its step counter is meaningless when episodes are +entered mid-trajectory via state restore. + +★ Verified ALE pitfall (this machine, ale-py 0.11.2): right after +`restoreState`, `getRAM()` / screen reads still return the PRE-restore values +until the next `act()`. Callers must therefore derive cell keys only from +frames returned by `env.step()`, never from immediate post-restore reads. +""" +import argparse +import json +import os +import statistics +import time + +import torch # checkpoint serialization only — there is no neural net here + + +def _atomic_save(state, path): + """tmp -> rename so a crash mid-write never corrupts the checkpoint.""" + os.makedirs(os.path.dirname(path), exist_ok=True) + tmp = f"{path}.tmp" + torch.save(state, tmp) + os.replace(tmp, path) + + +class RunLogger: + """Optional run-directory outputs: metrics.jsonl, periodic / milestone / + best checkpoints, resume, and a final.json summary. Inert when run_dir is + None, so the script still runs standalone. + + Same contract as 4-atari-hard/env.py with one change: milestone + checkpoints fire every 50M frames instead of 5M — a Go-Explore checkpoint + carries the whole archive (~0.5 GB at 50k cells), and a 500M-step run + would otherwise pile up 100 of them.""" + + MILESTONE_EVERY = 50_000_000 + + def __init__(self, run_dir, ckpt_every): + self.dir = run_dir + self.ckpt_dir = os.path.join(run_dir, "ckpt") if run_dir else None + self.ckpt_every = ckpt_every + if self.ckpt_dir: + os.makedirs(self.ckpt_dir, exist_ok=True) + self.f = open(os.path.join(run_dir, "metrics.jsonl"), "a", buffering=1) if run_dir else None + self.t0, self.last_frames = time.time(), 0 + self.ckpt_last, self.ms_last, self.best = 0, 0, float("-inf") + + def log(self, frames, scalars): + """Append one structured row (frames + sps + caller's scalars) to metrics.jsonl.""" + if not self.f: + return + now = time.time() + sps = (frames - self.last_frames) / max(now - self.t0, 1e-9) + self.f.write(json.dumps({"ts": round(now, 1), "frames": frames, "sps": round(sps, 1), **scalars}) + "\n") + self.t0, self.last_frames = now, frames + + def resolve_resume(self, resume_arg): + """'auto' -> run_dir/ckpt/latest.pt, else a path, else None.""" + if resume_arg == "auto" and self.ckpt_dir: + cand = os.path.join(self.ckpt_dir, "latest.pt") + return cand if os.path.exists(cand) else None + if resume_arg and resume_arg != "auto": + return resume_arg if os.path.exists(resume_arg) else None + return None + + def checkpoint(self, frames, state_fn, gate=None): + """Periodic 'latest', 50M-step milestone, and best-gate checkpoints. + state_fn() builds the dict only when a save actually happens.""" + if not self.ckpt_dir or not self.ckpt_every: + return + if frames - self.ckpt_last >= self.ckpt_every: + _atomic_save(state_fn(), os.path.join(self.ckpt_dir, "latest.pt")) + self.ckpt_last = frames + if frames - self.ms_last >= self.MILESTONE_EVERY: + _atomic_save(state_fn(), os.path.join(self.ckpt_dir, f"step_{frames // 1_000_000}M.pt")) + self.ms_last = frames + if gate is not None and gate > self.best: + self.best = gate + _atomic_save(state_fn(), os.path.join(self.ckpt_dir, "best.pt")) + + def finalize(self, frames, game_returns, state_fn, k=100): + """Final 'latest' checkpoint + a final.json result summary.""" + if self.ckpt_dir: + _atomic_save(state_fn(), os.path.join(self.ckpt_dir, "latest.pt")) + if self.dir: + tail = [float(x) for x in game_returns[-k:]] + with open(os.path.join(self.dir, "final.json"), "w") as fh: + json.dump({"frames_total": frames, "frames_unit": "agent_steps", + "gate_metric": "game_return_mean_lastK", "K": k, + "value_mean": statistics.fmean(tail) if tail else float("nan"), + "value_std": statistics.pstdev(tail) if len(tail) > 1 else 0.0, + "episodes_counted": len(tail)}, fh, indent=1) + if self.f: + self.f.close() + + +# Gymnasium / ALE ids. The "_goexplore" key marks a distinct benchmark +# protocol (deterministic, no sticky) — never cross-compare with the +# sticky-action `montezuma` numbers elsewhere in this repo. +ENV_IDS = { + "montezuma_goexplore": "ALE/MontezumaRevenge-v5", +} + + +def parse_args(): + p = argparse.ArgumentParser() + p.add_argument("--env", choices=list(ENV_IDS), default="montezuma_goexplore", + help="which game to explore") + # --- reproducibility / run-management flags (harness run contract) --- + p.add_argument("--seed", type=int, default=None, + help="seed for the action RNG (the emulator itself is deterministic)") + p.add_argument("--total-frames", type=int, default=None, + help="override the in-file TOTAL_FRAMES budget (agent steps actually executed)") + p.add_argument("--n-workers", type=int, default=None, + help="override the in-file N_WORKERS (parallel explorer processes)") + p.add_argument("--run-dir", type=str, default=None, + help="run directory: write metrics.jsonl / ckpt / final.json here") + p.add_argument("--ckpt-every", type=int, default=None, + help="periodic checkpoint interval in agent steps (resume-safe)") + p.add_argument("--resume", type=str, default=None, + help="'auto' (run-dir/ckpt/latest.pt) or a checkpoint path") + return p.parse_args() + + +def make_restore_env(env_key): + """Single raw ALE env with clone/restore access. + + Imports live here (not module top) so harness-side tests can stub this + module without pulling in ale_py. Returns the unwrapped env: TimeLimit's + step counter would spuriously truncate restore-based exploration, and + OrderEnforcing rejects step-after-restore patterns.""" + import ale_py + import gymnasium as gym + gym.register_envs(ale_py) + env = gym.make(ENV_IDS[env_key], frameskip=4, + repeat_action_probability=0.0, # deterministic — Phase 1 requirement + obs_type="grayscale").unwrapped + env.reset(seed=0) # canonical deterministic start; variation comes from action RNGs + assert env.spec.kwargs.get("repeat_action_probability", None) == 0.0 + return env diff --git a/README.md b/README.md index c27e5ac..c0b17cc 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,16 @@ Trained on a **Mac Studio (Apple M4 Max, 64 GB)** — different hardware from th > Random Network Distillation (Burda et al. 2018) for hard exploration. With 512 envs the first key is found reliably (~327k steps) and the extrinsic value bootstraps around 10M steps; with 128 envs the same code never scored in 50M steps — parallel breadth is what cracks the first-key bottleneck. Stopped at ~65M agent steps after the score plateaued **above the paper's PPO baseline (2497)**; not run to a fixed budget. Still far below RND's headline 8152, which used 128–1024 envs × 1.97B frames (~30× more experience). `Params` = trainable weights (actor-critic 1.69M + RND predictor 2.20M; the frozen RND target adds 1.68M). Single seed, so no ± std — a 3-seed run is the next step for a defensible number. +### Atari — Montezuma's Revenge (Go-Explore, exploration phase) + +Same hardware as the PPO + RND row above (Mac Studio, M4 Max, 64 GB). `ALE/MontezumaRevenge-v5` under the **deterministic protocol** (no sticky actions, frameskip 4, fixed seed) that restore-based exploration requires — these numbers are **not comparable** to the sticky-action RL rows above. 12 parallel explorer processes, no neural network. + +| Algorithm | Params | Train time | Best end-of-episode score | Frames | W&B | +|-----------|--------|------------|---------------------------|--------|-----| +| Go-Explore (exploration phase) | — (no NN) | ~5.5h | 31,000 (single seed) | 500M agent steps | [run](https://wandb.ai/rlcode/rl-atari-hard-go-explore/runs/m6ox4l3m) | + +> Go-Explore phase 1 (Ecoffet et al. 2019 / Nature 2021): an archive of downscaled-frame cells (11×8 pixels, 9 gray levels — no domain knowledge), emulator state save/restore to *return* to frontier cells, and repeated random actions to explore from them. The score is the best **end-of-episode** trajectory found by deterministic search, and it is **replay-verified**: re-executing the stored 5,336-action sequence from reset reproduces exactly 31,000. It is a trajectory-search result, **not an RL policy score** — the paper's robustification phase (distilling demos into a policy under sticky actions) is not run here. For reference, the Nature exploration-phase mean without domain knowledge is 24,758 at the same 2 B-frame budget (50+ seeds vs our single seed). Rooms found: 24. + ## Setup Requires Python 3.11 and [uv](https://docs.astral.sh/uv/).