结构修改

This commit is contained in:
renzhiyuan 2025-09-19 09:32:02 +08:00
parent 303dd39cb3
commit 068563b914
6 changed files with 71 additions and 25 deletions

4
go.mod
View File

@ -11,6 +11,7 @@ require (
github.com/gofiber/fiber/v2 v2.52.9
github.com/gofiber/websocket/v2 v2.2.1
github.com/google/wire v0.7.0
github.com/ollama/ollama v0.11.11
github.com/redis/go-redis/v9 v9.14.0
github.com/spf13/viper v1.17.0
github.com/tmc/langchaingo v0.1.13
@ -31,7 +32,6 @@ require (
github.com/fasthttp/websocket v1.5.3 // indirect
github.com/fsnotify/fsnotify v1.6.0 // indirect
github.com/go-sql-driver/mysql v1.8.1 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
@ -59,8 +59,8 @@ require (
github.com/valyala/tcplisten v1.0.0 // indirect
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/crypto v0.36.0 // indirect
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
golang.org/x/net v0.38.0 // indirect
golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117 // indirect

6
go.sum
View File

@ -187,6 +187,8 @@ github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6T
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/ollama/ollama v0.11.11 h1:mErMiUGclp47rCDbSUmBiY2L76EpT0uIYRZVBO6qg/k=
github.com/ollama/ollama v0.11.11/go.mod h1:9+1//yWPsDE2u+l1a5mpaKrYw4VdnSsRU3ioq5BvMms=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@ -265,6 +267,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@ -396,6 +400,8 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

View File

@ -14,9 +14,8 @@ import (
"encoding/json"
"log"
"github.com/tmc/langchaingo/llms"
"github.com/gofiber/websocket/v2"
"github.com/tmc/langchaingo/llms"
"xorm.io/builder"
)
@ -93,13 +92,25 @@ func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSo
}
toolDefinitions := r.registerTools(task)
prompt := r.getPrompt(sysInfo, history, req.Text)
msg, err := r.utilAgent.Llm.Call(context.TODO(), pkg.JsonStringIgonErr(prompt),
//prompt := r.getPrompt(sysInfo, history, req.Text)
//意图预测
//msg, err := r.utilAgent.Llm.Call(context.TODO(), pkg.JsonStringIgonErr(prompt),
// llms.WithTools(toolDefinitions),
// //llms.WithToolChoice(llms.FunctionCallBehaviorAuto),
// llms.WithJSONMode(),
//)
prompt := r.getPromptLLM(sysInfo, history, req.Text)
msg, err := r.utilAgent.Llm.GenerateContent(context.TODO(), prompt,
llms.WithTools(toolDefinitions),
llms.WithToolChoice(llms.FunctionCallBehaviorAuto),
llms.WithToolChoice("tool_name"),
llms.WithJSONMode(),
)
c.WriteMessage(1, []byte(msg))
if err != nil {
return errors.SystemError
}
c.WriteMessage(1, []byte(msg.Choices[0].Content))
// 构建消息
//messages := []entitys.Message{
// {
@ -187,7 +198,7 @@ func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiCha
cond := builder.NewCond()
cond = cond.And(builder.Eq{"session_id": sessionId})
_, err = r.hisImpl.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: r.conf.Sys.SessionLen}, &his)
_, err = r.hisImpl.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: r.conf.Sys.SessionLen}, &his, "his_id asc")
return
}
@ -207,25 +218,28 @@ func (r *AiRouterService) getTasks(sysId int32) (tasks []model.AiTask, err error
cond = cond.And(builder.Eq{"sys_id": sysId})
cond = cond.And(builder.IsNull{"delete_at"})
cond = cond.And(builder.Eq{"status": 1})
_, err = r.taskImpl.GetListToStruct(&cond, nil, &tasks)
_, err = r.taskImpl.GetListToStruct(&cond, nil, &tasks, "")
return
}
func (r *AiRouterService) registerTools(tasks []model.AiTask) []llms.Tool {
taskPrompt := make([]llms.Tool, len(tasks))
for k, task := range tasks {
taskPrompt := make([]llms.Tool, 0)
for _, task := range tasks {
var taskConfig entitys.TaskConfig
err := json.Unmarshal([]byte(task.Config), &taskConfig)
if err != nil {
continue
}
taskPrompt[k].Type = "function"
taskPrompt[k].Function = &llms.FunctionDefinition{
Name: task.Name,
Description: task.Desc,
Parameters: taskConfig.Param,
}
taskPrompt = append(taskPrompt, llms.Tool{
Type: "function",
Function: &llms.FunctionDefinition{
Name: task.Index,
Description: task.Desc,
Parameters: taskConfig.Param,
},
})
}
return taskPrompt
}
@ -234,7 +248,7 @@ func (r *AiRouterService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi
var (
prompt = make([]entitys.Message, 0)
)
prompt = append(prompt, entitys.Message{}, entitys.Message{
prompt = append(prompt, entitys.Message{
Role: "system",
Content: r.buildSystemPrompt(sysInfo.SysPrompt),
}, entitys.Message{
@ -247,6 +261,29 @@ func (r *AiRouterService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi
return prompt
}
func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []llms.MessageContent {
var (
prompt = make([]llms.MessageContent, 0)
)
prompt = append(prompt, llms.MessageContent{
Role: llms.ChatMessageTypeSystem,
Parts: []llms.ContentPart{
llms.TextPart(r.buildSystemPrompt(sysInfo.SysPrompt)),
},
}, llms.MessageContent{
Role: llms.ChatMessageTypeTool,
Parts: []llms.ContentPart{
llms.TextPart(pkg.JsonStringIgonErr(r.buildAssistant(history))),
},
}, llms.MessageContent{
Role: llms.ChatMessageTypeHuman,
Parts: []llms.ContentPart{
llms.TextPart(reqInput),
},
})
return prompt
}
// buildSystemPrompt 构建系统提示词
func (r *AiRouterService) buildSystemPrompt(prompt string) string {
if len(prompt) == 0 {

View File

@ -5,9 +5,10 @@ import (
"ai_scheduler/tmpl/dataTemp"
"ai_scheduler/utils"
"context"
"time"
"github.com/gofiber/fiber/v2/log"
"gorm.io/gorm"
"time"
)
type ChatImpl struct {
@ -50,7 +51,7 @@ func (impl *ChatImpl) AsyncProcess(ctx context.Context) {
return
// 定时打印通道大小
case <-time.After(time.Second * 5):
log.Infof("ChatHistoryAsyncProcess channel len: %d", len(impl.chatChannel))
//log.Infof("ChatHistoryAsyncProcess channel len: %d", len(impl.chatChannel))
}
}
}

View File

@ -2,7 +2,6 @@ package utils_ollama
import (
"ai_scheduler/internal/config"
"net/http"
"os"
@ -19,6 +18,7 @@ func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama {
ollama.WithModel(c.Ollama.Model),
ollama.WithHTTPClient(http.DefaultClient),
ollama.WithServerURL(getUrl(c)),
ollama.WithKeepAlive("1h"),
)
if err != nil {
logger.Fatal(err)

View File

@ -105,13 +105,15 @@ func (k DataTemp) GetOneBySearchToStrut(cond *builder.Cond, result interface{})
return err
}
func (k DataTemp) GetListToStruct(cond *builder.Cond, pageBoIn *ReqPageBo, result interface{}) (pageBoOut *RespPageBo, err error) {
func (k DataTemp) GetListToStruct(cond *builder.Cond, pageBoIn *ReqPageBo, result interface{}, orderBy string) (pageBoOut *RespPageBo, err error) {
var (
query, _ = builder.ToBoundSQL(*cond)
model = k.Db.Model(k.Model).Where(query)
total int64
)
if len(orderBy) == 0 {
orderBy = "updated_at desc"
}
// 1. 计算总数
if err = model.Count(&total).Error; err != nil {
return nil, err
@ -123,7 +125,7 @@ func (k DataTemp) GetListToStruct(cond *builder.Cond, pageBoIn *ReqPageBo, resul
// 3. 查询数据(确保 result 是指针,如 &[]User
if err = model.Limit(pageBoIn.GetSize()).
Offset(pageBoIn.GetOffset()).
Order("updated_at desc").
Order(orderBy).
Find(result).Error; err != nil {
return nil, err
}