结构修改
This commit is contained in:
parent
303dd39cb3
commit
068563b914
4
go.mod
4
go.mod
|
@ -11,6 +11,7 @@ require (
|
||||||
github.com/gofiber/fiber/v2 v2.52.9
|
github.com/gofiber/fiber/v2 v2.52.9
|
||||||
github.com/gofiber/websocket/v2 v2.2.1
|
github.com/gofiber/websocket/v2 v2.2.1
|
||||||
github.com/google/wire v0.7.0
|
github.com/google/wire v0.7.0
|
||||||
|
github.com/ollama/ollama v0.11.11
|
||||||
github.com/redis/go-redis/v9 v9.14.0
|
github.com/redis/go-redis/v9 v9.14.0
|
||||||
github.com/spf13/viper v1.17.0
|
github.com/spf13/viper v1.17.0
|
||||||
github.com/tmc/langchaingo v0.1.13
|
github.com/tmc/langchaingo v0.1.13
|
||||||
|
@ -31,7 +32,6 @@ require (
|
||||||
github.com/fasthttp/websocket v1.5.3 // indirect
|
github.com/fasthttp/websocket v1.5.3 // indirect
|
||||||
github.com/fsnotify/fsnotify v1.6.0 // indirect
|
github.com/fsnotify/fsnotify v1.6.0 // indirect
|
||||||
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
||||||
github.com/google/go-cmp v0.7.0 // indirect
|
|
||||||
github.com/google/uuid v1.6.0 // indirect
|
github.com/google/uuid v1.6.0 // indirect
|
||||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
|
@ -59,8 +59,8 @@ require (
|
||||||
github.com/valyala/tcplisten v1.0.0 // indirect
|
github.com/valyala/tcplisten v1.0.0 // indirect
|
||||||
go.uber.org/atomic v1.9.0 // indirect
|
go.uber.org/atomic v1.9.0 // indirect
|
||||||
go.uber.org/multierr v1.9.0 // indirect
|
go.uber.org/multierr v1.9.0 // indirect
|
||||||
|
golang.org/x/crypto v0.36.0 // indirect
|
||||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
|
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
|
||||||
golang.org/x/net v0.38.0 // indirect
|
|
||||||
golang.org/x/sys v0.31.0 // indirect
|
golang.org/x/sys v0.31.0 // indirect
|
||||||
golang.org/x/text v0.23.0 // indirect
|
golang.org/x/text v0.23.0 // indirect
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117 // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117 // indirect
|
||||||
|
|
6
go.sum
6
go.sum
|
@ -187,6 +187,8 @@ github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6T
|
||||||
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||||
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
||||||
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||||
|
github.com/ollama/ollama v0.11.11 h1:mErMiUGclp47rCDbSUmBiY2L76EpT0uIYRZVBO6qg/k=
|
||||||
|
github.com/ollama/ollama v0.11.11/go.mod h1:9+1//yWPsDE2u+l1a5mpaKrYw4VdnSsRU3ioq5BvMms=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
@ -265,6 +267,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
|
||||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||||
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||||
|
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||||
|
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||||
|
@ -396,6 +400,8 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
|
golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
|
||||||
|
golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
|
||||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
|
|
|
@ -14,9 +14,8 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
"github.com/tmc/langchaingo/llms"
|
|
||||||
|
|
||||||
"github.com/gofiber/websocket/v2"
|
"github.com/gofiber/websocket/v2"
|
||||||
|
"github.com/tmc/langchaingo/llms"
|
||||||
"xorm.io/builder"
|
"xorm.io/builder"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -93,13 +92,25 @@ func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSo
|
||||||
}
|
}
|
||||||
|
|
||||||
toolDefinitions := r.registerTools(task)
|
toolDefinitions := r.registerTools(task)
|
||||||
prompt := r.getPrompt(sysInfo, history, req.Text)
|
//prompt := r.getPrompt(sysInfo, history, req.Text)
|
||||||
msg, err := r.utilAgent.Llm.Call(context.TODO(), pkg.JsonStringIgonErr(prompt),
|
|
||||||
|
//意图预测
|
||||||
|
//msg, err := r.utilAgent.Llm.Call(context.TODO(), pkg.JsonStringIgonErr(prompt),
|
||||||
|
// llms.WithTools(toolDefinitions),
|
||||||
|
// //llms.WithToolChoice(llms.FunctionCallBehaviorAuto),
|
||||||
|
// llms.WithJSONMode(),
|
||||||
|
//)
|
||||||
|
prompt := r.getPromptLLM(sysInfo, history, req.Text)
|
||||||
|
msg, err := r.utilAgent.Llm.GenerateContent(context.TODO(), prompt,
|
||||||
llms.WithTools(toolDefinitions),
|
llms.WithTools(toolDefinitions),
|
||||||
llms.WithToolChoice(llms.FunctionCallBehaviorAuto),
|
llms.WithToolChoice("tool_name"),
|
||||||
llms.WithJSONMode(),
|
llms.WithJSONMode(),
|
||||||
)
|
)
|
||||||
c.WriteMessage(1, []byte(msg))
|
if err != nil {
|
||||||
|
return errors.SystemError
|
||||||
|
}
|
||||||
|
|
||||||
|
c.WriteMessage(1, []byte(msg.Choices[0].Content))
|
||||||
// 构建消息
|
// 构建消息
|
||||||
//messages := []entitys.Message{
|
//messages := []entitys.Message{
|
||||||
// {
|
// {
|
||||||
|
@ -187,7 +198,7 @@ func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiCha
|
||||||
cond := builder.NewCond()
|
cond := builder.NewCond()
|
||||||
cond = cond.And(builder.Eq{"session_id": sessionId})
|
cond = cond.And(builder.Eq{"session_id": sessionId})
|
||||||
|
|
||||||
_, err = r.hisImpl.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: r.conf.Sys.SessionLen}, &his)
|
_, err = r.hisImpl.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: r.conf.Sys.SessionLen}, &his, "his_id asc")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -207,25 +218,28 @@ func (r *AiRouterService) getTasks(sysId int32) (tasks []model.AiTask, err error
|
||||||
cond = cond.And(builder.Eq{"sys_id": sysId})
|
cond = cond.And(builder.Eq{"sys_id": sysId})
|
||||||
cond = cond.And(builder.IsNull{"delete_at"})
|
cond = cond.And(builder.IsNull{"delete_at"})
|
||||||
cond = cond.And(builder.Eq{"status": 1})
|
cond = cond.And(builder.Eq{"status": 1})
|
||||||
_, err = r.taskImpl.GetListToStruct(&cond, nil, &tasks)
|
_, err = r.taskImpl.GetListToStruct(&cond, nil, &tasks, "")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *AiRouterService) registerTools(tasks []model.AiTask) []llms.Tool {
|
func (r *AiRouterService) registerTools(tasks []model.AiTask) []llms.Tool {
|
||||||
taskPrompt := make([]llms.Tool, len(tasks))
|
taskPrompt := make([]llms.Tool, 0)
|
||||||
for k, task := range tasks {
|
for _, task := range tasks {
|
||||||
var taskConfig entitys.TaskConfig
|
var taskConfig entitys.TaskConfig
|
||||||
err := json.Unmarshal([]byte(task.Config), &taskConfig)
|
err := json.Unmarshal([]byte(task.Config), &taskConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
taskPrompt[k].Type = "function"
|
taskPrompt = append(taskPrompt, llms.Tool{
|
||||||
taskPrompt[k].Function = &llms.FunctionDefinition{
|
Type: "function",
|
||||||
Name: task.Name,
|
Function: &llms.FunctionDefinition{
|
||||||
Description: task.Desc,
|
Name: task.Index,
|
||||||
Parameters: taskConfig.Param,
|
Description: task.Desc,
|
||||||
}
|
Parameters: taskConfig.Param,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
}
|
}
|
||||||
return taskPrompt
|
return taskPrompt
|
||||||
}
|
}
|
||||||
|
@ -234,7 +248,7 @@ func (r *AiRouterService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi
|
||||||
var (
|
var (
|
||||||
prompt = make([]entitys.Message, 0)
|
prompt = make([]entitys.Message, 0)
|
||||||
)
|
)
|
||||||
prompt = append(prompt, entitys.Message{}, entitys.Message{
|
prompt = append(prompt, entitys.Message{
|
||||||
Role: "system",
|
Role: "system",
|
||||||
Content: r.buildSystemPrompt(sysInfo.SysPrompt),
|
Content: r.buildSystemPrompt(sysInfo.SysPrompt),
|
||||||
}, entitys.Message{
|
}, entitys.Message{
|
||||||
|
@ -247,6 +261,29 @@ func (r *AiRouterService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi
|
||||||
return prompt
|
return prompt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []llms.MessageContent {
|
||||||
|
var (
|
||||||
|
prompt = make([]llms.MessageContent, 0)
|
||||||
|
)
|
||||||
|
prompt = append(prompt, llms.MessageContent{
|
||||||
|
Role: llms.ChatMessageTypeSystem,
|
||||||
|
Parts: []llms.ContentPart{
|
||||||
|
llms.TextPart(r.buildSystemPrompt(sysInfo.SysPrompt)),
|
||||||
|
},
|
||||||
|
}, llms.MessageContent{
|
||||||
|
Role: llms.ChatMessageTypeTool,
|
||||||
|
Parts: []llms.ContentPart{
|
||||||
|
llms.TextPart(pkg.JsonStringIgonErr(r.buildAssistant(history))),
|
||||||
|
},
|
||||||
|
}, llms.MessageContent{
|
||||||
|
Role: llms.ChatMessageTypeHuman,
|
||||||
|
Parts: []llms.ContentPart{
|
||||||
|
llms.TextPart(reqInput),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return prompt
|
||||||
|
}
|
||||||
|
|
||||||
// buildSystemPrompt 构建系统提示词
|
// buildSystemPrompt 构建系统提示词
|
||||||
func (r *AiRouterService) buildSystemPrompt(prompt string) string {
|
func (r *AiRouterService) buildSystemPrompt(prompt string) string {
|
||||||
if len(prompt) == 0 {
|
if len(prompt) == 0 {
|
||||||
|
|
|
@ -5,9 +5,10 @@ import (
|
||||||
"ai_scheduler/tmpl/dataTemp"
|
"ai_scheduler/tmpl/dataTemp"
|
||||||
"ai_scheduler/utils"
|
"ai_scheduler/utils"
|
||||||
"context"
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2/log"
|
"github.com/gofiber/fiber/v2/log"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type ChatImpl struct {
|
type ChatImpl struct {
|
||||||
|
@ -50,7 +51,7 @@ func (impl *ChatImpl) AsyncProcess(ctx context.Context) {
|
||||||
return
|
return
|
||||||
// 定时打印通道大小
|
// 定时打印通道大小
|
||||||
case <-time.After(time.Second * 5):
|
case <-time.After(time.Second * 5):
|
||||||
log.Infof("ChatHistoryAsyncProcess channel len: %d", len(impl.chatChannel))
|
//log.Infof("ChatHistoryAsyncProcess channel len: %d", len(impl.chatChannel))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,6 @@ package utils_ollama
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/config"
|
"ai_scheduler/internal/config"
|
||||||
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
@ -19,6 +18,7 @@ func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama {
|
||||||
ollama.WithModel(c.Ollama.Model),
|
ollama.WithModel(c.Ollama.Model),
|
||||||
ollama.WithHTTPClient(http.DefaultClient),
|
ollama.WithHTTPClient(http.DefaultClient),
|
||||||
ollama.WithServerURL(getUrl(c)),
|
ollama.WithServerURL(getUrl(c)),
|
||||||
|
ollama.WithKeepAlive("1h"),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal(err)
|
logger.Fatal(err)
|
||||||
|
|
|
@ -105,13 +105,15 @@ func (k DataTemp) GetOneBySearchToStrut(cond *builder.Cond, result interface{})
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k DataTemp) GetListToStruct(cond *builder.Cond, pageBoIn *ReqPageBo, result interface{}) (pageBoOut *RespPageBo, err error) {
|
func (k DataTemp) GetListToStruct(cond *builder.Cond, pageBoIn *ReqPageBo, result interface{}, orderBy string) (pageBoOut *RespPageBo, err error) {
|
||||||
var (
|
var (
|
||||||
query, _ = builder.ToBoundSQL(*cond)
|
query, _ = builder.ToBoundSQL(*cond)
|
||||||
model = k.Db.Model(k.Model).Where(query)
|
model = k.Db.Model(k.Model).Where(query)
|
||||||
total int64
|
total int64
|
||||||
)
|
)
|
||||||
|
if len(orderBy) == 0 {
|
||||||
|
orderBy = "updated_at desc"
|
||||||
|
}
|
||||||
// 1. 计算总数
|
// 1. 计算总数
|
||||||
if err = model.Count(&total).Error; err != nil {
|
if err = model.Count(&total).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -123,7 +125,7 @@ func (k DataTemp) GetListToStruct(cond *builder.Cond, pageBoIn *ReqPageBo, resul
|
||||||
// 3. 查询数据(确保 result 是指针,如 &[]User)
|
// 3. 查询数据(确保 result 是指针,如 &[]User)
|
||||||
if err = model.Limit(pageBoIn.GetSize()).
|
if err = model.Limit(pageBoIn.GetSize()).
|
||||||
Offset(pageBoIn.GetOffset()).
|
Offset(pageBoIn.GetOffset()).
|
||||||
Order("updated_at desc").
|
Order(orderBy).
|
||||||
Find(result).Error; err != nil {
|
Find(result).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue