diff --git a/agent-schema.json b/agent-schema.json index 86fbaa22a..b0624c0ce 100644 --- a/agent-schema.json +++ b/agent-schema.json @@ -2038,6 +2038,10 @@ "type": "boolean", "description": "Opt in to dialling non-public IP addresses (valid for type 'fetch', 'api', 'openapi', 'a2a', and remote MCP toolsets). By default protected HTTP clients refuse connections \u2014 after DNS resolution, so DNS rebinding is also blocked \u2014 to loopback, RFC1918 private ranges, link-local (including the cloud metadata endpoint at 169.254.169.254), multicast and the unspecified address. Set this to true when an agent legitimately needs to call internal services. For fetch, 'allowed_domains' / 'blocked_domains' are evaluated independently and still apply." }, + "safer": { + "type": "boolean", + "description": "Enable destructive command detection for the shell toolset (only valid for type 'shell'). When enabled, every shell command requires explicit user approval regardless of permissions or --yolo. Commands matching docker-agent's embedded safety-pattern taxonomy use the matched blast-radius level; unmatched commands still warn with an unknown blast radius. Default false." + }, "sudo_askpass": { "type": "boolean", "description": "Opt in to a sudo privilege escalation flow for the shell toolset (only valid for type 'shell'). When enabled, sudo commands prompt the user for their password through the host UI via SUDO_ASKPASS; in non-interactive runs the prompt is declined automatically. Only a bare 'sudo ...' invocation in a POSIX shell is handled. No effect on Windows. Default false." 
diff --git a/docs/configuration/tools/index.md b/docs/configuration/tools/index.md index 1429371fc..2b9b1915d 100644 --- a/docs/configuration/tools/index.md +++ b/docs/configuration/tools/index.md @@ -15,7 +15,7 @@ Built-in tools are included with docker-agent and require no external dependenci | Type | Description | Page | | --- | --- | --- | | `filesystem` | Read, write, list, search, navigate | [Filesystem]({{ '/tools/filesystem/' | relative_url }}) | -| `shell` | Execute shell commands (sync + background jobs) | [Shell]({{ '/tools/shell/' | relative_url }}) | +| `shell` | Execute shell commands (sync + background jobs). Supports `safer: true` to require confirmation for every shell command, with a blast-radius warning for known destructive ones. | [Shell]({{ '/tools/shell/' | relative_url }}) | | `think` | Reasoning scratchpad | [Think]({{ '/tools/think/' | relative_url }}) | | `todo` | Task list management | [Todo]({{ '/tools/todo/' | relative_url }}) | | `tasks` | Persistent task database shared across sessions | [Tasks]({{ '/tools/tasks/' | relative_url }}) | @@ -40,6 +40,7 @@ Built-in tools are included with docker-agent and require no external dependenci toolsets: - type: filesystem - type: shell + safer: true - type: think - type: todo - type: memory diff --git a/docs/tools/shell/index.md b/docs/tools/shell/index.md index 6263edf66..546026237 100644 --- a/docs/tools/shell/index.md +++ b/docs/tools/shell/index.md @@ -26,6 +26,7 @@ toolsets: | Property | Type | Description | | -------------- | ------- | --------------------------------------------------------------------------------------------------- | | `env` | object | Environment variables to set for all shell commands | +| `safer` | boolean | Require explicit confirmation for every shell command and show a blast-radius warning (`unknown` when no destructive pattern matches); see [Safer mode](#safer-mode). Default `false`. | | `sudo_askpass` | boolean | Opt in to prompting for a `sudo` password (see [Sudo support](#sudo-support)). Default `false`. 
| ### Custom Environment Variables @@ -38,6 +39,22 @@ toolsets: PATH: "${PATH}:/custom/bin" ``` +### Safer mode + +Set `safer: true` to enable destructive command detection for the `shell` tool: + +```yaml +toolsets: + - type: shell + safer: true +``` + +When enabled, docker-agent checks each `shell` tool call before the normal approval flow. The runtime always asks for explicit user approval, even when `--yolo` or permissions would otherwise auto-approve it. If the command matches a known destructive operation, the confirmation uses the taxonomy's blast-radius level; otherwise it still warns with an `unknown` blast radius. + +See [`examples/shell_safer.yaml`](https://github.com/docker/docker-agent/blob/main/examples/shell_safer.yaml) for a complete example. + +Current destructive command patterns are loaded from docker-agent's embedded `safety_patterns.json` taxonomy. The list covers filesystem deletion/overwrite commands, Docker cleanup commands, and a selection of other common destructive commands outside those categories, such as Git history rewrites. Each match carries a blast-radius level (`low`, `medium`, `high`, or `unknown`). + ### Sudo support By default a shell command has no controlling terminal, so a `sudo` command that needs a password hangs until it times out (the agent usually gives up and falls back to printing manual instructions). diff --git a/examples/README.md b/examples/README.md index caf1efb81..c0975620b 100644 --- a/examples/README.md +++ b/examples/README.md @@ -65,6 +65,7 @@ Examples that wire up one of the toolsets shipped with docker-agent | File | What it shows | |------|---------------| | [`shell.yaml`](shell.yaml) | Plain `shell` toolset. | +| [`shell_safer.yaml`](shell_safer.yaml) | Shell toolset with `safer: true`, forcing confirmation for every shell command and flagging known destructive ones with a blast-radius level. | | [`filesystem.yaml`](filesystem.yaml) | Plain `filesystem` toolset. 
| | [`filesystem_allow_deny.yaml`](filesystem_allow_deny.yaml) | Restricting the filesystem tool with allow/deny path lists. | | [`script_shell.yaml`](script_shell.yaml) | Defining custom shell commands as named tools via `type: script`. | @@ -210,6 +211,7 @@ remote MCP endpoints. | File | What it shows | |------|---------------| | [`permissions.yaml`](permissions.yaml) | Top-level `permissions` block with `allow`/`deny` patterns for tool calls. | +| [`shell_safer.yaml`](shell_safer.yaml) | Shell `safer: true` mode that always asks before running a shell command and shows the blast radius of known destructive ones. | | [`llm_judge.yaml`](llm_judge.yaml) | Layered defense: deterministic permissions + an LLM-as-judge `pre_tool_use` hook + user prompts. | | [`redact_secrets.yaml`](redact_secrets.yaml) | Single-flag (`redact_secrets: true`) scrubbing of detected secrets in args, chat content, and tool output. | | [`redact_secrets_hooks.yaml`](redact_secrets_hooks.yaml) | The same scrubbing wired manually as three hooks. | diff --git a/examples/shell_safer.yaml b/examples/shell_safer.yaml new file mode 100644 index 000000000..d604e53e0 --- /dev/null +++ b/examples/shell_safer.yaml @@ -0,0 +1,10 @@ +agents: + root: + model: anthropic/claude-haiku-4-5 + description: Shell agent with safer mode; enable snapshots in user config + welcome_message: | + Shell safer mode is enabled. To capture snapshots for /undo, enable `settings.snapshot: true` in ~/.config/cagent/config.yaml. + instruction: Use the shell tool to run the command the user asks for. 
+ toolsets: + - type: shell + safer: true diff --git a/pkg/acp/agent.go b/pkg/acp/agent.go index 558ef5bfd..9f9e92b95 100644 --- a/pkg/acp/agent.go +++ b/pkg/acp/agent.go @@ -24,6 +24,7 @@ import ( "github.com/docker/docker-agent/pkg/team" "github.com/docker/docker-agent/pkg/teamloader" loaderdefaults "github.com/docker/docker-agent/pkg/teamloader/defaults" + "github.com/docker/docker-agent/pkg/tools" "github.com/docker/docker-agent/pkg/version" ) @@ -707,28 +708,12 @@ func (a *Agent) runAgent(ctx context.Context, acpSess *Session) error { // handleToolCallConfirmation handles tool call permission requests. func (a *Agent) handleToolCallConfirmation(ctx context.Context, acpSess *Session, e *runtime.ToolCallConfirmationEvent) error { - toolCallUpdate := buildToolCallUpdate(e.ToolCall, e.ToolDefinition, acp.ToolCallStatusPending) + toolCallUpdate := buildToolCallUpdate(e.ToolCall, e.ToolDefinition, e.Safety, acp.ToolCallStatusPending) permResp, err := a.conn.RequestPermission(ctx, acp.RequestPermissionRequest{ SessionId: acp.SessionId(acpSess.id), ToolCall: toolCallUpdate, - Options: []acp.PermissionOption{ - { - Kind: acp.PermissionOptionKindAllowOnce, - Name: "Allow this action", - OptionId: "allow", - }, - { - Kind: acp.PermissionOptionKindAllowAlways, - Name: "Allow and remember my choice", - OptionId: "allow-always", - }, - { - Kind: acp.PermissionOptionKindRejectOnce, - Name: "Skip this action", - OptionId: "reject", - }, - }, + Options: permissionOptions(e.Safety), }) if err != nil { return err @@ -757,6 +742,34 @@ func (a *Agent) handleToolCallConfirmation(ctx context.Context, acpSess *Session return nil } +func permissionOptions(safety *tools.ToolCallSafety) []acp.PermissionOption { + allowName := "Allow this action" + if safety != nil && safety.Destructive { + level := safety.BlastRadius + if level == "" { + level = tools.BlastRadiusUnknown + } + allowName = fmt.Sprintf("Allow destructive tool (blast radius: %s)", level) + } + return []acp.PermissionOption{ 
+ { + Kind: acp.PermissionOptionKindAllowOnce, + Name: allowName, + OptionId: "allow", + }, + { + Kind: acp.PermissionOptionKindAllowAlways, + Name: "Allow and remember my choice", + OptionId: "allow-always", + }, + { + Kind: acp.PermissionOptionKindRejectOnce, + Name: "Skip this action", + OptionId: "reject", + }, + } +} + // handleMaxIterationsReached handles max iterations events. func (a *Agent) handleMaxIterationsReached(ctx context.Context, acpSess *Session, e *runtime.MaxIterationsReachedEvent) error { title := fmt.Sprintf("Maximum iterations (%d) reached", e.MaxIterations) diff --git a/pkg/acp/toolcall.go b/pkg/acp/toolcall.go index 7a314dace..d0603b2a9 100644 --- a/pkg/acp/toolcall.go +++ b/pkg/acp/toolcall.go @@ -67,11 +67,18 @@ func buildToolCallComplete(arguments string, event *runtime.ToolCallResponseEven } // buildToolCallUpdate creates a tool call update for permission requests. -func buildToolCallUpdate(toolCall tools.ToolCall, tool tools.Tool, status acp.ToolCallStatus) acp.ToolCallUpdate { +func buildToolCallUpdate(toolCall tools.ToolCall, tool tools.Tool, safety *tools.ToolCallSafety, status acp.ToolCallStatus) acp.ToolCallUpdate { kind := acp.ToolKindExecute title := cmp.Or(tool.Annotations.Title, toolCall.Function.Name) - if tool.Annotations.ReadOnlyHint { + if safety != nil && safety.Destructive { + kind = acp.ToolKindDelete + level := safety.BlastRadius + if level == "" { + level = tools.BlastRadiusUnknown + } + title = fmt.Sprintf("Destructive tool: %s (blast radius: %s)", title, level) + } else if tool.Annotations.ReadOnlyHint { kind = acp.ToolKindRead } diff --git a/pkg/cli/printer.go b/pkg/cli/printer.go index 129d8693e..cb8973b23 100644 --- a/pkg/cli/printer.go +++ b/pkg/cli/printer.go @@ -15,6 +15,7 @@ import ( "github.com/docker/docker-agent/pkg/input" "github.com/docker/docker-agent/pkg/tools" + "github.com/docker/docker-agent/pkg/tui/components/toolconfirm" ) // ConfirmationResult represents the result of a user confirmation prompt 
@@ -80,9 +81,35 @@ func (p *Printer) PrintToolCall(toolCall tools.ToolCall) { p.Printf("\nCalling %s%s\n", bold(toolCall.Function.Name), formatToolCallArguments(toolCall.Function.Arguments)) } +func destructiveWarningPrinter() *color.Color { + return color.New(color.FgHiYellow, color.Bold) +} + +func blastRadiusPrinter(level tools.BlastRadiusLevel) *color.Color { + switch level { + case tools.BlastRadiusLow: + return color.New(color.FgGreen, color.Bold) + case tools.BlastRadiusMedium: + return color.New(color.FgYellow, color.Bold) + case tools.BlastRadiusHigh: + return color.New(color.FgRed, color.Bold) + default: + return color.New(color.FgWhite, color.Bold) + } +} + // PrintToolCallWithConfirmation prints a tool call and prompts for confirmation -func (p *Printer) PrintToolCallWithConfirmation(ctx context.Context, toolCall tools.ToolCall, rd io.Reader) ConfirmationResult { - p.Printf("\n%s\n", bold("🛠️ Tool call requires confirmation 🛠️")) +func (p *Printer) PrintToolCallWithConfirmation(ctx context.Context, toolCall tools.ToolCall, safety *tools.ToolCallSafety, rd io.Reader) ConfirmationResult { + if safety != nil && safety.Destructive { + level := safety.BlastRadius + if level == "" { + level = tools.BlastRadiusUnknown + } + p.Printf("\n%s\n", destructiveWarningPrinter().Sprint(toolconfirm.DestructiveWarningTitle)) + p.Printf("Blast radius level: %s\n", blastRadiusPrinter(level).Sprint(string(level))) + } else { + p.Printf("\n%s\n", bold("🛠️ Tool call requires confirmation 🛠️")) + } p.PrintToolCall(toolCall) p.Printf("\n%s", bold("Can I run this tool? 
([y]es/[a]ll/[n]o): ")) diff --git a/pkg/cli/runner.go b/pkg/cli/runner.go index 29cd28205..5fa3a94da 100644 --- a/pkg/cli/runner.go +++ b/pkg/cli/runner.go @@ -164,7 +164,7 @@ func Run(ctx context.Context, out *Printer, cfg Config, rt runtime.Runtime, sess case *runtime.AgentChoiceReasoningEvent: out.Print(e.Content) case *runtime.ToolCallConfirmationEvent: - result := out.PrintToolCallWithConfirmation(ctx, e.ToolCall, rd) + result := out.PrintToolCallWithConfirmation(ctx, e.ToolCall, e.Safety, rd) // If interrupted, skip resuming; the runtime will notice context cancellation and stop if ctx.Err() != nil { continue diff --git a/pkg/config/latest/types.go b/pkg/config/latest/types.go index 362c5d1ec..6ecfcc77c 100644 --- a/pkg/config/latest/types.go +++ b/pkg/config/latest/types.go @@ -1110,6 +1110,12 @@ type Toolset struct { // nil means the field was omitted and may inherit from a referenced definition. AllowPrivateIPs *bool `json:"allow_private_ips,omitempty" yaml:"allow_private_ips,omitempty"` + // For the `shell` toolset — enable destructive command detection for the + // shell tool. Every shell call then requires explicit user approval, + // regardless of permissions or --yolo; known destructive commands show + // their blast-radius level, all others an unknown blast radius. + Safer *bool `json:"safer,omitempty" yaml:"safer,omitempty"` + // For the `shell` toolset — opt in to a sudo privilege escalation flow. 
// When enabled, sudo commands prompt the user for their password (masked) // through the host UI via SUDO_ASKPASS; in non-interactive runs the prompt diff --git a/pkg/config/latest/validate.go b/pkg/config/latest/validate.go index 3b8a90084..09313d253 100644 --- a/pkg/config/latest/validate.go +++ b/pkg/config/latest/validate.go @@ -226,6 +226,9 @@ func (t *Toolset) validate() error { if t.AllowPrivateIPsEnabled() && t.Type != "fetch" && t.Type != "mcp" && t.Type != "api" && t.Type != "openapi" && t.Type != "a2a" { return errors.New("allow_private_ips can only be used with type 'fetch', 'api', 'openapi', 'a2a' or remote MCP toolsets") } + if t.Safer != nil && t.Type != "shell" { + return errors.New("safer can only be used with type 'shell'") + } if t.SudoAskpass != nil && t.Type != "shell" { return errors.New("sudo_askpass can only be used with type 'shell'") } diff --git a/pkg/config/toolset_validate_test.go b/pkg/config/toolset_validate_test.go index 11c74ef2c..fbdc13a51 100644 --- a/pkg/config/toolset_validate_test.go +++ b/pkg/config/toolset_validate_test.go @@ -200,6 +200,67 @@ agents: } } +func TestToolset_Validate_Safer(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + config string + wantErr string + }{ + { + name: "safer on shell is allowed", + config: ` +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: shell + safer: true +`, + }, + { + name: "safer false on shell is allowed", + config: ` +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: shell + safer: false +`, + }, + { + name: "safer on non-shell toolset is rejected", + config: ` +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: filesystem + safer: true +`, + wantErr: "safer can only be used with type 'shell'", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var cfg latest.Config + err := yaml.Unmarshal([]byte(tt.config), &cfg) + + if tt.wantErr != "" { + require.Error(t, err) + 
require.Contains(t, err.Error(), tt.wantErr) + } else { + require.NoError(t, err) + } + }) + } +} + func TestToolset_Validate_MCP_WorkingDir(t *testing.T) { t.Parallel() diff --git a/pkg/embeddedchat/embeddedchat_test.go b/pkg/embeddedchat/embeddedchat_test.go index 63935b33e..af8b7e43c 100644 --- a/pkg/embeddedchat/embeddedchat_test.go +++ b/pkg/embeddedchat/embeddedchat_test.go @@ -124,7 +124,7 @@ func TestSessionSendSurfacesConfirmationAndConfirmResumesRuntime(t *testing.T) { call := tools.ToolCall{ID: "call-1", Function: tools.FunctionCall{Name: "write_file"}} def := tools.Tool{Name: "write_file"} - rt.events <- dagentruntime.ToolCallConfirmation(call, def, "agent") + rt.events <- dagentruntime.ToolCallConfirmation(call, def, nil, "agent") event := receiveEvent(t, out) require.NotNil(t, event.Tool) diff --git a/pkg/runtime/event.go b/pkg/runtime/event.go index c09de30bc..6d0769c73 100644 --- a/pkg/runtime/event.go +++ b/pkg/runtime/event.go @@ -108,16 +108,18 @@ func ToolCall(toolCall tools.ToolCall, toolDefinition tools.Tool, agentName stri type ToolCallConfirmationEvent struct { AgentContext - Type string `json:"type"` - ToolCall tools.ToolCall `json:"tool_call"` - ToolDefinition tools.Tool `json:"tool_definition"` + Type string `json:"type"` + ToolCall tools.ToolCall `json:"tool_call"` + ToolDefinition tools.Tool `json:"tool_definition"` + Safety *tools.ToolCallSafety `json:"safety,omitempty"` } -func ToolCallConfirmation(toolCall tools.ToolCall, toolDefinition tools.Tool, agentName string) Event { +func ToolCallConfirmation(toolCall tools.ToolCall, toolDefinition tools.Tool, safety *tools.ToolCallSafety, agentName string) Event { return &ToolCallConfirmationEvent{ Type: "tool_call_confirmation", ToolCall: toolCall, ToolDefinition: toolDefinition, + Safety: safety, AgentContext: newAgentContext(agentName), } } diff --git a/pkg/runtime/tool_dispatch.go b/pkg/runtime/tool_dispatch.go index c42413a48..580ea66f3 100644 --- a/pkg/runtime/tool_dispatch.go +++ 
b/pkg/runtime/tool_dispatch.go @@ -94,8 +94,8 @@ func (e *sinkEmitter) EmitToolCallResponse(toolCallID string, tool tools.Tool, r e.events.Emit(ToolCallResponse(toolCallID, tool, result, output, agentName)) } -func (e *sinkEmitter) EmitToolCallConfirmation(toolCall tools.ToolCall, tool tools.Tool, agentName string) { - e.events.Emit(ToolCallConfirmation(toolCall, tool, agentName)) +func (e *sinkEmitter) EmitToolCallConfirmation(toolCall tools.ToolCall, tool tools.Tool, safety *tools.ToolCallSafety, agentName string) { + e.events.Emit(ToolCallConfirmation(toolCall, tool, safety, agentName)) } func (e *sinkEmitter) EmitHookBlocked(toolCall tools.ToolCall, tool tools.Tool, message, agentName string) { diff --git a/pkg/runtime/toolexec/dispatcher.go b/pkg/runtime/toolexec/dispatcher.go index 64a1e6621..204727df6 100644 --- a/pkg/runtime/toolexec/dispatcher.go +++ b/pkg/runtime/toolexec/dispatcher.go @@ -74,7 +74,7 @@ type Emitter interface { EmitToolCall(toolCall tools.ToolCall, tool tools.Tool, agentName string) EmitToolCallOutput(toolCallID string, tool tools.Tool, output, agentName string) EmitToolCallResponse(toolCallID string, tool tools.Tool, result *tools.ToolCallResult, output, agentName string) - EmitToolCallConfirmation(toolCall tools.ToolCall, tool tools.Tool, agentName string) + EmitToolCallConfirmation(toolCall tools.ToolCall, tool tools.Tool, safety *tools.ToolCallSafety, agentName string) EmitHookBlocked(toolCall tools.ToolCall, tool tools.Tool, message, agentName string) EmitMessageAdded(sessionID string, msg *session.Message, agentName string) } @@ -332,6 +332,12 @@ func (c *call) approveAndRun(ctx context.Context, runTool func() CallOutcome) Ca checkers = c.d.Permissions(c.sess) } + // Forced asks from safety validators must bypass deterministic + // auto-approval paths such as --yolo and permission allow rules. 
+ if safety := c.assessSafety(); safety != nil && safety.Destructive { + return c.askUser(ctx, runTool) + } + // readOnlyHint is intentionally false here so the pre_tool_use hook // gets a turn before the read-only fast-path applies. decision := Decide( @@ -374,6 +380,13 @@ func (c *call) approveAndRun(ctx context.Context, runTool func() CallOutcome) Ca return c.askUser(ctx, runTool) } +func (c *call) assessSafety() *tools.ToolCallSafety { + if c.tool.SafetyValidator == nil { + return nil + } + return c.tool.SafetyValidator(c.tc) +} + // consultPreToolUseHook fires the pre_tool_use hook chain in the // approval flow, before the user is asked. // @@ -506,12 +519,14 @@ func denySourceForChecker(checkerSource string) string { // with an explicit allow or deny verdict; returning nothing falls // through to the interactive confirmation. func (c *call) askUser(ctx context.Context, runTool func() CallOutcome) CallOutcome { - if outcome, handled := c.runPermissionRequestHook(ctx, runTool); handled { - return outcome + if safety := c.assessSafety(); safety == nil || !safety.Destructive { + if outcome, handled := c.runPermissionRequestHook(ctx, runTool); handled { + return outcome + } } slog.DebugContext(ctx, "Tools not approved, waiting for resume", "tool", c.tc.Function.Name, "session_id", c.sess.ID) - c.em.EmitToolCallConfirmation(c.tc, c.tool, c.a.Name()) + c.em.EmitToolCallConfirmation(c.tc, c.tool, c.assessSafety(), c.a.Name()) if c.d.Hooks != nil { c.d.Hooks.NotifyUserInput(ctx, c.sess.ID, "tool confirmation") diff --git a/pkg/runtime/toolexec/dispatcher_test.go b/pkg/runtime/toolexec/dispatcher_test.go index 99792c886..5bb55a3ae 100644 --- a/pkg/runtime/toolexec/dispatcher_test.go +++ b/pkg/runtime/toolexec/dispatcher_test.go @@ -61,7 +61,7 @@ func (e *captureEmitter) EmitToolCallResponse(toolCallID string, _ tools.Tool, r }) } -func (e *captureEmitter) EmitToolCallConfirmation(tc tools.ToolCall, _ tools.Tool, _ string) { +func (e *captureEmitter) 
EmitToolCallConfirmation(tc tools.ToolCall, _ tools.Tool, _ *tools.ToolCallSafety, _ string) { e.confirmations = append(e.confirmations, tc) if e.confirmed != nil { select { @@ -437,6 +437,41 @@ func TestDispatcher_DenyByPermissionsEmitsErrorResponse(t *testing.T) { assert.Contains(t, em.responses[0].Output, "denied by test policy") } +func TestDispatcher_SafetyValidatorForcesPromptDespiteYolo(t *testing.T) { + a := newAgent() + sess := session.New() + sess.ToolsApproved = true + + tool := tools.Tool{ + Name: "shell", + SafetyValidator: func(tools.ToolCall) *tools.ToolCallSafety { + return &tools.ToolCallSafety{Destructive: true, BlastRadius: tools.BlastRadiusHigh} + }, + Handler: func(context.Context, tools.ToolCall) (*tools.ToolCallResult, error) { + panic("must not run before approval") + }, + } + + ctx, cancel := context.WithCancel(t.Context()) + t.Cleanup(cancel) + d := &toolexec.Dispatcher{AgentFor: func(*session.Session) *agent.Agent { return a }} + em := &captureEmitter{confirmed: make(chan struct{})} + go func() { + <-em.confirmed + cancel() + }() + + d.Process(ctx, sess, []tools.ToolCall{{ + ID: "danger", + Function: tools.FunctionCall{Name: "shell", Arguments: `{"cmd":"rm -rf /tmp/x"}`}, + }}, []tools.Tool{tool}, em) + + require.Len(t, em.confirmations, 1) + require.Len(t, em.responses, 1) + assert.True(t, em.responses[0].IsError) + assert.Contains(t, em.responses[0].Output, "canceled by the user") +} + // TestDispatcher_ToolResponseTransformRewritesOutput pins the contract // of the new tool_response_transform hook: when a configured hook // returns HookSpecificOutput.UpdatedToolResponse, the dispatcher diff --git a/pkg/tools/builtin/shell/judge.go b/pkg/tools/builtin/shell/judge.go new file mode 100644 index 000000000..926725c6f --- /dev/null +++ b/pkg/tools/builtin/shell/judge.go @@ -0,0 +1,199 @@ +package shell + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "strings" + + "github.com/docker/docker-agent/pkg/chat" + 
"github.com/docker/docker-agent/pkg/model/provider" + "github.com/docker/docker-agent/pkg/tools" +) + +// LexicalSignals enumerates the high-precision destructive verbs that +// gate the residual LLM judge. A shell command containing any of these +// (case-insensitive substring match) AND classified as no-match by +// assessDestructiveShellCommand is the only path that escalates to a +// Judge.Refine call. The list is deliberately short: a false positive +// here costs an LLM call, a false negative leaks through to the default +// BlastRadiusUnknown handling — where the user is still gated. +// +// We err on the side of false negatives because the LLM judge is a +// defence-in-depth layer on top of the deterministic pattern set and +// the BlastRadiusUnknown fall-through; anything the lexical gate +// misses can still trip safer mode's catch-all confirmation. +var LexicalSignals = []string{ + "wipe", + "destroy", + "drop", + "purge", + "nuke", + "obliterate", + "erase", + "clobber", + "reset", + "annihilate", +} + +// Judge is the residual LLM-backed classifier consulted when the +// deterministic regex pass in assessDestructiveShellCommand returns no +// match but the command nevertheless looks possibly-destructive (i.e. +// passes shouldConsultJudge). +// +// Implementations MUST honour the context's deadline; the validator +// applies a hard 500 ms timeout via context.WithTimeout. On timeout, +// error, or a nil safety return, the validator falls back to the +// default BlastRadiusUnknown verdict — fail-closed semantics that keep +// the user gated when the judge can't decide. +// +// Returning a non-nil ToolCallSafety with Destructive=false is the +// only path that downgrades a possibly-destructive command to "safe to +// pass without confirmation"; callers should reserve that for commands +// the judge is confident are not destructive (e.g. a typed wrapper +// around `mv` that moves a file between two test directories). 
+type Judge interface { + Refine(ctx context.Context, cmd string) (*tools.ToolCallSafety, error) +} + +// shouldConsultJudge reports whether the residual judge should be +// invoked for a command the deterministic regex pass returned nil for. +// Returns true only when cmd contains at least one entry from +// LexicalSignals as a case-insensitive substring. +// +// This keeps Judge.Refine off the hot path: on a clean shell stream of +// inspection commands (docker ps / logs / inspect, build commands), no +// lexical signal trips and no LLM call ever fires. +func shouldConsultJudge(cmd string) bool { + lower := strings.ToLower(cmd) + for _, sig := range LexicalSignals { + if strings.Contains(lower, sig) { + return true + } + } + return false +} + +// ProviderJudge is the default Judge implementation. It wraps a model +// provider and asks it to classify the command using a tight prompt +// that returns a structured JSON verdict. +// +// Intended to be wired from the runtime against a small fast model +// (Haiku-class) so the residual path stays bounded around 200–500 ms. +// The judge issues a single non-streaming-ish completion per call (the +// underlying provider exposes only streaming completions, so the +// implementation drains the stream into a buffer and parses the +// trailing JSON object once Recv() returns io.EOF). +type ProviderJudge struct { + provider provider.Provider +} + +// NewProviderJudge wraps p as a Judge. The provider is expected to be +// pre-configured with a small fast model and reasonable max-tokens +// settings; ProviderJudge itself takes no further options. +func NewProviderJudge(p provider.Provider) *ProviderJudge { + return &ProviderJudge{provider: p} +} + +// Refine asks the LLM whether cmd should be treated as destructive, +// and at what blast-radius tier. 
Returns nil when the judge is +// uncertain, the response is unparseable, or the LLM emits an empty +// verdict — callers treat nil as "fall through to the default +// BlastRadiusUnknown gate". An error is returned for transport or +// provider failures so the caller can distinguish a soft "I don't +// know" (nil, nil) from a hard "the judge could not run" (nil, err). +func (j *ProviderJudge) Refine(ctx context.Context, cmd string) (*tools.ToolCallSafety, error) { + stream, err := j.provider.CreateChatCompletionStream(ctx, []chat.Message{ + {Role: chat.MessageRoleSystem, Content: judgeSystemPrompt}, + {Role: chat.MessageRoleUser, Content: cmd}, + }, nil) + if err != nil { + return nil, fmt.Errorf("safer judge: open stream: %w", err) + } + defer stream.Close() + + var sb strings.Builder + for { + resp, err := stream.Recv() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return nil, fmt.Errorf("safer judge: stream recv: %w", err) + } + for _, c := range resp.Choices { + sb.WriteString(c.Delta.Content) + } + } + return parseJudgeVerdict(sb.String()), nil +} + +// judgeSystemPrompt is the single-turn instruction the judge sends to +// the LLM. Wording is deliberately narrow: classify, single JSON +// object, fail-closed defaults. Smaller models occasionally prepend +// thinking-style preambles, so the parser locates the trailing JSON +// rather than assuming the response is pure JSON. +const judgeSystemPrompt = `You are a strict classifier for shell commands. ` + + `Given a single command, output ONLY a JSON object with two fields: ` + + `"blast_radius" (one of: "low", "medium", "high", "unknown") and ` + + `"reason" (a one-sentence explanation, max 25 words). ` + + `Use "high" only for commands that destroy data irreversibly. ` + + `Use "medium" for commands that destroy state recoverable from caches or rebuilds. ` + + `Use "low" if the command is non-destructive (read-only or trivially reversible). 
` + + `Use "unknown" if the destructiveness cannot be determined from the command alone. ` + + `Output the JSON object and nothing else.` + +// parseJudgeVerdict extracts the trailing JSON object from the LLM +// response and maps it to a ToolCallSafety. +// +// Returns nil for: +// - missing or unparseable JSON +// - blank blast_radius field +// - blast_radius "unknown" (we keep the deterministic +// BlastRadiusUnknown fall-through rather than overriding with a +// judge-provided Unknown that means the same thing) +// +// Returns a non-destructive verdict only on explicit "low". +func parseJudgeVerdict(response string) *tools.ToolCallSafety { + start := strings.LastIndex(response, "{") + if start < 0 { + return nil + } + var v struct { + BlastRadius string `json:"blast_radius"` + Reason string `json:"reason"` + } + if err := json.Unmarshal([]byte(response[start:]), &v); err != nil { + return nil + } + if strings.TrimSpace(v.BlastRadius) == "" { + return nil + } + radius := blastRadiusLevel(v.BlastRadius) + if radius == tools.BlastRadiusUnknown { + // The judge couldn't decide either — let the caller fall through + // to safer-mode's existing Unknown handling. Returning the same + // verdict from the judge would shadow the caller's reason + // string ("Shell command requires safer-mode confirmation.") + // with a less informative one. + return nil + } + reason := "Safer-mode LLM judge: " + v.Reason + if radius == tools.BlastRadiusLow { + // Explicit low → judge is confident this is safe; downgrade + // out of the destructive path entirely. The runtime's existing + // `if safety.Destructive` gate skips forced confirmation. 
+ return &tools.ToolCallSafety{ + Destructive: false, + BlastRadius: tools.BlastRadiusLow, + Reason: reason, + } + } + return &tools.ToolCallSafety{ + Destructive: true, + BlastRadius: radius, + Reason: reason, + } +} diff --git a/pkg/tools/builtin/shell/judge_test.go b/pkg/tools/builtin/shell/judge_test.go new file mode 100644 index 000000000..61e63dd62 --- /dev/null +++ b/pkg/tools/builtin/shell/judge_test.go @@ -0,0 +1,292 @@ +package shell + +import ( + "context" + "errors" + "testing" + + "github.com/docker/docker-agent/pkg/tools" +) + +// fakeJudge is a controllable Judge used to exercise the +// ValidateShellToolCall integration in isolation. Tests set Safety and +// Err directly; CallCount records how many times Refine fires so +// gating assertions can prove the LLM path was (or wasn't) entered. +type fakeJudge struct { + Safety *tools.ToolCallSafety + Err error + CallCount int + LastCmd string +} + +func (f *fakeJudge) Refine(_ context.Context, cmd string) (*tools.ToolCallSafety, error) { + f.CallCount++ + f.LastCmd = cmd + return f.Safety, f.Err +} + +// TestShouldConsultJudgeTriggers locks down the gate semantics every +// consumer of the Judge interface relies on. The validator pays the +// LLM round-trip ONLY when shouldConsultJudge returns true; making +// any change to these cases a deliberate red-test event. 
+func TestShouldConsultJudgeTriggers(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + cmd string + want bool + }{ + {"lexical signal present", "bun run drop-db", true}, + {"no lexical signal", "docker ps -a", false}, + {"uppercase signal — case-insensitive", "make NUKE", true}, + {"signal inside a flag", "myscript --reset", true}, + {"empty command", "", false}, + {"build commands have no signal", "go build ./...", false}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if got := shouldConsultJudge(tc.cmd); got != tc.want { + t.Errorf("shouldConsultJudge(%q) = %v, want %v", tc.cmd, got, tc.want) + } + }) + } +} + +// TestValidateShellToolCallJudgeGating covers the integration with +// ValidateShellToolCall. Five branches matter: +// +// 1. Pattern match → deterministic safety wins, judge NEVER fires. +// 2. Pattern miss + no lexical signal → BlastRadiusUnknown fall-through, +// judge NEVER fires. +// 3. Pattern miss + lexical signal + judge returns refined safety → +// judge's verdict wins. +// 4. Pattern miss + lexical signal + judge returns (nil, nil) → +// BlastRadiusUnknown fall-through (judge was uncertain). +// 5. Pattern miss + lexical signal + judge errors → BlastRadiusUnknown +// fall-through (fail-closed posture). 
+func TestValidateShellToolCallJudgeGating(t *testing.T) { + t.Parallel() + + makeCall := func(cmd string) tools.ToolCall { + return tools.ToolCall{ + Function: tools.FunctionCall{ + Name: ToolNameShell, + Arguments: `{"cmd":` + jsonString(cmd) + `}`, + }, + } + } + + t.Run("pattern match — judge not consulted", func(t *testing.T) { + t.Parallel() + judge := &fakeJudge{ + Safety: &tools.ToolCallSafety{ + Destructive: true, + BlastRadius: tools.BlastRadiusHigh, + Reason: "judge said high", + }, + } + h := &shellHandler{safer: true, judge: judge} + + got := h.ValidateShellToolCall(makeCall("rm -rf ./data")) + + if got == nil || got.BlastRadius != tools.BlastRadiusHigh { + t.Fatalf("deterministic verdict expected, got %+v", got) + } + if judge.CallCount != 0 { + t.Errorf("judge was consulted on a pattern match (%d calls)", judge.CallCount) + } + }) + + t.Run("pattern miss without lexical signal — judge not consulted", func(t *testing.T) { + t.Parallel() + judge := &fakeJudge{ + Safety: &tools.ToolCallSafety{BlastRadius: tools.BlastRadiusHigh}, + } + h := &shellHandler{safer: true, judge: judge} + + got := h.ValidateShellToolCall(makeCall("uptime")) + + if got == nil || got.BlastRadius != tools.BlastRadiusUnknown { + t.Fatalf("BlastRadiusUnknown fall-through expected, got %+v", got) + } + if judge.CallCount != 0 { + t.Errorf("judge was consulted without a lexical signal (%d calls)", judge.CallCount) + } + }) + + t.Run("pattern miss + lexical signal + judge returns refined verdict", func(t *testing.T) { + t.Parallel() + judge := &fakeJudge{ + Safety: &tools.ToolCallSafety{ + Destructive: true, + BlastRadius: tools.BlastRadiusHigh, + Reason: "Safer-mode LLM judge: drops the dev database", + }, + } + h := &shellHandler{safer: true, judge: judge} + + got := h.ValidateShellToolCall(makeCall("bun run drop-db")) + + if got == nil || got.BlastRadius != tools.BlastRadiusHigh { + t.Fatalf("judge verdict expected to win, got %+v", got) + } + if judge.CallCount != 1 { + 
t.Errorf("judge should fire exactly once, got %d calls", judge.CallCount) + } + if judge.LastCmd != "bun run drop-db" { + t.Errorf("judge received command %q, want %q", judge.LastCmd, "bun run drop-db") + } + }) + + t.Run("pattern miss + lexical signal + judge uncertain (nil, nil) — falls through", func(t *testing.T) { + t.Parallel() + judge := &fakeJudge{Safety: nil, Err: nil} + h := &shellHandler{safer: true, judge: judge} + + got := h.ValidateShellToolCall(makeCall("make wipe")) + + if got == nil || got.BlastRadius != tools.BlastRadiusUnknown { + t.Fatalf("BlastRadiusUnknown fall-through expected, got %+v", got) + } + if judge.CallCount != 1 { + t.Errorf("judge should fire exactly once, got %d calls", judge.CallCount) + } + }) + + t.Run("pattern miss + lexical signal + judge errors — fail-closed", func(t *testing.T) { + t.Parallel() + judge := &fakeJudge{ + Safety: &tools.ToolCallSafety{BlastRadius: tools.BlastRadiusLow}, // would otherwise downgrade + Err: errors.New("provider timeout"), + } + h := &shellHandler{safer: true, judge: judge} + + got := h.ValidateShellToolCall(makeCall("./script --reset-everything")) + + if got == nil || got.BlastRadius != tools.BlastRadiusUnknown { + t.Fatalf("fail-closed: BlastRadiusUnknown expected, got %+v", got) + } + }) + + t.Run("safer off — neither pattern nor judge consulted", func(t *testing.T) { + t.Parallel() + judge := &fakeJudge{ + Safety: &tools.ToolCallSafety{BlastRadius: tools.BlastRadiusHigh}, + } + h := &shellHandler{safer: false, judge: judge} + + got := h.ValidateShellToolCall(makeCall("rm -rf ./data")) + + if got != nil { + t.Fatalf("safer off must return nil, got %+v", got) + } + if judge.CallCount != 0 { + t.Errorf("judge consulted with safer off (%d calls)", judge.CallCount) + } + }) +} + +// TestParseJudgeVerdict covers the response-shape contract the +// ProviderJudge relies on: trailing JSON extraction, blast-radius +// mapping, low → non-destructive downgrade, unknown/missing → +// fall-through (nil). 
+func TestParseJudgeVerdict(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + response string + wantNil bool + wantRadius tools.BlastRadiusLevel + wantDestruc bool + }{ + { + name: "high verdict", + response: `{"blast_radius":"high","reason":"drops the database"}`, + wantRadius: tools.BlastRadiusHigh, + wantDestruc: true, + }, + { + name: "medium verdict", + response: `{"blast_radius":"medium","reason":"clears caches"}`, + wantRadius: tools.BlastRadiusMedium, + wantDestruc: true, + }, + { + name: "low verdict — downgrades to non-destructive", + response: `{"blast_radius":"low","reason":"git status only"}`, + wantRadius: tools.BlastRadiusLow, + wantDestruc: false, + }, + { + name: "unknown verdict — caller falls through", + response: `{"blast_radius":"unknown","reason":"cannot tell"}`, + wantNil: true, + }, + { + name: "blank blast_radius — caller falls through", + response: `{"blast_radius":"","reason":"empty"}`, + wantNil: true, + }, + { + name: "no JSON in response — caller falls through", + response: `the command looks safe to me`, + wantNil: true, + }, + { + name: "thinking preamble then JSON — trailing object extracted", + response: "Let me think...\n{\"blast_radius\":\"high\",\"reason\":\"rm -rf /\"}", + wantRadius: tools.BlastRadiusHigh, + wantDestruc: true, + }, + { + name: "malformed JSON — caller falls through", + response: `{"blast_radius":"high"`, // missing closing brace + wantNil: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got := parseJudgeVerdict(tc.response) + if tc.wantNil { + if got != nil { + t.Fatalf("expected nil verdict, got %+v", got) + } + return + } + if got == nil { + t.Fatalf("expected non-nil verdict, got nil") + } + if got.BlastRadius != tc.wantRadius { + t.Errorf("BlastRadius = %v, want %v", got.BlastRadius, tc.wantRadius) + } + if got.Destructive != tc.wantDestruc { + t.Errorf("Destructive = %v, want %v", got.Destructive, tc.wantDestruc) + } + }) + } +} + 
+// jsonString returns s quoted as a JSON string literal (handles +// embedded quotes / backslashes). We avoid pulling in encoding/json +// here because the test helper is one line either way. +func jsonString(s string) string { + var b []byte + b = append(b, '"') + for _, r := range s { + switch r { + case '"', '\\': + b = append(b, '\\') + b = append(b, byte(r)) + default: + b = append(b, []byte(string(r))...) + } + } + b = append(b, '"') + return string(b) +} diff --git a/pkg/tools/builtin/shell/safer.go b/pkg/tools/builtin/shell/safer.go new file mode 100644 index 000000000..22fbde42d --- /dev/null +++ b/pkg/tools/builtin/shell/safer.go @@ -0,0 +1,161 @@ +package shell + +import ( + _ "embed" + "encoding/json" + "fmt" + "regexp" + "strings" + "sync" + + "github.com/docker/docker-agent/pkg/tools" +) + +//go:embed safety_patterns.json +var safetyPatternsJSON []byte + +type safetyPattern struct { + Pattern string + BlastRadius tools.BlastRadiusLevel + regexp *regexp.Regexp +} + +type safetyPatternEntry struct { + Pattern string `json:"pattern"` + BlastRadius string `json:"blast_radius"` +} + +var loadSafetyPatterns = sync.OnceValues(func() ([]safetyPattern, error) { + var root any + if err := json.Unmarshal(safetyPatternsJSON, &root); err != nil { + return nil, fmt.Errorf("parse shell safety patterns: %w", err) + } + + entries := collectSafetyPatternEntries(root) + patterns := make([]safetyPattern, 0, len(entries)) + for _, entry := range entries { + pattern := normalizeCommand(entry.Pattern) + re, err := regexp.Compile(patternToRegexp(pattern)) + if err != nil { + return nil, fmt.Errorf("compile shell safety pattern %q: %w", entry.Pattern, err) + } + patterns = append(patterns, safetyPattern{ + Pattern: entry.Pattern, + BlastRadius: blastRadiusLevel(entry.BlastRadius), + regexp: re, + }) + } + return patterns, nil +}) + +func collectSafetyPatternEntries(value any) []safetyPatternEntry { + switch v := value.(type) { + case []any: + var entries []safetyPatternEntry 
+ for _, item := range v { + entries = append(entries, collectSafetyPatternEntries(item)...) + } + return entries + case map[string]any: + if pattern, ok := v["pattern"].(string); ok { + if blastRadius, ok := v["blast_radius"].(string); ok { + return []safetyPatternEntry{{Pattern: pattern, BlastRadius: blastRadius}} + } + } + var entries []safetyPatternEntry + for _, item := range v { + entries = append(entries, collectSafetyPatternEntries(item)...) + } + return entries + default: + return nil + } +} + +func patternToRegexp(pattern string) string { + var b strings.Builder + b.WriteString(`(?i)(?:^|.*\b)`) + for i := 0; i < len(pattern); { + switch pattern[i] { + case '<': + if end := strings.IndexByte(pattern[i:], '>'); end >= 0 { + b.WriteString(`\S+`) + i += end + 1 + continue + } + case '.': + if strings.HasPrefix(pattern[i:], "...") { + b.WriteString(`.*`) + i += len("...") + continue + } + } + b.WriteString(regexp.QuoteMeta(string(pattern[i]))) + i++ + } + b.WriteString(`(?:$|\b.*)`) + return b.String() +} + +func blastRadiusLevel(raw string) tools.BlastRadiusLevel { + switch strings.ToUpper(strings.TrimSpace(raw)) { + case "LOW": + return tools.BlastRadiusLow + case "MEDIUM", "LOW-MEDIUM": + return tools.BlastRadiusMedium + case "HIGH", "MEDIUM-HIGH": + return tools.BlastRadiusHigh + default: + return tools.BlastRadiusUnknown + } +} + +func assessDestructiveShellCommand(command string) *tools.ToolCallSafety { + patterns, err := loadSafetyPatterns() + if err != nil { + return &tools.ToolCallSafety{ + Destructive: true, + BlastRadius: tools.BlastRadiusUnknown, + Reason: err.Error(), + } + } + + normalized := normalizeCommand(command) + var best *tools.ToolCallSafety + bestSeverity := 0 + for _, pattern := range patterns { + if !pattern.regexp.MatchString(normalized) { + continue + } + severity := blastRadiusSeverity(pattern.BlastRadius) + if severity <= bestSeverity { + continue + } + bestSeverity = severity + best = &tools.ToolCallSafety{ + Destructive: true, 
+ BlastRadius: pattern.BlastRadius, + Reason: "Command matches destructive operation: " + pattern.Pattern, + } + } + return best +} + +func blastRadiusSeverity(level tools.BlastRadiusLevel) int { + switch level { + case tools.BlastRadiusHigh: + return 4 + case tools.BlastRadiusUnknown: + return 3 + case tools.BlastRadiusMedium: + return 2 + case tools.BlastRadiusLow: + return 1 + default: + return 0 + } +} + +func normalizeCommand(command string) string { + return strings.Join(strings.Fields(strings.ToLower(command)), " ") +} diff --git a/pkg/tools/builtin/shell/safer_test.go b/pkg/tools/builtin/shell/safer_test.go new file mode 100644 index 000000000..c297808a1 --- /dev/null +++ b/pkg/tools/builtin/shell/safer_test.go @@ -0,0 +1,74 @@ +package shell + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/tools" +) + +func TestAssessDestructiveShellCommand(t *testing.T) { + tests := []struct { + name string + command string + destructive bool + level tools.BlastRadiusLevel + }{ + {name: "rm rf", command: "rm -rf /tmp/x", destructive: true, level: tools.BlastRadiusHigh}, + {name: "rm recursive", command: "rm -r /tmp/x", destructive: true, level: tools.BlastRadiusHigh}, + {name: "rm force file", command: "rm -f /tmp/x", destructive: true, level: tools.BlastRadiusMedium}, + {name: "plain rm file", command: "rm /tmp/x", destructive: true, level: tools.BlastRadiusLow}, + {name: "find delete", command: "find . 
-delete", destructive: true, level: tools.BlastRadiusHigh}, + {name: "docker compose down volumes", command: "docker compose down --volumes", destructive: true, level: tools.BlastRadiusHigh}, + {name: "docker system prune", command: "docker system prune", destructive: true, level: tools.BlastRadiusMedium}, + {name: "git reset out of scope is embedded", command: "git reset --hard", destructive: true, level: tools.BlastRadiusHigh}, + {name: "normal command", command: "ls -la", destructive: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := assessDestructiveShellCommand(tt.command) + if !tt.destructive { + assert.Nil(t, got) + return + } + require.NotNil(t, got) + assert.True(t, got.Destructive) + assert.Equal(t, tt.level, got.BlastRadius) + }) + } +} + +func TestSafetyPatternsLoad(t *testing.T) { + patterns, err := loadSafetyPatterns() + require.NoError(t, err) + assert.Greater(t, len(patterns), 50) +} + +func TestValidateShellToolCallRespectsSaferFlag(t *testing.T) { + args, err := json.Marshal(RunShellArgs{Cmd: "rm -rf /tmp/x"}) + require.NoError(t, err) + call := tools.ToolCall{Function: tools.FunctionCall{Name: ToolNameShell, Arguments: string(args)}} + + assert.Nil(t, (&shellHandler{}).ValidateShellToolCall(call)) + + safety := (&shellHandler{safer: true}).ValidateShellToolCall(call) + require.NotNil(t, safety) + assert.True(t, safety.Destructive) + assert.Equal(t, tools.BlastRadiusHigh, safety.BlastRadius) +} + +func TestValidateShellToolCallSaferWarnsForUnmatchedCommand(t *testing.T) { + args, err := json.Marshal(RunShellArgs{Cmd: "ls -la"}) + require.NoError(t, err) + call := tools.ToolCall{Function: tools.FunctionCall{Name: ToolNameShell, Arguments: string(args)}} + + safety := (&shellHandler{safer: true}).ValidateShellToolCall(call) + require.NotNil(t, safety) + assert.True(t, safety.Destructive) + assert.Equal(t, tools.BlastRadiusUnknown, safety.BlastRadius) + assert.Contains(t, safety.Reason, "safer-mode") +} diff 
--git a/pkg/tools/builtin/shell/safety_patterns.json b/pkg/tools/builtin/shell/safety_patterns.json new file mode 100644 index 000000000..ece9297d3 --- /dev/null +++ b/pkg/tools/builtin/shell/safety_patterns.json @@ -0,0 +1,112 @@ +{ + "version": "0.1", + "description": "Seed destructive-command taxonomy for the Gordon safety spike. Each entry pairs a command shape with a blast_radius tier and a category tag. Consumed by the L2 classifier (Phase 2). See linked-doodling-brooks.md Appendix A for full context.", + "blast_radius_legend": { + "HIGH": "Force L3 approval gate; irreversible data loss", + "MEDIUM": "Surface warning in the response stream; not gated", + "LOW": "Pass through unchanged", + "UNKNOWN": "Default to MEDIUM until compound-command parsing or wrapper recursion resolves" + }, + "filesystem": [ + { "pattern": "rm -rf ", "blast_radius": "HIGH", "category": "fs-delete", "notes": "recursive + force; irreversible" }, + { "pattern": "rm -r ", "blast_radius": "HIGH", "category": "fs-delete", "notes": "recursive; aborts only on tty interrupt" }, + { "pattern": "rm -R ", "blast_radius": "HIGH", "category": "fs-delete", "notes": "alias of -r" }, + { "pattern": "rm -f ", "blast_radius": "MEDIUM", "category": "fs-delete", "notes": "force suppresses warning; ok for tmp" }, + { "pattern": "rm ", "blast_radius": "LOW", "category": "fs-delete", "notes": "prompts on tty" }, + { "pattern": "rmdir ", "blast_radius": "LOW", "category": "fs-delete", "notes": "empty-dir only" }, + { "pattern": "find -delete", "blast_radius": "HIGH", "category": "fs-delete", "notes": "recursive selective delete" }, + { "pattern": "find -exec rm {} \\;", "blast_radius": "HIGH", "category": "fs-delete", "notes": "recursive selective delete via exec" }, + { "pattern": "shred ", "blast_radius": "HIGH", "category": "fs-secure-delete", "notes": "non-recoverable by design" }, + { "pattern": "shred -u ", "blast_radius": "HIGH", "category": "fs-secure-delete", "notes": "overwrite + unlink" }, + { 
"pattern": "dd if=/dev/zero of=", "blast_radius": "HIGH", "category": "fs-overwrite", "notes": "zeroes file content" }, + { "pattern": "dd if=... of=/dev/", "blast_radius": "HIGH", "category": "block-device", "notes": "wipes whole disk" }, + { "pattern": "mkfs. ", "blast_radius": "HIGH", "category": "block-device", "notes": "formats partition" }, + { "pattern": "mkfs -t ", "blast_radius": "HIGH", "category": "block-device", "notes": "alt invocation of mkfs" }, + { "pattern": "wipefs ", "blast_radius": "HIGH", "category": "block-device", "notes": "clears filesystem signatures" }, + { "pattern": "> ", "blast_radius": "MEDIUM", "category": "fs-overwrite", "notes": "shell truncate redirect; easy to miss" }, + { "pattern": "truncate -s 0 ", "blast_radius": "MEDIUM", "category": "fs-overwrite", "notes": "explicit zero-length" }, + { "pattern": "mv -f ", "blast_radius": "MEDIUM", "category": "fs-modify", "notes": "force overwrites dst" }, + { "pattern": "mv ", "blast_radius": "MEDIUM", "category": "fs-modify", "notes": "overwrites without prompt on non-tty" }, + { "pattern": "cp -f ", "blast_radius": "MEDIUM", "category": "fs-modify", "notes": "force overwrite" }, + { "pattern": "cp -rf ", "blast_radius": "MEDIUM-HIGH", "category": "fs-modify", "notes": "recursive force overwrite" }, + { "pattern": "rsync --delete ", "blast_radius": "HIGH", "category": "fs-modify", "notes": "deletes from dst what's missing in src" }, + { "pattern": "ln -sf ", "blast_radius": "LOW-MEDIUM", "category": "fs-modify", "notes": "replaces existing file/symlink" }, + { "pattern": "chmod -R 000 ", "blast_radius": "MEDIUM", "category": "fs-permissions", "notes": "recursive lockout" }, + { "pattern": "chown -R ", "blast_radius": "MEDIUM", "category": "fs-permissions", "notes": "recursive ownership change" }, + { "pattern": "chmod -R 777 ", "blast_radius": "LOW", "category": "fs-permissions", "notes": "security regression, reversible" } + ], + "docker": [ + { "pattern": "docker volume rm ", 
"blast_radius": "HIGH", "category": "dk-volume-del", "notes": "irreversible data loss" }, + { "pattern": "docker volume prune", "blast_radius": "MEDIUM", "category": "dk-volume-del", "notes": "unused only; can surprise" }, + { "pattern": "docker volume prune -f", "blast_radius": "MEDIUM", "category": "dk-volume-del", "notes": "bypasses prompt" }, + { "pattern": "docker volume prune --force", "blast_radius": "MEDIUM", "category": "dk-volume-del", "notes": "long-form of -f" }, + { "pattern": "docker container rm ", "blast_radius": "MEDIUM", "category": "dk-container-del", "notes": "container only, image preserved" }, + { "pattern": "docker container rm -v ", "blast_radius": "HIGH", "category": "dk-container-del", "notes": "also drops attached volumes" }, + { "pattern": "docker rm -fv ", "blast_radius": "HIGH", "category": "dk-container-del", "notes": "force-kill + volumes" }, + { "pattern": "docker rm -f -v ", "blast_radius": "HIGH", "category": "dk-container-del", "notes": "expanded form" }, + { "pattern": "docker container prune", "blast_radius": "LOW-MEDIUM", "category": "dk-container-del", "notes": "stopped only" }, + { "pattern": "docker image rm ", "blast_radius": "LOW-MEDIUM", "category": "dk-image-del", "notes": "rebuildable; image cache loss" }, + { "pattern": "docker rmi ", "blast_radius": "LOW-MEDIUM", "category": "dk-image-del", "notes": "shorthand of image rm" }, + { "pattern": "docker image prune", "blast_radius": "LOW", "category": "dk-image-del", "notes": "dangling only by default" }, + { "pattern": "docker image prune -a", "blast_radius": "MEDIUM", "category": "dk-image-del", "notes": "all unused; high rebuild cost" }, + { "pattern": "docker network rm ", "blast_radius": "LOW-MEDIUM", "category": "dk-network-del", "notes": "breaks running connectivity" }, + { "pattern": "docker network prune", "blast_radius": "LOW-MEDIUM", "category": "dk-network-del", "notes": "unused only" }, + { "pattern": "docker system prune", "blast_radius": "MEDIUM", 
"category": "dk-multi-del", "notes": "containers + images + networks, NO volumes" }, + { "pattern": "docker system prune -a", "blast_radius": "MEDIUM-HIGH", "category": "dk-multi-del", "notes": "adds all unused images" }, + { "pattern": "docker system prune --volumes", "blast_radius": "HIGH", "category": "dk-multi-del", "notes": "adds named-volume deletion" }, + { "pattern": "docker system prune -af --volumes", "blast_radius": "HIGH", "category": "dk-multi-del", "notes": "the YOLO combination" }, + { "pattern": "docker buildx prune", "blast_radius": "MEDIUM", "category": "dk-build-cache", "notes": "rebuild cost only" }, + { "pattern": "docker buildx prune --all", "blast_radius": "MEDIUM-HIGH", "category": "dk-build-cache", "notes": "every cached layer" }, + { "pattern": "docker compose down", "blast_radius": "LOW-MEDIUM", "category": "dk-compose", "notes": "preserves named volumes" }, + { "pattern": "docker compose down -v", "blast_radius": "HIGH", "category": "dk-compose", "notes": "drops named volumes" }, + { "pattern": "docker compose down --volumes", "blast_radius": "HIGH", "category": "dk-compose", "notes": "long form of -v" }, + { "pattern": "docker compose down --volumes --remove-orphans", "blast_radius": "HIGH", "category": "dk-compose", "notes": "aggressive cleanup" }, + { "pattern": "docker compose rm -v ", "blast_radius": "MEDIUM-HIGH", "category": "dk-compose", "notes": "volume removal per service" }, + { "pattern": "docker context rm ", "blast_radius": "LOW", "category": "dk-context", "notes": "reversible if endpoint known" }, + { "pattern": "docker plugin rm ", "blast_radius": "MEDIUM", "category": "dk-plugin", "notes": "plugin state typically lost" }, + { "pattern": "docker stop $(docker ps -q)", "blast_radius": "LOW-MEDIUM", "category": "dk-runtime", "notes": "stops but doesn't delete" }, + { "pattern": "docker kill ", "blast_radius": "LOW", "category": "dk-runtime", "notes": "abrupt stop, no data loss" }, + { "pattern": "docker exec rm -rf ", 
"blast_radius": "HIGH", "category": "dk-exec-host", "notes": "filesystem op via bind mount hits host" }, + { "pattern": "docker run --rm -v $(pwd):/x rm -rf /x", "blast_radius": "HIGH", "category": "dk-bind-mount", "notes": "host-bind YOLO pattern" }, + { "pattern": "docker run --rm -v /:/host ", "blast_radius": "HIGH", "category": "dk-bind-mount", "notes": "mounts host root — almost always wrong" } + ], + "unknown_dynamic": [ + { "pattern": "docker rm $(docker ps -aq)", "blast_radius": "MEDIUM", "resolution_path": "unpack $(...) — if inner is list-all, escalate based on outer op" }, + { "pattern": "xargs rm -rf < ", "blast_radius": "MEDIUM", "resolution_path": "if is readable in the working dir, classify its contents" }, + { "pattern": "rm -rf $VAR", "blast_radius": "MEDIUM", "resolution_path": "unresolvable; MEDIUM unless var bound to a known-safe literal in same compound" }, + { "pattern": "find ... -exec ...", "blast_radius": "match-inner", "resolution_path": "classify the inner command pattern" } + ], + "out_of_scope_v1": { + "rationale": "Local Docker state is the spike's threat model. Add when Gordon grows these surfaces or when telemetry shows the gap matters.", + "sql": [ + { "pattern": "DROP DATABASE ", "blast_radius": "HIGH", "category": "sql-ddl" }, + { "pattern": "DROP TABLE ", "blast_radius": "HIGH", "category": "sql-ddl" }, + { "pattern": "DROP SCHEMA CASCADE", "blast_radius": "HIGH", "category": "sql-ddl" }, + { "pattern": "TRUNCATE TABLE ", "blast_radius": "HIGH", "category": "sql-dml" }, + { "pattern": "DELETE FROM (no WHERE)", "blast_radius": "HIGH", "category": "sql-dml" }, + { "pattern": "UPDATE
SET ... (no WHERE)", "blast_radius": "HIGH", "category": "sql-dml" } + ], + "git": [ + { "pattern": "git push --force", "blast_radius": "HIGH", "category": "git-history" }, + { "pattern": "git push --force-with-lease", "blast_radius": "MEDIUM", "category": "git-history" }, + { "pattern": "git reset --hard", "blast_radius": "MEDIUM-HIGH", "category": "git-local" }, + { "pattern": "git clean -fd", "blast_radius": "MEDIUM-HIGH", "category": "git-local" }, + { "pattern": "git branch -D ", "blast_radius": "MEDIUM", "category": "git-local" }, + { "pattern": "git filter-branch / git filter-repo", "blast_radius": "HIGH", "category": "git-history" }, + { "pattern": "git stash drop / git stash clear", "blast_radius": "LOW-MEDIUM", "category": "git-local" } + ], + "cloud_apis": [ + { "pattern": "gcloud projects delete ", "blast_radius": "HIGH", "category": "cloud-gcp" }, + { "pattern": "aws s3 rb s3:// --force", "blast_radius": "HIGH", "category": "cloud-aws" }, + { "pattern": "kubectl delete -A", "blast_radius": "HIGH", "category": "cloud-k8s" } + ], + "package_managers": [ + { "pattern": "npm uninstall -g ", "blast_radius": "LOW-MEDIUM", "category": "pkg-mgr" }, + { "pattern": "pip uninstall -y ", "blast_radius": "LOW-MEDIUM", "category": "pkg-mgr" } + ], + "remote_code_exec": [ + { "pattern": "curl | bash", "blast_radius": "MEDIUM-HIGH", "category": "remote-code-exec", "notes": "URL response can change between classification and execution; flag separately" }, + { "pattern": "wget -qO- | sh", "blast_radius": "MEDIUM-HIGH", "category": "remote-code-exec", "notes": "same TOCTOU concern as curl|bash" } + ] + } +} diff --git a/pkg/tools/builtin/shell/shell.go b/pkg/tools/builtin/shell/shell.go index 69f9298a8..3d3e2e9ac 100644 --- a/pkg/tools/builtin/shell/shell.go +++ b/pkg/tools/builtin/shell/shell.go @@ -54,8 +54,16 @@ type shellHandler struct { env []string timeout time.Duration workingDir string - jobs *concurrent.Map[string, *backgroundJob] - jobCounter atomic.Int64 + 
+ safer bool + // judge, when non-nil, is the residual LLM-backed classifier + // consulted by ValidateShellToolCall after the deterministic regex + // pass returns no match and the command trips a lexical destructive + // signal (see shouldConsultJudge / LexicalSignals). nil disables + // the LLM path entirely — safer mode falls back to its default + // BlastRadiusUnknown verdict in that case. + judge Judge + jobs *concurrent.Map[string, *backgroundJob] + jobCounter atomic.Int64 // sudoAskpass opts this toolset into the one-time sudo privilege // escalation flow (SUDO_ASKPASS bridged to the elicitation handler). @@ -146,6 +154,65 @@ type RunShellArgs struct { Timeout int `json:"timeout,omitempty" jsonschema:"Timeout in seconds (default 30)"` } +func (h *shellHandler) ValidateShellToolCall(toolCall tools.ToolCall) *tools.ToolCallSafety { + if !h.safer || toolCall.Function.Name != ToolNameShell { + return nil + } + + var params RunShellArgs + args := toolCall.Function.Arguments + if args == "" { + args = "{}" + } + if err := json.Unmarshal([]byte(args), &params); err != nil { + return nil + } + if safety := assessDestructiveShellCommand(params.Cmd); safety != nil { + return safety + } + // Pattern miss. If a residual Judge is wired up AND the command + // contains a destructive lexical signal, ask the judge for a + // refined verdict before falling back to BlastRadiusUnknown. The + // judge call is gated behind shouldConsultJudge so a clean stream + // of inspection commands never pays an LLM round-trip. + // + // Fail-closed: timeout, transport error, or a nil safety return + // all leave the default Unknown verdict in place — the user is + // still gated, just without the refined blast-radius label. + // + // context.Background() with a hard 500 ms cap is intentional: the + // validator type signature does not (yet) carry a context. A future + // change can plumb the dispatcher's context through. 
+ if h.judge != nil && shouldConsultJudge(params.Cmd) { + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + if safety, err := h.judge.Refine(ctx, params.Cmd); err == nil && safety != nil { + return safety + } + } + return &tools.ToolCallSafety{ + Destructive: true, + BlastRadius: tools.BlastRadiusUnknown, + Reason: "Shell command requires safer-mode confirmation.", + } +} + +// SetJudge installs the residual LLM Judge that ValidateShellToolCall +// consults when the deterministic regex pass returns no match for a +// command containing a destructive lexical signal. Passing nil +// disables the LLM path; safer mode then falls back to its default +// BlastRadiusUnknown verdict for every pattern miss. +// +// Intended to be called once at toolset wiring, after New: the runtime +// constructs a small-model-backed ProviderJudge from its model config +// and hands it off here. The assignment is unsynchronized, so calling +// SetJudge after the toolset has started serving traffic races with +// in-flight validator invocations; redo the wiring during agent reload +// rather than mid-run. +func (t *ToolSet) SetJudge(j Judge) { + t.handler.judge = j +} + // UnmarshalJSON accepts both the canonical "cmd" key and the common alias // "command" for the shell command parameter. // @@ -521,6 +588,9 @@ func CreateToolSet(ctx context.Context, toolset latest.Toolset, runConfig *confi env = append(env, os.Environ()...) 
ts := New(env, runConfig) + if toolset.Safer != nil && *toolset.Safer { + ts.handler.safer = true + } if toolset.SudoAskpass != nil && *toolset.SudoAskpass { ts.handler.sudoAskpass = true } @@ -602,6 +672,7 @@ func (t *ToolSet) Tools(context.Context) ([]tools.Tool, error) { Parameters: tools.MustSchemaFor[RunShellArgs](), OutputSchema: tools.MustSchemaFor[string](), Handler: tools.NewHandler(t.handler.RunShell), + SafetyValidator: t.handler.ValidateShellToolCall, Annotations: tools.ToolAnnotations{Title: "Shell"}, AddDescriptionParameter: true, }, diff --git a/pkg/tools/tools.go b/pkg/tools/tools.go index dc72b7a54..bdfed695a 100644 --- a/pkg/tools/tools.go +++ b/pkg/tools/tools.go @@ -180,9 +180,30 @@ type Tool struct { OutputSchema any `json:"outputSchema"` Handler ToolHandler `json:"-"` AddDescriptionParameter bool `json:"-"` + // SafetyValidator inspects a concrete tool call just before the approval + // pipeline. When it reports a destructive call, the runtime must ask the + // user even if permissions or --yolo would otherwise auto-approve it. + SafetyValidator ToolCallSafetyValidator `json:"-"` // ModelOverride is the per-toolset model for the LLM turn that processes // this tool's results. Set automatically from the toolset "model" field. 
ModelOverride string `json:"-"` } type ToolAnnotations mcp.ToolAnnotations + +type BlastRadiusLevel string + +const ( + BlastRadiusLow BlastRadiusLevel = "low" + BlastRadiusMedium BlastRadiusLevel = "medium" + BlastRadiusHigh BlastRadiusLevel = "high" + BlastRadiusUnknown BlastRadiusLevel = "unknown" +) + +type ToolCallSafety struct { + Destructive bool `json:"destructive,omitempty"` + BlastRadius BlastRadiusLevel `json:"blast_radius,omitempty"` + Reason string `json:"reason,omitempty"` +} + +type ToolCallSafetyValidator func(toolCall ToolCall) *ToolCallSafety diff --git a/pkg/tui/components/toolconfirm/toolconfirm.go b/pkg/tui/components/toolconfirm/toolconfirm.go index e7d95b4e0..8cf5550cb 100644 --- a/pkg/tui/components/toolconfirm/toolconfirm.go +++ b/pkg/tui/components/toolconfirm/toolconfirm.go @@ -21,11 +21,30 @@ import ( "github.com/docker/docker-agent/pkg/tools" ) -// User-facing strings of the confirmation prompt. -const ( - Title = "Tool Confirmation" - Question = "Do you want to allow this tool call?" -) +const DestructiveWarningTitle = "WARNING: Destructive Tool Confirmation" + +// Title returns the user-facing dialog title for this confirmation. +func Title(safety *tools.ToolCallSafety) string { + if safety != nil && safety.Destructive { + return DestructiveWarningTitle + } + return "Tool Confirmation" +} + +func BlastRadiusLevel(safety *tools.ToolCallSafety) tools.BlastRadiusLevel { + if safety == nil || safety.BlastRadius == "" { + return tools.BlastRadiusUnknown + } + return safety.BlastRadius +} + +// Question returns the user-facing confirmation question. +func Question(safety *tools.ToolCallSafety) string { + if safety == nil || !safety.Destructive { + return "Do you want to allow this tool call?" + } + return "This is a destructive tool call with blast radius level: " + string(BlastRadiusLevel(safety)) + ". Do you want to allow it?" +} // Decision is the user's answer to a tool confirmation. 
type Decision int diff --git a/pkg/tui/components/toolconfirm/toolconfirm_test.go b/pkg/tui/components/toolconfirm/toolconfirm_test.go index 552dcb338..d1413cef8 100644 --- a/pkg/tui/components/toolconfirm/toolconfirm_test.go +++ b/pkg/tui/components/toolconfirm/toolconfirm_test.go @@ -98,6 +98,13 @@ func TestRejectionReasonsAreStable(t *testing.T) { assert.Equal(t, []string{"bad_args", "wrong_tool", "unsafe", "clarify"}, ids) } +func TestTitleAndQuestionForDestructiveTool(t *testing.T) { + safety := &tools.ToolCallSafety{Destructive: true, BlastRadius: tools.BlastRadiusHigh} + assert.Equal(t, DestructiveWarningTitle, Title(safety)) + assert.Contains(t, Question(safety), "destructive tool call") + assert.Contains(t, Question(safety), "blast radius level: high") +} + func TestKeyMapDecisionFor(t *testing.T) { t.Parallel() diff --git a/pkg/tui/dialog/tool_confirmation.go b/pkg/tui/dialog/tool_confirmation.go index aeb79a17f..98084a18e 100644 --- a/pkg/tui/dialog/tool_confirmation.go +++ b/pkg/tui/dialog/tool_confirmation.go @@ -6,6 +6,7 @@ import ( "github.com/charmbracelet/x/ansi" "github.com/docker/docker-agent/pkg/runtime" + "github.com/docker/docker-agent/pkg/tools" "github.com/docker/docker-agent/pkg/tui/components/messages" "github.com/docker/docker-agent/pkg/tui/components/toolconfirm" "github.com/docker/docker-agent/pkg/tui/core" @@ -53,6 +54,40 @@ func (d *toolConfirmationDialog) dialogDimensions() (dialogWidth, contentWidth i return dialogWidth, contentWidth } +func renderBlastRadiusLevel(level tools.BlastRadiusLevel) string { + return blastRadiusLevelStyle(level).Render(string(level)) +} + +func blastRadiusLevelStyle(level tools.BlastRadiusLevel) lipgloss.Style { + switch level { + case tools.BlastRadiusLow: + return lipgloss.NewStyle().Foreground(styles.Success).Bold(true) + case tools.BlastRadiusMedium: + return lipgloss.NewStyle().Foreground(styles.Warning).Bold(true) + case tools.BlastRadiusHigh: + return 
lipgloss.NewStyle().Foreground(styles.Error).Bold(true) + default: + return lipgloss.NewStyle().Foreground(styles.TextSecondary).Bold(true) + } +} + +func renderConfirmationTitle(safety *tools.ToolCallSafety, contentWidth int) string { + style := styles.DialogTitleStyle.Width(contentWidth) + if safety != nil && safety.Destructive { + style = style.Foreground(styles.Warning) + } + return style.Render(toolconfirm.Title(safety)) +} + +func renderConfirmationQuestion(safety *tools.ToolCallSafety, contentWidth int) string { + if safety == nil || !safety.Destructive { + return styles.DialogQuestionStyle.Width(contentWidth).Render(toolconfirm.Question(safety)) + } + level := renderBlastRadiusLevel(toolconfirm.BlastRadiusLevel(safety)) + question := "This is a destructive tool call with blast radius level: " + level + ". Do you want to allow it?" + return styles.DialogQuestionStyle.Width(contentWidth).Render(question) +} + // SetSize implements [Dialog]. func (d *toolConfirmationDialog) SetSize(width, height int) tea.Cmd { d.BaseDialog.SetSize(width, height) @@ -62,14 +97,13 @@ func (d *toolConfirmationDialog) SetSize(width, height int) tea.Cmd { maxDialogHeight := height * toolConfirmDialogHeightPercent / 100 // Measure fixed UI elements using the same rendering as View() - titleStyle := styles.DialogTitleStyle.Width(contentWidth) - title := titleStyle.Render(toolconfirm.Title) + title := renderConfirmationTitle(d.msg.Safety, contentWidth) titleHeight := lipgloss.Height(title) separator := d.renderSeparator(contentWidth) separatorHeight := lipgloss.Height(separator) - question := styles.DialogQuestionStyle.Width(contentWidth).Render(toolconfirm.Question) + question := renderConfirmationQuestion(d.msg.Safety, contentWidth) questionHeight := lipgloss.Height(question) options := d.renderOptions(contentWidth) @@ -234,8 +268,7 @@ func (d *toolConfirmationDialog) View() string { dialogStyle := styles.DialogStyle.Width(dialogWidth) - titleStyle := 
styles.DialogTitleStyle.Width(contentWidth) - title := titleStyle.Render(toolconfirm.Title) + title := renderConfirmationTitle(d.msg.Safety, contentWidth) // Separator separator := d.renderSeparator(contentWidth) @@ -251,7 +284,7 @@ func (d *toolConfirmationDialog) View() string { } // Confirmation prompt - question := styles.DialogQuestionStyle.Width(contentWidth).Render(toolconfirm.Question) + question := renderConfirmationQuestion(d.msg.Safety, contentWidth) options := d.renderOptions(contentWidth) parts = append(parts, "", question, "", options) diff --git a/pkg/tui/dialog/tool_confirmation_test.go b/pkg/tui/dialog/tool_confirmation_test.go new file mode 100644 index 000000000..52004a600 --- /dev/null +++ b/pkg/tui/dialog/tool_confirmation_test.go @@ -0,0 +1,30 @@ +package dialog + +import ( + "testing" + + "github.com/charmbracelet/x/ansi" + "github.com/stretchr/testify/assert" + + "github.com/docker/docker-agent/pkg/tools" + "github.com/docker/docker-agent/pkg/tui/components/toolconfirm" +) + +func TestRenderConfirmationTitleWarnsForDestructiveTool(t *testing.T) { + safety := &tools.ToolCallSafety{Destructive: true, BlastRadius: tools.BlastRadiusHigh} + + rendered := renderConfirmationTitle(safety, 80) + + assert.Contains(t, ansi.Strip(rendered), toolconfirm.DestructiveWarningTitle) + assert.Contains(t, rendered, "\x1b[") +} + +func TestRenderConfirmationQuestionColorsBlastRadius(t *testing.T) { + safety := &tools.ToolCallSafety{Destructive: true, BlastRadius: tools.BlastRadiusHigh} + + rendered := renderConfirmationQuestion(safety, 80) + + plain := ansi.Strip(rendered) + assert.Contains(t, plain, "blast radius level: high") + assert.Contains(t, rendered, "\x1b[") +}