Compare commits

...

8 Commits

Author SHA1 Message Date
renzhiyuan ff81cddc46 Merge branch 'feature/fiber' into dev 2025-09-22 14:23:39 +08:00
renzhiyuan 3f37427181 结构修改 2025-09-22 14:22:16 +08:00
renzhiyuan 1412554661 结构修改 2025-09-22 11:46:42 +08:00
renzhiyuan 570e13e527 结构修改 2025-09-22 11:46:06 +08:00
renzhiyuan 32866d59a1 结构修改 2025-09-22 11:23:28 +08:00
renzhiyuan b81c9ef137 结构修改 2025-09-22 11:03:40 +08:00
renzhiyuan 0ce09611de Merge remote-tracking branch 'origin/feature/fiber' into feature/fiber 2025-09-22 09:50:08 +08:00
renzhiyuan 03f6130a6e 结构修改 2025-09-22 09:49:48 +08:00
12 changed files with 369 additions and 181 deletions

View File

@ -26,3 +26,9 @@ redis:
db: db:
driver: mysql driver: mysql
source: root:SD###sdf323r343@tcp(121.199.38.107:3306)/sys_ai?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai source: root:SD###sdf323r343@tcp(121.199.38.107:3306)/sys_ai?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai
tools:
zltxOrderDetail:
enabled: true
base_url: "https://gateway.dev.cdlsxd.cn/zltx_api/admin/direct/ai/"
api_key: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ1c2VyQ2VudGVyIiwiZXhwIjoxNzU2MTgyNTM1LCJuYmYiOjE3NTYxODA3MzUsImp0aSI6IjEiLCJQaG9uZSI6IjE4MDAwMDAwMDAwIiwiVXNlck5hbWUiOiJsc3hkIiwiUmVhbE5hbWUiOiLotoXnuqfnrqHnkIblkZgiLCJBY2NvdW50VHlwZSI6MSwiR3JvdXBDb2RlcyI6IlZDTF9DQVNISUVSLFZDTF9PUEVSQVRFLFZDTF9BRE1JTixWQ0xfQUFBLFZDTF9WQ0xfT1BFUkFULFZDTF9JTlZPSUNFLENSTV9BRE1JTixMSUFOTElBTl9BRE1JTixNQVJLRVRNQUcyX0FETUlOLFBIT05FQklMTF9BRE1JTixRSUFOWkhVX1NVUFBFUl9BRE0sTUFSS0VUSU5HU0FBU19TVVBFUkFETUlOLENBUkRfQ09ERSxDQVJEX1BST0NVUkVNRU5ULE1BUktFVElOR1NZU1RFTV9TVVBFUixTVEFUSVNUSUNBTFNZU1RFTV9BRE1JTixaTFRYX0FETUlOLFpMVFhfT1BFUkFURSIsIkRpbmdVc2VySWQiOiIxNjIwMjYxMjMwMjg5MzM4MzQifQ.N1xv1PYbcO8_jR5adaczc16YzGsr4z101gwEZdulkRaREBJNYTOnFrvRxTFx3RJTooXsqTqroE1MR84v_1WPX6BS6kKonA-kC1Jgot6yrt5rFWhGNGb2Cpr9rKIFCCQYmiGd3AUgDazEeaQ0_sodv3E-EXg9VfE1SX8nMcck9Yjnc8NCy7RTWaBIaSeOdZcEl-JfCD0S6GSx3oErp_hk-U9FKGwf60wAuDGTY1R0BP4BYpcEqS-C2LSnsSGyURi54Cuk5xH8r1WuF0Dm5bwAj5d7Hvs77-N_sUF-C5ONqyZJRAEhYLgcmN9RX_WQZfizdQJxizlTczdpzYfy-v-1eQ"

View File

@ -2,25 +2,31 @@ package biz
import ( import (
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
"ai_scheduler/internal/data/constant"
errors "ai_scheduler/internal/data/error" errors "ai_scheduler/internal/data/error"
"ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model" "ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg" "ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/mapstructure"
"ai_scheduler/internal/pkg/utils_ollama" "ai_scheduler/internal/pkg/utils_ollama"
"ai_scheduler/internal/tools" "ai_scheduler/internal/tools"
"ai_scheduler/tmpl/dataTemp" "ai_scheduler/tmpl/dataTemp"
"context" "context"
"encoding/json" "encoding/json"
"log" "fmt"
"strings"
"gitea.cdlsxd.cn/self-tools/l_request"
"github.com/gofiber/fiber/v2/log"
"github.com/gofiber/websocket/v2" "github.com/gofiber/websocket/v2"
"github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/llms"
"xorm.io/builder" "xorm.io/builder"
) )
// AiRouterService 智能路由服务 // AiRouterBiz 智能路由服务
type AiRouterService struct { type AiRouterBiz struct {
//aiClient entitys.AIClient //aiClient entitys.AIClient
toolManager *tools.Manager toolManager *tools.Manager
sessionImpl *impl.SessionImpl sessionImpl *impl.SessionImpl
@ -41,8 +47,9 @@ func NewAiRouterBiz(
hisImpl *impl.ChatImpl, hisImpl *impl.ChatImpl,
conf *config.Config, conf *config.Config,
utilAgent *utils_ollama.UtilOllama, utilAgent *utils_ollama.UtilOllama,
) entitys.RouterService {
return &AiRouterService{ ) *AiRouterBiz {
return &AiRouterBiz{
//aiClient: aiClient, //aiClient: aiClient,
toolManager: toolManager, toolManager: toolManager,
sessionImpl: sessionImpl, sessionImpl: sessionImpl,
@ -55,13 +62,13 @@ func NewAiRouterBiz(
} }
// Route 执行智能路由 // Route 执行智能路由
func (r *AiRouterService) Route(ctx context.Context, req *entitys.ChatRequest) (*entitys.ChatResponse, error) { func (r *AiRouterBiz) Route(ctx context.Context, req *entitys.ChatRequest) (*entitys.ChatResponse, error) {
return nil, nil return nil, nil
} }
// Route 执行智能路由 // Route 执行智能路由
func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) error { func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) error {
session := c.Headers("X-Session", "") session := c.Headers("X-Session", "")
if len(session) == 0 { if len(session) == 0 {
@ -91,26 +98,27 @@ func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSo
return errors.SystemError return errors.SystemError
} }
toolDefinitions := r.registerTools(task) //toolDefinitions := r.registerTools(task)
//prompt := r.getPrompt(sysInfo, history, req.Text) //prompt := r.getPrompt(sysInfo, history, req.Text)
//意图预测 //意图预测
//msg, err := r.utilAgent.Llm.Call(context.TODO(), pkg.JsonStringIgonErr(prompt), prompt := r.getPromptLLM(sysInfo, history, req.Text, task)
match, err := r.utilAgent.Llm.GenerateContent(context.TODO(), prompt,
//llms.WithTools(toolDefinitions), //llms.WithTools(toolDefinitions),
// //llms.WithToolChoice(llms.FunctionCallBehaviorAuto), //llms.WithToolChoice("tool_name"),
// llms.WithJSONMode(),
//)
prompt := r.getPromptLLM(sysInfo, history, req.Text)
msg, err := r.utilAgent.Llm.GenerateContent(context.TODO(), prompt,
llms.WithTools(toolDefinitions),
llms.WithToolChoice("tool_name"),
llms.WithJSONMode(), llms.WithJSONMode(),
) )
if err != nil { if err != nil {
return errors.SystemError return errors.SystemError
} }
log.Info(match)
c.WriteMessage(1, []byte(msg.Choices[0].Content)) var matchJson entitys.Match
err = json.Unmarshal([]byte(match.Choices[0].Content), &matchJson)
if err != nil {
return errors.SystemError
}
return r.handleMatch(c, &matchJson, task)
//c.WriteMessage(1, []byte(msg.Choices[0].Content))
// 构建消息 // 构建消息
//messages := []entitys.Message{ //messages := []entitys.Message{
// { // {
@ -193,7 +201,98 @@ func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSo
return nil return nil
} }
func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) { func (r *AiRouterBiz) handleMatch(c *websocket.Conn, matchJson *entitys.Match, tasks []model.AiTask) (err error) {
defer func() {
if err != nil {
c.WriteMessage(websocket.TextMessage, []byte(err.Error()))
}
c.WriteMessage(websocket.TextMessage, []byte("EOF"))
}()
if !matchJson.IsMatch {
c.WriteMessage(websocket.TextMessage, []byte(matchJson.Reasoning))
return
}
var pointTask *model.AiTask
for _, task := range tasks {
if task.Index == matchJson.Index {
pointTask = &task
break
}
}
if pointTask == nil || pointTask.Index == "other" {
return r.handleOtherTask(c, matchJson)
}
switch pointTask.Type {
case constant.TaskTypeApi:
err = r.handleApiTask(c, matchJson, pointTask)
case constant.TaskTypeFunc:
err = r.handleTask(c, matchJson, pointTask)
default:
return r.handleOtherTask(c, matchJson)
}
return
}
func (r *AiRouterBiz) handleTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
var configData entitys.ConfigDataTool
err = json.Unmarshal([]byte(task.Config), &configData)
if err != nil {
return
}
err = r.toolManager.ExecuteTool(c, configData.Tool, []byte(matchJson.Parameters))
if err != nil {
return
}
return
}
func (r *AiRouterBiz) handleOtherTask(c *websocket.Conn, matchJson *entitys.Match) (err error) {
c.WriteMessage(1, []byte(matchJson.Reasoning))
return
}
func (r *AiRouterBiz) handleApiTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (err error) {
var (
request l_request.Request
auth = c.Headers("X-Authorization", "")
requestParam map[string]interface{}
)
err = json.Unmarshal([]byte(matchJson.Parameters), &requestParam)
if err != nil {
return
}
request.Url = strings.ReplaceAll(task.Config, "${authorization}", auth)
for k, v := range requestParam {
task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", fmt.Sprintf("%v", v))
}
var configData entitys.ConfigDataHttp
err = json.Unmarshal([]byte(task.Config), &configData)
if err != nil {
return
}
err = mapstructure.Decode(configData.Request, &request)
if err != nil {
return
}
if len(request.Url) == 0 {
err = errors.NewBusinessErr("00022", "api地址获取失败")
return
}
res, err := request.Send()
if err != nil {
return
}
c.WriteMessage(1, res.Content)
return
}
func (r *AiRouterBiz) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) {
cond := builder.NewCond() cond := builder.NewCond()
cond = cond.And(builder.Eq{"session_id": sessionId}) cond = cond.And(builder.Eq{"session_id": sessionId})
@ -203,7 +302,7 @@ func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiCha
return return
} }
func (r *AiRouterService) getSysInfo(appKey string) (sysInfo model.AiSy, err error) { func (r *AiRouterBiz) getSysInfo(appKey string) (sysInfo model.AiSy, err error) {
cond := builder.NewCond() cond := builder.NewCond()
cond = cond.And(builder.Eq{"app_key": appKey}) cond = cond.And(builder.Eq{"app_key": appKey})
cond = cond.And(builder.IsNull{"delete_at"}) cond = cond.And(builder.IsNull{"delete_at"})
@ -212,7 +311,7 @@ func (r *AiRouterService) getSysInfo(appKey string) (sysInfo model.AiSy, err err
return return
} }
func (r *AiRouterService) getTasks(sysId int32) (tasks []model.AiTask, err error) { func (r *AiRouterBiz) getTasks(sysId int32) (tasks []model.AiTask, err error) {
cond := builder.NewCond() cond := builder.NewCond()
cond = cond.And(builder.Eq{"sys_id": sysId}) cond = cond.And(builder.Eq{"sys_id": sysId})
@ -223,7 +322,7 @@ func (r *AiRouterService) getTasks(sysId int32) (tasks []model.AiTask, err error
return return
} }
func (r *AiRouterService) registerTools(tasks []model.AiTask) []llms.Tool { func (r *AiRouterBiz) registerTools(tasks []model.AiTask) []llms.Tool {
taskPrompt := make([]llms.Tool, 0) taskPrompt := make([]llms.Tool, 0)
for _, task := range tasks { for _, task := range tasks {
var taskConfig entitys.TaskConfig var taskConfig entitys.TaskConfig
@ -244,7 +343,7 @@ func (r *AiRouterService) registerTools(tasks []model.AiTask) []llms.Tool {
return taskPrompt return taskPrompt
} }
func (r *AiRouterService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []entitys.Message { func (r *AiRouterBiz) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []entitys.Message {
var ( var (
prompt = make([]entitys.Message, 0) prompt = make([]entitys.Message, 0)
) )
@ -261,7 +360,7 @@ func (r *AiRouterService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi
return prompt return prompt
} }
func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []llms.MessageContent { func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
var ( var (
prompt = make([]llms.MessageContent, 0) prompt = make([]llms.MessageContent, 0)
) )
@ -275,6 +374,11 @@ func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiCha
Parts: []llms.ContentPart{ Parts: []llms.ContentPart{
llms.TextPart(pkg.JsonStringIgonErr(r.buildAssistant(history))), llms.TextPart(pkg.JsonStringIgonErr(r.buildAssistant(history))),
}, },
}, llms.MessageContent{
Role: llms.ChatMessageTypeTool,
Parts: []llms.ContentPart{
llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))),
},
}, llms.MessageContent{ }, llms.MessageContent{
Role: llms.ChatMessageTypeHuman, Role: llms.ChatMessageTypeHuman,
Parts: []llms.ContentPart{ Parts: []llms.ContentPart{
@ -285,7 +389,7 @@ func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiCha
} }
// buildSystemPrompt 构建系统提示词 // buildSystemPrompt 构建系统提示词
func (r *AiRouterService) buildSystemPrompt(prompt string) string { func (r *AiRouterBiz) buildSystemPrompt(prompt string) string {
if len(prompt) == 0 { if len(prompt) == 0 {
prompt = "[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**\n[rule]\n1.返回以下格式的JSON{ \"index\": \"工具索引index\", \"confidence\": 0.0-1.0,\"reasoning\": \"判断理由\"}\n2.严格返回字符串格式禁用markdown格式返回\n3.只返回json字符串不包含任何其他解释性文字\n4.当用户意图非常不清晰时使用,尝试进行追问具体希望查询内容" prompt = "[system] 你是一个智能路由系统,核心职责是 **精准解析用户意图并路由至对应任务模块**\n[rule]\n1.返回以下格式的JSON{ \"index\": \"工具索引index\", \"confidence\": 0.0-1.0,\"reasoning\": \"判断理由\"}\n2.严格返回字符串格式禁用markdown格式返回\n3.只返回json字符串不包含任何其他解释性文字\n4.当用户意图非常不清晰时使用,尝试进行追问具体希望查询内容"
} }
@ -293,7 +397,7 @@ func (r *AiRouterService) buildSystemPrompt(prompt string) string {
return prompt return prompt
} }
func (r *AiRouterService) buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) { func (r *AiRouterBiz) buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) {
for _, item := range his { for _, item := range his {
if len(chatHis.SessionId) == 0 { if len(chatHis.SessionId) == 0 {
chatHis.SessionId = item.SessionID chatHis.SessionId = item.SessionID
@ -311,61 +415,8 @@ func (r *AiRouterService) buildAssistant(his []model.AiChatHi) (chatHis entitys.
return return
} }
// extractIntent 从AI响应中提取意图
func (r *AiRouterService) extractIntent(response *entitys.ChatResponse) string {
if response == nil || response.Message == "" {
return ""
}
// 尝试解析JSON
var intent struct {
Intent string `json:"intent"`
Confidence string `json:"confidence"`
Reasoning string `json:"reasoning"`
}
err := json.Unmarshal([]byte(response.Message), &intent)
if err != nil {
log.Printf("Failed to parse intent JSON: %v", err)
return ""
}
return intent.Intent
}
// handleOrderDiagnosis 处理订单诊断意图
func (r *AiRouterService) handleOrderDiagnosis(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) {
// 调用订单详情工具
//orderDetailTool, ok := r.toolManager.GetTool("zltxOrderDetail")
//if orderDetailTool == nil || !ok {
// return nil, fmt.Errorf("order detail tool not found")
//}
//orderDetailTool.Execute(ctx, json.RawMessage{})
//
//// 获取相关工具定义
//toolDefinitions := r.toolManager.GetToolDefinitions(constants.Caller(req.Caller))
//
//// 调用AI获取是否需要使用工具
//response, err := r.aiClient.Chat(ctx, messages, toolDefinitions)
//if err != nil {
// return nil, fmt.Errorf("failed to chat with AI: %w", err)
//}
//
//// 如果没有工具调用,直接返回
//if len(response.ToolCalls) == 0 {
// return response, nil
//}
//
//// 执行工具调用
//toolResults, err := r.toolManager.ExecuteToolCalls(ctx, response.ToolCalls)
//if err != nil {
// return nil, fmt.Errorf("failed to execute tools: %w", err)
//}
return nil, nil
}
// handleKnowledgeQA 处理知识问答意图 // handleKnowledgeQA 处理知识问答意图
func (r *AiRouterService) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) { func (r *AiRouterBiz) handleKnowledgeQA(ctx context.Context, req *entitys.ChatRequest, messages []entitys.Message) (*entitys.ChatResponse, error) {
return nil, nil return nil, nil
} }

View File

@ -0,0 +1,84 @@
package biz
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg/utils_ollama"
"ai_scheduler/internal/tools"
"ai_scheduler/utils"
"encoding/json"
"flag"
"fmt"
"os"
"path/filepath"
"testing"
"github.com/gofiber/fiber/v2/log"
)
func Test_task(t *testing.T) {
var c entitys.TaskConfig
config := `{"param": {"type": "object", "required": ["number"], "properties": {"number": {"type": "string", "description": "订单编号/流水号"}}}, "request": {"url": "http://www.baidu.com/${number}", "headers": {"Authorization": "${authorization}"}, "method": "GET"}}`
err := json.Unmarshal([]byte(config), &c)
t.Log(err)
}
type configData struct {
Param map[string]interface{} `json:"param"`
Do map[string]interface{} `json:"do"`
}
func Test_Order(t *testing.T) {
routerBiz := in()
err := routerBiz.handleTask(nil, &entitys.Match{Index: "order_diagnosis", Parameters: `{"order_number":"12312312312"}`}, &model.AiTask{Config: `{"tool": "zltxOrderDetail", "param": {"type": "object", "optional": [], "required": ["order_number"], "properties": {"order_number": {"type": "string", "description": "订单编号/流水号"}}}}`})
t.Log(err)
}
func in() *AiRouterBiz {
modDir, err := getModuleDir()
if err != nil {
panic("1")
}
configPath := flag.String("config", fmt.Sprintf("%s/config/config.yaml", modDir), "Path to configuration file")
flag.Parse()
configConfig, err := config.LoadConfig(*configPath)
if err != nil {
panic("加载配置失败")
}
allLogger := log.DefaultLogger()
manager := tools.NewManager(configConfig)
db, _ := utils.NewGormDb(configConfig)
sessionImpl := impl.NewSessionImpl(db)
sysImpl := impl.NewSysImpl(db)
taskImpl := impl.NewTaskImpl(db)
chatImpl := impl.NewChatImpl(db)
utilOllama := utils_ollama.NewUtilOllama(configConfig, allLogger)
routerBiz := NewAiRouterBiz(manager, sessionImpl, sysImpl, taskImpl, chatImpl, configConfig, utilOllama)
return routerBiz
}
func getModuleDir() (string, error) {
dir, err := os.Getwd()
if err != nil {
return "", err
}
for {
modPath := filepath.Join(dir, "go.mod")
if _, err := os.Stat(modPath); err == nil {
return dir, nil // 找到 go.mod
}
// 向上查找父目录
parent := filepath.Dir(dir)
if parent == dir {
break // 到达根目录,未找到
}
dir = parent
}
return "", fmt.Errorf("go.mod not found in current directory or parents")
}

View File

@ -76,7 +76,6 @@ type ToolConfig struct {
Enabled bool `mapstructure:"enabled"` Enabled bool `mapstructure:"enabled"`
BaseURL string `mapstructure:"base_url"` BaseURL string `mapstructure:"base_url"`
APIKey string `mapstructure:"api_key"` APIKey string `mapstructure:"api_key"`
BizSystem string `mapstructure:"biz_system"`
} }
// LoggingConfig 日志配置 // LoggingConfig 日志配置

View File

@ -11,6 +11,7 @@ const (
type TaskType int32 type TaskType int32
const ( const (
TaskTypeApi ConnStatus = iota + 1 TaskTypeApi = 1
TaskTypeKnowle TaskTypeKnowle = 2
TaskTypeFunc = 3
) )

View File

@ -66,12 +66,17 @@ type Tool interface {
Name() string Name() string
Description() string Description() string
Definition() ToolDefinition Definition() ToolDefinition
Execute(ctx context.Context, args json.RawMessage) (interface{}, error) Execute(c *websocket.Conn, args json.RawMessage) error
} }
// AIClient AI客户端接口 type ConfigDataHttp struct {
type AIClient interface { Param map[string]interface{} `json:"param"`
Chat(ctx context.Context, messages []Message, tools []ToolDefinition) (*ChatResponse, error) Request map[string]interface{} `json:"request"`
}
type ConfigDataTool struct {
Param map[string]interface{} `json:"param"`
Tool string `json:"tool"`
} }
// Message 消息 // Message 消息
@ -88,7 +93,13 @@ type FuncApi struct {
type TaskConfig struct { type TaskConfig struct {
Param interface{} `json:"param"` Param interface{} `json:"param"`
} }
type Match struct {
Confidence float64 `json:"confidence"`
Index string `json:"index"`
IsMatch bool `json:"is_match"`
Parameters string `json:"parameters"`
Reasoning string `json:"reasoning"`
}
type ChatHis struct { type ChatHis struct {
SessionId string `json:"session_id"` SessionId string `json:"session_id"`
Messages []HisMessage `json:"messages"` Messages []HisMessage `json:"messages"`

View File

@ -18,7 +18,7 @@ func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama {
ollama.WithModel(c.Ollama.Model), ollama.WithModel(c.Ollama.Model),
ollama.WithHTTPClient(http.DefaultClient), ollama.WithHTTPClient(http.DefaultClient),
ollama.WithServerURL(getUrl(c)), ollama.WithServerURL(getUrl(c)),
ollama.WithKeepAlive("1h"), ollama.WithKeepAlive("-1s"),
) )
if err != nil { if err != nil {
logger.Fatal(err) logger.Fatal(err)

View File

@ -1,6 +1,7 @@
package services package services
import ( import (
"ai_scheduler/internal/biz"
"ai_scheduler/internal/data/constant" "ai_scheduler/internal/data/constant"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"ai_scheduler/internal/gateway" "ai_scheduler/internal/gateway"
@ -17,15 +18,15 @@ import (
// ChatHandler 聊天处理器 // ChatHandler 聊天处理器
type ChatService struct { type ChatService struct {
routerService entitys.RouterService routerBiz *biz.AiRouterBiz
Gw *gateway.Gateway Gw *gateway.Gateway
mu sync.Mutex mu sync.Mutex
} }
// NewChatHandler 创建聊天处理器 // NewChatHandler 创建聊天处理器
func NewChatService(routerService entitys.RouterService, gw *gateway.Gateway) *ChatService { func NewChatService(routerService *biz.AiRouterBiz, gw *gateway.Gateway) *ChatService {
return &ChatService{ return &ChatService{
routerService: routerService, routerBiz: routerService,
Gw: gw, Gw: gw,
} }
} }
@ -100,7 +101,7 @@ func (h *ChatService) Chat(c *websocket.Conn) {
log.Println("JSON parse error:", err) log.Println("JSON parse error:", err)
continue continue
} }
err = h.routerService.RouteWithSocket(c, &req) err = h.routerBiz.RouteWithSocket(c, &req)
if err != nil { if err != nil {
log.Println("处理失败:", err) log.Println("处理失败:", err)
continue continue

View File

@ -1,14 +0,0 @@
package test
import (
"ai_scheduler/internal/entitys"
"encoding/json"
"testing"
)
func Test_task(t *testing.T) {
var c entitys.TaskConfig
config := `{"param": {"type": "object", "required": ["number"], "properties": {"number": {"type": "string", "description": "订单编号/流水号"}}}, "request": {"url": "http://www.baidu.com/${number}", "headers": {"Authorization": "${authorization}"}, "method": "GET"}}`
err := json.Unmarshal([]byte(config), &c)
t.Log(err)
}

View File

@ -0,0 +1,41 @@
package test
import (
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg/mapstructure"
"encoding/json"
"testing"
"gitea.cdlsxd.cn/self-tools/l_request"
)
func Test_task(t *testing.T) {
var c entitys.TaskConfig
config := `{"param": {"type": "object", "required": ["number"], "properties": {"number": {"type": "string", "description": "订单编号/流水号"}}}, "request": {"url": "http://www.baidu.com/${number}", "headers": {"Authorization": "${authorization}"}, "method": "GET"}}`
err := json.Unmarshal([]byte(config), &c)
t.Log(err)
}
type configData struct {
Param map[string]interface{} `json:"param"`
Do map[string]interface{} `json:"do"`
}
func Test_task2(t *testing.T) {
var (
c l_request.Request
config configData
)
configJson := `{"tool": "zltxOrderDetail", "param": {"type": "object", "optional": [], "required": ["order_number"], "properties": {"order_number": {"type": "string", "description": "订单编号/流水号"}}}}`
err := json.Unmarshal([]byte(configJson), &config)
if err != nil {
panic(err)
}
mapstructure.Decode(config.Do, &c)
t.Log(err)
}
func in() {
}

View File

@ -5,9 +5,10 @@ import (
"ai_scheduler/internal/data/constants" "ai_scheduler/internal/data/constants"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gofiber/websocket/v2"
) )
// Manager 工具管理器 // Manager 工具管理器
@ -22,16 +23,16 @@ func NewManager(config *config.Config) *Manager {
} }
// 注册天气工具 // 注册天气工具
if config.Tools.Weather.Enabled { //if config.Tools.Weather.Enabled {
weatherTool := NewWeatherTool() // weatherTool := NewWeatherTool()
m.tools[weatherTool.Name()] = weatherTool // m.tools[weatherTool.Name()] = weatherTool
} //}
//
// 注册计算器工具 //// 注册计算器工具
if config.Tools.Calculator.Enabled { //if config.Tools.Calculator.Enabled {
calcTool := NewCalculatorTool() // calcTool := NewCalculatorTool()
m.tools[calcTool.Name()] = calcTool // m.tools[calcTool.Name()] = calcTool
} //}
// 注册知识库工具 // 注册知识库工具
// if config.Knowledge.Enabled { // if config.Knowledge.Enabled {
@ -80,43 +81,43 @@ func (m *Manager) GetToolDefinitions(caller constants.Caller) []entitys.ToolDefi
} }
// ExecuteTool 执行工具 // ExecuteTool 执行工具
func (m *Manager) ExecuteTool(ctx context.Context, name string, args json.RawMessage) (interface{}, error) { func (m *Manager) ExecuteTool(c *websocket.Conn, name string, args json.RawMessage) error {
tool, exists := m.GetTool(name) tool, exists := m.GetTool(name)
if !exists { if !exists {
return nil, fmt.Errorf("tool not found: %s", name) return fmt.Errorf("tool not found: %s", name)
} }
return tool.Execute(ctx, args) return tool.Execute(c, args)
} }
// ExecuteToolCalls 执行多个工具调用 // ExecuteToolCalls 执行多个工具调用
func (m *Manager) ExecuteToolCalls(ctx context.Context, toolCalls []entitys.ToolCall) ([]entitys.ToolCall, error) { //func (m *Manager) ExecuteToolCalls(c *websocket.Conn, toolCalls []entitys.ToolCall) ([]entitys.ToolCall, error) {
results := make([]entitys.ToolCall, len(toolCalls)) // results := make([]entitys.ToolCall, len(toolCalls))
//
for i, toolCall := range toolCalls { // for i, toolCall := range toolCalls {
results[i] = toolCall // results[i] = toolCall
//
// 执行工具 // // 执行工具
result, err := m.ExecuteTool(ctx, toolCall.Function.Name, toolCall.Function.Arguments) // err := m.ExecuteTool(c, toolCall.Function.Name, toolCall.Function.Arguments)
if err != nil { // if err != nil {
// 将错误信息作为结果返回 // // 将错误信息作为结果返回
errorResult := map[string]interface{}{ // errorResult := map[string]interface{}{
"error": err.Error(), // "error": err.Error(),
} // }
resultBytes, _ := json.Marshal(errorResult) // resultBytes, _ := json.Marshal(errorResult)
results[i].Result = resultBytes // results[i].Result = resultBytes
} else { // } else {
// 将成功结果序列化 // // 将成功结果序列化
resultBytes, err := json.Marshal(result) // resultBytes, err := json.Marshal(result)
if err != nil { // if err != nil {
errorResult := map[string]interface{}{ // errorResult := map[string]interface{}{
"error": fmt.Sprintf("failed to serialize result: %v", err), // "error": fmt.Sprintf("failed to serialize result: %v", err),
} // }
resultBytes, _ = json.Marshal(errorResult) // resultBytes, _ = json.Marshal(errorResult)
} // }
results[i].Result = resultBytes // results[i].Result = resultBytes
} // }
} // }
//
return results, nil // return results, nil
} //}

View File

@ -4,10 +4,11 @@ import (
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http"
"gitea.cdlsxd.cn/self-tools/l_request"
"github.com/gofiber/websocket/v2"
) )
// ZltxOrderDetailTool 直连天下订单详情工具 // ZltxOrderDetailTool 直连天下订单详情工具
@ -53,7 +54,7 @@ func (w *ZltxOrderDetailTool) Definition() entitys.ToolDefinition {
// ZltxOrderDetailRequest 直连天下订单详情请求参数 // ZltxOrderDetailRequest 直连天下订单详情请求参数
type ZltxOrderDetailRequest struct { type ZltxOrderDetailRequest struct {
Number string `json:"number"` OrderNumber string `json:"order_number"`
} }
// ZltxOrderDetailResponse 直连天下订单详情响应 // ZltxOrderDetailResponse 直连天下订单详情响应
@ -70,37 +71,43 @@ type ZltxOrderDetailData struct {
} }
// Execute 执行直连天下订单详情查询 // Execute 执行直连天下订单详情查询
func (w *ZltxOrderDetailTool) Execute(ctx context.Context, args json.RawMessage) (interface{}, error) { func (w *ZltxOrderDetailTool) Execute(c *websocket.Conn, args json.RawMessage) error {
var req ZltxOrderDetailRequest var req ZltxOrderDetailRequest
if err := json.Unmarshal(args, &req); err != nil { if err := json.Unmarshal(args, &req); err != nil {
return nil, fmt.Errorf("invalid zltxOrderDetail request: %w", err) return fmt.Errorf("invalid zltxOrderDetail request: %w", err)
} }
if req.Number == "" { if req.OrderNumber == "" {
return nil, fmt.Errorf("number is required") return fmt.Errorf("number is required")
} }
// 这里可以集成真实的直连天下订单详情API // 这里可以集成真实的直连天下订单详情API
return w.getZltxOrderDetail(ctx, req.Number), nil return w.getZltxOrderDetail(c, req.OrderNumber)
} }
// getMockZltxOrderDetail 获取模拟直连天下订单详情数据 // getMockZltxOrderDetail 获取模拟直连天下订单详情数据
func (w *ZltxOrderDetailTool) getZltxOrderDetail(ctx context.Context, number string) *ZltxOrderDetailResponse { func (w *ZltxOrderDetailTool) getZltxOrderDetail(c *websocket.Conn, number string) (err error) {
url := fmt.Sprintf("%s/admin/direct/ai/%s", w.config.BaseURL, number) //查询订单详情
authorization := fmt.Sprintf("Bearer %s", w.config.APIKey) var auth string
if c != nil {
auth = c.Headers("X-Authorization", "")
}
if len(auth) == 0 {
auth = w.config.APIKey
}
req := l_request.Request{
Url: fmt.Sprintf("%s%s", w.config.BaseURL, number),
Headers: map[string]string{
"Authorization": fmt.Sprintf("Bearer %s", auth),
},
Method: "GET",
}
res, err := req.Send()
// 发送http请求
req, err := http.NewRequest("GET", url, nil)
if err != nil { if err != nil {
return &ZltxOrderDetailResponse{} return
} }
req.Header.Set("Authorization", authorization) c.WriteMessage(websocket.TextMessage, res.Content)
resp, err := http.DefaultClient.Do(req) return
if err != nil {
return &ZltxOrderDetailResponse{}
}
defer resp.Body.Close()
return &ZltxOrderDetailResponse{}
} }