Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 4 additions & 19 deletions go/api/client/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,20 @@ package client
import (
"context"
"fmt"
"net/url"

api "github.com/kagent-dev/kagent/go/api/httpapi"
"github.com/kagent-dev/kagent/go/api/v1alpha2"
)

// Agent defines the agent operations
type Agent interface {
ListAgents(ctx context.Context, opts ...ListAgentsOptions) (*api.StandardResponse[[]api.AgentResponse], error)
ListAgents(ctx context.Context) (*api.StandardResponse[[]api.AgentResponse], error)
CreateAgent(ctx context.Context, request *v1alpha2.Agent) (*api.StandardResponse[*v1alpha2.Agent], error)
GetAgent(ctx context.Context, agentRef string) (*api.StandardResponse[*api.AgentResponse], error)
UpdateAgent(ctx context.Context, request *v1alpha2.Agent) (*api.StandardResponse[*v1alpha2.Agent], error)
DeleteAgent(ctx context.Context, agentRef string) error
}

// ListAgentsOptions configures ListAgents requests.
type ListAgentsOptions struct {
Namespace string
}

// agentClient handles agent-related requests
type agentClient struct {
client *BaseClient
Expand All @@ -33,23 +27,14 @@ func NewAgentClient(client *BaseClient) Agent {
return &agentClient{client: client}
}

// ListAgents lists all agents for a user. When Namespace is set, only agents in that namespace are returned.
func (c *agentClient) ListAgents(ctx context.Context, opts ...ListAgentsOptions) (*api.StandardResponse[[]api.AgentResponse], error) {
if len(opts) > 1 {
return nil, fmt.Errorf("ListAgents accepts at most one options argument")
}

// ListAgents lists all agents for a user
func (c *agentClient) ListAgents(ctx context.Context) (*api.StandardResponse[[]api.AgentResponse], error) {
userID := c.client.GetUserIDOrDefault("")
if userID == "" {
return nil, fmt.Errorf("userID is required")
}

path := "/api/agents"
if len(opts) > 0 && opts[0].Namespace != "" {
path += "?namespace=" + url.QueryEscape(opts[0].Namespace)
}

resp, err := c.client.Get(ctx, path, userID)
resp, err := c.client.Get(ctx, "/api/agents", userID)
if err != nil {
return nil, err
}
Expand Down
14 changes: 7 additions & 7 deletions go/api/database/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ import (
"context"
"time"

a2a "github.com/a2aproject/a2a-go/v2/a2a"
"github.com/kagent-dev/kagent/go/api/v1alpha2"
"github.com/pgvector/pgvector-go"
"trpc.group/trpc-go/trpc-a2a-go/protocol"
)

type QueryOptions struct {
Expand All @@ -24,8 +24,8 @@ type Client interface {
StoreFeedback(ctx context.Context, feedback *Feedback) error
StoreSession(ctx context.Context, session *Session) error
StoreAgent(ctx context.Context, agent *Agent) error
StoreTask(ctx context.Context, task *protocol.Task) error
StorePushNotification(ctx context.Context, config *protocol.TaskPushNotificationConfig) error
StoreTask(ctx context.Context, task *a2a.Task) error
StorePushNotification(ctx context.Context, config *a2a.PushConfig) error
StoreToolServer(ctx context.Context, toolServer *ToolServer) (*ToolServer, error)
StoreEvents(ctx context.Context, messages ...*Event) error

Expand All @@ -40,23 +40,23 @@ type Client interface {
// Get methods
GetSession(ctx context.Context, sessionID string, userID string) (*Session, error)
GetAgent(ctx context.Context, name string) (*Agent, error)
GetTask(ctx context.Context, id string) (*protocol.Task, error)
GetTask(ctx context.Context, id string) (*a2a.Task, error)
GetTool(ctx context.Context, name string) (*Tool, error)
GetToolServer(ctx context.Context, name string) (*ToolServer, error)
GetPushNotification(ctx context.Context, taskID string, configID string) (*protocol.TaskPushNotificationConfig, error)
GetPushNotification(ctx context.Context, taskID string, configID string) (*a2a.PushConfig, error)

// List methods
ListTools(ctx context.Context) ([]Tool, error)
ListFeedback(ctx context.Context, userID string) ([]Feedback, error)
ListTasksForSession(ctx context.Context, sessionID string) ([]*protocol.Task, error)
ListTasksForSession(ctx context.Context, sessionID string) ([]*a2a.Task, error)
ListSessions(ctx context.Context, userID string) ([]Session, error)
ListSessionsForAgent(ctx context.Context, agentID string, userID string) ([]Session, error)
ListSessionsForAgentAllUsers(ctx context.Context, agentID string) ([]Session, error)
ListAgents(ctx context.Context) ([]Agent, error)
ListToolServers(ctx context.Context) ([]ToolServer, error)
ListToolsForServer(ctx context.Context, serverName string, groupKind string) ([]Tool, error)
ListEventsForSession(ctx context.Context, sessionID, userID string, options QueryOptions) ([]*Event, error)
ListPushNotifications(ctx context.Context, taskID string) ([]*protocol.TaskPushNotificationConfig, error)
ListPushNotifications(ctx context.Context, taskID string) ([]*a2a.PushConfig, error)

// Helper methods
RefreshToolsForServer(ctx context.Context, serverName string, groupKind string, tools ...*v1alpha2.MCPTool) error
Expand Down
48 changes: 25 additions & 23 deletions go/api/database/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ import (
"encoding/json"
"time"

a2a "github.com/a2aproject/a2a-go/v2/a2a"
"github.com/kagent-dev/kagent/go/api/adk"
"github.com/kagent-dev/kagent/go/api/v1alpha2"
"github.com/pgvector/pgvector-go"
"trpc.group/trpc-go/trpc-a2a-go/protocol"
)

type Agent struct {
Expand All @@ -32,16 +32,16 @@ type Event struct {
Data string `json:"data"` // JSON-serialized protocol.Message
}

func (m *Event) Parse() (protocol.Message, error) {
var data protocol.Message
func (m *Event) Parse() (a2a.Message, error) {
var data a2a.Message
if err := json.Unmarshal([]byte(m.Data), &data); err != nil {
return protocol.Message{}, err
return a2a.Message{}, err
}
return data, nil
}

func ParseMessages(messages []Event) ([]*protocol.Message, error) {
result := make([]*protocol.Message, 0, len(messages))
func ParseMessages(messages []Event) ([]*a2a.Message, error) {
result := make([]*a2a.Message, 0, len(messages))
for _, message := range messages {
parsed, err := message.Parse()
if err != nil {
Expand Down Expand Up @@ -77,24 +77,25 @@ type Session struct {
}

type Task struct {
ID string `json:"id"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt *time.Time `json:"deleted_at,omitempty"`
Data string `json:"data"` // JSON-serialized task data
SessionID string `json:"session_id"`
ID string `json:"id"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt *time.Time `json:"deleted_at,omitempty"`
Data string `json:"data"` // JSON-serialized task data
ProtocolVersion *string `json:"protocol_version,omitempty"`
SessionID string `json:"session_id"`
}

func (t *Task) Parse() (protocol.Task, error) {
var data protocol.Task
func (t *Task) Parse() (a2a.Task, error) {
var data a2a.Task
if err := json.Unmarshal([]byte(t.Data), &data); err != nil {
return protocol.Task{}, err
return a2a.Task{}, err
}
return data, nil
}

func ParseTasks(tasks []Task) ([]*protocol.Task, error) {
result := make([]*protocol.Task, 0, len(tasks))
func ParseTasks(tasks []Task) ([]*a2a.Task, error) {
result := make([]*a2a.Task, 0, len(tasks))
for _, task := range tasks {
parsed, err := task.Parse()
if err != nil {
Expand All @@ -106,12 +107,13 @@ func ParseTasks(tasks []Task) ([]*protocol.Task, error) {
}

type PushNotification struct {
ID string `json:"id"`
TaskID string `json:"task_id"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt *time.Time `json:"deleted_at,omitempty"`
Data string `json:"data"` // JSON-serialized push notification config
ID string `json:"id"`
TaskID string `json:"task_id"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt *time.Time `json:"deleted_at,omitempty"`
Data string `json:"data"` // JSON-serialized push notification config
ProtocolVersion *string `json:"protocol_version,omitempty"`
}

// FeedbackIssueType represents the category of feedback issue
Expand Down
58 changes: 45 additions & 13 deletions go/core/internal/a2a/a2a_handler_mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,27 @@ package a2a
import (
"fmt"
"net/http"
"slices"
"strings"
"sync"

a2atype "github.com/a2aproject/a2a-go/v2/a2a"
a2aclient "github.com/a2aproject/a2a-go/v2/a2aclient"
"github.com/a2aproject/a2a-go/v2/a2acompat/a2av0"
"github.com/a2aproject/a2a-go/v2/a2asrv"
"github.com/gorilla/mux"
authimpl "github.com/kagent-dev/kagent/go/core/internal/httpserver/auth"
common "github.com/kagent-dev/kagent/go/core/internal/utils"
"github.com/kagent-dev/kagent/go/core/pkg/auth"
"trpc.group/trpc-go/trpc-a2a-go/client"
"trpc.group/trpc-go/trpc-a2a-go/server"
)

// A2AHandlerMux is an interface that defines methods for adding, getting, and removing agentic task handlers.
type A2AHandlerMux interface {
SetAgentHandler(
agentRef string,
client *client.A2AClient,
card server.AgentCard,
tracing server.Middleware,
client *a2aclient.Client,
card a2atype.AgentCard,
tracing middleware,
) error
RemoveAgentHandler(
agentRef string,
Expand All @@ -38,6 +41,10 @@ type handlerMux struct {

var _ A2AHandlerMux = &handlerMux{}

type middleware interface {
Wrap(next http.Handler) http.Handler
}

func NewA2AHttpMux(agentPathPrefix, sandboxPathPrefix string, authenticator auth.AuthProvider) *handlerMux {
return &handlerMux{
handlers: make(map[string]http.Handler),
Expand All @@ -49,23 +56,48 @@ func NewA2AHttpMux(agentPathPrefix, sandboxPathPrefix string, authenticator auth

func (a *handlerMux) SetAgentHandler(
agentRef string,
client *client.A2AClient,
card server.AgentCard,
tracing server.Middleware,
client *a2aclient.Client,
card a2atype.AgentCard,
tracing middleware,
) error {
middlewares := []server.Middleware{authimpl.NewA2AAuthenticator(a.authenticator)}
// TODO(cleanup): Replace this protocol mux with the standard v1 handler stack once legacy clients/runtimes are unsupported.
requestHandler := NewPassthroughRequestHandler(client, &card)
legacyJSONRPCHandler := a2av0.NewJSONRPCHandler(requestHandler)
v1JSONRPCHandler := a2asrv.NewJSONRPCHandler(requestHandler)
cardHandler := a2asrv.NewAgentCardHandler(a2av0.NewStaticAgentCardProducer(&card))
wellKnownPath := "/" + strings.TrimPrefix(a2asrv.WellKnownAgentCardPath, "/")

var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, wellKnownPath) {
cardHandler.ServeHTTP(w, r)
return
}
wireVersion, err := common.NegotiateA2AWireVersion(r)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
switch wireVersion {
case common.A2AWireVersionLegacy:
legacyJSONRPCHandler.ServeHTTP(w, r)
case common.A2AWireVersionV1:
v1JSONRPCHandler.ServeHTTP(w, r)
default:
http.Error(w, fmt.Sprintf("unknown negotiated A2A wire version %q", wireVersion), http.StatusBadRequest)
}
})
middlewares := []middleware{authimpl.NewA2AAuthenticator(a.authenticator)}
if tracing != nil {
middlewares = append(middlewares, tracing)
}
srv, err := server.NewA2AServer(card, NewPassthroughManager(client), server.WithMiddleWare(middlewares...))
if err != nil {
return fmt.Errorf("failed to create A2A server: %w", err)
for _, middleware := range slices.Backward(middlewares) {
handler = middleware.Wrap(handler)
}

a.lock.Lock()
defer a.lock.Unlock()

a.handlers[agentRef] = srv.Handler()
a.handlers[agentRef] = handler

return nil
}
Expand Down
Loading
Loading