Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions core/src/main/java/com/google/adk/tools/AgentTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.google.adk.agents.ConfigAgentUtils.ConfigurationException;
import com.google.adk.agents.LlmAgent;
import com.google.adk.events.Event;
import com.google.adk.plugins.Plugin;
import com.google.adk.runner.InMemoryRunner;
import com.google.adk.runner.Runner;
import com.google.adk.sessions.State;
Expand All @@ -46,6 +47,7 @@ public class AgentTool extends BaseTool {

private final BaseAgent agent;
private final boolean skipSummarization;
private final boolean includePlugins;

public static BaseTool fromConfig(ToolArgsConfig args, String configAbsPath)
throws ConfigurationException {
Expand All @@ -62,21 +64,34 @@ public static BaseTool fromConfig(ToolArgsConfig args, String configAbsPath)
}

BaseAgent agent = resolvedAgents.get(0);
return AgentTool.create(agent, args.getOrDefault("skipSummarization", false).booleanValue());
return AgentTool.create(
agent,
args.getOrDefault("skipSummarization", false).booleanValue(),
args.getOrDefault("includePlugins", false).booleanValue());
}

public static AgentTool create(
BaseAgent agent, boolean skipSummarization, boolean includePlugins) {
return new AgentTool(agent, skipSummarization, includePlugins);
}

public static AgentTool create(BaseAgent agent, boolean skipSummarization) {
return new AgentTool(agent, skipSummarization);
return new AgentTool(agent, skipSummarization, /* includePlugins= */ false);
}

public static AgentTool create(BaseAgent agent) {
return new AgentTool(agent, false);
return new AgentTool(agent, /* skipSummarization= */ false, /* includePlugins= */ false);
}

protected AgentTool(BaseAgent agent, boolean skipSummarization) {
this(agent, skipSummarization, /* includePlugins= */ false);
}

protected AgentTool(BaseAgent agent, boolean skipSummarization, boolean includePlugins) {
super(agent.name(), agent.description());
this.agent = agent;
this.skipSummarization = skipSummarization;
this.includePlugins = includePlugins;
}

@VisibleForTesting
Expand Down Expand Up @@ -159,9 +174,11 @@ public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContex
content = Content.fromParts(Part.fromText(input.toString()));
}

Runner runner = new InMemoryRunner(this.agent, toolContext.agentName());
// Session state is final, can't update to toolContext state
// session.toBuilder().setState(toolContext.getState());
ImmutableList<Plugin> plugins =
this.includePlugins
? ImmutableList.of(toolContext.invocationContext().pluginManager())
: ImmutableList.of();
Runner runner = new InMemoryRunner(this.agent, toolContext.agentName(), plugins);
return runner
.sessionService()
.createSession(toolContext.agentName(), "tmp-user", toolContext.state(), null)
Expand Down
166 changes: 166 additions & 0 deletions core/src/test/java/com/google/adk/tools/AgentToolTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import com.google.adk.agents.LlmAgent;
import com.google.adk.agents.SequentialAgent;
import com.google.adk.models.LlmResponse;
import com.google.adk.plugins.Plugin;
import com.google.adk.plugins.PluginManager;
import com.google.adk.sessions.InMemorySessionService;
import com.google.adk.sessions.Session;
import com.google.adk.testing.TestLlm;
Expand All @@ -41,6 +43,7 @@
import io.reactivex.rxjava3.core.Flowable;
import io.reactivex.rxjava3.core.Maybe;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand Down Expand Up @@ -704,6 +707,169 @@ public void declaration_emptySequentialAgent_fallsBackToRequest() {
.build());
}

@Test
public void call_withIncludePluginsTrue_propagatesPlugins() throws Exception {
AtomicBoolean callbackCalled = new AtomicBoolean(false);
Plugin mockPlugin =
new Plugin() {
@Override
public String getName() {
return "mock_plugin";
}

@Override
public Maybe<Content> beforeRunCallback(InvocationContext invocationContext) {
callbackCalled.set(true);
return Maybe.empty();
}
};
LlmAgent testAgent =
createTestAgentBuilder(createTestLlm(LlmResponse.builder().build()))
.name("agent_name")
.description("agent description")
.build();
AgentTool agentTool =
AgentTool.create(testAgent, /* skipSummarization= */ false, /* includePlugins= */ true);
Session session =
sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet();
InvocationContext invocationContext =
InvocationContext.builder()
.invocationId(InvocationContext.newInvocationContextId())
.agent(testAgent)
.session(session)
.sessionService(sessionService)
.pluginManager(new PluginManager(ImmutableList.of(mockPlugin)))
.build();
ToolContext toolContext = ToolContext.builder(invocationContext).build();

Map<String, Object> unused =
agentTool.runAsync(ImmutableMap.of("request", "magic"), toolContext).blockingGet();

assertThat(callbackCalled.get()).isTrue();
}

@Test
public void call_withIncludePluginsFalse_doesNotPropagatePlugins() throws Exception {
AtomicBoolean callbackCalled = new AtomicBoolean(false);
Plugin mockPlugin =
new Plugin() {
@Override
public String getName() {
return "mock_plugin";
}

@Override
public Maybe<Content> beforeRunCallback(InvocationContext invocationContext) {
callbackCalled.set(true);
return Maybe.empty();
}
};
LlmAgent testAgent =
createTestAgentBuilder(createTestLlm(LlmResponse.builder().build()))
.name("agent_name")
.description("agent description")
.build();
AgentTool agentTool =
AgentTool.create(testAgent, /* skipSummarization= */ false, /* includePlugins= */ false);
Session session =
sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet();
InvocationContext invocationContext =
InvocationContext.builder()
.invocationId(InvocationContext.newInvocationContextId())
.agent(testAgent)
.session(session)
.sessionService(sessionService)
.pluginManager(new PluginManager(ImmutableList.of(mockPlugin)))
.build();
ToolContext toolContext = ToolContext.builder(invocationContext).build();

Map<String, Object> unused =
agentTool.runAsync(ImmutableMap.of("request", "magic"), toolContext).blockingGet();

assertThat(callbackCalled.get()).isFalse();
}

@Test
public void call_createWithAgentOnly_defaultsIncludePluginsToFalse() throws Exception {
AtomicBoolean callbackCalled = new AtomicBoolean(false);
Plugin mockPlugin =
new Plugin() {
@Override
public String getName() {
return "mock_plugin";
}

@Override
public Maybe<Content> beforeRunCallback(InvocationContext invocationContext) {
callbackCalled.set(true);
return Maybe.empty();
}
};
LlmAgent testAgent =
createTestAgentBuilder(createTestLlm(LlmResponse.builder().build()))
.name("agent_name")
.description("agent description")
.build();
AgentTool agentTool = AgentTool.create(testAgent);
Session session =
sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet();
InvocationContext invocationContext =
InvocationContext.builder()
.invocationId(InvocationContext.newInvocationContextId())
.agent(testAgent)
.session(session)
.sessionService(sessionService)
.pluginManager(new PluginManager(ImmutableList.of(mockPlugin)))
.build();
ToolContext toolContext = ToolContext.builder(invocationContext).build();

Map<String, Object> unused =
agentTool.runAsync(ImmutableMap.of("request", "magic"), toolContext).blockingGet();

assertThat(callbackCalled.get()).isFalse();
}

@Test
public void call_createWithAgentAndSkipSummarization_defaultsIncludePluginsToFalse()
throws Exception {
AtomicBoolean callbackCalled = new AtomicBoolean(false);
Plugin mockPlugin =
new Plugin() {
@Override
public String getName() {
return "mock_plugin";
}

@Override
public Maybe<Content> beforeRunCallback(InvocationContext invocationContext) {
callbackCalled.set(true);
return Maybe.empty();
}
};
LlmAgent testAgent =
createTestAgentBuilder(createTestLlm(LlmResponse.builder().build()))
.name("agent_name")
.description("agent description")
.build();
AgentTool agentTool = AgentTool.create(testAgent, /* skipSummarization= */ true);
Session session =
sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet();
InvocationContext invocationContext =
InvocationContext.builder()
.invocationId(InvocationContext.newInvocationContextId())
.agent(testAgent)
.session(session)
.sessionService(sessionService)
.pluginManager(new PluginManager(ImmutableList.of(mockPlugin)))
.build();
ToolContext toolContext = ToolContext.builder(invocationContext).build();

Map<String, Object> unused =
agentTool.runAsync(ImmutableMap.of("request", "magic"), toolContext).blockingGet();

assertThat(callbackCalled.get()).isFalse();
}

private ToolContext createToolContext(BaseAgent agent) {
Session session =
sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet();
Expand Down
Loading