ai-courseware/eino-project/internal/pkg/adkutil/adkutil.go

240 lines
6.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package adkutil
import (
"context"
"encoding/json"
"errors"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/schema"
krlog "github.com/go-kratos/kratos/v2/log"
)
// Result is the unified outcome of one agent query.
//
// Either field may be nil. The query helpers in this package return a Result
// with a nil error only when at least one of the two fields is set.
type Result struct {
	// Customized is the most recent non-nil CustomizedOutput seen on an event.
	Customized any
	// Message is the most recent message read from an event's MessageOutput.
	Message *schema.Message
}
// QueryJSON sends query to agent through Query and converts the result to *T.
//
// Conversion is attempted in this order:
//  1. Result.Customized already holds a non-nil *T: it is returned as is.
//  2. Result.Customized holds a value of any other type: it is round-tripped
//     through encoding/json into a new T.
//  3. Result.Message has non-empty Content: the content is parsed as JSON
//     into a new T.
//
// An error from Query is returned unchanged. If no step succeeds, the error
// "agent output not match target type" is returned.
func QueryJSON[T any](ctx context.Context, agent adk.Agent, query string) (*T, error) {
	res, err := Query(ctx, agent, query)
	if err != nil {
		return nil, err
	}

	if res.Customized != nil {
		if typed, ok := res.Customized.(*T); ok {
			// A typed-nil *T stored in the interface passes the assertion.
			// Returning it would hand the caller a nil result together with
			// a nil error, so only a non-nil pointer is accepted; a nil one
			// falls through to the message fallback below.
			if typed != nil {
				return typed, nil
			}
		} else if raw, mErr := json.Marshal(res.Customized); mErr == nil {
			var out T
			if json.Unmarshal(raw, &out) == nil {
				return &out, nil
			}
		}
	}

	if res.Message != nil && res.Message.Content != "" {
		var out T
		if json.Unmarshal([]byte(res.Message.Content), &out) == nil {
			return &out, nil
		}
	}

	return nil, errors.New("agent output not match target type")
}
// QueryWithLogger runs query against agent with a non-streaming runner,
// consumes every event from the returned iterator and collects a Result.
//
// For each event it keeps the latest non-nil CustomizedOutput and the latest
// non-nil message read from MessageOutput. An event carrying Err does not stop
// iteration; the error is remembered and only returned if no output at all was
// collected. Errors from MessageOutput.GetMessage are remembered the same way.
//
// logger may be nil, in which case nothing is logged. When non-nil, the query,
// each event's error, customized output (as JSON), message role/content and
// named tool calls are logged at info level.
//
// Returns:
//   - (Result, nil) when Customized or Message was collected;
//   - (empty Result, last recorded error) when nothing was collected and an
//     error was seen;
//   - (empty Result, "agent no output") otherwise.
func QueryWithLogger(ctx context.Context, agent adk.Agent, query string, logger *krlog.Helper) (Result, error) {
	runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: agent})
	it := runner.Query(ctx, query)

	var (
		out     Result
		lastErr error
	)
	if logger != nil {
		logger.Infof("agent query start: %s", query)
	}

	for {
		ev, ok := it.Next()
		if !ok || ev == nil {
			break
		}
		if logger != nil {
			logger.Infof("agent event received: err=%v", ev.Err)
		}
		if ev.Err != nil {
			lastErr = ev.Err
		}
		if ev.Output == nil {
			continue
		}

		if custom := ev.Output.CustomizedOutput; custom != nil {
			out.Customized = custom
			if logger != nil {
				// Marshal error is ignored on purpose: the bytes are used
				// for a log line only and never affect the result.
				b, _ := json.Marshal(custom)
				logger.Infof("agent customized output=%s", string(b))
			}
		}

		if ev.Output.MessageOutput == nil {
			continue
		}
		msg, msgErr := ev.Output.MessageOutput.GetMessage()
		if msgErr != nil {
			// Keep the cause so a failed message read is not reported as
			// the generic "agent no output" error.
			lastErr = msgErr
		}
		if msg == nil {
			continue
		}
		out.Message = msg
		if logger == nil {
			continue
		}
		logger.Infof("agent message role=%s content=%s", msg.Role, msg.Content)
		for _, tc := range msg.ToolCalls {
			if tc.Function.Name != "" {
				logger.Infof("agent tool call name=%s args=%s", tc.Function.Name, tc.Function.Arguments)
			}
		}
	}

	if out.Customized != nil || out.Message != nil {
		return out, nil
	}
	if lastErr != nil {
		return out, lastErr
	}
	return out, errors.New("agent no output")
}
// QueryJSONWithLogger sends query to agent through QueryWithLogger (passing
// logger along, which may be nil) and converts the result to *T.
//
// Conversion is attempted in this order:
//  1. Result.Customized already holds a non-nil *T: it is returned as is.
//  2. Result.Customized holds a value of any other type: it is round-tripped
//     through encoding/json into a new T.
//  3. Result.Message has non-empty Content: the content is parsed as JSON
//     into a new T.
//
// An error from QueryWithLogger is returned unchanged. If no step succeeds,
// the error "agent output not match target type" is returned.
func QueryJSONWithLogger[T any](ctx context.Context, agent adk.Agent, query string, logger *krlog.Helper) (*T, error) {
	res, err := QueryWithLogger(ctx, agent, query, logger)
	if err != nil {
		return nil, err
	}

	if res.Customized != nil {
		if typed, ok := res.Customized.(*T); ok {
			// A typed-nil *T stored in the interface passes the assertion.
			// Returning it would hand the caller a nil result together with
			// a nil error, so only a non-nil pointer is accepted; a nil one
			// falls through to the message fallback below.
			if typed != nil {
				return typed, nil
			}
		} else if raw, mErr := json.Marshal(res.Customized); mErr == nil {
			var out T
			if json.Unmarshal(raw, &out) == nil {
				return &out, nil
			}
		}
	}

	if res.Message != nil && res.Message.Content != "" {
		var out T
		if json.Unmarshal([]byte(res.Message.Content), &out) == nil {
			return &out, nil
		}
	}

	return nil, errors.New("agent output not match target type")
}
// Query issues a single non-streaming request to agent and extracts a unified
// Result from the events it produces.
//
//   - The latest non-nil CustomizedOutput (structured output) is stored in
//     Result.Customized.
//   - The latest non-nil message read from MessageOutput (plain text or
//     parseable JSON) is stored in Result.Message.
//   - Iteration stops at the first event whose Err is set.
//
// Returns:
//   - (Result, nil) when Customized or Message was collected, even if an
//     error was seen afterwards;
//   - (empty Result, last recorded error) when nothing was collected and an
//     event error or a GetMessage error occurred;
//   - (empty Result, "agent no output") otherwise.
func Query(ctx context.Context, agent adk.Agent, query string) (Result, error) {
	runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: agent})
	it := runner.Query(ctx, query)

	var (
		out     Result
		lastErr error
	)
	for {
		ev, ok := it.Next()
		if !ok || ev == nil {
			break
		}
		if ev.Err != nil {
			lastErr = ev.Err
			break
		}
		if ev.Output == nil {
			continue
		}
		if custom := ev.Output.CustomizedOutput; custom != nil {
			out.Customized = custom
		}
		if ev.Output.MessageOutput == nil {
			continue
		}
		msg, msgErr := ev.Output.MessageOutput.GetMessage()
		if msgErr != nil {
			// Keep the cause so a failed message read is not reported as
			// the generic "agent no output" error. Iteration continues, so
			// later events can still contribute output.
			lastErr = msgErr
		}
		if msg != nil {
			out.Message = msg
		}
	}

	if out.Customized != nil || out.Message != nil {
		return out, nil
	}
	if lastErr != nil {
		return out, lastErr
	}
	return out, errors.New("agent no output")
}
// ToPayload converts a query Result into a value that can be returned to a
// caller directly.
//
//   - A non-nil Customized is returned unchanged.
//   - Otherwise, if Message is set, its Content is parsed as JSON and the
//     decoded value is returned; when parsing fails the text is wrapped as
//     map[string]any{"message": content}.
//   - If neither field is set, nil is returned.
func ToPayload(res Result) any {
	switch {
	case res.Customized != nil:
		return res.Customized
	case res.Message == nil:
		return nil
	}

	text := res.Message.Content
	var parsed any
	if err := json.Unmarshal([]byte(text), &parsed); err != nil {
		// Not valid JSON: fall back to a plain-text envelope.
		return map[string]any{"message": text}
	}
	return parsed
}
// Stream consumes agent output with streaming enabled and writes the Content
// of every non-empty message read from MessageOutput to the returned channel.
// It is intended for SSE/WebSocket style consumers.
//
// Original author's note: output returned directly by a tool is usually merged
// into the final message (not verifiable from this file).
//
// The channel is buffered (8) and is closed by the producing goroutine when
// the event iterator is exhausted, yields a nil event, or ctx is done. Event
// errors (ev.Err) and GetMessage errors are not reported: the channel carries
// text only. The returned error is always nil.
func Stream(ctx context.Context, agent adk.Agent, query string) (<-chan string, error) {
	runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: agent, EnableStreaming: true})
	it := runner.Query(ctx, query)
	ch := make(chan string, 8)

	go func() {
		defer close(ch)
		for {
			ev, ok := it.Next()
			if !ok || ev == nil {
				return
			}
			if ev.Output == nil || ev.Output.MessageOutput == nil {
				continue
			}
			// The error is deliberately dropped: there is no way to deliver
			// it through a chan string, and a nil message is skipped below.
			msg, _ := ev.Output.MessageOutput.GetMessage()
			if msg == nil || msg.Content == "" {
				continue
			}
			// Do not block forever if the consumer stopped reading: give up
			// once ctx is cancelled so this goroutine can exit.
			select {
			case ch <- msg.Content:
			case <-ctx.Done():
				return
			}
		}
	}()

	return ch, nil
}
// StreamWithLogger behaves like a streaming query that writes message content
// to the returned channel, and additionally logs each event. The logs are
// meant to show whether the model actually streamed (several MessageOutput
// events) or only tool output (CustomizedOutput) was produced.
//
// logger may be nil, in which case nothing is logged.
//
// The channel is buffered (8) and is closed by the producing goroutine when
// the event iterator is exhausted, yields a nil event, or ctx is done. Event
// errors (ev.Err) and GetMessage errors are not reported on the channel. The
// returned error is always nil.
func StreamWithLogger(ctx context.Context, agent adk.Agent, query string, logger *krlog.Helper) (<-chan string, error) {
	runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: agent, EnableStreaming: true})
	it := runner.Query(ctx, query)
	ch := make(chan string, 8)

	go func() {
		defer close(ch)
		var msgCount int
		for {
			ev, ok := it.Next()
			if !ok || ev == nil {
				if logger != nil {
					logger.Infof("stream finished, message_count=%d", msgCount)
				}
				return
			}
			// Guarded like every other log call here; an unguarded call
			// panics when the caller passes a nil logger.
			if logger != nil {
				logger.Infof("stream event: %v", ev)
			}
			if ev.Output == nil {
				continue
			}
			if ev.Output.CustomizedOutput != nil && logger != nil {
				logger.Infof("customized output received")
			}
			if ev.Output.MessageOutput == nil {
				continue
			}
			// The error is deliberately dropped: there is no way to deliver
			// it through a chan string, and a nil message is skipped below.
			msg, _ := ev.Output.MessageOutput.GetMessage()
			if msg == nil {
				continue
			}
			msgCount++
			if logger != nil {
				logger.Infof("message event #%d role=%s len=%d", msgCount, msg.Role, len(msg.Content))
			}
			if msg.Content == "" {
				continue
			}
			// Do not block forever if the consumer stopped reading: give up
			// once ctx is cancelled so this goroutine can exit.
			select {
			case ch <- msg.Content:
			case <-ctx.Done():
				if logger != nil {
					logger.Infof("stream finished, message_count=%d", msgCount)
				}
				return
			}
		}
	}()

	return ch, nil
}