ai-courseware/eino-project/internal/pkg/adkutil/adkutil.go

199 lines
5.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package adkutil
import (
"context"
"encoding/json"
"errors"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/schema"
krlog "github.com/go-kratos/kratos/v2/log"
)
// Result is the unified outcome of one agent run as collected by Query and
// QueryWithLogger.
type Result struct {
	// Customized is the most recent non-nil CustomizedOutput seen on an agent
	// event; it stays nil when no event carried one.
	Customized any

	// Message is the most recent message read from an event's MessageOutput;
	// it stays nil when no event carried one.
	Message *schema.Message
}
// QueryJSON runs a non-streaming query against agent via Query and decodes the
// result into a *T.
//
// Decoding is attempted in this order:
//  1. If the customized output already is a non-nil *T, it is returned as is.
//  2. Otherwise the customized output is round-tripped through encoding/json
//     into a fresh T.
//  3. Otherwise non-empty message content is parsed as JSON into a fresh T.
//
// An error from Query is returned unchanged. If no step succeeds, the returned
// error starts with "agent output not match target type" and, when available,
// is followed by the reason each attempted step failed.
func QueryJSON[T any](ctx context.Context, agent adk.Agent, query string) (*T, error) {
	r, err := Query(ctx, agent, query)
	if err != nil {
		return nil, err
	}

	// reason accumulates why each decoding attempt failed so the final error
	// tells the caller what was wrong with the agent output.
	var reason string

	if r.Customized != nil {
		if v, ok := r.Customized.(*T); ok {
			if v != nil {
				return v, nil
			}
			// A typed nil pointer is non-nil as an interface value; returning
			// it with a nil error would make callers dereference nil.
			reason = "customized output: nil pointer"
		} else {
			var out T
			b, decErr := json.Marshal(r.Customized)
			if decErr == nil {
				decErr = json.Unmarshal(b, &out)
			}
			if decErr == nil {
				return &out, nil
			}
			reason = "customized output: " + decErr.Error()
		}
	}

	if r.Message != nil && r.Message.Content != "" {
		var out T
		decErr := json.Unmarshal([]byte(r.Message.Content), &out)
		if decErr == nil {
			return &out, nil
		}
		if reason != "" {
			reason += "; "
		}
		reason += "message content: " + decErr.Error()
	}

	const mismatch = "agent output not match target type"
	if reason == "" {
		return nil, errors.New(mismatch)
	}
	return nil, errors.New(mismatch + ": " + reason)
}
// QueryWithLogger sends a single non-streaming request to agent, logging the
// query, every received event, and every collected output through logger, and
// folds the events into one Result. A nil logger disables all logging.
//
// Unlike Query, iteration does not stop at an event that carries an error; the
// error is remembered and the remaining events are still consumed.
//
// The Result keeps the last non-nil CustomizedOutput and the last message read
// from a MessageOutput. If either was collected, the Result is returned with a
// nil error. Otherwise the last error seen (from an event or from reading a
// message) is returned, or "agent no output" when there was none.
func QueryWithLogger(ctx context.Context, agent adk.Agent, query string, logger *krlog.Helper) (Result, error) {
	runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: agent})
	it := runner.Query(ctx, query)

	var (
		out     Result
		lastErr error
	)
	if logger != nil {
		logger.Infof("agent query start: %s", query)
	}

	for {
		ev, ok := it.Next()
		if !ok || ev == nil {
			break
		}
		if logger != nil {
			logger.Infof("agent event received: err=%v", ev.Err)
		}
		if ev.Err != nil {
			lastErr = ev.Err
		}
		if ev.Output == nil {
			continue
		}

		if c := ev.Output.CustomizedOutput; c != nil {
			out.Customized = c
			if logger != nil {
				// The encoding is for the log line only; a failure here must
				// not affect the collected result.
				if b, err := json.Marshal(c); err != nil {
					logger.Warnf("agent customized output not JSON-encodable: %v", err)
				} else {
					logger.Infof("agent customized output=%s", string(b))
				}
			}
		}

		if mo := ev.Output.MessageOutput; mo != nil {
			msg, err := mo.GetMessage()
			if err != nil {
				// Keep the read failure so it is reported instead of the
				// generic "agent no output" when nothing else was collected.
				lastErr = err
				if logger != nil {
					logger.Warnf("agent message read failed: %v", err)
				}
			}
			if msg != nil {
				out.Message = msg
				if logger != nil {
					logger.Infof("agent message role=%s content=%s", msg.Role, msg.Content)
					for _, tc := range msg.ToolCalls {
						if tc.Function.Name != "" {
							logger.Infof("agent tool call name=%s args=%s", tc.Function.Name, tc.Function.Arguments)
						}
					}
				}
			}
		}
	}

	if out.Customized != nil || out.Message != nil {
		return out, nil
	}
	if lastErr != nil {
		return out, lastErr
	}
	return out, errors.New("agent no output")
}
// QueryJSONWithLogger runs a non-streaming query against agent via
// QueryWithLogger (which logs through logger) and decodes the result into a
// *T.
//
// Decoding is attempted in this order:
//  1. If the customized output already is a non-nil *T, it is returned as is.
//  2. Otherwise the customized output is round-tripped through encoding/json
//     into a fresh T.
//  3. Otherwise non-empty message content is parsed as JSON into a fresh T.
//
// An error from QueryWithLogger is returned unchanged. If no step succeeds,
// the returned error starts with "agent output not match target type" and,
// when available, is followed by the reason each attempted step failed.
func QueryJSONWithLogger[T any](ctx context.Context, agent adk.Agent, query string, logger *krlog.Helper) (*T, error) {
	r, err := QueryWithLogger(ctx, agent, query, logger)
	if err != nil {
		return nil, err
	}

	// reason accumulates why each decoding attempt failed so the final error
	// tells the caller what was wrong with the agent output.
	var reason string

	if r.Customized != nil {
		if v, ok := r.Customized.(*T); ok {
			if v != nil {
				return v, nil
			}
			// A typed nil pointer is non-nil as an interface value; returning
			// it with a nil error would make callers dereference nil.
			reason = "customized output: nil pointer"
		} else {
			var out T
			b, decErr := json.Marshal(r.Customized)
			if decErr == nil {
				decErr = json.Unmarshal(b, &out)
			}
			if decErr == nil {
				return &out, nil
			}
			reason = "customized output: " + decErr.Error()
		}
	}

	if r.Message != nil && r.Message.Content != "" {
		var out T
		decErr := json.Unmarshal([]byte(r.Message.Content), &out)
		if decErr == nil {
			return &out, nil
		}
		if reason != "" {
			reason += "; "
		}
		reason += "message content: " + decErr.Error()
	}

	const mismatch = "agent output not match target type"
	if reason == "" {
		return nil, errors.New(mismatch)
	}
	return nil, errors.New(mismatch + ": " + reason)
}
// Query sends a single non-streaming request to agent and folds the emitted
// events into one Result.
//   - Result.Customized is the last non-nil CustomizedOutput (the structured
//     output).
//   - Result.Message is the last message read from a MessageOutput (plain text
//     or JSON that callers may parse).
//   - Iteration stops at the first event that carries an error.
//
// If either output was collected, the Result is returned with a nil error,
// even when an error ended the run. Otherwise the last error seen (from an
// event or from reading a message) is returned, or "agent no output" when
// there was none.
func Query(ctx context.Context, agent adk.Agent, query string) (Result, error) {
	runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: agent})
	it := runner.Query(ctx, query)

	var (
		out     Result
		lastErr error
	)
	for {
		ev, ok := it.Next()
		if !ok || ev == nil {
			break
		}
		if ev.Err != nil {
			lastErr = ev.Err
			break
		}
		if ev.Output == nil {
			continue
		}

		if c := ev.Output.CustomizedOutput; c != nil {
			out.Customized = c
		}
		if mo := ev.Output.MessageOutput; mo != nil {
			msg, err := mo.GetMessage()
			if err != nil {
				// Keep the read failure so it is reported instead of the
				// generic "agent no output" when nothing else was collected.
				lastErr = err
			}
			if msg != nil {
				out.Message = msg
			}
		}
	}

	if out.Customized != nil || out.Message != nil {
		return out, nil
	}
	if lastErr != nil {
		return out, lastErr
	}
	return out, errors.New("agent no output")
}
// ToPayload converts a Result into a value that can be returned to a client
// directly.
//   - A non-nil Customized is returned unchanged.
//   - Otherwise the message content is parsed as JSON and the decoded value is
//     returned; if parsing fails the text is wrapped as {"message": text}.
//   - With neither a customized output nor a message, nil is returned.
func ToPayload(res Result) any {
	switch {
	case res.Customized != nil:
		return res.Customized
	case res.Message == nil:
		return nil
	}

	text := res.Message.Content
	var parsed any
	if err := json.Unmarshal([]byte(text), &parsed); err != nil {
		// Not valid JSON: hand the raw text back under a "message" key.
		return map[string]any{"message": text}
	}
	return parsed
}
// Stream runs agent with streaming enabled and forwards the non-empty content
// of every MessageOutput to the returned channel, for use with SSE/WebSocket
// style transports. Per the original author's note, direct tool returns are
// usually merged into the final message.
//
// The channel is closed when the event iterator is exhausted or when ctx is
// cancelled while a send is pending, so a consumer that stops reading no
// longer leaves the producing goroutine blocked forever as long as it cancels
// ctx. The returned error is always nil.
//
// NOTE(review): GetMessage is called once per event; whether that yields
// incremental fragments or one complete message per event depends on the ADK
// implementation — confirm against the eino documentation.
func Stream(ctx context.Context, agent adk.Agent, query string) (<-chan string, error) {
	runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: agent, EnableStreaming: true})
	it := runner.Query(ctx, query)

	// Small buffer so the producer can run slightly ahead of the consumer.
	ch := make(chan string, 8)
	go func() {
		defer close(ch)
		for {
			ev, ok := it.Next()
			if !ok || ev == nil {
				return
			}
			if ev.Output == nil || ev.Output.MessageOutput == nil {
				continue
			}
			// The read error is intentionally dropped: the channel carries
			// text only and has no way to report a failure to the consumer.
			msg, _ := ev.Output.MessageOutput.GetMessage()
			if msg == nil || msg.Content == "" {
				continue
			}
			select {
			case ch <- msg.Content:
			case <-ctx.Done():
				// Consumer gave up; stop instead of blocking on the send.
				return
			}
		}
	}()
	return ch, nil
}