From 366a0b1e85458ebb2c73094001ef5b21bbdcb634 Mon Sep 17 00:00:00 2001
From: renzhiyuan <465386466@qq.com>
Date: Thu, 9 Oct 2025 15:09:00 +0800
Subject: [PATCH] Structural changes: split off a generate model and add an image-recognition hook
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
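
Split chat and generation onto separate Ollama models and add an
image-recognition hook:

- config: add ollama.generate_model (OllamaConfig.GenerateModel) next to
  the existing chat model.
- llm_service: OllamaService now takes *config.Config; getPrompt works on
  the full RequireData and, when image bytes are present, calls the new
  RecognizeWithImg helper, which emits a loading frame and runs a
  non-streaming Generate against the generate model.
- utils_ollama: ChatStream accepts an explicit model name (an empty string
  falls back to the configured chat model); add a Generation wrapper
  around the Ollama Generate API.
- tools: NormalChatTool is constructed with the config and chats via the
  generate model (its initial empty stream frame is commented out);
  ZltxOrderDetailTool passes "" and keeps the default model.
- router: re-enable getImgData by removing its leading return.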
---
config/config.yaml | 1 +
internal/biz/llm_service/ollama.go | 43 +++++++++++++++++++++++++----
internal/biz/router.go | 2 +-
internal/config/config.go | 7 +++--
internal/pkg/utils_ollama/client.go | 15 ++++++++--
internal/tools/manager.go | 2 +-
internal/tools/normal_chat.go | 20 ++++++++------
internal/tools/zltx_order_detail.go | 2 +-
8 files changed, 69 insertions(+), 23 deletions(-)
diff --git a/config/config.yaml b/config/config.yaml
index ea79049..9f9ec5d 100644
--- a/config/config.yaml
+++ b/config/config.yaml
@@ -6,6 +6,7 @@ server:
ollama:
base_url: "http://127.0.0.1:11434"
model: "qwen3-coder:480b-cloud"
+ generate_model: "deepseek-v3.1:671b-cloud"
timeout: "120s"
level: "info"
format: "json"
diff --git a/internal/biz/llm_service/ollama.go b/internal/biz/llm_service/ollama.go
index 00929aa..9a7ee25 100644
--- a/internal/biz/llm_service/ollama.go
+++ b/internal/biz/llm_service/ollama.go
@@ -1,6 +1,7 @@
package llm_service
import (
+ "ai_scheduler/internal/config"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
@@ -16,18 +17,24 @@ import (
type OllamaService struct {
client *utils_ollama.Client
+ config *config.Config
}
func NewOllamaGenerate(
client *utils_ollama.Client,
+ config *config.Config,
) *OllamaService {
return &OllamaService{
client: client,
+ config: config,
}
}
func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entitys.RequireData) (msg string, err error) {
- prompt := r.getPrompt(requireData.Sys, requireData.Histories, requireData.UserInput, requireData.Tasks)
+ prompt, err := r.getPrompt(ctx, requireData)
+ if err != nil {
+ return
+ }
toolDefinitions := r.registerToolsOllama(requireData.Tasks)
match, err := r.client.ToolSelect(context.TODO(), prompt, toolDefinitions)
if err != nil {
@@ -53,21 +60,45 @@ func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entity
return
}
-func (r *OllamaService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []api.Message {
+func (r *OllamaService) getPrompt(ctx context.Context, requireData *entitys.RequireData) ([]api.Message, error) {
+
var (
prompt = make([]api.Message, 0)
)
prompt = append(prompt, api.Message{
Role: "system",
- Content: buildSystemPrompt(sysInfo.SysPrompt),
+ Content: buildSystemPrompt(requireData.Sys.SysPrompt),
}, api.Message{
Role: "assistant",
- Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(buildAssistant(history))),
+ Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(buildAssistant(requireData.Histories))),
}, api.Message{
Role: "user",
- Content: reqInput,
+ Content: requireData.UserInput,
})
- return prompt
+
+ if len(requireData.ImgByte) > 0 {
+ _, err := r.RecognizeWithImg(ctx, requireData)
+ if err != nil {
+ return nil, err
+ }
+
+ }
+ return prompt, nil
+}
+
+func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) {
+ requireData.Ch <- entitys.Response{
+ Index: "",
+ Content: "图片识别中。。。",
+ Type: entitys.ResponseLoading,
+ }
+ desc, err = r.client.Generation(ctx, &api.GenerateRequest{
+ Model: r.config.Ollama.GenerateModel,
+ Stream: new(bool),
+ System: "识别图片内容",
+ Prompt: requireData.UserInput,
+ })
+ return
}
func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool {
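
A note on the image path above: getPrompt discards the RecognizeWithImg result, and the
GenerateRequest carries no image payload, so the generate model only ever sees the system
and user text. A minimal sketch of one way to close that loop, assuming RequireData.ImgByte
holds raw image bytes and that the description belongs in the prompt (hypothetical wiring,
not part of this patch):

    if len(requireData.ImgByte) > 0 {
        desc, err := r.RecognizeWithImg(ctx, requireData)
        if err != nil {
            return nil, err
        }
        // Fold the recognized image description into the conversation.
        prompt = append(prompt, api.Message{
            Role:    "assistant",
            Content: fmt.Sprintf("图片内容:%s", desc.Response), // GenerateResponse.Response holds the generated text
        })
    }

and, inside RecognizeWithImg, attaching the image to the request (api.ImageData is a []byte alias):

        Images: []api.ImageData{api.ImageData(requireData.ImgByte)},
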
diff --git a/internal/biz/router.go b/internal/biz/router.go
index c88fa62..d26933c 100644
--- a/internal/biz/router.go
+++ b/internal/biz/router.go
@@ -181,7 +181,7 @@ func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Dura
}
func (r *AiRouterBiz) getImgData(imgUrl string, requireData *entitys.RequireData) (err error) {
- return
+
if len(imgUrl) == 0 {
return
}
diff --git a/internal/config/config.go b/internal/config/config.go
index 2e80b18..5cb97d8 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -39,9 +39,10 @@ type ServerConfig struct {
// OllamaConfig Ollama配置
type OllamaConfig struct {
- BaseURL string `mapstructure:"base_url"`
- Model string `mapstructure:"model"`
- Timeout time.Duration `mapstructure:"timeout"`
+ BaseURL string `mapstructure:"base_url"`
+ Model string `mapstructure:"model"`
+ GenerateModel string `mapstructure:"generate_model"`
+ Timeout time.Duration `mapstructure:"timeout"`
}
type Redis struct {
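
One caveat with the new field: GenerateModel has no default, and RecognizeWithImg passes it
straight to Generation, whereas ChatStream only falls back to Model when handed an empty
string. A minimal caller-side guard for configs that predate generate_model, assuming cfg is
the loaded *config.Config (a sketch, not part of this patch):

    genModel := cfg.Ollama.GenerateModel
    if genModel == "" {
        genModel = cfg.Ollama.Model // reuse the chat model when generate_model is unset
    }
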
diff --git a/internal/pkg/utils_ollama/client.go b/internal/pkg/utils_ollama/client.go
index fddbf0a..a5bc9ca 100644
--- a/internal/pkg/utils_ollama/client.go
+++ b/internal/pkg/utils_ollama/client.go
@@ -59,10 +59,13 @@ func (c *Client) ToolSelect(ctx context.Context, messages []api.Message, tools [
return
}
-func (c *Client) ChatStream(ctx context.Context, ch chan entitys.Response, messages []api.Message, index string) (err error) {
+func (c *Client) ChatStream(ctx context.Context, ch chan entitys.Response, messages []api.Message, index string, model string) (err error) {
+ if len(model) == 0 {
+ model = c.config.Model
+ }
// 构建聊天请求
req := &api.ChatRequest{
- Model: c.config.Model,
+ Model: model,
Messages: messages,
Stream: nil,
Think: &api.ThinkValue{Value: true},
@@ -90,6 +93,14 @@ func (c *Client) ChatStream(ctx context.Context, ch chan entitys.Response, messa
return
}
+func (c *Client) Generation(ctx context.Context, generateRequest *api.GenerateRequest) (result api.GenerateResponse, err error) {
+ err = c.client.Generate(ctx, generateRequest, func(resp api.GenerateResponse) error {
+ result = resp
+ return nil
+ })
+ return
+}
+
// convertResponse 转换响应格式
func (c *Client) convertResponse(resp *api.ChatResponse) *entitys.ChatResponse {
//result := &entitys.ChatResponse{
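
For reference, a rough sketch of how the two client entry points are driven after this
change; demoClient and its parameters are illustrative names, not from the patch. Generation
keeps only the last callback payload, so callers should disable streaming (as RecognizeWithImg
does) to get the full text in a single response:

    func demoClient(ctx context.Context, llm *utils_ollama.Client, cfg *config.Config,
        ch chan entitys.Response, messages []api.Message) error {
        // One-shot generation against the dedicated generate model.
        noStream := false
        resp, err := llm.Generation(ctx, &api.GenerateRequest{
            Model:  cfg.Ollama.GenerateModel,
            Prompt: "hello",
            Stream: &noStream, // single, complete response
        })
        if err != nil {
            return err
        }
        fmt.Println(resp.Response)

        // Streaming chat: an explicit model overrides the configured chat model...
        if err := llm.ChatStream(ctx, ch, messages, "normal_chat", cfg.Ollama.GenerateModel); err != nil {
            return err
        }
        // ...while an empty string keeps c.config.Model.
        return llm.ChatStream(ctx, ch, messages, "zltx_order_detail", "")
    }
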
diff --git a/internal/tools/manager.go b/internal/tools/manager.go
index 4ef85a6..8136efb 100644
--- a/internal/tools/manager.go
+++ b/internal/tools/manager.go
@@ -71,7 +71,7 @@ func NewManager(config *config.Config, llm *utils_ollama.Client) *Manager {
}
// 普通对话
- chat := NewNormalChatTool(m.llm)
+ chat := NewNormalChatTool(m.llm, config)
m.tools[chat.Name()] = chat
return m
diff --git a/internal/tools/normal_chat.go b/internal/tools/normal_chat.go
index 9ef8903..56010fc 100644
--- a/internal/tools/normal_chat.go
+++ b/internal/tools/normal_chat.go
@@ -1,6 +1,7 @@
package tools
import (
+ "ai_scheduler/internal/config"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/utils_ollama"
@@ -13,12 +14,13 @@ import (
// NormalChatTool 普通对话
type NormalChatTool struct {
- llm *utils_ollama.Client
+ llm *utils_ollama.Client
+ config *config.Config
}
// NewNormalChatTool 实例普通对话
-func NewNormalChatTool(llm *utils_ollama.Client) *NormalChatTool {
- return &NormalChatTool{llm: llm}
+func NewNormalChatTool(llm *utils_ollama.Client, config *config.Config) *NormalChatTool {
+ return &NormalChatTool{llm: llm, config: config}
}
// Name 返回工具名称
@@ -56,11 +58,11 @@ func (w *NormalChatTool) Execute(ctx context.Context, requireData *entitys.Requi
// getMockZltxOrderDetail 获取模拟直连天下订单详情数据
func (w *NormalChatTool) chat(requireData *entitys.RequireData, chat *NormalChat) (err error) {
- requireData.Ch <- entitys.Response{
- Index: w.Name(),
- Content: "",
- Type: entitys.ResponseStream,
- }
+ //requireData.Ch <- entitys.Response{
+ // Index: w.Name(),
+ // Content: "",
+ // Type: entitys.ResponseStream,
+ //}
err = w.llm.ChatStream(context.TODO(), requireData.Ch, []api.Message{
{
@@ -75,7 +77,7 @@ func (w *NormalChatTool) chat(requireData *entitys.RequireData, chat *NormalChat
Role: "user",
Content: chat.ChatContent,
},
- }, w.Name())
+ }, w.Name(), w.config.Ollama.GenerateModel)
if err != nil {
return fmt.Errorf(":%s", err)
}
diff --git a/internal/tools/zltx_order_detail.go b/internal/tools/zltx_order_detail.go
index 84ff6f6..6c9bd9c 100644
--- a/internal/tools/zltx_order_detail.go
+++ b/internal/tools/zltx_order_detail.go
@@ -173,7 +173,7 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireDat
Role: "user",
Content: requireData.UserInput,
},
- }, w.Name())
+ }, w.Name(), "")
if err != nil {
return fmt.Errorf("订单日志解析失败:%s", err)
}