diff --git a/4-atari-hard/2-go-explore.py b/4-atari-hard/2-go-explore.py
new file mode 100644
index 0000000..d9790e0
--- /dev/null
+++ b/4-atari-hard/2-go-explore.py
@@ -0,0 +1,438 @@
+"""Go-Explore Phase 1 (exploration phase) for Montezuma's Revenge.
+
+Ecoffet et al., 2019: "Go-Explore: a New Approach for Hard-Exploration
+Problems" (arXiv:1901.10995); Nature 2021 version "First return, then
+explore" (arXiv:2004.12919). No neural network: intrinsic-motivation methods
+(RND etc.) suffer from detachment (forgetting promising frontiers) and
+derailment (exploration noise breaking the return trip). Go-Explore fixes
+both mechanically — remember everything in an archive, and RETURN exactly
+via emulator state restore, then explore from there:
+
+ archive: cell -> (best trajectory reaching it, emulator snapshot, score)
+ loop: sample cells (novelty-weighted) -> restore -> random exploration
+ -> add/update cells reached
+
+Design notes (verified against the official uber-research/go-explore code):
+
+ 1. Cell key = grayscale frame -> cv2.resize to 11x8 (INTER_AREA) ->
+ quantize to 9 levels: floor(8 * p / 255). 88-byte key.
+ 2. Selection weight = 1 / sqrt(seen_times + 1) (Nature simplification);
+ sampling WITH replacement, batch of 100; the virtual DONE cell is
+ never selected.
+ 3. Exploration from a restored cell: up to K=100 agent steps, repeated
+ random actions (keep current action w.p. 0.95 -> geometric runs,
+ mean 20). Episode end = LIFE LOSS (or game over) -> the transition
+ maps to the DONE cell and the exploration episode aborts.
+ 4. Archive accept rule: replace/insert iff score is higher, or equal
+ score with a shorter trajectory. Scores are raw and unclipped.
+ On update the cell's counters reset and its snapshot/trajectory are
+ replaced; the *chosen* cell's chosen_since_new resets when anything
+ new is found.
+ 5. Trajectories are not stored per cell: a global append-only experience
+ log (prev_id linked list) + per-cell traj_last pointers reconstruct
+ any cell's action sequence — this is the demo source for a future
+ robustification phase, so the log is flushed to compressed chunks in
+ the run dir rather than discarded.
+ 6. ★ ALE pitfall (machine-verified): post-restore RAM/screen reads are
+ STALE until the next act. Cell keys come only from frames returned by
+ env.step(); the lives baseline travels in cell metadata.
+ 7. frames axis = agent steps actually EXECUTED by workers (frameskip 4
+ applied; the hypothetical "replay from start" steps are not counted),
+ matching the harness budget/tier semantics.
+ 8. Phase-1 caveat: the score is a deterministic trajectory-search result,
+ NOT an RL-policy score. Never compare against sticky-action RL
+ numbers (e.g. the RND campaign) without this caveat.
+"""
+import multiprocessing as mp
+import os
+import pickle
+import time
+
+import cv2
+import numpy as np
+
+from env_go_explore import ENV_IDS, RunLogger, make_restore_env, parse_args
+
+
+# --- Tunables. TOTAL_FRAMES and N_WORKERS can be overridden from the CLI
+# (see the __main__ block); the rest are fixed in-file settings. ---
+TOTAL_FRAMES = 5_000_000 # agent steps executed (override with --total-frames)
+BATCH_CELLS = 100 # cells sampled (with replacement) per iteration
+EXPLORE_STEPS = 100 # K: max agent steps per exploration episode
+ACTION_REPEAT_P = 0.95 # keep current action w.p. 0.95 (geometric, mean 20)
+CELL_W, CELL_H = 11, 8 # downscale resolution (official fixed setting)
+CELL_LEVELS = 8 # quantize to floor(8*p/255) -> values 0..8
+N_WORKERS = 12 # M4 Max 16 cores: leave headroom for master + OS
+LOG_EVERY_BATCHES = 10 # metrics.jsonl cadence (~1M steps/100s at full speed)
+EXPLOG_CHUNK = 1 << 22 # 4M entries per experience-log chunk (~40MB in RAM)
+ROOM_RAM_BYTE = 3 # Montezuma current-room RAM index (diagnostic only)
+# Same (key_bytes, done_bool) shape as every other archive key.
+DONE_KEY = (b"DONE", True) # virtual end-of-episode cell (never sampled)
+
+
+def cell_key(frame):
+    """Turn a (210, 160) uint8 grayscale frame into its 88-byte archive key.
+
+    The frame is area-downsampled to CELL_W x CELL_H, every pixel p is mapped
+    to floor(CELL_LEVELS * p / 255) (the uint8 cast truncates), and the
+    resulting grid is returned as raw bytes.
+    """
+    downscaled = cv2.resize(frame, (CELL_W, CELL_H), interpolation=cv2.INTER_AREA)
+    levels = (downscaled / 255.0) * CELL_LEVELS
+    quantized = levels.astype(np.uint8)
+    return quantized.tobytes()
+
+
+class Cell:
+    """One archive entry.
+
+    snapshot/lives capture the state AT this cell, which is what a worker
+    needs to restore it and keep exploring. score/traj_len describe the
+    trajectory that reached it, and traj_last is the id of that trajectory's
+    final step in the experience log (used to rebuild the action sequence).
+    seen/chosen/chosen_since_new are selection counters and start at zero.
+    """
+    __slots__ = ("snapshot", "score", "traj_len", "traj_last",
+                 "seen", "chosen", "chosen_since_new", "lives")
+
+    def __init__(self, snapshot, score, traj_len, traj_last, lives):
+        # restore information
+        self.snapshot, self.lives = snapshot, lives
+        # trajectory bookkeeping
+        self.score, self.traj_len, self.traj_last = score, traj_len, traj_last
+        # counters
+        self.seen = 0
+        self.chosen = 0
+        self.chosen_since_new = 0
+
+
+class ExperienceLog:
+    """Append-only step log as a prev_id linked list (design note 5).
+
+    RAM holds only the active chunk; full chunks flush to
+    <dir>/chunk_NNNNN.npz (compressed — rewards/dones are almost all zero).
+    With dir=None (probes/tests) full chunks stay in RAM instead.
+
+    Entry ids are global and dense: id // chunk_size is the chunk index and
+    id % chunk_size the offset inside that chunk. A prev_id of -1 marks the
+    root of a trajectory.
+    """
+
+    def __init__(self, log_dir, chunk_size=EXPLOG_CHUNK, ancestor_dir=None):
+        """log_dir: where full chunks are written (None = keep them in RAM).
+        chunk_size: entries per chunk.
+        ancestor_dir: explog dir of the run we resumed FROM — chunks flushed
+        before the resume live there."""
+        self.dir = log_dir
+        if log_dir:
+            os.makedirs(log_dir, exist_ok=True)
+        self.chunk_size = chunk_size
+        self.ancestor = ancestor_dir
+        self.count = 0         # entries appended so far == id of the next entry
+        self.n_flushed = 0     # full chunks flushed so far == index of the active chunk
+        self._ram_chunks = []  # dir=None mode only
+        self._ram_base = 0     # global chunk index stored at _ram_chunks[0]
+        self._cache = {}       # chunk_idx -> loaded arrays (reconstruction)
+        self._new_chunk()
+
+    def _new_chunk(self):
+        """Allocate fresh (uninitialised) arrays for the active chunk."""
+        n = self.chunk_size
+        self.prev = np.empty(n, dtype=np.int64)
+        self.act = np.empty(n, dtype=np.uint8)
+        self.rew = np.empty(n, dtype=np.float32)
+        self.done = np.empty(n, dtype=np.uint8)
+        self.fill = 0
+
+    def append(self, prev_id, action, reward, done):
+        """Record one step and return its global entry id. The active chunk
+        is flushed as soon as it becomes full."""
+        i = self.fill
+        self.prev[i], self.act[i], self.rew[i], self.done[i] = prev_id, action, reward, done
+        self.fill += 1
+        idx = self.count
+        self.count += 1
+        if self.fill == self.chunk_size:
+            self._flush()
+        return idx
+
+    def _flush(self):
+        """Persist the active chunk (disk, or RAM when dir is None) and
+        start a new one."""
+        arrays = {"prev": self.prev[:self.fill], "act": self.act[:self.fill],
+                  "rew": self.rew[:self.fill], "done": self.done[:self.fill]}
+        if self.dir:
+            # savez_compressed appends ".npz" to the name it is given, so the
+            # temp file is "<...>.tmp.npz"; the rename makes the write atomic.
+            tmp = os.path.join(self.dir, f"chunk_{self.n_flushed:05d}.tmp")
+            np.savez_compressed(tmp, **arrays)
+            os.replace(f"{tmp}.npz", os.path.join(self.dir, f"chunk_{self.n_flushed:05d}.npz"))
+        else:
+            # lands at list position n_flushed - _ram_base (see _chunk)
+            self._ram_chunks.append({k: v.copy() for k, v in arrays.items()})
+        self.n_flushed += 1
+        self._new_chunk()
+
+    def _chunk_path(self, chunk_idx):
+        """A flushed chunk lives in our own dir, or (after a cross-run-dir
+        resume) in the ancestor run's explog dir.
+
+        Raises RuntimeError when the chunk is in neither place. Either
+        directory may be unset (None) and is then simply not searched."""
+        name = f"chunk_{chunk_idx:05d}.npz"
+        for folder in (self.dir, self.ancestor):
+            if not folder:
+                continue
+            candidate = os.path.join(folder, name)
+            if os.path.exists(candidate):
+                return candidate
+        raise RuntimeError(f"explog chunk {chunk_idx} not found in {self.dir}"
+                           + (f" or {self.ancestor}" if self.ancestor else ""))
+
+    def _chunk(self, chunk_idx):
+        """Return the prev/act arrays of chunk `chunk_idx` (active chunk,
+        RAM-held chunk, or one loaded from disk and cached)."""
+        if chunk_idx == self.n_flushed:
+            return {"prev": self.prev, "act": self.act}
+        if not self.dir and chunk_idx >= self._ram_base:
+            # RAM mode: the list only holds chunks flushed by THIS object, so
+            # translate the global index (it is offset after a resume).
+            return self._ram_chunks[chunk_idx - self._ram_base]
+        if chunk_idx not in self._cache:
+            # NpzFile keeps the archive open until closed; indexing loads the
+            # arrays into memory, so the handle can be released right away.
+            with np.load(self._chunk_path(chunk_idx)) as z:
+                self._cache[chunk_idx] = {"prev": z["prev"], "act": z["act"]}
+        return self._cache[chunk_idx]
+
+    def reconstruct_actions(self, last_id):
+        """Walk the prev_id chain back to the root (-1); return actions in
+        forward order. This is how demos are rebuilt for replay/Phase 2."""
+        actions = []
+        idx = last_id
+        while idx >= 0:
+            c = self._chunk(idx // self.chunk_size)
+            off = idx % self.chunk_size
+            actions.append(int(c["act"][off]))
+            idx = int(c["prev"][off])
+        return actions[::-1]
+
+    def state(self):
+        """Checkpointable state: counters plus a copy of the active chunk's
+        filled prefix. Flushed chunks are NOT included."""
+        return {"count": self.count, "n_flushed": self.n_flushed,
+                "chunk_size": self.chunk_size,
+                "cur_prev": self.prev[:self.fill].copy(), "cur_act": self.act[:self.fill].copy(),
+                "cur_rew": self.rew[:self.fill].copy(), "cur_done": self.done[:self.fill].copy()}
+
+    def load_state(self, st):
+        """Restore from state(). In dir mode every flushed chunk must be
+        reachable (own dir or ancestor's), otherwise RuntimeError is raised
+        here. In RAM mode, missing chunks only raise when they are read."""
+        assert st["chunk_size"] == self.chunk_size, "explog chunk_size mismatch"
+        self.n_flushed = st["n_flushed"]
+        if self.dir:
+            for i in range(st["n_flushed"]):
+                self._chunk_path(i)  # raises loudly if a chunk is missing
+        elif self._ram_base + len(self._ram_chunks) != st["n_flushed"]:
+            # The RAM chunks held here do not line up with the restored
+            # numbering (typically: a fresh object resuming a checkpoint).
+            # Re-base so chunks flushed from now on keep their global index;
+            # earlier chunks are looked up in the ancestor dir on demand.
+            self._ram_chunks = []
+            self._ram_base = st["n_flushed"]
+        self.count = st["count"]
+        self._new_chunk()
+        n = len(st["cur_prev"])
+        self.prev[:n], self.act[:n] = st["cur_prev"], st["cur_act"]
+        self.rew[:n], self.done[:n] = st["cur_rew"], st["cur_done"]
+        self.fill = n
+
+
+class Archive:
+    """Cell store, novelty-weighted selection and the accept rule. Every
+    mutation happens serially in the master process."""
+
+    def __init__(self):
+        self.cells = {}        # (key_bytes, done_bool) -> Cell
+        self.rooms = set()     # diagnostic only (RAM byte 3)
+        self.done_scores = []  # recent end-of-episode scores (logging)
+
+    def seed_root(self, key, snapshot, lives):
+        """Insert the start-of-game cell: score 0, empty trajectory (-1)."""
+        root = Cell(snapshot, 0.0, 0, -1, lives)
+        self.cells[(key, False)] = root
+
+    @property
+    def best_done_score(self):
+        """Score stored on the virtual DONE cell, or -inf if no episode has
+        ended yet."""
+        done_cell = self.cells.get(DONE_KEY)
+        if done_cell:
+            return done_cell.score
+        return float("-inf")
+
+    @property
+    def max_archive_score(self):
+        """Highest score over all cells except the DONE cell."""
+        return max(cell.score for key, cell in self.cells.items() if key != DONE_KEY)
+
+    def sample(self, n, rng):
+        """Draw n cells with replacement, p ∝ 1/sqrt(seen+1); DONE excluded.
+
+        Returns (key, CAPTURE) pairs; the capture freezes the cell's
+        snapshot/score/trajectory AT SAMPLING TIME. The trajectory walk must
+        use the capture, never the live cell: an earlier result in the same
+        batch may replace the cell, and stitching actions executed from the
+        OLD state onto the NEW prefix fabricates scores no single playthrough
+        achieved (2026-06-08 incident — caught by publish-time replay
+        verification; the official code ships these values inside the task
+        for this reason)."""
+        candidates = [k for k in self.cells if k != DONE_KEY]
+        weights = np.array([1.0 / np.sqrt(self.cells[k].seen + 1.0) for k in candidates])
+        cumulative = np.cumsum(weights)
+        draws = rng.random(n) * cumulative[-1]
+        # clamp guards against a draw landing past the last bin
+        positions = np.minimum(np.searchsorted(cumulative, draws), len(candidates) - 1)
+
+        picks = []
+        for pos in positions:
+            key = candidates[int(pos)]
+            cell = self.cells[key]
+            cell.chosen += 1
+            cell.chosen_since_new += 1
+            capture = {"snapshot": cell.snapshot, "lives": cell.lives, "score": cell.score,
+                       "traj_len": cell.traj_len, "traj_last": cell.traj_last}
+            picks.append((key, capture))
+        return picks
+
+    def update_from_trajectory(self, chosen_key, capture, res, explog):
+        """Replay one exploration episode into the archive (master-side,
+        serial): every step is appended to the experience log, the raw score
+        and length are accumulated starting from the SAMPLING-TIME capture
+        (never the live cell — see sample()), and the accept rule (note 4) is
+        applied to the cell each step lands in."""
+        chosen = self.cells.get(chosen_key)
+        score = capture["score"]
+        length = capture["traj_len"]
+        last_id = capture["traj_last"]
+        improved = False
+        visited = set()  # keys already counted towards `seen` this episode
+
+        for i in range(res["n_steps"]):
+            last_id = explog.append(last_id, res["actions"][i], res["rewards"][i], res["dones"][i])
+            score += float(res["rewards"][i])
+            length += 1
+            done = bool(res["dones"][i])
+            key = DONE_KEY if done else (res["keys"][i], False)
+
+            first_visit = key not in visited
+            visited.add(key)
+            cell = self.cells.get(key)
+            if cell is None:
+                # brand-new cell
+                cell = Cell(res["snapshots"][i], score, length, last_id, res["lives"][i])
+                cell.seen = 1
+                self.cells[key] = cell
+                improved = True
+            else:
+                if first_visit:
+                    cell.seen += 1
+                higher = score > cell.score
+                shorter_tie = score == cell.score and length < cell.traj_len
+                if higher or shorter_tie:
+                    cell.snapshot = res["snapshots"][i]
+                    cell.lives = res["lives"][i]
+                    cell.score, cell.traj_len, cell.traj_last = score, length, last_id
+                    cell.seen = cell.chosen = cell.chosen_since_new = 0  # reset_cell_on_update
+                    improved = True
+            if done:
+                self.done_scores.append(score)
+                break
+
+        if improved and chosen is not None:
+            chosen.chosen_since_new = 0
+        self.rooms.update(res["rooms"])
+
+    def state(self):
+        """Plain-dict snapshot for checkpoints (only the last 200 done
+        scores are kept)."""
+        fields = ("snapshot", "score", "traj_len", "traj_last",
+                  "seen", "chosen", "chosen_since_new", "lives")
+        dumped = {key: {name: getattr(cell, name) for name in fields}
+                  for key, cell in self.cells.items()}
+        return {"cells": dumped,
+                "rooms": sorted(self.rooms), "done_scores": self.done_scores[-200:]}
+
+    def load_state(self, st):
+        """Rebuild cells, rooms and done_scores from a state() dict."""
+        cells = self.cells = {}
+        for key, saved in st["cells"].items():
+            cell = Cell(saved["snapshot"], saved["score"], saved["traj_len"],
+                        saved["traj_last"], saved["lives"])
+            cell.seen = saved["seen"]
+            cell.chosen = saved["chosen"]
+            cell.chosen_since_new = saved["chosen_since_new"]
+            cells[key] = cell
+        self.rooms = set(st["rooms"])
+        self.done_scores = list(st["done_scores"])
+
+
+# ---------------------------------------------------------------------------
+# Worker side. Top-level functions: mp 'spawn' re-imports this module, so the
+# main body below stays behind the __main__ guard. Each worker owns one env.
+# ---------------------------------------------------------------------------
+_W = {}
+
+
+def _worker_init(env_key):
+    """Pool initializer: build this worker's env once and stash it, together
+    with its ALE handle, in the per-process _W dict."""
+    env = make_restore_env(env_key)
+    _W["env"] = env
+    _W["ale"] = env.ale
+
+
+def _explore_task(task):
+    """Run one exploration episode in this worker.
+
+    task = (snapshot bytes | None for root reset, lives, k, seed).
+    The emulator is restored (or reset), then up to k steps of repeated
+    random actions are executed; the episode aborts on life loss / game
+    over. Returns per-step lists (keys/snapshots feed the archive insert)
+    plus the set of room ids read from RAM along the way."""
+    snapshot, prev_lives, k, seed = task
+    env = _W["env"]
+    ale = _W["ale"]
+    rng = np.random.default_rng(seed)
+    if snapshot is not None:
+        ale.restoreState(pickle.loads(snapshot))
+        # design note 6: NO reads here — the restored state is stale until we act
+    else:
+        env.reset(seed=0)
+
+    n_actions = env.action_space.n
+    actions, rewards, dones = [], [], []
+    keys, snapshots, lives_list = [], [], []
+    rooms = set()
+
+    action = int(rng.integers(n_actions))
+    for _ in range(k):
+        # keep the current action with probability ACTION_REPEAT_P
+        if rng.random() > ACTION_REPEAT_P:
+            action = int(rng.integers(n_actions))
+        frame, reward, terminated, _truncated, _info = env.step(action)
+        lives = ale.lives()
+        done = bool(terminated) or lives < prev_lives
+
+        actions.append(action)
+        rewards.append(float(reward))
+        dones.append(done)
+        keys.append(cell_key(frame))
+        snapshots.append(pickle.dumps(ale.cloneState()))
+        lives_list.append(lives)
+        rooms.add(int(ale.getRAM()[ROOM_RAM_BYTE]))
+
+        if done:
+            break
+        prev_lives = lives
+
+    return {"n_steps": len(actions), "actions": actions, "rewards": rewards,
+            "dones": dones, "keys": keys, "snapshots": snapshots,
+            "lives": lives_list, "rooms": rooms}
+
+
+if __name__ == "__main__":
+    # Master process: owns the archive, the experience log and the RNG;
+    # workers only execute exploration episodes.
+    args = parse_args()
+    if args.total_frames:
+        TOTAL_FRAMES = args.total_frames
+    if args.n_workers:
+        N_WORKERS = args.n_workers
+    seed = args.seed if args.seed is not None else 0
+    rng = np.random.default_rng(seed)
+
+    logger = RunLogger(args.run_dir, args.ckpt_every)
+    explog_dir = os.path.join(args.run_dir, "explog") if args.run_dir else None
+    # cross-run-dir resume: flushed explog chunks live next to the checkpoint
+    # we resume from (the harness relaunches into a fresh run dir)
+    resume_path = logger.resolve_resume(args.resume)
+    ancestor = (os.path.join(os.path.dirname(os.path.dirname(resume_path)), "explog")
+                if resume_path else None)
+    explog = ExperienceLog(explog_dir, ancestor_dir=ancestor)
+    archive = Archive()
+    frames = 0
+    batch = 0
+
+    def _state_fn():
+        """Build the checkpoint dict from the current master-side state."""
+        return {"version": 1, "frames": frames, "batch": batch,
+                "archive": archive.state(), "explog": explog.state(),
+                "rng": rng.bit_generator.state}
+
+    # --- resume or seed the root cell ---
+    if resume_path:
+        import torch
+        # NOTE(security): weights_only=False is a full pickle load — only
+        # resume from checkpoints this script wrote itself.
+        ckpt = torch.load(resume_path, map_location="cpu", weights_only=False)
+        frames, batch = ckpt["frames"], ckpt["batch"]
+        archive.load_state(ckpt["archive"])
+        explog.load_state(ckpt["explog"])
+        rng.bit_generator.state = ckpt["rng"]
+        # The logger starts counting from 0. Without this, the first sps row
+        # after a resume divides ALL restored frames by a few seconds, and a
+        # milestone checkpoint (the whole archive) fires immediately on
+        # every resume past MILESTONE_EVERY frames.
+        logger.last_frames = frames
+        logger.ms_last = frames - frames % logger.MILESTONE_EVERY
+        print(f"resumed from {resume_path}: frames={frames} batch={batch} "
+              f"cells={len(archive.cells)} explog={explog.count}")
+    else:
+        # root cell from a fresh reset (reset obs is NOT stale — note 6 only
+        # applies to restores)
+        env0 = make_restore_env(args.env)
+        frame0, _ = env0.reset(seed=0)
+        archive.seed_root(cell_key(np.asarray(frame0)), pickle.dumps(env0.ale.cloneState()),
+                          env0.ale.lives())
+        env0.close()
+
+    print(f"env: {args.env} workers: {N_WORKERS} total_frames: {TOTAL_FRAMES:,} seed: {seed}")
+    ctx = mp.get_context("spawn")
+    t_start = time.time()
+    with ctx.Pool(N_WORKERS, initializer=_worker_init, initargs=(args.env,)) as pool:
+        while frames < TOTAL_FRAMES:
+            picks = archive.sample(BATCH_CELLS, rng)  # (key, sampling-time capture)
+            tasks = [(cap["snapshot"], cap["lives"], EXPLORE_STEPS, int(rng.integers(2 ** 31)))
+                     for _, cap in picks]
+            results = pool.map(_explore_task, tasks, chunksize=2)  # ordered -> deterministic
+            for (key, cap), res in zip(picks, results):
+                archive.update_from_trajectory(key, cap, res, explog)
+                frames += res["n_steps"]
+            batch += 1
+
+            if batch % LOG_EVERY_BATCHES == 0:
+                best = archive.best_done_score
+                gate = best if best != float("-inf") else 0.0
+                tail = archive.done_scores[-20:]
+                print(f"batch {batch:>6} frames {frames:>11,} cells {len(archive.cells):>6} "
+                      f"best_done {gate:>8.0f} max_arch {archive.max_archive_score:>8.0f} "
+                      f"rooms {len(archive.rooms):>3}")
+                logger.log(frames, {
+                    "game_return_mean_lastK": gate,  # semantics: best end-of-episode score (K=1)
+                    "ep_return_mean": float(np.mean(tail)) if tail else 0.0,
+                    "game_return_count": len(archive.done_scores),
+                    "best_done_score": gate,
+                    "max_archive_score": archive.max_archive_score,
+                    "n_cells": len(archive.cells),
+                    "rooms_found": len(archive.rooms),
+                    "explog_entries": explog.count,
+                    "batch": batch,
+                    "nan_flag": 0,
+                })
+                # gate=None when nothing has scored yet, so best.pt is only
+                # written once an episode ended with a positive score
+                logger.checkpoint(frames, _state_fn,
+                                  gate=gate if gate > 0 else None)
+
+    best = archive.best_done_score
+    final_score = best if best != float("-inf") else 0.0
+    hours = (time.time() - t_start) / 3600
+    print(f"done: frames {frames:,} cells {len(archive.cells)} rooms {len(archive.rooms)} "
+          f"best_done {final_score:.0f} ({hours:.2f}h)")
+    # final.json value_mean = best end-of-episode score (the official Phase-1
+    # metric; see targets.yaml montezuma_goexplore for the protocol caveat)
+    logger.finalize(frames, [final_score], _state_fn, k=1)
diff --git a/4-atari-hard/env_go_explore.py b/4-atari-hard/env_go_explore.py
new file mode 100644
index 0000000..1e5420b
--- /dev/null
+++ b/4-atari-hard/env_go_explore.py
@@ -0,0 +1,152 @@
+"""Go-Explore env setup (restore-based exploration on raw gymnasium ALE).
+
+Separate plumbing from this folder's `env.py` (the PPO/RND envpool stack):
+Go-Explore's exploration phase needs the emulator's save/restore API
+(ale.cloneState / restoreState), which envpool does not expose. Each
+(worker) process owns a single raw ALE env built by `make_restore_env`.
+The harness binds promotion markers to the script PLUS the sibling modules
+it actually imports, so 2-go-explore.py is hashed with THIS file, not env.py.
+
+Protocol (Ecoffet et al. 2019/2021, exploration phase): fully deterministic —
+frameskip 4, NO sticky actions, no no-ops, seed 0. Stochasticity only enters
+in the (separate, later) robustification phase. The TimeLimit wrapper is
+stripped (`.unwrapped`): its step counter is meaningless when episodes are
+entered mid-trajectory via state restore.
+
+★ Verified ALE pitfall (this machine, ale-py 0.11.2): right after
+`restoreState`, `getRAM()` / screen reads still return the PRE-restore values
+until the next `act()`. Callers must therefore derive cell keys only from
+frames returned by `env.step()`, never from immediate post-restore reads.
+"""
+import argparse
+import json
+import os
+import statistics
+import time
+
+import torch # checkpoint serialization only — there is no neural net here
+
+
+def _atomic_save(state, path):
+    """Write `state` to `path` via tmp -> rename, so a crash mid-write never
+    corrupts an existing checkpoint.
+
+    The parent directory is created when `path` has one; a bare filename
+    is written to the current directory. If the write or the rename fails,
+    the partial "<path>.tmp" is removed and the error is re-raised.
+    """
+    parent = os.path.dirname(path)
+    if parent:  # os.makedirs("") raises FileNotFoundError
+        os.makedirs(parent, exist_ok=True)
+    tmp = f"{path}.tmp"
+    try:
+        torch.save(state, tmp)
+        os.replace(tmp, path)
+    except BaseException:
+        # best-effort cleanup; never mask the original error
+        try:
+            os.remove(tmp)
+        except OSError:
+            pass
+        raise
+
+
+class RunLogger:
+    """Optional run-directory outputs: metrics.jsonl, periodic / milestone /
+    best checkpoints, resume resolution and a final.json summary. With
+    run_dir=None every method is a no-op, so the script still runs standalone.
+
+    Same contract as 4-atari-hard/env.py with one change: milestone
+    checkpoints fire every 50M frames instead of 5M — a Go-Explore checkpoint
+    carries the whole archive (~0.5 GB at 50k cells), and a 500M-step run
+    would otherwise pile up 100 of them."""
+
+    MILESTONE_EVERY = 50_000_000
+
+    def __init__(self, run_dir, ckpt_every):
+        self.dir = run_dir
+        self.ckpt_every = ckpt_every
+        if run_dir:
+            self.ckpt_dir = os.path.join(run_dir, "ckpt")
+            os.makedirs(self.ckpt_dir, exist_ok=True)  # also creates run_dir
+            self.f = open(os.path.join(run_dir, "metrics.jsonl"), "a", buffering=1)
+        else:
+            self.ckpt_dir = None
+            self.f = None
+        self.t0 = time.time()
+        self.last_frames = 0
+        self.ckpt_last = 0
+        self.ms_last = 0
+        self.best = float("-inf")
+
+    def log(self, frames, scalars):
+        """Append one row to metrics.jsonl: timestamp, frames, sps since the
+        previous row, then the caller's scalars."""
+        if not self.f:
+            return
+        now = time.time()
+        elapsed = max(now - self.t0, 1e-9)
+        sps = (frames - self.last_frames) / elapsed
+        row = {"ts": round(now, 1), "frames": frames, "sps": round(sps, 1), **scalars}
+        self.f.write(json.dumps(row) + "\n")
+        self.t0 = now
+        self.last_frames = frames
+
+    def resolve_resume(self, resume_arg):
+        """'auto' -> run_dir/ckpt/latest.pt, anything else is taken as a
+        path. Returns None when nothing was requested or the file does not
+        exist."""
+        if not resume_arg:
+            return None
+        if resume_arg == "auto":
+            if not self.ckpt_dir:
+                return None
+            cand = os.path.join(self.ckpt_dir, "latest.pt")
+            return cand if os.path.exists(cand) else None
+        return resume_arg if os.path.exists(resume_arg) else None
+
+    def checkpoint(self, frames, state_fn, gate=None):
+        """Write whichever of the periodic 'latest', 50M-step milestone and
+        best-gate checkpoints are due. state_fn() is only called when a
+        save actually happens."""
+        if not (self.ckpt_dir and self.ckpt_every):
+            return
+
+        def save_as(filename):
+            _atomic_save(state_fn(), os.path.join(self.ckpt_dir, filename))
+
+        if frames - self.ckpt_last >= self.ckpt_every:
+            save_as("latest.pt")
+            self.ckpt_last = frames
+        if frames - self.ms_last >= self.MILESTONE_EVERY:
+            save_as(f"step_{frames // 1_000_000}M.pt")
+            self.ms_last = frames
+        if gate is not None and gate > self.best:
+            self.best = gate
+            save_as("best.pt")
+
+    def finalize(self, frames, game_returns, state_fn, k=100):
+        """Write the final 'latest' checkpoint and the final.json summary
+        over the last k returns, then close metrics.jsonl."""
+        if self.ckpt_dir:
+            _atomic_save(state_fn(), os.path.join(self.ckpt_dir, "latest.pt"))
+        if self.dir:
+            tail = [float(x) for x in game_returns[-k:]]
+            with open(os.path.join(self.dir, "final.json"), "w") as fh:
+                summary = {"frames_total": frames, "frames_unit": "agent_steps",
+                           "gate_metric": "game_return_mean_lastK", "K": k,
+                           "value_mean": statistics.fmean(tail) if tail else float("nan"),
+                           "value_std": statistics.pstdev(tail) if len(tail) > 1 else 0.0,
+                           "episodes_counted": len(tail)}
+                json.dump(summary, fh, indent=1)
+        if self.f:
+            self.f.close()
+
+
+# Gymnasium / ALE ids, keyed by the value accepted for --env. The
+# "_goexplore" key marks a distinct benchmark protocol (deterministic, no
+# sticky) — never cross-compare with the sticky-action `montezuma` numbers
+# elsewhere in this repo.
+ENV_IDS = {
+ "montezuma_goexplore": "ALE/MontezumaRevenge-v5",
+}
+
+
+def parse_args():
+    """Parse the command line: --env plus the harness run-contract flags
+    (seed, frame budget, worker count, run dir, checkpointing, resume).
+    Every run-contract flag defaults to None, i.e. "not given"."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--env", choices=list(ENV_IDS), default="montezuma_goexplore",
+                        help="which game to explore")
+    # --- reproducibility / run-management flags (harness run contract) ---
+    run_flags = (
+        ("--seed", int,
+         "seed for the action RNG (the emulator itself is deterministic)"),
+        ("--total-frames", int,
+         "override the in-file TOTAL_FRAMES budget (agent steps actually executed)"),
+        ("--n-workers", int,
+         "override the in-file N_WORKERS (parallel explorer processes)"),
+        ("--run-dir", str,
+         "run directory: write metrics.jsonl / ckpt / final.json here"),
+        ("--ckpt-every", int,
+         "periodic checkpoint interval in agent steps (resume-safe)"),
+        ("--resume", str,
+         "'auto' (run-dir/ckpt/latest.pt) or a checkpoint path"),
+    )
+    for flag, kind, text in run_flags:
+        parser.add_argument(flag, type=kind, default=None, help=text)
+    return parser.parse_args()
+
+
+def make_restore_env(env_key):
+    """Build one raw ALE env with clone/restore access and return it
+    unwrapped and already reset with seed 0.
+
+    ale_py / gymnasium are imported here rather than at module top so
+    harness-side tests can stub this module without pulling in ale_py. The
+    wrappers are stripped on purpose: TimeLimit's step counter would
+    spuriously truncate restore-based exploration, and OrderEnforcing
+    rejects step-after-restore patterns."""
+    import ale_py
+    import gymnasium as gym
+    gym.register_envs(ale_py)
+    wrapped = gym.make(
+        ENV_IDS[env_key],
+        frameskip=4,
+        repeat_action_probability=0.0,  # deterministic — Phase 1 requirement
+        obs_type="grayscale",
+    )
+    env = wrapped.unwrapped
+    # canonical deterministic start; variation comes from action RNGs
+    env.reset(seed=0)
+    # sanity check: the env spec really carries the no-sticky-actions setting
+    assert env.spec.kwargs.get("repeat_action_probability", None) == 0.0
+    return env
diff --git a/README.md b/README.md
index c27e5ac..c0b17cc 100644
--- a/README.md
+++ b/README.md
@@ -47,6 +47,16 @@ Trained on a **Mac Studio (Apple M4 Max, 64 GB)** — different hardware from th
> Random Network Distillation (Burda et al. 2018) for hard exploration. With 512 envs the first key is found reliably (~327k steps) and the extrinsic value bootstraps around 10M steps; with 128 envs the same code never scored in 50M steps — parallel breadth is what cracks the first-key bottleneck. Stopped at ~65M agent steps after the score plateaued **above the paper's PPO baseline (2497)**; not run to a fixed budget. Still far below RND's headline 8152, which used 128–1024 envs × 1.97B frames (~30× more experience). `Params` = trainable weights (actor-critic 1.69M + RND predictor 2.20M; the frozen RND target adds 1.68M). Single seed, so no ± std — a 3-seed run is the next step for a defensible number.
+### Atari — Montezuma's Revenge (Go-Explore, exploration phase)
+
+Same hardware as the PPO + RND row above (Mac Studio, M4 Max, 64 GB). `ALE/MontezumaRevenge-v5` under the **deterministic protocol** (no sticky actions, frameskip 4, fixed seed) that restore-based exploration requires — these numbers are **not comparable** to the sticky-action RL rows above. 12 parallel explorer processes, no neural network.
+
+| Algorithm | Params | Train time | Best end-of-episode score | Frames | W&B |
+|-----------|--------|------------|---------------------------|--------|-----|
+| Go-Explore (exploration phase) | — (no NN) | ~5.5h | 31,000 (single seed) | 500M agent steps | [run](https://wandb.ai/rlcode/rl-atari-hard-go-explore/runs/m6ox4l3m) |
+
+> Go-Explore phase 1 (Ecoffet et al. 2019 / Nature 2021): an archive of downscaled-frame cells (11×8 pixels, 9 gray levels — no domain knowledge), emulator state save/restore to *return* to frontier cells, and repeated random actions to explore from them. The score is the best **end-of-episode** trajectory found by deterministic search, and it is **replay-verified**: re-executing the stored 5,336-action sequence from reset reproduces exactly 31,000. It is a trajectory-search result, **not an RL policy score** — the paper's robustification phase (distilling demos into a policy under sticky actions) is not run here. For reference, the Nature exploration-phase mean without domain knowledge is 24,758 at the same budget of 2 B emulator frames (our 500M agent steps × frameskip 4), averaged over 50+ seeds versus our single seed. Rooms found: 24.
+
## Setup
Requires Python 3.11 and [uv](https://docs.astral.sh/uv/).