240 lines
6.1 KiB
Go
240 lines
6.1 KiB
Go
package adkutil
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
|
||
"github.com/cloudwego/eino/adk"
|
||
"github.com/cloudwego/eino/schema"
|
||
krlog "github.com/go-kratos/kratos/v2/log"
|
||
)
|
||
|
||
type Result struct {
|
||
Customized any
|
||
Message *schema.Message
|
||
}
|
||
|
||
func QueryJSON[T any](ctx context.Context, agent adk.Agent, query string) (*T, error) {
|
||
r, err := Query(ctx, agent, query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if r.Customized != nil {
|
||
if v, ok := r.Customized.(*T); ok {
|
||
return v, nil
|
||
}
|
||
b, _ := json.Marshal(r.Customized)
|
||
var out T
|
||
if json.Unmarshal(b, &out) == nil {
|
||
return &out, nil
|
||
}
|
||
}
|
||
if r.Message != nil && r.Message.Content != "" {
|
||
var out T
|
||
if json.Unmarshal([]byte(r.Message.Content), &out) == nil {
|
||
return &out, nil
|
||
}
|
||
}
|
||
return nil, errors.New("agent output not match target type")
|
||
}
|
||
|
||
func QueryWithLogger(ctx context.Context, agent adk.Agent, query string, logger *krlog.Helper) (Result, error) {
|
||
runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: agent})
|
||
it := runner.Query(ctx, query)
|
||
var out Result
|
||
var lastErr error
|
||
if logger != nil {
|
||
logger.Infof("agent query start: %s", query)
|
||
}
|
||
for {
|
||
ev, ok := it.Next()
|
||
if !ok || ev == nil {
|
||
break
|
||
}
|
||
if logger != nil {
|
||
logger.Infof("agent event received: err=%v", ev.Err)
|
||
}
|
||
if ev.Err != nil {
|
||
lastErr = ev.Err
|
||
}
|
||
if ev.Output != nil {
|
||
if ev.Output.CustomizedOutput != nil {
|
||
out.Customized = ev.Output.CustomizedOutput
|
||
if logger != nil {
|
||
b, _ := json.Marshal(ev.Output.CustomizedOutput)
|
||
logger.Infof("agent customized output=%s", string(b))
|
||
}
|
||
}
|
||
if ev.Output.MessageOutput != nil {
|
||
msg, _ := ev.Output.MessageOutput.GetMessage()
|
||
if msg != nil {
|
||
out.Message = msg
|
||
if logger != nil {
|
||
logger.Infof("agent message role=%s content=%s", msg.Role, msg.Content)
|
||
if len(msg.ToolCalls) > 0 {
|
||
for _, tc := range msg.ToolCalls {
|
||
if tc.Function.Name != "" {
|
||
logger.Infof("agent tool call name=%s args=%s", tc.Function.Name, tc.Function.Arguments)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
if out.Customized != nil || out.Message != nil {
|
||
return out, nil
|
||
}
|
||
if lastErr != nil {
|
||
return out, lastErr
|
||
}
|
||
return out, errors.New("agent no output")
|
||
}
|
||
|
||
func QueryJSONWithLogger[T any](ctx context.Context, agent adk.Agent, query string, logger *krlog.Helper) (*T, error) {
|
||
r, err := QueryWithLogger(ctx, agent, query, logger)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if r.Customized != nil {
|
||
if v, ok := r.Customized.(*T); ok {
|
||
return v, nil
|
||
}
|
||
b, _ := json.Marshal(r.Customized)
|
||
var out T
|
||
if json.Unmarshal(b, &out) == nil {
|
||
return &out, nil
|
||
}
|
||
}
|
||
if r.Message != nil && r.Message.Content != "" {
|
||
var out T
|
||
if json.Unmarshal([]byte(r.Message.Content), &out) == nil {
|
||
return &out, nil
|
||
}
|
||
}
|
||
return nil, errors.New("agent output not match target type")
|
||
}
|
||
|
||
// Query 对 Agent 发起一次非流式请求,并提取统一结果
|
||
// - 优先返回工具的 CustomizedOutput(结构化输出)
|
||
// - 若无,则回退到最终的 MessageOutput(文本或可解析JSON)
|
||
// - 若均无,则返回错误(agent no output)
|
||
func Query(ctx context.Context, agent adk.Agent, query string) (Result, error) {
|
||
runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: agent})
|
||
it := runner.Query(ctx, query)
|
||
var out Result
|
||
var lastErr error
|
||
for {
|
||
ev, ok := it.Next()
|
||
if !ok || ev == nil {
|
||
break
|
||
}
|
||
if ev.Err != nil {
|
||
lastErr = ev.Err
|
||
break
|
||
}
|
||
if ev.Output != nil {
|
||
if ev.Output.CustomizedOutput != nil {
|
||
out.Customized = ev.Output.CustomizedOutput
|
||
}
|
||
if ev.Output.MessageOutput != nil {
|
||
msg, _ := ev.Output.MessageOutput.GetMessage()
|
||
if msg != nil {
|
||
out.Message = msg
|
||
}
|
||
}
|
||
}
|
||
}
|
||
if out.Customized != nil || out.Message != nil {
|
||
return out, nil
|
||
}
|
||
if lastErr != nil {
|
||
return out, lastErr
|
||
}
|
||
return out, errors.New("agent no output")
|
||
}
|
||
|
||
// ToPayload 将 Query 的结果转换为可直接返回的 payload
|
||
// - CustomizedOutput 直接返回
|
||
// - MessageOutput 尝试当作 JSON 解析,失败则包装为 {"message": text}
|
||
func ToPayload(res Result) any {
|
||
if res.Customized != nil {
|
||
return res.Customized
|
||
}
|
||
if res.Message != nil {
|
||
var obj any
|
||
if json.Unmarshal([]byte(res.Message.Content), &obj) == nil {
|
||
return obj
|
||
}
|
||
return map[string]any{"message": res.Message.Content}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// Stream 以流式方式消费 Agent 输出,将 MessageOutput 的内容按片段写入 channel
|
||
// 用于 SSE/WS 等场景;工具直接返回通常会合并到最终消息
|
||
func Stream(ctx context.Context, agent adk.Agent, query string) (<-chan string, error) {
|
||
runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: agent, EnableStreaming: true})
|
||
it := runner.Query(ctx, query)
|
||
ch := make(chan string, 8)
|
||
go func() {
|
||
defer close(ch)
|
||
for {
|
||
ev, ok := it.Next()
|
||
if !ok || ev == nil {
|
||
break
|
||
}
|
||
if ev.Output != nil && ev.Output.MessageOutput != nil {
|
||
msg, _ := ev.Output.MessageOutput.GetMessage()
|
||
if msg != nil && msg.Content != "" {
|
||
ch <- msg.Content
|
||
}
|
||
}
|
||
}
|
||
}()
|
||
return ch, nil
|
||
}
|
||
|
||
// StreamWithLogger:记录事件以判断是否触发模型流式(MessageOutput多次)或仅工具输出(CustomizedOutput)
|
||
func StreamWithLogger(ctx context.Context, agent adk.Agent, query string, logger *krlog.Helper) (<-chan string, error) {
|
||
runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: agent, EnableStreaming: true})
|
||
it := runner.Query(ctx, query)
|
||
ch := make(chan string, 8)
|
||
go func() {
|
||
defer close(ch)
|
||
var msgCount int
|
||
for {
|
||
ev, ok := it.Next()
|
||
if !ok || ev == nil {
|
||
if logger != nil {
|
||
logger.Infof("stream finished, message_count=%d", msgCount)
|
||
}
|
||
break
|
||
}
|
||
|
||
logger.Infof("stream event: %v", ev)
|
||
|
||
if ev.Output != nil {
|
||
if ev.Output.CustomizedOutput != nil && logger != nil {
|
||
logger.Infof("customized output received")
|
||
}
|
||
if ev.Output.MessageOutput != nil {
|
||
msg, _ := ev.Output.MessageOutput.GetMessage()
|
||
if msg != nil {
|
||
msgCount++
|
||
if logger != nil {
|
||
logger.Infof("message event #%d role=%s len=%d", msgCount, msg.Role, len(msg.Content))
|
||
}
|
||
if msg.Content != "" {
|
||
ch <- msg.Content
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}()
|
||
return ch, nil
|
||
}
|