diff --git a/core/src/main/java/com/google/adk/tools/AgentTool.java b/core/src/main/java/com/google/adk/tools/AgentTool.java index 903de9a80..66d4a2700 100644 --- a/core/src/main/java/com/google/adk/tools/AgentTool.java +++ b/core/src/main/java/com/google/adk/tools/AgentTool.java @@ -26,6 +26,7 @@ import com.google.adk.agents.ConfigAgentUtils.ConfigurationException; import com.google.adk.agents.LlmAgent; import com.google.adk.events.Event; +import com.google.adk.plugins.Plugin; import com.google.adk.runner.InMemoryRunner; import com.google.adk.runner.Runner; import com.google.adk.sessions.State; @@ -46,6 +47,7 @@ public class AgentTool extends BaseTool { private final BaseAgent agent; private final boolean skipSummarization; + private final boolean includePlugins; public static BaseTool fromConfig(ToolArgsConfig args, String configAbsPath) throws ConfigurationException { @@ -62,21 +64,34 @@ public static BaseTool fromConfig(ToolArgsConfig args, String configAbsPath) } BaseAgent agent = resolvedAgents.get(0); - return AgentTool.create(agent, args.getOrDefault("skipSummarization", false).booleanValue()); + return AgentTool.create( + agent, + args.getOrDefault("skipSummarization", false).booleanValue(), + args.getOrDefault("includePlugins", false).booleanValue()); + } + + public static AgentTool create( + BaseAgent agent, boolean skipSummarization, boolean includePlugins) { + return new AgentTool(agent, skipSummarization, includePlugins); } public static AgentTool create(BaseAgent agent, boolean skipSummarization) { - return new AgentTool(agent, skipSummarization); + return new AgentTool(agent, skipSummarization, /* includePlugins= */ false); } public static AgentTool create(BaseAgent agent) { - return new AgentTool(agent, false); + return new AgentTool(agent, /* skipSummarization= */ false, /* includePlugins= */ false); } protected AgentTool(BaseAgent agent, boolean skipSummarization) { + this(agent, skipSummarization, /* includePlugins= */ false); + } + + protected AgentTool(BaseAgent agent, boolean skipSummarization, boolean includePlugins) { super(agent.name(), agent.description()); this.agent = agent; this.skipSummarization = skipSummarization; + this.includePlugins = includePlugins; } @VisibleForTesting @@ -159,9 +174,11 @@ public Single> runAsync(Map args, ToolContex content = Content.fromParts(Part.fromText(input.toString())); } - Runner runner = new InMemoryRunner(this.agent, toolContext.agentName()); - // Session state is final, can't update to toolContext state - // session.toBuilder().setState(toolContext.getState()); + ImmutableList plugins = + this.includePlugins + ? ImmutableList.of(toolContext.invocationContext().pluginManager()) + : ImmutableList.of(); + Runner runner = new InMemoryRunner(this.agent, toolContext.agentName(), plugins); return runner .sessionService() .createSession(toolContext.agentName(), "tmp-user", toolContext.state(), null) diff --git a/core/src/test/java/com/google/adk/tools/AgentToolTest.java b/core/src/test/java/com/google/adk/tools/AgentToolTest.java index f96e2bd17..b37db6611 100644 --- a/core/src/test/java/com/google/adk/tools/AgentToolTest.java +++ b/core/src/test/java/com/google/adk/tools/AgentToolTest.java @@ -28,6 +28,8 @@ import com.google.adk.agents.LlmAgent; import com.google.adk.agents.SequentialAgent; import com.google.adk.models.LlmResponse; +import com.google.adk.plugins.Plugin; +import com.google.adk.plugins.PluginManager; import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.adk.testing.TestLlm; @@ -41,6 +43,7 @@ import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -704,6 +707,169 @@ public void declaration_emptySequentialAgent_fallsBackToRequest() { .build()); } + @Test + public void call_withIncludePluginsTrue_propagatesPlugins() throws Exception { + AtomicBoolean callbackCalled = new AtomicBoolean(false); + Plugin mockPlugin = + new Plugin() { + @Override + public String getName() { + return "mock_plugin"; + } + + @Override + public Maybe beforeRunCallback(InvocationContext invocationContext) { + callbackCalled.set(true); + return Maybe.empty(); + } + }; + LlmAgent testAgent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .name("agent_name") + .description("agent description") + .build(); + AgentTool agentTool = + AgentTool.create(testAgent, /* skipSummarization= */ false, /* includePlugins= */ true); + Session session = + sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet(); + InvocationContext invocationContext = + InvocationContext.builder() + .invocationId(InvocationContext.newInvocationContextId()) + .agent(testAgent) + .session(session) + .sessionService(sessionService) + .pluginManager(new PluginManager(ImmutableList.of(mockPlugin))) + .build(); + ToolContext toolContext = ToolContext.builder(invocationContext).build(); + + Map unused = + agentTool.runAsync(ImmutableMap.of("request", "magic"), toolContext).blockingGet(); + + assertThat(callbackCalled.get()).isTrue(); + } + + @Test + public void call_withIncludePluginsFalse_doesNotPropagatePlugins() throws Exception { + AtomicBoolean callbackCalled = new AtomicBoolean(false); + Plugin mockPlugin = + new Plugin() { + @Override + public String getName() { + return "mock_plugin"; + } + + @Override + public Maybe beforeRunCallback(InvocationContext invocationContext) { + callbackCalled.set(true); + return Maybe.empty(); + } + }; + LlmAgent testAgent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .name("agent_name") + .description("agent description") + .build(); + AgentTool agentTool = + AgentTool.create(testAgent, /* skipSummarization= */ false, /* includePlugins= */ false); + Session session = + sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet(); + InvocationContext invocationContext = + InvocationContext.builder() + .invocationId(InvocationContext.newInvocationContextId()) + .agent(testAgent) + .session(session) + .sessionService(sessionService) + .pluginManager(new PluginManager(ImmutableList.of(mockPlugin))) + .build(); + ToolContext toolContext = ToolContext.builder(invocationContext).build(); + + Map unused = + agentTool.runAsync(ImmutableMap.of("request", "magic"), toolContext).blockingGet(); + + assertThat(callbackCalled.get()).isFalse(); + } + + @Test + public void call_createWithAgentOnly_defaultsIncludePluginsToFalse() throws Exception { + AtomicBoolean callbackCalled = new AtomicBoolean(false); + Plugin mockPlugin = + new Plugin() { + @Override + public String getName() { + return "mock_plugin"; + } + + @Override + public Maybe beforeRunCallback(InvocationContext invocationContext) { + callbackCalled.set(true); + return Maybe.empty(); + } + }; + LlmAgent testAgent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .name("agent_name") + .description("agent description") + .build(); + AgentTool agentTool = AgentTool.create(testAgent); + Session session = + sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet(); + InvocationContext invocationContext = + InvocationContext.builder() + .invocationId(InvocationContext.newInvocationContextId()) + .agent(testAgent) + .session(session) + .sessionService(sessionService) + .pluginManager(new PluginManager(ImmutableList.of(mockPlugin))) + .build(); + ToolContext toolContext = ToolContext.builder(invocationContext).build(); + + Map unused = + agentTool.runAsync(ImmutableMap.of("request", "magic"), toolContext).blockingGet(); + + assertThat(callbackCalled.get()).isFalse(); + } + + @Test + public void call_createWithAgentAndSkipSummarization_defaultsIncludePluginsToFalse() + throws Exception { + AtomicBoolean callbackCalled = new AtomicBoolean(false); + Plugin mockPlugin = + new Plugin() { + @Override + public String getName() { + return "mock_plugin"; + } + + @Override + public Maybe beforeRunCallback(InvocationContext invocationContext) { + callbackCalled.set(true); + return Maybe.empty(); + } + }; + LlmAgent testAgent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .name("agent_name") + .description("agent description") + .build(); + AgentTool agentTool = AgentTool.create(testAgent, /* skipSummarization= */ true); + Session session = + sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet(); + InvocationContext invocationContext = + InvocationContext.builder() + .invocationId(InvocationContext.newInvocationContextId()) + .agent(testAgent) + .session(session) + .sessionService(sessionService) + .pluginManager(new PluginManager(ImmutableList.of(mockPlugin))) + .build(); + ToolContext toolContext = ToolContext.builder(invocationContext).build(); + + Map unused = + agentTool.runAsync(ImmutableMap.of("request", "magic"), toolContext).blockingGet(); + + assertThat(callbackCalled.get()).isFalse(); + } + private ToolContext createToolContext(BaseAgent agent) { Session session = sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet();