diff --git a/config/config.yaml b/config/config.yaml index ea79049..9f9ec5d 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -6,6 +6,7 @@ server: ollama: base_url: "http://127.0.0.1:11434" model: "qwen3-coder:480b-cloud" + generate_model: "deepseek-v3.1:671b-cloud" timeout: "120s" level: "info" format: "json" diff --git a/internal/biz/llm_service/ollama.go b/internal/biz/llm_service/ollama.go index 00929aa..9a7ee25 100644 --- a/internal/biz/llm_service/ollama.go +++ b/internal/biz/llm_service/ollama.go @@ -1,6 +1,7 @@ package llm_service import ( + "ai_scheduler/internal/config" "ai_scheduler/internal/data/model" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg" @@ -16,18 +17,24 @@ import ( type OllamaService struct { client *utils_ollama.Client + config *config.Config } func NewOllamaGenerate( client *utils_ollama.Client, + config *config.Config, ) *OllamaService { return &OllamaService{ client: client, + config: config, } } func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entitys.RequireData) (msg string, err error) { - prompt := r.getPrompt(requireData.Sys, requireData.Histories, requireData.UserInput, requireData.Tasks) + prompt, err := r.getPrompt(ctx, requireData) + if err != nil { + return + } toolDefinitions := r.registerToolsOllama(requireData.Tasks) match, err := r.client.ToolSelect(context.TODO(), prompt, toolDefinitions) if err != nil { @@ -53,21 +60,45 @@ func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entity return } -func (r *OllamaService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []api.Message { +func (r *OllamaService) getPrompt(ctx context.Context, requireData *entitys.RequireData) ([]api.Message, error) { + var ( prompt = make([]api.Message, 0) ) prompt = append(prompt, api.Message{ Role: "system", - Content: buildSystemPrompt(sysInfo.SysPrompt), + Content: buildSystemPrompt(requireData.Sys.SysPrompt), }, api.Message{ 
Role: "assistant", - Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(buildAssistant(history))), + Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(buildAssistant(requireData.Histories))), }, api.Message{ Role: "user", - Content: reqInput, + Content: requireData.UserInput, }) - return prompt + + if len(requireData.ImgByte) > 0 { + desc, err := r.RecognizeWithImg(ctx, requireData) + if err != nil { + return nil, err + } + prompt = append(prompt, api.Message{Role: "user", Content: fmt.Sprintf("图片识别结果:%s", desc.Response)}) + } + return prompt, nil +} + +func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) { + requireData.Ch <- entitys.Response{ + Index: "", + Content: "图片识别中。。。", + Type: entitys.ResponseLoading, + } + desc, err = r.client.Generation(ctx, &api.GenerateRequest{ + Model: r.config.Ollama.GenerateModel, + Stream: new(bool), + System: "识别图片内容", + Prompt: requireData.UserInput, Images: []api.ImageData{requireData.ImgByte}, + }) + return } func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool { diff --git a/internal/biz/router.go b/internal/biz/router.go index c88fa62..d26933c 100644 --- a/internal/biz/router.go +++ b/internal/biz/router.go @@ -181,7 +181,7 @@ func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Dura } func (r *AiRouterBiz) getImgData(imgUrl string, requireData *entitys.RequireData) (err error) { - return + if len(imgUrl) == 0 { return } // TODO(review): fetch image bytes from imgUrl into requireData.ImgByte — currently never populated diff --git a/internal/config/config.go b/internal/config/config.go index 2e80b18..5cb97d8 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -39,9 +39,10 @@ type ServerConfig struct { // OllamaConfig Ollama配置 type OllamaConfig struct { - BaseURL string `mapstructure:"base_url"` - Model string `mapstructure:"model"` - Timeout time.Duration `mapstructure:"timeout"` + BaseURL string `mapstructure:"base_url"` + Model string `mapstructure:"model"` + GenerateModel string `mapstructure:"generate_model"` + Timeout time.Duration `mapstructure:"timeout"` } type Redis struct { diff --git 
a/internal/pkg/utils_ollama/client.go b/internal/pkg/utils_ollama/client.go index fddbf0a..a5bc9ca 100644 --- a/internal/pkg/utils_ollama/client.go +++ b/internal/pkg/utils_ollama/client.go @@ -59,10 +59,13 @@ func (c *Client) ToolSelect(ctx context.Context, messages []api.Message, tools [ return } -func (c *Client) ChatStream(ctx context.Context, ch chan entitys.Response, messages []api.Message, index string) (err error) { +func (c *Client) ChatStream(ctx context.Context, ch chan entitys.Response, messages []api.Message, index string, model string) (err error) { + if len(model) == 0 { + model = c.config.Model + } // 构建聊天请求 req := &api.ChatRequest{ - Model: c.config.Model, + Model: model, Messages: messages, Stream: nil, Think: &api.ThinkValue{Value: true}, @@ -90,6 +93,14 @@ func (c *Client) ChatStream(ctx context.Context, ch chan entitys.Response, messa return } +func (c *Client) Generation(ctx context.Context, generateRequest *api.GenerateRequest) (result api.GenerateResponse, err error) { + err = c.client.Generate(ctx, generateRequest, func(resp api.GenerateResponse) error { + result = resp + return nil + }) + return +} + // convertResponse 转换响应格式 func (c *Client) convertResponse(resp *api.ChatResponse) *entitys.ChatResponse { //result := &entitys.ChatResponse{ diff --git a/internal/tools/manager.go b/internal/tools/manager.go index 4ef85a6..8136efb 100644 --- a/internal/tools/manager.go +++ b/internal/tools/manager.go @@ -71,7 +71,7 @@ func NewManager(config *config.Config, llm *utils_ollama.Client) *Manager { } // 普通对话 - chat := NewNormalChatTool(m.llm) + chat := NewNormalChatTool(m.llm, config) m.tools[chat.Name()] = chat return m diff --git a/internal/tools/normal_chat.go b/internal/tools/normal_chat.go index 9ef8903..56010fc 100644 --- a/internal/tools/normal_chat.go +++ b/internal/tools/normal_chat.go @@ -1,6 +1,7 @@ package tools import ( + "ai_scheduler/internal/config" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg" 
"ai_scheduler/internal/pkg/utils_ollama" @@ -13,12 +14,13 @@ import ( // NormalChatTool 普通对话 type NormalChatTool struct { - llm *utils_ollama.Client + llm *utils_ollama.Client + config *config.Config } // NewNormalChatTool 实例普通对话 -func NewNormalChatTool(llm *utils_ollama.Client) *NormalChatTool { - return &NormalChatTool{llm: llm} +func NewNormalChatTool(llm *utils_ollama.Client, config *config.Config) *NormalChatTool { + return &NormalChatTool{llm: llm, config: config} } // Name 返回工具名称 @@ -56,11 +58,11 @@ func (w *NormalChatTool) Execute(ctx context.Context, requireData *entitys.Requi // getMockZltxOrderDetail 获取模拟直连天下订单详情数据 func (w *NormalChatTool) chat(requireData *entitys.RequireData, chat *NormalChat) (err error) { - requireData.Ch <- entitys.Response{ - Index: w.Name(), - Content: "", - Type: entitys.ResponseStream, - } + //requireData.Ch <- entitys.Response{ + // Index: w.Name(), + // Content: "", + // Type: entitys.ResponseStream, + //} err = w.llm.ChatStream(context.TODO(), requireData.Ch, []api.Message{ { @@ -75,7 +77,7 @@ func (w *NormalChatTool) chat(requireData *entitys.RequireData, chat *NormalChat Role: "user", Content: chat.ChatContent, }, - }, w.Name()) + }, w.Name(), w.config.Ollama.GenerateModel) if err != nil { return fmt.Errorf(":%s", err) } diff --git a/internal/tools/zltx_order_detail.go b/internal/tools/zltx_order_detail.go index 84ff6f6..6c9bd9c 100644 --- a/internal/tools/zltx_order_detail.go +++ b/internal/tools/zltx_order_detail.go @@ -173,7 +173,7 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireDat Role: "user", Content: requireData.UserInput, }, - }, w.Name()) + }, w.Name(), "") if err != nil { return fmt.Errorf("订单日志解析失败:%s", err) }