From 068563b91491d60fab01fdd95d76f83f65e109b0 Mon Sep 17 00:00:00 2001 From: renzhiyuan <465386466@qq.com> Date: Fri, 19 Sep 2025 09:32:02 +0800 Subject: [PATCH] =?UTF-8?q?=E7=BB=93=E6=9E=84=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- go.mod | 4 +- go.sum | 6 +++ internal/biz/router.go | 71 ++++++++++++++++++++++------- internal/data/impl/chat_history.go | 5 +- internal/pkg/utils_ollama/ollama.go | 2 +- tmpl/dataTemp/queryTempl.go | 8 ++-- 6 files changed, 71 insertions(+), 25 deletions(-) diff --git a/go.mod b/go.mod index 7dca890..48c2db6 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/gofiber/fiber/v2 v2.52.9 github.com/gofiber/websocket/v2 v2.2.1 github.com/google/wire v0.7.0 + github.com/ollama/ollama v0.11.11 github.com/redis/go-redis/v9 v9.14.0 github.com/spf13/viper v1.17.0 github.com/tmc/langchaingo v0.1.13 @@ -31,7 +32,6 @@ require ( github.com/fasthttp/websocket v1.5.3 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect - github.com/google/go-cmp v0.7.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect @@ -59,8 +59,8 @@ require ( github.com/valyala/tcplisten v1.0.0 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect + golang.org/x/crypto v0.36.0 // indirect golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect - golang.org/x/net v0.38.0 // indirect golang.org/x/sys v0.31.0 // indirect golang.org/x/text v0.23.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117 // indirect diff --git a/go.sum b/go.sum index ac6b693..2542880 100644 --- a/go.sum +++ b/go.sum @@ -187,6 +187,8 @@ github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6T github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/ollama/ollama v0.11.11 h1:mErMiUGclp47rCDbSUmBiY2L76EpT0uIYRZVBO6qg/k= +github.com/ollama/ollama v0.11.11/go.mod h1:9+1//yWPsDE2u+l1a5mpaKrYw4VdnSsRU3ioq5BvMms= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -265,6 +267,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= +golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -396,6 +400,8 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y= +golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/internal/biz/router.go b/internal/biz/router.go index 06a87e2..1471009 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -14,9 +14,8 @@ import ( "encoding/json" "log" - "github.com/tmc/langchaingo/llms" - "github.com/gofiber/websocket/v2" + "github.com/tmc/langchaingo/llms" "xorm.io/builder" ) @@ -93,13 +92,25 @@ func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSo } toolDefinitions := r.registerTools(task) - prompt := r.getPrompt(sysInfo, history, req.Text) - msg, err := r.utilAgent.Llm.Call(context.TODO(), pkg.JsonStringIgonErr(prompt), + //prompt := r.getPrompt(sysInfo, history, req.Text) + + //意图预测 + //msg, err := r.utilAgent.Llm.Call(context.TODO(), pkg.JsonStringIgonErr(prompt), + // llms.WithTools(toolDefinitions), + // //llms.WithToolChoice(llms.FunctionCallBehaviorAuto), + // llms.WithJSONMode(), + //) + prompt := r.getPromptLLM(sysInfo, history, req.Text) + msg, err := r.utilAgent.Llm.GenerateContent(context.TODO(), prompt, llms.WithTools(toolDefinitions), - llms.WithToolChoice(llms.FunctionCallBehaviorAuto), + llms.WithToolChoice("tool_name"), llms.WithJSONMode(), ) - c.WriteMessage(1, []byte(msg)) + if err != nil { + return errors.SystemError + } + + c.WriteMessage(1, []byte(msg.Choices[0].Content)) // 构建消息 //messages := []entitys.Message{ // { @@ -187,7 +198,7 @@ func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiCha cond := builder.NewCond() cond = cond.And(builder.Eq{"session_id": sessionId}) - _, err = r.hisImpl.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: r.conf.Sys.SessionLen}, &his) + _, err = r.hisImpl.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: r.conf.Sys.SessionLen}, &his, "his_id asc") return } @@ -207,25 +218,28 @@ func (r *AiRouterService) getTasks(sysId int32) (tasks []model.AiTask, err error cond = cond.And(builder.Eq{"sys_id": sysId}) cond = cond.And(builder.IsNull{"delete_at"}) cond = cond.And(builder.Eq{"status": 1}) - _, err = r.taskImpl.GetListToStruct(&cond, nil, &tasks) + _, err = r.taskImpl.GetListToStruct(&cond, nil, &tasks, "") return } func (r *AiRouterService) registerTools(tasks []model.AiTask) []llms.Tool { - taskPrompt := make([]llms.Tool, len(tasks)) - for k, task := range tasks { + taskPrompt := make([]llms.Tool, 0) + for _, task := range tasks { var taskConfig entitys.TaskConfig err := json.Unmarshal([]byte(task.Config), &taskConfig) if err != nil { continue } - taskPrompt[k].Type = "function" - taskPrompt[k].Function = &llms.FunctionDefinition{ - Name: task.Name, - Description: task.Desc, - Parameters: taskConfig.Param, - } + taskPrompt = append(taskPrompt, llms.Tool{ + Type: "function", + Function: &llms.FunctionDefinition{ + Name: task.Index, + Description: task.Desc, + Parameters: taskConfig.Param, + }, + }) + } return taskPrompt } @@ -234,7 +248,7 @@ func (r *AiRouterService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi var ( prompt = make([]entitys.Message, 0) ) - prompt = append(prompt, entitys.Message{}, entitys.Message{ + prompt = append(prompt, entitys.Message{ Role: "system", Content: r.buildSystemPrompt(sysInfo.SysPrompt), }, entitys.Message{ @@ -247,6 +261,29 @@ func (r *AiRouterService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi return prompt } +func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []llms.MessageContent { + var ( + prompt = make([]llms.MessageContent, 0) + ) + prompt = append(prompt, llms.MessageContent{ + Role: llms.ChatMessageTypeSystem, + Parts: []llms.ContentPart{ + llms.TextPart(r.buildSystemPrompt(sysInfo.SysPrompt)), + }, + }, llms.MessageContent{ + Role: llms.ChatMessageTypeTool, + Parts: []llms.ContentPart{ + llms.TextPart(pkg.JsonStringIgonErr(r.buildAssistant(history))), + }, + }, llms.MessageContent{ + Role: llms.ChatMessageTypeHuman, + Parts: []llms.ContentPart{ + llms.TextPart(reqInput), + }, + }) + return prompt +} + // buildSystemPrompt 构建系统提示词 func (r *AiRouterService) buildSystemPrompt(prompt string) string { if len(prompt) == 0 { diff --git a/internal/data/impl/chat_history.go b/internal/data/impl/chat_history.go index 3a164a4..6f6027d 100644 --- a/internal/data/impl/chat_history.go +++ b/internal/data/impl/chat_history.go @@ -5,9 +5,10 @@ import ( "ai_scheduler/tmpl/dataTemp" "ai_scheduler/utils" "context" + "time" + "github.com/gofiber/fiber/v2/log" "gorm.io/gorm" - "time" ) type ChatImpl struct { @@ -50,7 +51,7 @@ func (impl *ChatImpl) AsyncProcess(ctx context.Context) { return // 定时打印通道大小 case <-time.After(time.Second * 5): - log.Infof("ChatHistoryAsyncProcess channel len: %d", len(impl.chatChannel)) + //log.Infof("ChatHistoryAsyncProcess channel len: %d", len(impl.chatChannel)) } } } diff --git a/internal/pkg/utils_ollama/ollama.go b/internal/pkg/utils_ollama/ollama.go index 8f6aac8..f5546bb 100644 --- a/internal/pkg/utils_ollama/ollama.go +++ b/internal/pkg/utils_ollama/ollama.go @@ -2,7 +2,6 @@ package utils_ollama import ( "ai_scheduler/internal/config" - "net/http" "os" @@ -19,6 +18,7 @@ func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama { ollama.WithModel(c.Ollama.Model), ollama.WithHTTPClient(http.DefaultClient), ollama.WithServerURL(getUrl(c)), + ollama.WithKeepAlive("1h"), ) if err != nil { logger.Fatal(err) diff --git a/tmpl/dataTemp/queryTempl.go b/tmpl/dataTemp/queryTempl.go index 8eb7352..e3d98e3 100644 --- a/tmpl/dataTemp/queryTempl.go +++ b/tmpl/dataTemp/queryTempl.go @@ -105,13 +105,15 @@ func (k DataTemp) GetOneBySearchToStrut(cond *builder.Cond, result interface{}) return err } -func (k DataTemp) GetListToStruct(cond *builder.Cond, pageBoIn *ReqPageBo, result interface{}) (pageBoOut *RespPageBo, err error) { +func (k DataTemp) GetListToStruct(cond *builder.Cond, pageBoIn *ReqPageBo, result interface{}, orderBy string) (pageBoOut *RespPageBo, err error) { var ( query, _ = builder.ToBoundSQL(*cond) model = k.Db.Model(k.Model).Where(query) total int64 ) - + if len(orderBy) == 0 { + orderBy = "updated_at desc" + } // 1. 计算总数 if err = model.Count(&total).Error; err != nil { return nil, err @@ -123,7 +125,7 @@ func (k DataTemp) GetListToStruct(cond *builder.Cond, pageBoIn *ReqPageBo, resul // 3. 查询数据(确保 result 是指针,如 &[]User) if err = model.Limit(pageBoIn.GetSize()). Offset(pageBoIn.GetOffset()). - Order("updated_at desc"). + Order(orderBy). Find(result).Error; err != nil { return nil, err }