Compare commits

..

3 Commits

Author SHA1 Message Date
wolter 303dd39cb3 feat: 合并分支 2025-09-18 18:30:35 +08:00
wolter c91e5519e9 Merge branch 'refs/heads/feature/fiber' into feature/session
# Conflicts:
#	internal/data/model/ai_chat_his.gen.go
2025-09-18 18:18:11 +08:00
renzhiyuan 99c397aabf 结构修改 2025-09-18 17:56:08 +08:00
30 changed files with 488 additions and 114 deletions

View File

@ -11,7 +11,7 @@ import (
func main() {
configPath := flag.String("config", "config.yaml", "Path to configuration file")
configPath := flag.String("config", "./config/config.yaml", "Path to configuration file")
flag.Parse()
bc, err := config.LoadConfig(*configPath)
if err != nil {

View File

@ -5,7 +5,7 @@ server:
ollama:
base_url: "http://localhost:11434"
model: "deepseek-r1:8b"
model: "qwen3:8b"
timeout: "120s"
level: "info"
format: "json"
@ -13,7 +13,7 @@ ollama:
sys:
session_len: 3
Redis:
redis:
host: 47.97.27.195:6379
type: node
pass: lansexiongdi@666
@ -23,6 +23,6 @@ Redis:
maxIdleTime: 30 #每个连接最大空闲时间,如果超过了这个时间会被关闭
tls: 30
db:
DB:
db:
driver: mysql
source: transfer:Lsxd@34234QW@tcp(lsxdpolar.rwlb.rds.aliyuncs.com:3306)/transfer?charset=utf8mb4&parseTime=true&
source: root:SD###sdf323r343@tcp(121.199.38.107:3306)/sys_ai?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai

5
go.mod
View File

@ -5,12 +5,12 @@ go 1.24.0
toolchain go1.24.7
require (
gitea.cdlsxd.cn/self-tools/l_request v1.0.8
github.com/emirpasic/gods v1.18.1
github.com/go-kratos/kratos/v2 v2.9.1
github.com/gofiber/fiber/v2 v2.52.9
github.com/gofiber/websocket/v2 v2.2.1
github.com/google/wire v0.7.0
github.com/ollama/ollama v0.11.10
github.com/redis/go-redis/v9 v9.14.0
github.com/spf13/viper v1.17.0
github.com/tmc/langchaingo v0.1.13
@ -31,6 +31,7 @@ require (
github.com/fasthttp/websocket v1.5.3 // indirect
github.com/fsnotify/fsnotify v1.6.0 // indirect
github.com/go-sql-driver/mysql v1.8.1 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
@ -58,8 +59,8 @@ require (
github.com/valyala/tcplisten v1.0.0 // indirect
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/crypto v0.36.0 // indirect
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
golang.org/x/net v0.38.0 // indirect
golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117 // indirect

8
go.sum
View File

@ -38,6 +38,8 @@ cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3f
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
gitea.cdlsxd.cn/self-tools/l_request v1.0.8 h1:FaKRql9mCVcSoaGqPeBOAruZ52slzRngQ6VRTYKNSsA=
gitea.cdlsxd.cn/self-tools/l_request v1.0.8/go.mod h1:Qf4hVXm2Eu5vOvwXk8D7U0q/aekMCkZ4Fg9wnRKlasQ=
gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a h1:lSA0F4e9A2NcQSqGqTOXqu2aRi/XEQxDCBwM8yJtE6s=
gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a/go.mod h1:EXuID2Zs0pAQhH8yz+DNjUbjppKQzKFAn28TMYPB6IU=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
@ -185,8 +187,6 @@ github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6T
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/ollama/ollama v0.11.10 h1:J9zaoTPwIXOrYXCRAqI7rV4cJ+FOMuQc/vBqQ5GIdWg=
github.com/ollama/ollama v0.11.10/go.mod h1:9+1//yWPsDE2u+l1a5mpaKrYw4VdnSsRU3ioq5BvMms=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@ -265,8 +265,6 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@ -398,8 +396,6 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

View File

@ -0,0 +1,49 @@
package biz
import (
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"context"
)
type ChatHistoryBiz struct {
chatRepo *impl.ChatImpl
}
func NewChatHistoryBiz(chatRepo *impl.ChatImpl) *ChatHistoryBiz {
s := &ChatHistoryBiz{
chatRepo: chatRepo,
}
go s.AsyncProcess(context.Background())
return s
}
func (s *ChatHistoryBiz) create(ctx context.Context, sessionID, role, content string) error {
chat := model.AiChatHi{
SessionID: sessionID,
Role: role,
Content: content,
}
return s.chatRepo.Create(&chat)
}
// 添加会话历史
func (s *ChatHistoryBiz) Create(ctx context.Context, chat entitys.ChatHistory) error {
return s.create(ctx, chat.SessionID, chat.Role.String(), chat.Content)
}
// 异步添加会话历史
func (s *ChatHistoryBiz) AsyncCreate(ctx context.Context, chat entitys.ChatHistory) {
s.chatRepo.AsyncCreate(ctx, model.AiChatHi{
SessionID: chat.SessionID,
Role: chat.Role.String(),
Content: chat.Content,
})
}
// 异步处理会话历史
func (s *ChatHistoryBiz) AsyncProcess(ctx context.Context) {
s.chatRepo.AsyncProcess(ctx)
}

View File

@ -2,4 +2,4 @@ package biz
import "github.com/google/wire"
var ProviderSetBiz = wire.NewSet(NewAiRouterBiz, NewSessionBiz)
var ProviderSetBiz = wire.NewSet(NewAiRouterBiz, NewSessionBiz, NewChatHistoryBiz)

View File

@ -6,14 +6,15 @@ import (
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/utils_ollama"
"ai_scheduler/internal/tools"
"ai_scheduler/tmpl/dataTemp"
"context"
"encoding/json"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/ollama"
"log"
"strings"
"github.com/tmc/langchaingo/llms"
"github.com/gofiber/websocket/v2"
"xorm.io/builder"
@ -21,23 +22,36 @@ import (
// AiRouterService 智能路由服务
type AiRouterService struct {
aiClient entitys.AIClient
//aiClient entitys.AIClient
toolManager *tools.Manager
sessionImpl *impl.SessionImpl
sysImpl *impl.SysImpl
taskImpl *impl.TaskImpl
hisImpl *impl.ChatImpl
conf *config.Config
utilAgent *utils_ollama.UtilOllama
}
// NewRouterService 创建路由服务
func NewAiRouterBiz(aiClient entitys.AIClient, toolManager *tools.Manager, sessionImpl *impl.SessionImpl, sysImpl *impl.SysImpl, taskImpl *impl.TaskImpl, conf *config.Config) entitys.RouterService {
func NewAiRouterBiz(
//aiClient entitys.AIClient,
toolManager *tools.Manager,
sessionImpl *impl.SessionImpl,
sysImpl *impl.SysImpl,
taskImpl *impl.TaskImpl,
hisImpl *impl.ChatImpl,
conf *config.Config,
utilAgent *utils_ollama.UtilOllama,
) entitys.RouterService {
return &AiRouterService{
aiClient: aiClient,
//aiClient: aiClient,
toolManager: toolManager,
sessionImpl: sessionImpl,
conf: conf,
sysImpl: sysImpl,
hisImpl: hisImpl,
taskImpl: taskImpl,
utilAgent: utilAgent,
}
}
@ -50,48 +64,42 @@ func (r *AiRouterService) Route(ctx context.Context, req *entitys.ChatRequest) (
// Route 执行智能路由
func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRequest) error {
session := c.Headers("x-session", "")
session := c.Headers("X-Session", "")
if len(session) == 0 {
return errors.SessionNotFound
}
auth := c.Headers("x-authorization", "")
auth := c.Headers("X-Authorization", "")
if len(auth) == 0 {
return errors.AuthNotFound
}
key := c.Headers("x-app-key", "")
key := c.Headers("X-App-Key", "")
if len(key) == 0 {
return errors.KeyNotFound
}
sysInfo, err := r.getSysInfo(session)
sysInfo, err := r.getSysInfo(key)
if err != nil {
return errors.SysNotFound
}
history, err := r.getSessionHis(session)
history, err := r.getSessionChatHis(session)
if err != nil {
return errors.SystemError
}
taskPrompt, err := r.getTasks(sysInfo.SysID)
task, err := r.getTasks(sysInfo.SysID)
if err != nil {
return errors.SystemError
}
var (
messages = make([]entitys.Message, 0)
toolDefinitions := r.registerTools(task)
prompt := r.getPrompt(sysInfo, history, req.Text)
msg, err := r.utilAgent.Llm.Call(context.TODO(), pkg.JsonStringIgonErr(prompt),
llms.WithTools(toolDefinitions),
llms.WithToolChoice(llms.FunctionCallBehaviorAuto),
llms.WithJSONMode(),
)
messages = append(messages, entitys.Message{}, entitys.Message{
Role: "system",
Content: r.buildSystemPrompt(sysInfo.SysPrompt),
}, entitys.Message{
Role: "assistant",
Content: r.buildIntentPrompt(history, task),
}, entitys.Message{
Role: "user",
Content: req.Text,
})
c.WriteMessage(1, []byte(msg))
// 构建消息
//messages := []entitys.Message{
// {
@ -174,14 +182,13 @@ func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSo
return nil
}
func (r *AiRouterService) getSessionHis(sessionId string) (his []model.AiSession, err error) {
func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) {
cond := builder.NewCond()
cond = cond.And(builder.Eq{"session_id": sessionId})
cond = cond.And(builder.IsNull{"delete_at"})
cond = cond.And(builder.Eq{"status": 1})
_, err = r.sessionImpl.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: r.conf.Sys.SessionLen}, &his)
_, err = r.hisImpl.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: r.conf.Sys.SessionLen}, &his)
return
}
@ -194,22 +201,50 @@ func (r *AiRouterService) getSysInfo(appKey string) (sysInfo model.AiSy, err err
return
}
func (r *AiRouterService) getTasks(sysId int32) (taskPrompt []llms.FunctionDefinition, err error) {
var tasks []model.AiTask
func (r *AiRouterService) getTasks(sysId int32) (tasks []model.AiTask, err error) {
cond := builder.NewCond()
cond = cond.And(builder.Eq{"sys_id": sysId})
cond = cond.And(builder.IsNull{"delete_at"})
cond = cond.And(builder.Eq{"status": 1})
err = r.taskImpl.GetOneBySearchToStrut(&cond, &tasks)
taskPrompt = make([]llms.FunctionDefinition, len(tasks))
_, err = r.taskImpl.GetListToStruct(&cond, nil, &tasks)
return
}
func (r *AiRouterService) registerTools(tasks []model.AiTask) []llms.Tool {
taskPrompt := make([]llms.Tool, len(tasks))
for k, task := range tasks {
taskPrompt[k] = llms.FunctionDefinition{
var taskConfig entitys.TaskConfig
err := json.Unmarshal([]byte(task.Config), &taskConfig)
if err != nil {
continue
}
taskPrompt[k].Type = "function"
taskPrompt[k].Function = &llms.FunctionDefinition{
Name: task.Name,
Description: task.Desc,
Parameters: task.Parameters,
Parameters: taskConfig.Param,
}
}
return
return taskPrompt
}
func (r *AiRouterService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []entitys.Message {
var (
prompt = make([]entitys.Message, 0)
)
prompt = append(prompt, entitys.Message{}, entitys.Message{
Role: "system",
Content: r.buildSystemPrompt(sysInfo.SysPrompt),
}, entitys.Message{
Role: "assistant",
Content: pkg.JsonStringIgonErr(r.buildAssistant(history)),
}, entitys.Message{
Role: "user",
Content: reqInput,
})
return prompt
}
// buildSystemPrompt 构建系统提示词
@ -221,28 +256,22 @@ func (r *AiRouterService) buildSystemPrompt(prompt string) string {
return prompt
}
func (r *AiRouterService) buildIntentPrompt(his []model.AiSession, task []model.AiTask) string {
prompt := `##任务
分析用户输入判断用户的意图类型,没有使用Markdown格式的json格式回复
##意图类型
1. product_diagnosis - 商品诊断用户想要查询诊断或了解商品相关信息
2. order_diagnosis - 订单诊断用户想要查询诊断或了解订单相关信息
3. knowledge_qa - 知识问答用户想要进行一般性问答或获取知识信息
##判断规则
1.当用户意图不够清晰且不匹配 knowledge_qa 以外意图时使用knowledge_qa
2.当用户意图非常不清晰时使用 unknown
##格式要求
1.返回以下格式的JSON
{ "intent": "product_diagnosis" | "order_diagnosis" | "knowledge_qa" | "unknown", "confidence": 0.0-1.0,"reasoning": "判断理由"}
2.严格返回字符串格式禁用markdown格式返回
3.只返回json字符串不包含任何其他解释性文字
## 用户当前的问题是
{user_input}
`
prompt = strings.ReplaceAll(prompt, "{user_input}", userInput)
return prompt
func (r *AiRouterService) buildAssistant(his []model.AiChatHi) (chatHis entitys.ChatHis) {
for _, item := range his {
if len(chatHis.SessionId) == 0 {
chatHis.SessionId = item.SessionID
}
chatHis.Messages = append(chatHis.Messages, entitys.HisMessage{
Role: item.Role,
Content: item.Content,
Timestamp: item.CreateAt.Format("2006-01-02 15:04:05"),
})
}
chatHis.Context = entitys.HisContext{
UserLanguage: "zh-CN",
SystemMode: "technical_support",
}
return
}
// extractIntent 从AI响应中提取意图

View File

@ -1,6 +1,7 @@
package biz
import (
"ai_scheduler/internal/constants"
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
@ -15,6 +16,7 @@ import (
type SessionBiz struct {
sessionRepo *impl.SessionImpl
sysRepo *impl.SysImpl
chatRepo *impl.ChatImpl
conf *config.Config
}
@ -28,14 +30,19 @@ func NewSessionBiz(conf *config.Config, sessionImpl *impl.SessionImpl, sysImpl *
}
// InitSession 初始化会话 ,当天存在则返回会话,如果不存在则创建一个
func (s *SessionBiz) SessionInit(ctx context.Context, req *entitys.SessionInitRequest) (sessionId string, err error) {
func (s *SessionBiz) SessionInit(ctx context.Context, req *entitys.SessionInitRequest) (result *entitys.SessionInitResponse, err error) {
// 获取系统配置
sysConfig, has, err := s.sysRepo.FindOne(s.sysRepo.WithSysId(req.SysId))
if err != nil {
return "", err
return
} else if !has {
return "", fmt.Errorf("sys not found")
err = fmt.Errorf("sys not found")
return
}
result = &entitys.SessionInitResponse{
Chat: make([]entitys.ChatHistory, 0),
}
// 获取 当天的session
@ -46,20 +53,56 @@ func (s *SessionBiz) SessionInit(ctx context.Context, req *entitys.SessionInitRe
s.sysRepo.WithSysId(sysConfig.SysID), // 条件系统ID
)
if err != nil {
return "", err
return
} else if !has {
// 不存在,创建一个
session = model.AiSession{
SysID: sysConfig.SysID,
SessionID: utils.UUID(),
UserID: req.UserId,
}
err = s.sessionRepo.Create(&session)
if err != nil {
return "", err
return
}
chat := entitys.ChatHistory{
SessionID: session.SessionID,
Role: constants.RoleSystem,
Content: sysConfig.Prologue,
}
result.Chat = append(result.Chat, chat)
// 开场白写入会话历史
s.chatRepo.AsyncCreate(ctx, model.AiChatHi{
SessionID: chat.SessionID,
Role: chat.Role.String(),
Content: chat.Content,
})
} else {
// 存在,返回会话历史
var chatList []model.AiChatHi
chatList, err = s.chatRepo.FindAll(
s.chatRepo.WithSessionId(session.SessionID), // 条件会话ID
s.chatRepo.OrderByDesc("create_at"), // 排序:按创建时间降序
s.chatRepo.WithLimit(constants.ChatHistoryLimit), // 限制返回条数
)
if err != nil {
return
}
// 转换为 entitys.ChatHistory 类型
for _, chat := range chatList {
result.Chat = append(result.Chat, entitys.ChatHistory{
SessionID: chat.SessionID,
Role: constants.Caller(chat.Role),
Content: chat.Content,
})
}
}
return session.SessionID, nil
return
}
// SessionList 会话列表

View File

@ -14,8 +14,13 @@ type Config struct {
Sys SysConfig `mapstructure:"sys"`
Tools ToolsConfig `mapstructure:"tools"`
Logging LoggingConfig `mapstructure:"logging"`
Redis *Redis `protobuf:"bytes,1,opt,name=Redis,proto3" json:"Redis,omitempty"`
DB *DB `protobuf:"bytes,3,opt,name=TransDB,proto3" json:"TransDB,omitempty"`
Redis Redis `mapstructure:"redis"`
DB DB `mapstructure:"db"`
// LLM *LLM `mapstructure:"llm"`
}
type LLM struct {
Model string `mapstructure:"model"`
}
// SysConfig 系统配置
@ -37,24 +42,24 @@ type OllamaConfig struct {
}
type Redis struct {
Host string `protobuf:"bytes,1,opt,name=host,proto3" json:"host,omitempty"`
Type string `protobuf:"bytes,2,opt,name=type,proto3" json:"type,omitempty"`
Pass string `protobuf:"bytes,3,opt,name=pass,proto3" json:"pass,omitempty"`
Key string `protobuf:"bytes,4,opt,name=key,proto3" json:"key,omitempty"`
Tls int32 `protobuf:"varint,5,opt,name=tls,proto3" json:"tls,omitempty"`
Db int32 `protobuf:"varint,6,opt,name=db,proto3" json:"db,omitempty"`
MaxIdle int32 `protobuf:"varint,7,opt,name=maxIdle,proto3" json:"maxIdle,omitempty"`
PoolSize int32 `protobuf:"varint,8,opt,name=poolSize,proto3" json:"poolSize,omitempty"`
MaxIdleTime int32 `protobuf:"varint,9,opt,name=maxIdleTime,proto3" json:"maxIdleTime,omitempty"`
Host string `mapstructure:"host"`
Type string `mapstructure:"type"`
Pass string `mapstructure:"pass"`
Key string `mapstructure:"key"`
Tls int32 `mapstructure:"tls"`
Db int32 `mapstructure:"db"`
MaxIdle int32 `mapstructure:"maxIdle"`
PoolSize int32 `mapstructure:"poolSize"`
MaxIdleTime int32 `mapstructure:"maxIdleTime"`
}
type DB struct {
Driver string `protobuf:"bytes,1,opt,name=driver,proto3" json:"driver,omitempty"`
Source string `protobuf:"bytes,2,opt,name=source,proto3" json:"source,omitempty"`
MaxIdle int32 `protobuf:"varint,3,opt,name=maxIdle,proto3" json:"maxIdle,omitempty"`
MaxOpen int32 `protobuf:"varint,4,opt,name=maxOpen,proto3" json:"maxOpen,omitempty"`
MaxLifetime int32 `protobuf:"varint,5,opt,name=maxLifetime,proto3" json:"maxLifetime,omitempty"`
IsDebug bool `protobuf:"varint,6,opt,name=isDebug,proto3" json:"isDebug,omitempty"`
Driver string `mapstructure:"driver"`
Source string `mapstructure:"source"`
MaxIdle int32 `mapstructure:"maxIdle"`
MaxOpen int32 `mapstructure:"maxOpen"`
MaxLifetime int32 `mapstructure:"maxLifetime"`
IsDebug bool `mapstructure:"isDebug"`
}
// ToolsConfig 工具配置

View File

@ -5,6 +5,14 @@ type Caller string
const (
CallerZltx Caller = "zltx" // 直连天下
CallerHyt Caller = "hyt" // 货易通
// 角色, 系统角色,用户角色
RoleSystem Caller = "system" // 系统角色
RoleUser Caller = "user" // 用户角色
RoleAssistant Caller = "assistant" // 助手角色
// 分页默认条数
ChatHistoryLimit = 10
)
func (c Caller) String() string {

View File

@ -7,3 +7,10 @@ const (
ConnStatusNormal
ConnStatusIgnore
)
type TaskType int32
const (
TaskTypeApi ConnStatus = iota + 1
TaskTypeKnowle
)

View File

@ -52,6 +52,7 @@ type BaseRepository[P PO] interface {
WithId(id interface{}) CondFunc // 查询id
WithStatus(status int) CondFunc // 查询status
GetDb() *gorm.DB // 获取数据库连接
WithLimit(limit int) CondFunc // 限制返回条数
}
// PaginationResult 分页查询结果
@ -207,3 +208,9 @@ func (this *BaseModel[P]) WithStatus(status int) CondFunc {
func (this *BaseModel[P]) GetDb() *gorm.DB {
return this.Db
}
func (this *BaseModel[P]) WithLimit(limit int) CondFunc {
return func(db *gorm.DB) *gorm.DB {
return db.Limit(limit)
}
}

View File

@ -0,0 +1,15 @@
package impl
//import (
// "ai_scheduler/internal/data/model"
// "ai_scheduler/tmpl/dataTemp"
// "ai_scheduler/utils"
//)
//type ChatHisImpl struct {
// dataTemp.DataTemp
//}
//
//func NewChatHisImpl(db *utils.Db) *ChatHisImpl {
// return &ChatHisImpl{*dataTemp.NewDataTemp(db, new(model.AiChatHi))}
//}

View File

@ -0,0 +1,56 @@
package impl
import (
"ai_scheduler/internal/data/model"
"ai_scheduler/tmpl/dataTemp"
"ai_scheduler/utils"
"context"
"github.com/gofiber/fiber/v2/log"
"gorm.io/gorm"
"time"
)
type ChatImpl struct {
dataTemp.DataTemp
BaseRepository[model.AiChatHi]
chatChannel chan model.AiChatHi
}
func NewChatImpl(db *utils.Db) *ChatImpl {
return &ChatImpl{
DataTemp: *dataTemp.NewDataTemp(db, new(model.AiChatHi)),
BaseRepository: NewBaseModel[model.AiChatHi](db.Client),
chatChannel: make(chan model.AiChatHi, 100),
}
}
// WithSessionId 条件会话ID
func (impl *ChatImpl) WithSessionId(sessionId interface{}) CondFunc {
return func(db *gorm.DB) *gorm.DB {
return db.Where("session_id = ?", sessionId)
}
}
// 异步添加会话历史
func (impl *ChatImpl) AsyncCreate(ctx context.Context, chat model.AiChatHi) {
impl.chatChannel <- chat
}
// 异步处理会话历史
func (impl *ChatImpl) AsyncProcess(ctx context.Context) {
for {
select {
case chat := <-impl.chatChannel:
log.Infof("ChatHistoryAsyncProcess chat: %v", chat)
if err := impl.Create(&chat); err != nil {
log.Errorf("ChatHistoryAsyncProcess err: %v", err)
}
case <-ctx.Done():
log.Infof("ChatHistoryAsyncProcess ctx done")
return
// 定时打印通道大小
case <-time.After(time.Second * 5):
log.Infof("ChatHistoryAsyncProcess channel len: %d", len(impl.chatChannel))
}
}
}

View File

@ -4,4 +4,4 @@ import (
"github.com/google/wire"
)
var ProviderImpl = wire.NewSet(NewSessionImpl, NewSysImpl, NewTaskImpl)
var ProviderImpl = wire.NewSet(NewSessionImpl, NewSysImpl, NewTaskImpl, NewChatImpl)

View File

@ -10,13 +10,13 @@ import (
type SessionImpl struct {
dataTemp.DataTemp
BaseModel[model.AiSession]
BaseRepository[model.AiSession]
}
func NewSessionImpl(db *utils.Db) *SessionImpl {
return &SessionImpl{
DataTemp: *dataTemp.NewDataTemp(db, new(model.AiSession)),
BaseModel: BaseModel[model.AiSession]{},
DataTemp: *dataTemp.NewDataTemp(db, new(model.AiSession)),
BaseRepository: NewBaseModel[model.AiSession](db.Client),
}
}
@ -40,3 +40,9 @@ func (s *SessionImpl) WithSysId(sysId interface{}) CondFunc {
return db.Where("sys_id = ?", sysId)
}
}
func (impl *SessionImpl) WithSessionId(sessionId interface{}) CondFunc {
return func(db *gorm.DB) *gorm.DB {
return db.Where("session_id = ?", sessionId)
}
}

View File

@ -14,7 +14,7 @@ const TableNameAiChatHi = "ai_chat_his"
type AiChatHi struct {
HisID int64 `gorm:"column:his_id;primaryKey" json:"his_id"`
SessionID string `gorm:"column:session_id;not null" json:"session_id"`
Role string `gorm:"column:role;not null" json:"role"`
Role string `gorm:"column:role;not null;comment:system系统输出assistant助手输出,user用户输入" json:"role"` // system系统输出assistant助手输出,user用户输入
Content string `gorm:"column:content;not null" json:"content"`
CreateAt *time.Time `gorm:"column:create_at;default:CURRENT_TIMESTAMP" json:"create_at"`
}

View File

@ -0,0 +1,9 @@
package entitys
import "ai_scheduler/internal/constants"
type ChatHistory struct {
SessionID string `json:"session_id"`
Role constants.Caller `json:"role"`
Content string `json:"content"`
}

View File

@ -5,6 +5,10 @@ type SessionInitRequest struct {
UserId string `json:"user_id"`
}
type SessionInitResponse struct {
Chat []ChatHistory `json:"chat"`
}
type SessionListRequest struct {
SysId string `json:"sys_id"`
UserId string `json:"user_id"`

View File

@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"gitea.cdlsxd.cn/self-tools/l_request"
"github.com/gofiber/websocket/v2"
)
@ -79,8 +80,29 @@ type Message struct {
Content string `json:"content"`
}
type Func struct {
Parameters struct{} `json:"parameters"`
type FuncApi struct {
Param interface{} `json:"param"`
Request l_request.Request `json:"request"`
}
type TaskConfig struct {
Param interface{} `json:"param"`
}
type ChatHis struct {
SessionId string `json:"session_id"`
Messages []HisMessage `json:"messages"`
Context HisContext `json:"context"`
}
type HisMessage struct {
Role string `json:"role"`
Content string `json:"content"`
Timestamp string `json:"timestamp"`
}
type HisContext struct {
UserLanguage string `json:"user_language"`
SystemMode string `json:"system_mode"`
}
// RouterService 路由服务接口

View File

@ -1 +1,8 @@
package pkg
import "encoding/json"
func JsonStringIgonErr(data interface{}) string {
dataByte, _ := json.Marshal(data)
return string(dataByte)
}

View File

@ -1,12 +1,13 @@
package pkg
import (
"ai_scheduler/internal/pkg/ollama"
"ai_scheduler/internal/pkg/utils_ollama"
"github.com/google/wire"
)
var ProviderSetClient = wire.NewSet(
NewRdb,
NewGormDb,
ollama.NewClient,
utils_ollama.NewUtilOllama,
)

View File

@ -32,7 +32,7 @@ func NewRdb(c *config.Config) *Rdb {
}
// buildRdb 构建redis client
func buildRdb(c *config.Redis) *redis.Client {
func buildRdb(c config.Redis) *redis.Client {
rdb := redis.NewClient(&redis.Options{
Addr: c.Host,

View File

@ -10,7 +10,7 @@ import (
"gorm.io/gorm"
)
func DBConn(c *config.DB) (*gorm.DB, func()) {
func DBConn(c config.DB) (*gorm.DB, func()) {
mysqlConn, err := sql.Open(c.Driver, c.Source)
gormDB, err := gorm.Open(
mysql.New(mysql.Config{Conn: mysqlConn}),

View File

@ -1,4 +1,4 @@
package ollama
package utils_ollama
import (
"ai_scheduler/internal/config"
@ -9,6 +9,7 @@ import (
"time"
"github.com/ollama/ollama/api"
"github.com/tmc/langchaingo/llms/ollama"
)
// Client Ollama客户端适配器

View File

@ -0,0 +1,47 @@
package utils_ollama
import (
"ai_scheduler/internal/config"
"net/http"
"os"
"github.com/gofiber/fiber/v2/log"
"github.com/tmc/langchaingo/llms/ollama"
)
type UtilOllama struct {
Llm *ollama.LLM
}
func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama {
llm, err := ollama.New(
ollama.WithModel(c.Ollama.Model),
ollama.WithHTTPClient(http.DefaultClient),
ollama.WithServerURL(getUrl(c)),
)
if err != nil {
logger.Fatal(err)
panic(err)
}
return &UtilOllama{
Llm: llm,
}
}
//func (o *UtilOllama) a() {
// var agent agents.Agent
// agent = agents.NewOneShotAgent(llm, tools, opts...)
//
// agents.NewExecutor()
//}
func getUrl(c *config.Config) string {
baseURL := c.Ollama.BaseURL
envURL := os.Getenv("OLLAMA_BASE_URL")
if envURL != "" {
baseURL = envURL
}
return baseURL
}

View File

@ -65,7 +65,7 @@ func routerHttp(app *fiber.App, sessionService *services.SessionService) {
return nil
})
r.Post("/session/init", sessionService.SessionInit)
r.Post("/session/init", sessionService.SessionInit) // 会话初始化,不存在则创建,存在则返回会话ID和默认条数会话历史
r.Post("/session/list", sessionService.SessionList)
}

View File

@ -8,11 +8,13 @@ import (
type SessionService struct {
sessionBiz *biz.SessionBiz
chatBiz *biz.ChatHistoryBiz
}
func NewSessionService(sessionBiz *biz.SessionBiz) *SessionService {
func NewSessionService(sessionBiz *biz.SessionBiz, chatBiz *biz.ChatHistoryBiz) *SessionService {
return &SessionService{
sessionBiz: sessionBiz,
chatBiz: chatBiz,
}
}
@ -23,11 +25,9 @@ func (s *SessionService) SessionInit(c *fiber.Ctx) error {
return err
}
sessionId, err := s.sessionBiz.SessionInit(c.Context(), req)
result, err := s.sessionBiz.SessionInit(c.Context(), req)
return handRes(c, err, fiber.Map{
"session_id": sessionId,
})
return handRes(c, err, result)
}
// SessionList 获取会话列表

View File

@ -0,0 +1,14 @@
package test
import (
"ai_scheduler/internal/entitys"
"encoding/json"
"testing"
)
func Test_task(t *testing.T) {
var c entitys.TaskConfig
config := `{"param": {"type": "object", "required": ["number"], "properties": {"number": {"type": "string", "description": "订单编号/流水号"}}}, "request": {"url": "http://www.baidu.com/${number}", "headers": {"Authorization": "${authorization}"}, "method": "GET"}}`
err := json.Unmarshal([]byte(config), &c)
t.Log(err)
}

View File

@ -5,6 +5,8 @@ import (
"ai_scheduler/utils"
"context"
"database/sql"
"fmt"
"reflect"
"github.com/go-kratos/kratos/v2/log"
"gorm.io/gorm"
@ -103,7 +105,52 @@ func (k DataTemp) GetOneBySearchToStrut(cond *builder.Cond, result interface{})
return err
}
func (k DataTemp) GetListToStruct(cond *builder.Cond, pageBoIn *ReqPageBo, result []interface{}) (pageBoOut *RespPageBo, err error) {
func (k DataTemp) GetListToStruct(cond *builder.Cond, pageBoIn *ReqPageBo, result interface{}) (pageBoOut *RespPageBo, err error) {
var (
query, _ = builder.ToBoundSQL(*cond)
model = k.Db.Model(k.Model).Where(query)
total int64
)
// 1. 计算总数
if err = model.Count(&total).Error; err != nil {
return nil, err
}
// 2. 设置分页
pageBoOut = pageBoOut.SetDataByReq(total, pageBoIn)
// 3. 查询数据(确保 result 是指针,如 &[]User
if err = model.Limit(pageBoIn.GetSize()).
Offset(pageBoIn.GetOffset()).
Order("updated_at desc").
Find(result).Error; err != nil {
return nil, err
}
// 4. 可选:用 reflect 处理结果(确保类型正确)
val := reflect.ValueOf(result)
if val.Kind() != reflect.Ptr {
return nil, fmt.Errorf("result must be a pointer")
}
val = val.Elem() // 解引用
if val.Kind() == reflect.Slice {
for i := 0; i < val.Len(); i++ {
elem := val.Index(i)
if elem.Kind() == reflect.Struct {
// 示例:打印某个字段
if field := elem.FieldByName("ID"); field.IsValid() {
fmt.Println("ID:", field.Interface())
}
}
}
}
return pageBoOut, nil
}
func (k DataTemp) GetListToStruct2(cond *builder.Cond, pageBoIn *ReqPageBo, result interface{}) (pageBoOut *RespPageBo, err error) {
var (
query, _ = builder.ToBoundSQL(*cond)
model = k.Db.Model(k.Model).Where(query)