结构修改

This commit is contained in:
renzhiyuan 2025-10-09 15:09:00 +08:00
parent f052000d59
commit 366a0b1e85
8 changed files with 69 additions and 23 deletions

View File

@ -6,6 +6,7 @@ server:
ollama: ollama:
base_url: "http://127.0.0.1:11434" base_url: "http://127.0.0.1:11434"
model: "qwen3-coder:480b-cloud" model: "qwen3-coder:480b-cloud"
generate_model: "deepseek-v3.1:671b-cloud"
timeout: "120s" timeout: "120s"
level: "info" level: "info"
format: "json" format: "json"

View File

@ -1,6 +1,7 @@
package llm_service package llm_service
import ( import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/data/model" "ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg" "ai_scheduler/internal/pkg"
@ -16,18 +17,24 @@ import (
type OllamaService struct { type OllamaService struct {
client *utils_ollama.Client client *utils_ollama.Client
config *config.Config
} }
func NewOllamaGenerate( func NewOllamaGenerate(
client *utils_ollama.Client, client *utils_ollama.Client,
config *config.Config,
) *OllamaService { ) *OllamaService {
return &OllamaService{ return &OllamaService{
client: client, client: client,
config: config,
} }
} }
func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entitys.RequireData) (msg string, err error) { func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entitys.RequireData) (msg string, err error) {
prompt := r.getPrompt(requireData.Sys, requireData.Histories, requireData.UserInput, requireData.Tasks) prompt, err := r.getPrompt(ctx, requireData)
if err != nil {
return
}
toolDefinitions := r.registerToolsOllama(requireData.Tasks) toolDefinitions := r.registerToolsOllama(requireData.Tasks)
match, err := r.client.ToolSelect(context.TODO(), prompt, toolDefinitions) match, err := r.client.ToolSelect(context.TODO(), prompt, toolDefinitions)
if err != nil { if err != nil {
@ -53,21 +60,45 @@ func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entity
return return
} }
func (r *OllamaService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []api.Message { func (r *OllamaService) getPrompt(ctx context.Context, requireData *entitys.RequireData) ([]api.Message, error) {
var ( var (
prompt = make([]api.Message, 0) prompt = make([]api.Message, 0)
) )
prompt = append(prompt, api.Message{ prompt = append(prompt, api.Message{
Role: "system", Role: "system",
Content: buildSystemPrompt(sysInfo.SysPrompt), Content: buildSystemPrompt(requireData.Sys.SysPrompt),
}, api.Message{ }, api.Message{
Role: "assistant", Role: "assistant",
Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(buildAssistant(history))), Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(buildAssistant(requireData.Histories))),
}, api.Message{ }, api.Message{
Role: "user", Role: "user",
Content: reqInput, Content: requireData.UserInput,
}) })
return prompt
if len(requireData.ImgByte) > 0 {
_, err := r.RecognizeWithImg(ctx, requireData)
if err != nil {
return nil, err
}
}
return prompt, nil
}
func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) {
requireData.Ch <- entitys.Response{
Index: "",
Content: "图片识别中。。。",
Type: entitys.ResponseLoading,
}
desc, err = r.client.Generation(ctx, &api.GenerateRequest{
Model: r.config.Ollama.GenerateModel,
Stream: new(bool),
System: "识别图片内容",
Prompt: requireData.UserInput,
})
return
} }
func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool { func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool {

View File

@ -181,7 +181,7 @@ func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Dura
} }
func (r *AiRouterBiz) getImgData(imgUrl string, requireData *entitys.RequireData) (err error) { func (r *AiRouterBiz) getImgData(imgUrl string, requireData *entitys.RequireData) (err error) {
return
if len(imgUrl) == 0 { if len(imgUrl) == 0 {
return return
} }

View File

@ -41,6 +41,7 @@ type ServerConfig struct {
type OllamaConfig struct { type OllamaConfig struct {
BaseURL string `mapstructure:"base_url"` BaseURL string `mapstructure:"base_url"`
Model string `mapstructure:"model"` Model string `mapstructure:"model"`
GenerateModel string `mapstructure:"generate_model"`
Timeout time.Duration `mapstructure:"timeout"` Timeout time.Duration `mapstructure:"timeout"`
} }

View File

@ -59,10 +59,13 @@ func (c *Client) ToolSelect(ctx context.Context, messages []api.Message, tools [
return return
} }
func (c *Client) ChatStream(ctx context.Context, ch chan entitys.Response, messages []api.Message, index string) (err error) { func (c *Client) ChatStream(ctx context.Context, ch chan entitys.Response, messages []api.Message, index string, model string) (err error) {
if len(model) == 0 {
model = c.config.Model
}
// 构建聊天请求 // 构建聊天请求
req := &api.ChatRequest{ req := &api.ChatRequest{
Model: c.config.Model, Model: model,
Messages: messages, Messages: messages,
Stream: nil, Stream: nil,
Think: &api.ThinkValue{Value: true}, Think: &api.ThinkValue{Value: true},
@ -90,6 +93,14 @@ func (c *Client) ChatStream(ctx context.Context, ch chan entitys.Response, messa
return return
} }
func (c *Client) Generation(ctx context.Context, generateRequest *api.GenerateRequest) (result api.GenerateResponse, err error) {
err = c.client.Generate(ctx, generateRequest, func(resp api.GenerateResponse) error {
result = resp
return nil
})
return
}
// convertResponse 转换响应格式 // convertResponse 转换响应格式
func (c *Client) convertResponse(resp *api.ChatResponse) *entitys.ChatResponse { func (c *Client) convertResponse(resp *api.ChatResponse) *entitys.ChatResponse {
//result := &entitys.ChatResponse{ //result := &entitys.ChatResponse{

View File

@ -71,7 +71,7 @@ func NewManager(config *config.Config, llm *utils_ollama.Client) *Manager {
} }
// 普通对话 // 普通对话
chat := NewNormalChatTool(m.llm) chat := NewNormalChatTool(m.llm, config)
m.tools[chat.Name()] = chat m.tools[chat.Name()] = chat
return m return m

View File

@ -1,6 +1,7 @@
package tools package tools
import ( import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg" "ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/utils_ollama" "ai_scheduler/internal/pkg/utils_ollama"
@ -14,11 +15,12 @@ import (
// NormalChatTool 普通对话 // NormalChatTool 普通对话
type NormalChatTool struct { type NormalChatTool struct {
llm *utils_ollama.Client llm *utils_ollama.Client
config *config.Config
} }
// NewNormalChatTool 实例普通对话 // NewNormalChatTool 实例普通对话
func NewNormalChatTool(llm *utils_ollama.Client) *NormalChatTool { func NewNormalChatTool(llm *utils_ollama.Client, config *config.Config) *NormalChatTool {
return &NormalChatTool{llm: llm} return &NormalChatTool{llm: llm, config: config}
} }
// Name 返回工具名称 // Name 返回工具名称
@ -56,11 +58,11 @@ func (w *NormalChatTool) Execute(ctx context.Context, requireData *entitys.Requi
// getMockZltxOrderDetail 获取模拟直连天下订单详情数据 // getMockZltxOrderDetail 获取模拟直连天下订单详情数据
func (w *NormalChatTool) chat(requireData *entitys.RequireData, chat *NormalChat) (err error) { func (w *NormalChatTool) chat(requireData *entitys.RequireData, chat *NormalChat) (err error) {
requireData.Ch <- entitys.Response{ //requireData.Ch <- entitys.Response{
Index: w.Name(), // Index: w.Name(),
Content: "<think></think>", // Content: "<think></think>",
Type: entitys.ResponseStream, // Type: entitys.ResponseStream,
} //}
err = w.llm.ChatStream(context.TODO(), requireData.Ch, []api.Message{ err = w.llm.ChatStream(context.TODO(), requireData.Ch, []api.Message{
{ {
@ -75,7 +77,7 @@ func (w *NormalChatTool) chat(requireData *entitys.RequireData, chat *NormalChat
Role: "user", Role: "user",
Content: chat.ChatContent, Content: chat.ChatContent,
}, },
}, w.Name()) }, w.Name(), w.config.Ollama.GenerateModel)
if err != nil { if err != nil {
return fmt.Errorf("%s", err) return fmt.Errorf("%s", err)
} }

View File

@ -173,7 +173,7 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireDat
Role: "user", Role: "user",
Content: requireData.UserInput, Content: requireData.UserInput,
}, },
}, w.Name()) }, w.Name(), "")
if err != nil { if err != nil {
return fmt.Errorf("订单日志解析失败:%s", err) return fmt.Errorf("订单日志解析失败:%s", err)
} }