ai-courseware/eino-project/internal/domain/workflow/zltx_product.go

86 lines
2.3 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package workflow
import (
"context"
"fmt"
"strings"
"eino-project/internal/domain/agent"
"eino-project/internal/domain/llm"
"eino-project/internal/pkg/adkutil"
"github.com/cloudwego/eino/compose"
)
type (
	// ZltxProductWorkflow answers product questions by running a small eino
	// compose graph around the product chat agent.
	ZltxProductWorkflow struct {
		// models is the LLM handle passed to agent.NewProductChatAgent on each run.
		models llm.LLM
	}

	// productItem is a single product entry in the agent's JSON answer.
	productItem struct {
		ID          string `json:"id"`
		Name        string `json:"name"`
		Price       int    `json:"price"`
		Description string `json:"description"`
	}

	// productQueryRes is the JSON shape requested from the product agent via
	// adkutil.QueryJSON.
	productQueryRes struct {
		// Items may contain nil entries if the model emits JSON nulls.
		Items  []*productItem `json:"items"`
		Source string         `json:"source"`
		// Count is the total reported by the agent; it is not guaranteed to
		// equal len(Items).
		Count int `json:"count"`
	}
)
// NewZltxProductWorkflow returns a workflow that builds its product chat agent
// from models each time Run is called.
func NewZltxProductWorkflow(models llm.LLM) *ZltxProductWorkflow {
	wf := new(ZltxProductWorkflow)
	wf.models = models
	return wf
}
// Run answers a single product query.
//
// The message is whitespace-trimmed, sent to the product chat agent (which is
// asked for a productQueryRes as JSON via adkutil.QueryJSON), and the
// structured answer is rendered into a short Chinese summary. Errors from
// building, compiling or invoking the graph are returned wrapped with context.
func (w *ZltxProductWorkflow) Run(ctx context.Context, message string) (string, error) {
	r, err := w.compileGraph(ctx)
	if err != nil {
		return "", fmt.Errorf("compile zltx product graph: %w", err)
	}
	ret, err := r.Invoke(ctx, map[string]any{"message": message})
	if err != nil {
		return "", fmt.Errorf("run zltx product graph: %w", err)
	}
	return ret, nil
}

// compileGraph wires START -> preprocess -> agent_call -> describe -> END and
// compiles it. Every AddLambdaNode/AddEdge error is checked and reported with
// the offending node or edge instead of being discarded.
func (w *ZltxProductWorkflow) compileGraph(ctx context.Context) (compose.Runnable[map[string]any, string], error) {
	g := compose.NewGraph[map[string]any, string]()

	nodes := []struct {
		key    string
		lambda *compose.Lambda
	}{
		{"preprocess", compose.InvokableLambda(w.preprocess)},
		{"agent_call", compose.InvokableLambda(w.queryProducts)},
		{"describe", compose.InvokableLambda(w.describeProducts)},
	}
	for _, n := range nodes {
		if err := g.AddLambdaNode(n.key, n.lambda); err != nil {
			return nil, fmt.Errorf("add node %q: %w", n.key, err)
		}
	}

	edges := [][2]string{
		{compose.START, "preprocess"},
		{"preprocess", "agent_call"},
		{"agent_call", "describe"},
		{"describe", compose.END},
	}
	for _, e := range edges {
		if err := g.AddEdge(e[0], e[1]); err != nil {
			return nil, fmt.Errorf("add edge %q -> %q: %w", e[0], e[1], err)
		}
	}

	return g.Compile(ctx)
}

// preprocess extracts the "message" entry from the graph input and trims
// surrounding whitespace. A missing or non-string entry yields "".
func (w *ZltxProductWorkflow) preprocess(_ context.Context, in map[string]any) (string, error) {
	raw, _ := in["message"].(string) // comma-ok: non-string input is treated as empty
	return strings.TrimSpace(raw), nil
}

// queryProducts asks a freshly built product chat agent for a JSON
// productQueryRes. A nil result without an error is passed on as the zero
// value, which describeProducts renders as "not found".
func (w *ZltxProductWorkflow) queryProducts(ctx context.Context, q string) (productQueryRes, error) {
	ag := agent.NewProductChatAgent(ctx, w.models)
	out, err := adkutil.QueryJSON[productQueryRes](ctx, ag, q)
	if err != nil {
		return productQueryRes{}, fmt.Errorf("query product agent: %w", err)
	}
	if out == nil {
		return productQueryRes{}, nil
	}
	return *out, nil
}

// describeProducts renders the agent's answer as a Chinese summary: a header
// with the reported count and source, then one "。"-terminated sentence per
// item, the first prefixed with "首条:". Nil items (JSON nulls) are skipped
// rather than dereferenced.
func (w *ZltxProductWorkflow) describeProducts(_ context.Context, res productQueryRes) (string, error) {
	items := make([]*productItem, 0, len(res.Items))
	for _, it := range res.Items {
		if it != nil {
			items = append(items, it)
		}
	}
	if res.Count <= 0 || len(items) == 0 {
		return "未找到相关商品", nil
	}

	var b strings.Builder
	fmt.Fprintf(&b, "共%d条来源%s。", res.Count, res.Source)
	for i, it := range items {
		if i == 0 {
			b.WriteString("首条:")
		}
		// Each item ends with "。" so consecutive items do not run together.
		fmt.Fprintf(&b, "%s(编号%s),价格%d%s。", it.Name, it.ID, it.Price, it.Description)
	}
	return b.String(), nil
}