diff --git a/docs/features/api-server/index.md b/docs/features/api-server/index.md index 104451fc0..2e96fa37d 100644 --- a/docs/features/api-server/index.md +++ b/docs/features/api-server/index.md @@ -50,6 +50,7 @@ All endpoints are under the `/api` prefix. | `DELETE` | `/api/sessions/:id` | Delete a session | | `PATCH` | `/api/sessions/:id/title` | Update session title | | `PATCH` | `/api/sessions/:id/permissions` | Update session permissions | +| `PATCH` | `/api/sessions/:id/mode` | Switch the session between `build` (default) and `plan` mode — see [Plan mode](#plan-mode) | | `POST` | `/api/sessions/:id/resume` | Resume a paused session (after tool confirmation) | | `POST` | `/api/sessions/:id/tools/toggle` | Toggle auto-approve (YOLO) mode | | `POST` | `/api/sessions/:id/elicitation` | Respond to an MCP tool elicitation request | @@ -204,6 +205,30 @@ By default, tool calls require approval. In the API workflow: Toggle auto-approve with `POST /api/sessions/:id/tools/toggle` for automated workflows. +## Plan mode {#plan-mode} + +Each session has an interaction `mode` that controls what the agent is allowed to do during a turn: + +- `build` (default) — the agent has its full toolset. +- `plan` — the runtime hides every tool that isn't tagged with the MCP-spec `ReadOnlyHint` annotation, and splices a per-turn system reminder telling the agent to draft a plan instead of acting. Use this when you want the agent to research and propose changes before the user authorises execution. + +The mode is server-scoped session state, persisted alongside the rest of the session. + +**Setting the mode** + +- At create time: `POST /api/sessions` with `{ "mode": "plan" }` in the body. Empty / omitted means `build`. Unknown values are rejected with `400`. +- Mid-session: `PATCH /api/sessions/:id/mode` with `{ "mode": "plan" }` or `{ "mode": "build" }`. The new mode applies on the **next** turn — an in-flight turn finishes under the mode it started with. 
Responds with `{ "id": "...", "mode": "..." }`. + +The current mode is included in `GET /api/sessions/:id` and `GET /api/sessions/:id/snapshot` responses as the top-level `mode` field. + +**Inheritance** + +Sub-sessions created by delegation tools (`transfer_task`, `run_skill`, the `agent` background-agent builtin) inherit the parent's mode, so a plan-mode parent can't bypass the filter by delegating to a child that would otherwise default to `build`. + +**Harness-backed agents** + +Plan mode is not supported for agents that delegate the whole turn to an external coding harness (`agent.harness` set in the YAML): the harness manages its own toolset, so the runtime cannot enforce the read-only filter. Attempting to run a harness agent while the session is in plan mode produces an `error` event with `code: "unsupported_mode"` — switch back to `build` first, or pick a non-harness agent. + ## Driving a running TUI with `--listen` {#listen} The same session API can be exposed by an **interactive run** so an external diff --git a/pkg/api/types.go b/pkg/api/types.go index 5994ea5fb..ec6aac21b 100644 --- a/pkg/api/types.go +++ b/pkg/api/types.go @@ -135,6 +135,7 @@ type SessionResponse struct { OutputTokens int64 `json:"output_tokens"` WorkingDir string `json:"working_dir,omitempty"` Permissions *session.PermissionsConfig `json:"permissions,omitempty"` + Mode session.Mode `json:"mode,omitempty"` } // UpdateSessionPermissionsRequest represents a request to update session permissions. @@ -142,6 +143,17 @@ type UpdateSessionPermissionsRequest struct { Permissions *session.PermissionsConfig `json:"permissions"` } +// UpdateSessionModeRequest represents a request to update a session's mode. +type UpdateSessionModeRequest struct { + Mode session.Mode `json:"mode"` +} + +// UpdateSessionModeResponse represents the response from updating a session's mode. 
+type UpdateSessionModeResponse struct { + ID string `json:"id"` + Mode session.Mode `json:"mode"` +} + // ResumeSessionRequest represents a request to resume a session type ResumeSessionRequest struct { Confirmation string `json:"confirmation"` @@ -304,6 +316,7 @@ type SessionSnapshotResponse struct { Messages []session.Message `json:"messages"` ToolsApproved bool `json:"tools_approved"` Permissions *session.PermissionsConfig `json:"permissions,omitempty"` + Mode session.Mode `json:"mode,omitempty"` InputTokens int64 `json:"input_tokens"` OutputTokens int64 `json:"output_tokens"` diff --git a/pkg/runtime/agent_delegation.go b/pkg/runtime/agent_delegation.go index af4a510ff..ee04bf43f 100644 --- a/pkg/runtime/agent_delegation.go +++ b/pkg/runtime/agent_delegation.go @@ -169,6 +169,20 @@ func newSubSession(parent *session.Session, cfg SubSessionConfig, childAgent *ag session.WithSendUserMessage(false), session.WithParentID(parent.ID), session.WithAttachedFiles(attachedFiles), + // Propagate the parent's interaction mode so that plan mode is + // not bypassable via delegation: transfer_task / handoff / the + // agent builtin are read-only and survive plan-mode tool + // filtering, but without this line the child session would + // default back to build mode and the child agent would get + // every mutating tool. Inheriting the parent's mode preserves + // the "hard tool removal" guarantee across the whole delegation + // tree (sub-skills, transferred tasks, background agents). + // + // LoadMode (not direct field access) because the parent's + // mode may be flipped concurrently by PATCH + // /sessions/:id/mode while the parent's turn is still + // running. 
+ session.WithMode(parent.LoadMode()), } if cfg.PinAgent { opts = append(opts, session.WithAgentName(cfg.AgentName)) diff --git a/pkg/runtime/agent_delegation_test.go b/pkg/runtime/agent_delegation_test.go index 679032c8d..c582d6780 100644 --- a/pkg/runtime/agent_delegation_test.go +++ b/pkg/runtime/agent_delegation_test.go @@ -148,6 +148,30 @@ func TestNewSubSession(t *testing.T) { // We can verify the user message is still the default. assert.Equal(t, "Please proceed.", s.GetLastUserMessageContent()) }) + + t.Run("inherits parent mode (build)", func(t *testing.T) { + // Default-mode parent should produce a build-mode child. This + // is the trivial case but documents the invariant. + buildParent := session.New(session.WithUserMessage("hello")) + s := newSubSession(buildParent, SubSessionConfig{Task: "t"}, childAgent) + assert.Equal(t, session.ModeBuild, s.Mode) + }) + + t.Run("inherits parent mode (plan)", func(t *testing.T) { + // Regression test for the plan-mode delegation bypass: a + // plan-mode parent must produce plan-mode children, so that + // downstream filterToolsForSession strips mutating tools from + // the child's toolset. Without WithMode(parent.LoadMode()) in + // newSubSession the child would default back to build and a + // plan-mode agent could route around the filter via + // transfer_task / run_skill / the agent builtin. 
+ planParent := session.New( + session.WithUserMessage("hello"), + session.WithMode(session.ModePlan), + ) + s := newSubSession(planParent, SubSessionConfig{Task: "t"}, childAgent) + assert.Equal(t, session.ModePlan, s.Mode) + }) } func TestSubSessionConfig_DefaultValues(t *testing.T) { diff --git a/pkg/runtime/event.go b/pkg/runtime/event.go index 4cedc03ef..9849ea3ea 100644 --- a/pkg/runtime/event.go +++ b/pkg/runtime/event.go @@ -234,6 +234,14 @@ const ( ErrorCodeToolFailed = "tool_failed" ErrorCodeHookBlocked = "hook_blocked" ErrorCodeLoopDetected = "loop_detected" + // ErrorCodeUnsupportedMode signals that the session's current Mode + // (e.g. plan) is incompatible with the agent that's about to run. + // Today this only fires when a plan-mode session tries to run a + // harness-backed agent: the runtime can't enforce plan mode's + // read-only tool filter for harness agents because the harness + // owns its toolset, so the turn is refused instead of running with + // a partial (advisory-only) guarantee. + ErrorCodeUnsupportedMode = "unsupported_mode" ) type ErrorEvent struct { diff --git a/pkg/runtime/harness.go b/pkg/runtime/harness.go index c48b380f8..8ab9287d0 100644 --- a/pkg/runtime/harness.go +++ b/pkg/runtime/harness.go @@ -24,6 +24,22 @@ func (r *LocalRuntime) runHarnessAgent(ctx context.Context, sess *session.Sessio ctx, span := r.startSpan(ctx, "runtime.harness", trace.WithAttributes(traceAttributesForHarness(sess, a)...)) defer span.End() + // Plan mode's hard guarantee — every non-read-only tool is stripped + // from the model's toolset — relies on the runtime owning the + // toolset. Harness agents delegate the whole turn (tools included) + // to an external library, so we can't enforce the filter here. + // Rather than degrade plan mode to an advisory prompt (which the + // reminder text explicitly contradicts), refuse the turn so the + // user can either switch to build mode or pick a non-harness + // agent. 
+ if sess.LoadMode() == session.ModePlan { + msg := fmt.Sprintf("plan mode is not supported for harness-backed agents (%q): the harness manages its own toolset, so the read-only tool filter cannot be enforced. Switch back to build mode to run this agent.", a.Name()) + events.Emit(ErrorWithCode(ErrorCodeUnsupportedMode, msg)) + r.notifyError(ctx, a, sess.ID, msg) + span.SetStatus(codes.Error, "plan mode unsupported for harness agent") + return turnEndReasonError + } + provider, err := codingharness.NewProvider(a.Harness()) if err != nil { msg := fmt.Sprintf("failed to configure harness: %v", err) @@ -46,6 +62,10 @@ func (r *LocalRuntime) runHarnessAgent(ctx context.Context, sess *session.Sessio }() turnStartMsgs := r.executeTurnStartHooks(ctx, sess, a, events) + // No plan-mode reminder spliced here: plan mode is refused for + // harness agents above, so by the time we reach this point + // sess.Mode is guaranteed to be build (or empty, which normalises + // to build). messages := sess.GetMessages(a, append(baseExtra, turnStartMsgs...)...) stop, msg, rewritten := r.executeBeforeLLMCallHooks(ctx, sess, a, modelID, 1, messages) if stop { diff --git a/pkg/runtime/harness_test.go b/pkg/runtime/harness_test.go index 1792156de..2d9d421f2 100644 --- a/pkg/runtime/harness_test.go +++ b/pkg/runtime/harness_test.go @@ -167,6 +167,49 @@ printf '%s\n' '{"type":"result","result":"Hello world"}' assert.Equal(t, []string{"Hello", " world"}, chunks) } +// TestHarnessAgentRefusesPlanMode pins the plan-mode-vs-harness invariant: +// the runtime owns the toolset in the normal LLM loop and can strip +// non-read-only tools, but a harness-backed agent delegates the whole +// turn (tools included) to an external library. Rather than degrade +// plan mode to "advisory prompt only" — which the reminder text +// explicitly contradicts — the runtime refuses the turn and surfaces +// an unsupported_mode error so the user can switch back to build mode +// or pick a non-harness agent. 
+func TestHarnessAgentRefusesPlanMode(t *testing.T) { + if stdruntime.GOOS == "windows" { + t.Skip("shell script shim test") + } + + binDir := t.TempDir() + // Intentionally produces output that would normally be surfaced as + // an assistant message; the test asserts that the harness never + // runs, so this output should be dropped. + writeHarnessScript(t, binDir, "codex", `#!/bin/sh +printf '%s\n' '{"type":"item.completed","item":{"type":"agent_message","text":"this should not appear"}}' +`) + t.Setenv("PATH", binDir+string(os.PathListSeparator)+os.Getenv("PATH")) + + rt := newHarnessRuntime(t, "codex") + sess := session.New( + session.WithUserMessage("do the task"), + session.WithMode(session.ModePlan), + ) + events := collectRuntimeEvents(t, rt, sess) + + var errEvent *ErrorEvent + for _, ev := range events { + if e, ok := ev.(*ErrorEvent); ok { + errEvent = e + break + } + } + require.NotNil(t, errEvent, "expected ErrorEvent rejecting plan mode for harness agent") + assert.Equal(t, ErrorCodeUnsupportedMode, errEvent.Code) + assert.Contains(t, errEvent.Error, "plan mode") + // Harness must not have produced any assistant content. 
+ assert.Empty(t, sess.GetLastAssistantMessageContent()) +} + func writeHarnessScript(t *testing.T, dir, name, content string) { t.Helper() require.NoError(t, os.WriteFile(filepath.Join(dir, name), []byte(content), 0o755)) diff --git a/pkg/runtime/loop.go b/pkg/runtime/loop.go index 59f23b2d1..3b149b297 100644 --- a/pkg/runtime/loop.go +++ b/pkg/runtime/loop.go @@ -266,7 +266,7 @@ func (r *LocalRuntime) runStreamLoop(ctx context.Context, sess *session.Session, sink.Emit(ErrorWithCode(ErrorCodeToolFailed, fmt.Sprintf("failed to get tools: %v", err))) return } - agentTools = filterExcludedTools(agentTools, sess.ExcludedTools) + agentTools = filterToolsForSession(agentTools, sess) sink.Emit(ToolsetInfo(len(agentTools), false, a.Name())) @@ -348,7 +348,7 @@ func (r *LocalRuntime) runStreamLoop(ctx context.Context, sess *session.Session, sink.Emit(ErrorWithCode(ErrorCodeToolFailed, fmt.Sprintf("failed to get tools: %v", err))) return } - agentTools = filterExcludedTools(agentTools, sess.ExcludedTools) + agentTools = filterToolsForSession(agentTools, sess) // Emit updated tool count. After a ToolListChanged MCP notification // the cache is invalidated, so getTools above re-fetches from the @@ -554,7 +554,13 @@ func (r *LocalRuntime) runTurn( // files) refresh every turn while session-level context (cwd, OS, // arch) stays stable — all without bloating the stored history. turnStartMsgs := r.executeTurnStartHooks(ctx, sess, a, events) - messages := sess.GetMessages(a, slices.Concat(ls.sessionStartMsgs, ls.userPromptMsgs, turnStartMsgs)...) + // Plan-mode reminder rides alongside the turn_start hook output so it + // participates in the same per-turn splice (and the cache_control marker + // that GetMessages applies to the last extra). It is appended last so its + // instruction is the most recent system context the model sees before the + // user prompt — minimising the chance the model ignores it. 
+ planReminder := planModeReminderMessages(sess) + messages := sess.GetMessages(a, slices.Concat(ls.sessionStartMsgs, ls.userPromptMsgs, turnStartMsgs, planReminder)...) slog.DebugContext(ctx, "Retrieved messages for processing", "agent", a.Name(), "message_count", len(messages)) // before_llm_call hooks fire just before the model is invoked. @@ -990,6 +996,35 @@ func filterExcludedTools(agentTools []tools.Tool, excluded []string) []tools.Too return filtered } +// filterToolsForSession applies all session-level tool filters: the explicit +// ExcludedTools name list (used by skill sub-sessions) and, when the session +// is in plan mode, anything whose tool definition doesn't advertise +// ReadOnlyHint. The MCP spec's ReadOnlyHint is the canonical "this tool has +// no side effects" signal, so it's the right knob for plan mode and it +// extends naturally to user-added MCP tools without any per-tool config. +func filterToolsForSession(agentTools []tools.Tool, sess *session.Session) []tools.Tool { + out := filterExcludedTools(agentTools, sess.ExcludedTools) + // LoadMode rather than direct field access: PATCH /sessions/:id/mode + // may flip Mode concurrently with the runtime stream goroutine. + if sess.LoadMode() == session.ModePlan { + out = filterToReadOnlyTools(out) + } + return out +} + +// filterToReadOnlyTools keeps only tools whose definition advertises +// ReadOnlyHint. Used by plan mode to hide every write/execute tool from the +// model so it can't reach for them even if the system reminder is ignored. +func filterToReadOnlyTools(agentTools []tools.Tool) []tools.Tool { + filtered := make([]tools.Tool, 0, len(agentTools)) + for _, t := range agentTools { + if t.Annotations.ReadOnlyHint { + filtered = append(filtered, t) + } + } + return filtered +} + // reprobe re-runs ensureToolSetsAreStarted after a batch of tool calls. // If new tools became available (by name-set diff), it emits a ToolsetInfo // event to update the TUI immediately. 
The new tools will be picked up by @@ -1010,7 +1045,7 @@ func (r *LocalRuntime) reprobe( slog.WarnContext(ctx, "reprobe: getTools failed", "agent", a.Name(), "error", err) return } - updated = filterExcludedTools(updated, sess.ExcludedTools) + updated = filterToolsForSession(updated, sess) // Emit any pending warnings that getTools just generated. r.emitAgentWarnings(a, events) diff --git a/pkg/runtime/plan_mode.go b/pkg/runtime/plan_mode.go new file mode 100644 index 000000000..2f89c24ff --- /dev/null +++ b/pkg/runtime/plan_mode.go @@ -0,0 +1,48 @@ +package runtime + +import ( + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/session" +) + +// planModeReminder is the per-turn system instruction injected when a session +// is in plan mode. Two layers enforce plan mode: the runtime hides every +// non-read-only tool from the model (see filterToolsForSession in loop.go), +// and this reminder tells the model how it should behave. Hiding the tools +// is the hard guarantee; the reminder is the explanation, so the model +// produces a useful plan instead of just bouncing off missing tools. +const planModeReminder = ` +You are currently in PLAN MODE. + +In this mode you research the codebase, ask clarifying questions, and write a +clear, actionable plan for the user. You MUST NOT make any changes to the +system: + +- No edits to files (no write, edit, create, or delete). +- No shell commands or background jobs. +- No state-changing tool calls of any kind. + +Only read-only tools have been made available to you for this turn. If you try +to call a tool that isn't in your list, the user has explicitly disabled it +for planning. + +End the turn by presenting the plan in your final message and asking the user +to review it. The user will switch you to BUILD MODE when they want execution +to begin. +` + +// planModeReminderMessages returns the system-reminder messages to splice +// before the conversation history when sess is in plan mode. 
Returns nil for +// other modes so callers can use it unconditionally. +// +// Reads mode via LoadMode so it stays consistent with concurrent +// PATCH /sessions/:id/mode writes coming through SessionManager. +func planModeReminderMessages(sess *session.Session) []chat.Message { + if sess == nil || sess.LoadMode() != session.ModePlan { + return nil + } + return []chat.Message{{ + Role: chat.MessageRoleSystem, + Content: planModeReminder, + }} +} diff --git a/pkg/runtime/runtime_test.go b/pkg/runtime/runtime_test.go index 86cc3fc62..1203ccf83 100644 --- a/pkg/runtime/runtime_test.go +++ b/pkg/runtime/runtime_test.go @@ -2395,6 +2395,60 @@ func TestFilterExcludedTools(t *testing.T) { }) } +func TestFilterToolsForSession_PlanMode(t *testing.T) { + readOnly := tools.Tool{Name: "read_file", Annotations: tools.ToolAnnotations{ReadOnlyHint: true}} + mutating := tools.Tool{Name: "write_file"} + all := []tools.Tool{readOnly, mutating, {Name: "shell"}} + + t.Run("build mode keeps all tools", func(t *testing.T) { + sess := &session.Session{Mode: session.ModeBuild} + result := filterToolsForSession(all, sess) + assert.Len(t, result, 3) + }) + + t.Run("empty mode is treated as build", func(t *testing.T) { + // Sessions loaded before the mode column existed have Mode == "". 
+ sess := &session.Session{} + result := filterToolsForSession(all, sess) + assert.Len(t, result, 3) + }) + + t.Run("plan mode keeps only read-only tools", func(t *testing.T) { + sess := &session.Session{Mode: session.ModePlan} + result := filterToolsForSession(all, sess) + assert.Len(t, result, 1) + assert.Equal(t, "read_file", result[0].Name) + }) + + t.Run("plan mode still respects ExcludedTools", func(t *testing.T) { + readOnly2 := tools.Tool{Name: "list_directory", Annotations: tools.ToolAnnotations{ReadOnlyHint: true}} + sess := &session.Session{ + Mode: session.ModePlan, + ExcludedTools: []string{"read_file"}, + } + result := filterToolsForSession([]tools.Tool{readOnly, readOnly2, mutating}, sess) + assert.Len(t, result, 1) + assert.Equal(t, "list_directory", result[0].Name) + }) +} + +func TestPlanModeReminderMessages(t *testing.T) { + t.Run("build mode returns nil", func(t *testing.T) { + assert.Nil(t, planModeReminderMessages(&session.Session{Mode: session.ModeBuild})) + }) + + t.Run("nil session returns nil", func(t *testing.T) { + assert.Nil(t, planModeReminderMessages(nil)) + }) + + t.Run("plan mode returns a single system reminder", func(t *testing.T) { + msgs := planModeReminderMessages(&session.Session{Mode: session.ModePlan}) + assert.Len(t, msgs, 1) + assert.Equal(t, chat.MessageRoleSystem, msgs[0].Role) + assert.Contains(t, msgs[0].Content, "PLAN MODE") + }) +} + func TestMergeExcludedTools(t *testing.T) { t.Run("both empty", func(t *testing.T) { assert.Nil(t, mergeExcludedTools(nil, nil)) diff --git a/pkg/server/server.go b/pkg/server/server.go index 257c25faf..8654eb8f0 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -71,6 +71,7 @@ func (s *Server) registerRoutes() { group.POST("/sessions/:id/resume", s.resumeSession) group.POST("/sessions/:id/tools/toggle", s.toggleSessionYolo) group.PATCH("/sessions/:id/permissions", s.updateSessionPermissions) + group.PATCH("/sessions/:id/mode", s.updateSessionMode) 
group.PATCH("/sessions/:id/title", s.updateSessionTitle) group.PATCH("/sessions/:id/tokens", s.updateSessionTokens) group.PATCH("/sessions/:id/starred", s.setSessionStarred) @@ -225,6 +226,17 @@ func (s *Server) createSession(c echo.Context) error { return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid request body: %v", err)) } + // Reject unknown mode values with a clear 400 instead of silently + // coercing to build (which is what the WithMode opt's NormalizeMode + // would do downstream). Matches the validation already done by + // PATCH /api/sessions/:id/mode so API clients get the same shape + // of error whether they pick a mode at create-time or via the + // mode-update endpoint. An empty mode is still accepted: it means + // "default" and resolves to ModeBuild. + if sessionTemplate.Mode != "" && !sessionTemplate.Mode.IsValid() { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid mode %q; must be one of: %s, %s", sessionTemplate.Mode, session.ModeBuild, session.ModePlan)) + } + sess, err := s.sm.CreateSession(c.Request().Context(), &sessionTemplate) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to create session: %v", err)) @@ -249,6 +261,9 @@ func (s *Server) getSession(c echo.Context) error { OutputTokens: sess.OutputTokens, WorkingDir: sess.WorkingDir, Permissions: sess.Permissions, + // LoadMode: PATCH /sessions/:id/mode may be writing Mode via + // StoreMode concurrently while the session is actively streaming. 
+ Mode: sess.LoadMode(), }) } @@ -329,6 +344,26 @@ func (s *Server) updateSessionPermissions(c echo.Context) error { return c.JSON(http.StatusOK, map[string]string{"message": "session permissions updated"}) } +func (s *Server) updateSessionMode(c echo.Context) error { + sessionID := c.Param("id") + var req api.UpdateSessionModeRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid request body: %v", err)) + } + if !req.Mode.IsValid() { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid mode %q; must be one of: %s, %s", req.Mode, session.ModeBuild, session.ModePlan)) + } + + if err := s.sm.UpdateSessionMode(c.Request().Context(), sessionID, req.Mode); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to update session mode: %v", err)) + } + + return c.JSON(http.StatusOK, api.UpdateSessionModeResponse{ + ID: sessionID, + Mode: req.Mode, + }) +} + func (s *Server) updateSessionTitle(c echo.Context) error { sessionID := c.Param("id") var req api.UpdateSessionTitleRequest diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index 4bf503544..14d2386ce 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -163,6 +163,35 @@ func httpDo(t *testing.T, ctx context.Context, method, socketPath, path string, return buf } +// rawHTTPDo issues an HTTP request like httpDo but returns the status +// code and body without asserting < 400. Use it from tests that need +// to verify error responses (4xx/5xx). 
+func rawHTTPDo(t *testing.T, ctx context.Context, method, socketPath, path string, payload any) (int, []byte) { + t.Helper() + + buf, err := json.Marshal(payload) + require.NoError(t, err) + req, err := http.NewRequestWithContext(ctx, method, "http://_"+path, bytes.NewReader(buf)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "unix", strings.TrimPrefix(socketPath, "unix://")) + }, + }, + } + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + out, err := io.ReadAll(resp.Body) + require.NoError(t, err) + return resp.StatusCode, out +} + func unmarshal(t *testing.T, buf []byte, v any) { t.Helper() err := json.Unmarshal(buf, &v) @@ -199,6 +228,82 @@ func TestServer_UpdateSessionTitle(t *testing.T) { assert.Equal(t, newTitle, sessionResp.Title) } +func TestServer_UpdateSessionMode(t *testing.T) { + t.Parallel() + + ctx := t.Context() + store := session.NewInMemorySessionStore() + lnPath := startServerWithStore(t, ctx, prepareAgentsDir(t), store) + + // Create a session in default (build) mode. + createResp := httpDo(t, ctx, http.MethodPost, lnPath, "/api/sessions", map[string]any{}) + var createdSession session.Session + unmarshal(t, createResp, &createdSession) + require.NotEmpty(t, createdSession.ID) + + // Switch the session into plan mode. + patchResp := httpDo(t, ctx, http.MethodPatch, lnPath, + "/api/sessions/"+createdSession.ID+"/mode", + api.UpdateSessionModeRequest{Mode: session.ModePlan}) + var modeResp api.UpdateSessionModeResponse + unmarshal(t, patchResp, &modeResp) + + assert.Equal(t, createdSession.ID, modeResp.ID) + assert.Equal(t, session.ModePlan, modeResp.Mode) + + // GET should reflect the new mode. 
+ getResp := httpGET(t, ctx, lnPath, "/api/sessions/"+createdSession.ID) + var sessionResp api.SessionResponse + unmarshal(t, getResp, &sessionResp) + assert.Equal(t, session.ModePlan, sessionResp.Mode) + + // Switch back to build mode. + patchResp = httpDo(t, ctx, http.MethodPatch, lnPath, + "/api/sessions/"+createdSession.ID+"/mode", + api.UpdateSessionModeRequest{Mode: session.ModeBuild}) + unmarshal(t, patchResp, &modeResp) + assert.Equal(t, session.ModeBuild, modeResp.Mode) +} + +func TestServer_CreateSession_AcceptsMode(t *testing.T) { + t.Parallel() + + ctx := t.Context() + store := session.NewInMemorySessionStore() + lnPath := startServerWithStore(t, ctx, prepareAgentsDir(t), store) + + // Creating a session with mode=plan should persist that mode. + createResp := httpDo(t, ctx, http.MethodPost, lnPath, "/api/sessions", + map[string]any{"mode": string(session.ModePlan)}) + var createdSession session.Session + unmarshal(t, createResp, &createdSession) + require.NotEmpty(t, createdSession.ID) + assert.Equal(t, session.ModePlan, createdSession.Mode) + + getResp := httpGET(t, ctx, lnPath, "/api/sessions/"+createdSession.ID) + var sessionResp api.SessionResponse + unmarshal(t, getResp, &sessionResp) + assert.Equal(t, session.ModePlan, sessionResp.Mode) +} + +func TestServer_RejectsInvalidMode(t *testing.T) { + t.Parallel() + + ctx := t.Context() + store := session.NewInMemorySessionStore() + lnPath := startServerWithStore(t, ctx, prepareAgentsDir(t), store) + + // Unknown mode values must be rejected with 400 rather than + // silently coerced to build — mirrors the validation already done + // by PATCH /api/sessions/:id/mode so clients get the same error + // shape at create-time and at update-time. + status, body := rawHTTPDo(t, ctx, http.MethodPost, lnPath, "/api/sessions", + map[string]any{"mode": "yolo"}) + assert.Equal(t, http.StatusBadRequest, status) + // Body is a JSON-escaped echo error message; match on the escaped form. 
+ assert.Contains(t, string(body), `invalid mode \"yolo\"`) +} + func startServerWithStore(t *testing.T, ctx context.Context, agentsDir string, store session.Store) string { t.Helper() diff --git a/pkg/server/session_manager.go b/pkg/server/session_manager.go index ba13671f3..621c5a887 100644 --- a/pkg/server/session_manager.go +++ b/pkg/server/session_manager.go @@ -316,11 +316,14 @@ func (sm *SessionManager) GetSessionSnapshot(ctx context.Context, id string) (*a Messages: sess.GetAllMessages(), ToolsApproved: sess.ToolsApproved, Permissions: sess.Permissions, - InputTokens: sess.InputTokens, - OutputTokens: sess.OutputTokens, - Streaming: streaming, - Agent: agent, - LastEventSeq: lastSeq, + // LoadMode: PATCH /sessions/:id/mode may be writing Mode via + // StoreMode concurrently while the session is actively streaming. + Mode: sess.LoadMode(), + InputTokens: sess.InputTokens, + OutputTokens: sess.OutputTokens, + Streaming: streaming, + Agent: agent, + LastEventSeq: lastSeq, }, nil } @@ -353,6 +356,10 @@ func (sm *SessionManager) CreateSession(ctx context.Context, sessionTemplate *se opts = append(opts, session.WithPermissions(sessionTemplate.Permissions)) } + if sessionTemplate.Mode != "" { + opts = append(opts, session.WithMode(sessionTemplate.Mode)) + } + sess := session.New(opts...) // Copy model-related fields from the template so callers can pin a @@ -741,6 +748,34 @@ func (sm *SessionManager) UpdateSessionPermissions(ctx context.Context, sessionI return sm.sessionStore.UpdateSession(ctx, sess) } +// UpdateSessionMode updates the interaction mode (build/plan) for a session. +// If the session is actively running, it also updates the in-memory session +// object so the next turn's tool filter and plan-mode reminder see the new +// mode without having to round-trip through the store. +// +// The runtime stream goroutine reads Mode via LoadMode while we write +// here, so the write goes through StoreMode to take s.mu. 
sm.mux still +// gates UpdateSession to keep the store snapshot consistent with +// other session-manager mutations (Title, Permissions, …). +func (sm *SessionManager) UpdateSessionMode(ctx context.Context, sessionID string, mode session.Mode) error { + mode = session.NormalizeMode(mode) + sm.mux.Lock() + defer sm.mux.Unlock() + + if rt, ok := sm.runtimeSessions.Load(sessionID); ok && rt.session != nil { + rt.session.StoreMode(mode) + slog.DebugContext(ctx, "Updated mode for active session", "session_id", sessionID, "mode", mode) + return sm.sessionStore.UpdateSession(ctx, rt.session) + } + + sess, err := sm.sessionStore.GetSession(ctx, sessionID) + if err != nil { + return err + } + sess.StoreMode(mode) + return sm.sessionStore.UpdateSession(ctx, sess) +} + // UpdateSessionTitle updates the title for a session. // If the session is actively running, it also updates the in-memory session // object to prevent subsequent runtime saves from overwriting the title. diff --git a/pkg/session/branch.go b/pkg/session/branch.go index 6de7b5ed6..32dc22fa5 100644 --- a/pkg/session/branch.go +++ b/pkg/session/branch.go @@ -76,6 +76,7 @@ func (s *Session) Clone() *Session { CustomModelsUsed: cloneStringSlice(s.CustomModelsUsed), AttachedFiles: cloneStringSlice(s.AttachedFiles), ExcludedTools: cloneStringSlice(s.ExcludedTools), + Mode: s.Mode, AgentName: s.AgentName, ParentID: s.ParentID, MessageUsageHistory: slices.Clone(s.MessageUsageHistory), diff --git a/pkg/session/migrations.go b/pkg/session/migrations.go index 400d41bd1..3cd176964 100644 --- a/pkg/session/migrations.go +++ b/pkg/session/migrations.go @@ -400,6 +400,13 @@ func getAllMigrations() []Migration { Description: "Add first_kept_entry column to session_items for compaction-preserved messages", UpSQL: `ALTER TABLE session_items ADD COLUMN first_kept_entry INTEGER DEFAULT 0`, }, + { + ID: 22, + Name: "022_add_mode_column", + Description: "Add mode column to sessions table for build/plan mode", + UpSQL: `ALTER TABLE 
sessions ADD COLUMN mode TEXT DEFAULT ''`, + DownSQL: `ALTER TABLE sessions DROP COLUMN mode`, + }, } } diff --git a/pkg/session/migrations_pinned_test.go b/pkg/session/migrations_pinned_test.go index 1ffdff8d0..96ae7c5aa 100644 --- a/pkg/session/migrations_pinned_test.go +++ b/pkg/session/migrations_pinned_test.go @@ -39,7 +39,7 @@ func TestMigrationCatalogIsContentPinned(t *testing.T) { got := digestMigrationCatalog(getAllMigrations()) - const wantDigest = "0c6d5df46b970104cf49988ee3931e33643d5db85c68dd41b74b639d0094cec9" + const wantDigest = "399611d010efb60b9349257e05e0e68702d432e5019cdd56b4ae4e69654ac691" if got != wantDigest { t.Fatalf(`migration catalogue content has changed. diff --git a/pkg/session/session.go b/pkg/session/session.go index fffdb82e5..3482f285e 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -159,6 +159,13 @@ type Session struct { // recursive run_skill calls. ExcludedTools []string `json:"-"` + // Mode is the session's interaction mode. ModeBuild (default) gives the + // agent its full toolset. ModePlan filters the toolset to read-only tools + // and injects a system reminder so the agent drafts a plan instead of + // making changes. The mode can be flipped at any time via + // PATCH /api/sessions/:id/mode; the next turn picks it up. + Mode Mode `json:"mode,omitempty"` + // AgentName, when set, tells RunStream which agent to use for this session // instead of reading from the shared runtime currentAgent field. This is // required for background agent tasks where multiple sessions may run @@ -185,6 +192,41 @@ type MessageUsageRecord struct { Usage chat.Usage `json:"usage"` } +// Mode is the session's interaction mode (build vs plan). +// +// ModeBuild is the default and gives the agent its full toolset. +// ModePlan hides every tool whose definition lacks +// Annotations.ReadOnlyHint, leaving only read-only tools, and injects a +// per-turn system reminder telling the agent to plan rather than act.
The runtime applies +// both effects automatically based on this field, so callers only need to +// flip the mode — they don't have to compute tool lists themselves. +type Mode string + +const ( + ModeBuild Mode = "build" + ModePlan Mode = "plan" +) + +// IsValid reports whether m is a known mode. +func (m Mode) IsValid() bool { + switch m { + case ModeBuild, ModePlan: + return true + default: + return false + } +} + +// NormalizeMode returns m if it is a known mode, or ModeBuild otherwise. +// Use this when reading mode from external input (persistence, HTTP body) +// to make sure downstream code always sees a valid mode. +func NormalizeMode(m Mode) Mode { + if m.IsValid() { + return m + } + return ModeBuild +} + // PermissionsConfig defines session-level tool permission overrides // using pattern-based rules (Allow/Ask/Deny arrays). type PermissionsConfig struct { @@ -470,6 +512,34 @@ func (s *Session) Usage() (input, output int64) { return s.InputTokens, s.OutputTokens } +// LoadMode returns the session's interaction mode under s.mu. Use +// this from any code that may run concurrently with the runtime +// (filterToolsForSession, planModeReminderMessages, the harness gate, +// the PATCH /sessions/:id/mode handler). Direct reads of sess.Mode +// are only safe at construction/serialisation time, before the +// session is shared with the runtime goroutine. Returns ModeBuild +// for any value that NormalizeMode doesn't recognise so downstream +// code always sees a valid mode. +// +// Verb naming matches sync/atomic.Value (Load/Store) rather than the +// field name to avoid a method-vs-field collision on Mode. +func (s *Session) LoadMode() Mode { + s.mu.RLock() + defer s.mu.RUnlock() + return NormalizeMode(s.Mode) +} + +// StoreMode atomically updates the session's interaction mode under +// s.mu. Unknown/empty inputs are normalised to ModeBuild so callers +// can pass through HTTP/store input without an extra IsValid gate. 
+// Use this in place of `sess.Mode = ...` everywhere except +// construction/serialisation paths. +func (s *Session) StoreMode(mode Mode) { + s.mu.Lock() + s.Mode = NormalizeMode(mode) + s.mu.Unlock() +} + // ApplyCompaction atomically resets the session's cumulative token // counts and appends a summary item under s.mu so concurrent readers // (e.g. the persistence observer's UpdateSession snapshot) cannot @@ -767,6 +837,14 @@ func WithExcludedTools(names []string) Opt { } } +// WithMode sets the session's interaction mode. An empty or unknown mode is +// normalised to ModeBuild so callers can pass through user input directly. +func WithMode(mode Mode) Opt { + return func(s *Session) { + s.Mode = NormalizeMode(mode) + } +} + // WithAttachedFiles seeds the session with absolute paths of files the user // attached. Used when creating sub-sessions so that delegated agents inherit // the parent's file context. Empty and duplicate paths are dropped. diff --git a/pkg/session/store.go b/pkg/session/store.go index 40970c579..87caeac83 100644 --- a/pkg/session/store.go +++ b/pkg/session/store.go @@ -230,6 +230,7 @@ func (s *InMemorySessionStore) UpdateSession(_ context.Context, session *Session AgentModelOverrides: cloneStringMap(session.AgentModelOverrides), CustomModelsUsed: cloneStringSlice(session.CustomModelsUsed), AttachedFiles: slices.Clone(session.AttachedFiles), + Mode: session.Mode, ParentID: session.ParentID, } session.mu.RUnlock() @@ -354,7 +355,7 @@ type SQLiteSessionStore struct { // sessionSelectColumns is the canonical SELECT list for the sessions table. // The column order matches what scanSession expects; all read paths use this // constant so that adding a column requires updating exactly one place. 
-const sessionSelectColumns = `id, tools_approved, input_tokens, output_tokens, title, cost, send_user_message, max_iterations, working_dir, created_at, starred, permissions, agent_model_overrides, custom_models_used, thinking, parent_id` +const sessionSelectColumns = `id, tools_approved, input_tokens, output_tokens, title, cost, send_user_message, max_iterations, working_dir, created_at, starred, permissions, agent_model_overrides, custom_models_used, thinking, parent_id, mode` // sessionPersistedFields holds the encoded form of a Session's JSON-bearing // columns plus the SQL representation of parent_id (nil for the empty @@ -591,16 +592,20 @@ func (s *SQLiteSessionStore) AddSession(ctx context.Context, session *Session) e } defer func() { _ = tx.Rollback() }() + // LoadMode (not direct Mode access) since the runtime may write + // to sess.Mode concurrently via StoreMode (PATCH /sessions/:id/mode). + // Other fields here pre-date this PR and follow the existing + // unlocked pattern — see TODO on session.go for the broader cleanup. 
_, err = tx.ExecContext(ctx, `INSERT INTO sessions ( id, tools_approved, input_tokens, output_tokens, title, cost, send_user_message, max_iterations, working_dir, created_at, permissions, agent_model_overrides, - custom_models_used, thinking, parent_id - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + custom_models_used, thinking, parent_id, mode + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, session.ID, session.ToolsApproved, session.InputTokens, session.OutputTokens, session.Title, session.Cost, session.SendUserMessage, session.MaxIterations, session.WorkingDir, session.CreatedAt.Format(time.RFC3339), fields.PermissionsJSON, fields.AgentModelOverridesJSON, - fields.CustomModelsUsedJSON, false, fields.ParentID) + fields.CustomModelsUsedJSON, false, fields.ParentID, string(session.LoadMode())) if err != nil { return err } @@ -628,6 +633,7 @@ func scanSession(scanner interface { workingDir sql.NullString permissionsJSON sql.NullString parentID sql.NullString + modeStr sql.NullString agentModelOverridesJSON string customModelsUsedJSON string createdAtStr string @@ -639,6 +645,7 @@ func scanSession(scanner interface { &sess.Title, &sess.Cost, &sess.SendUserMessage, &sess.MaxIterations, &workingDir, &createdAtStr, &sess.Starred, &permissionsJSON, &agentModelOverridesJSON, &customModelsUsedJSON, &thinking, &parentID, + &modeStr, ) if err != nil { return nil, err @@ -651,6 +658,7 @@ func scanSession(scanner interface { sess.WorkingDir = workingDir.String sess.ParentID = parentID.String + sess.Mode = NormalizeMode(Mode(modeStr.String)) if permissionsJSON.Valid && permissionsJSON.String != "" { sess.Permissions = &PermissionsConfig{} @@ -908,9 +916,9 @@ func (s *SQLiteSessionStore) UpdateSession(ctx context.Context, session *Session `INSERT INTO sessions ( id, tools_approved, input_tokens, output_tokens, title, cost, send_user_message, max_iterations, working_dir, created_at, starred, permissions, agent_model_overrides, - custom_models_used, 
thinking, parent_id + custom_models_used, thinking, parent_id, mode ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(id) DO UPDATE SET title = excluded.title, tools_approved = excluded.tools_approved, @@ -925,11 +933,12 @@ func (s *SQLiteSessionStore) UpdateSession(ctx context.Context, session *Session agent_model_overrides = excluded.agent_model_overrides, custom_models_used = excluded.custom_models_used, thinking = excluded.thinking, - parent_id = excluded.parent_id`, + parent_id = excluded.parent_id, + mode = excluded.mode`, session.ID, session.ToolsApproved, session.InputTokens, session.OutputTokens, session.Title, session.Cost, session.SendUserMessage, session.MaxIterations, session.WorkingDir, session.CreatedAt.Format(time.RFC3339), session.Starred, fields.PermissionsJSON, fields.AgentModelOverridesJSON, - fields.CustomModelsUsedJSON, false, fields.ParentID) + fields.CustomModelsUsedJSON, false, fields.ParentID, string(session.LoadMode())) if err != nil { return err } @@ -1076,14 +1085,14 @@ func (s *SQLiteSessionStore) addSessionTx(ctx context.Context, tx *sql.Tx, sessi `INSERT INTO sessions ( id, tools_approved, input_tokens, output_tokens, title, cost, send_user_message, max_iterations, working_dir, created_at, starred, permissions, agent_model_overrides, - custom_models_used, thinking, parent_id + custom_models_used, thinking, parent_id, mode ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, session.ID, session.ToolsApproved, session.InputTokens, session.OutputTokens, session.Title, session.Cost, session.SendUserMessage, session.MaxIterations, session.WorkingDir, session.CreatedAt.Format(time.RFC3339), session.Starred, fields.PermissionsJSON, fields.AgentModelOverridesJSON, fields.CustomModelsUsedJSON, false, - fields.ParentID) + fields.ParentID, string(session.LoadMode())) return err } diff --git 
a/pkg/session/store_test.go b/pkg/session/store_test.go index cef3d0479..225c1d953 100644 --- a/pkg/session/store_test.go +++ b/pkg/session/store_test.go @@ -542,6 +542,101 @@ func TestUpdateSession_Permissions(t *testing.T) { assert.Equal(t, []string{"dangerous_*"}, retrieved.Permissions.Deny) } +func TestSessionMode_SQLite(t *testing.T) { + tempDB := filepath.Join(t.TempDir(), "test_session_mode.db") + + store, err := NewSQLiteSessionStore(tempDB) + require.NoError(t, err) + defer store.(*SQLiteSessionStore).Close() + + // Default Mode (empty string) round-trips as ModeBuild after scan. + defaultSess := &Session{ + ID: "default-mode-session", + Title: "Default mode", + CreatedAt: time.Now(), + } + require.NoError(t, store.AddSession(t.Context(), defaultSess)) + retrieved, err := store.GetSession(t.Context(), defaultSess.ID) + require.NoError(t, err) + assert.Equal(t, ModeBuild, retrieved.Mode) + + // ModePlan persists and reloads. + planSess := &Session{ + ID: "plan-mode-session", + Title: "Plan mode", + CreatedAt: time.Now(), + Mode: ModePlan, + } + require.NoError(t, store.AddSession(t.Context(), planSess)) + retrieved, err = store.GetSession(t.Context(), planSess.ID) + require.NoError(t, err) + assert.Equal(t, ModePlan, retrieved.Mode) + + // Mode can be flipped via UpdateSession. + planSess.Mode = ModeBuild + require.NoError(t, store.UpdateSession(t.Context(), planSess)) + retrieved, err = store.GetSession(t.Context(), planSess.ID) + require.NoError(t, err) + assert.Equal(t, ModeBuild, retrieved.Mode) +} + +func TestNormalizeMode(t *testing.T) { + assert.Equal(t, ModeBuild, NormalizeMode("")) + assert.Equal(t, ModeBuild, NormalizeMode("garbage")) + assert.Equal(t, ModeBuild, NormalizeMode(ModeBuild)) + assert.Equal(t, ModePlan, NormalizeMode(ModePlan)) +} + +func TestSession_ModeAccessors(t *testing.T) { + t.Run("LoadMode normalises unset / invalid", func(t *testing.T) { + // Sessions loaded before the mode column existed have Mode == "" + // in the struct. 
LoadMode should still return a valid value + // (ModeBuild) so downstream code never sees the empty sentinel. + s := &Session{} + assert.Equal(t, ModeBuild, s.LoadMode()) + + s.Mode = Mode("garbage") + assert.Equal(t, ModeBuild, s.LoadMode()) + }) + + t.Run("StoreMode normalises invalid input", func(t *testing.T) { + s := &Session{} + s.StoreMode(Mode("garbage")) + assert.Equal(t, ModeBuild, s.LoadMode()) + s.StoreMode(ModePlan) + assert.Equal(t, ModePlan, s.LoadMode()) + }) + + t.Run("Load and Store are race-free", func(t *testing.T) { + // Go's race detector catches unsynchronised reads/writes; this + // test is a no-op without -race but must succeed under it. + // Spawn one writer flipping the mode while this goroutine + // samples it repeatedly; LoadMode/StoreMode both take s.mu so + // neither observes a torn value. + s := New(WithUserMessage("hi")) + stop := make(chan struct{}) + done := make(chan struct{}) + go func() { + for { + select { + case <-stop: + close(done) + return + default: + s.StoreMode(ModePlan) + s.StoreMode(ModeBuild) + } + } + }() + for range 1000 { + m := s.LoadMode() + require.Truef(t, m == ModeBuild || m == ModePlan, "unexpected mode %q", m) + } + close(stop) + <-done + }) +} + func TestAgentModelOverrides_SQLite(t *testing.T) { tempDB := filepath.Join(t.TempDir(), "test_model_overrides.db")