结构修改
This commit is contained in:
parent
f052000d59
commit
366a0b1e85
|
@ -6,6 +6,7 @@ server:
|
|||
ollama:
|
||||
base_url: "http://127.0.0.1:11434"
|
||||
model: "qwen3-coder:480b-cloud"
|
||||
generate_model: "deepseek-v3.1:671b-cloud"
|
||||
timeout: "120s"
|
||||
level: "info"
|
||||
format: "json"
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package llm_service
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/data/model"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/pkg"
|
||||
|
@ -16,18 +17,24 @@ import (
|
|||
|
||||
type OllamaService struct {
|
||||
client *utils_ollama.Client
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
func NewOllamaGenerate(
|
||||
client *utils_ollama.Client,
|
||||
config *config.Config,
|
||||
) *OllamaService {
|
||||
return &OllamaService{
|
||||
client: client,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entitys.RequireData) (msg string, err error) {
|
||||
prompt := r.getPrompt(requireData.Sys, requireData.Histories, requireData.UserInput, requireData.Tasks)
|
||||
prompt, err := r.getPrompt(ctx, requireData)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
toolDefinitions := r.registerToolsOllama(requireData.Tasks)
|
||||
match, err := r.client.ToolSelect(context.TODO(), prompt, toolDefinitions)
|
||||
if err != nil {
|
||||
|
@ -53,21 +60,45 @@ func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entity
|
|||
return
|
||||
}
|
||||
|
||||
func (r *OllamaService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []api.Message {
|
||||
func (r *OllamaService) getPrompt(ctx context.Context, requireData *entitys.RequireData) ([]api.Message, error) {
|
||||
|
||||
var (
|
||||
prompt = make([]api.Message, 0)
|
||||
)
|
||||
prompt = append(prompt, api.Message{
|
||||
Role: "system",
|
||||
Content: buildSystemPrompt(sysInfo.SysPrompt),
|
||||
Content: buildSystemPrompt(requireData.Sys.SysPrompt),
|
||||
}, api.Message{
|
||||
Role: "assistant",
|
||||
Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(buildAssistant(history))),
|
||||
Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(buildAssistant(requireData.Histories))),
|
||||
}, api.Message{
|
||||
Role: "user",
|
||||
Content: reqInput,
|
||||
Content: requireData.UserInput,
|
||||
})
|
||||
return prompt
|
||||
|
||||
if len(requireData.ImgByte) > 0 {
|
||||
_, err := r.RecognizeWithImg(ctx, requireData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
}
|
||||
return prompt, nil
|
||||
}
|
||||
|
||||
func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) {
|
||||
requireData.Ch <- entitys.Response{
|
||||
Index: "",
|
||||
Content: "图片识别中。。。",
|
||||
Type: entitys.ResponseLoading,
|
||||
}
|
||||
desc, err = r.client.Generation(ctx, &api.GenerateRequest{
|
||||
Model: r.config.Ollama.GenerateModel,
|
||||
Stream: new(bool),
|
||||
System: "识别图片内容",
|
||||
Prompt: requireData.UserInput,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool {
|
||||
|
|
|
@ -181,7 +181,7 @@ func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Dura
|
|||
}
|
||||
|
||||
func (r *AiRouterBiz) getImgData(imgUrl string, requireData *entitys.RequireData) (err error) {
|
||||
return
|
||||
|
||||
if len(imgUrl) == 0 {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -41,6 +41,7 @@ type ServerConfig struct {
|
|||
type OllamaConfig struct {
|
||||
BaseURL string `mapstructure:"base_url"`
|
||||
Model string `mapstructure:"model"`
|
||||
GenerateModel string `mapstructure:"generate_model"`
|
||||
Timeout time.Duration `mapstructure:"timeout"`
|
||||
}
|
||||
|
||||
|
|
|
@ -59,10 +59,13 @@ func (c *Client) ToolSelect(ctx context.Context, messages []api.Message, tools [
|
|||
return
|
||||
}
|
||||
|
||||
func (c *Client) ChatStream(ctx context.Context, ch chan entitys.Response, messages []api.Message, index string) (err error) {
|
||||
func (c *Client) ChatStream(ctx context.Context, ch chan entitys.Response, messages []api.Message, index string, model string) (err error) {
|
||||
if len(model) == 0 {
|
||||
model = c.config.Model
|
||||
}
|
||||
// 构建聊天请求
|
||||
req := &api.ChatRequest{
|
||||
Model: c.config.Model,
|
||||
Model: model,
|
||||
Messages: messages,
|
||||
Stream: nil,
|
||||
Think: &api.ThinkValue{Value: true},
|
||||
|
@ -90,6 +93,14 @@ func (c *Client) ChatStream(ctx context.Context, ch chan entitys.Response, messa
|
|||
return
|
||||
}
|
||||
|
||||
func (c *Client) Generation(ctx context.Context, generateRequest *api.GenerateRequest) (result api.GenerateResponse, err error) {
|
||||
err = c.client.Generate(ctx, generateRequest, func(resp api.GenerateResponse) error {
|
||||
result = resp
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// convertResponse 转换响应格式
|
||||
func (c *Client) convertResponse(resp *api.ChatResponse) *entitys.ChatResponse {
|
||||
//result := &entitys.ChatResponse{
|
||||
|
|
|
@ -71,7 +71,7 @@ func NewManager(config *config.Config, llm *utils_ollama.Client) *Manager {
|
|||
}
|
||||
|
||||
// 普通对话
|
||||
chat := NewNormalChatTool(m.llm)
|
||||
chat := NewNormalChatTool(m.llm, config)
|
||||
m.tools[chat.Name()] = chat
|
||||
|
||||
return m
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package tools
|
||||
|
||||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/pkg"
|
||||
"ai_scheduler/internal/pkg/utils_ollama"
|
||||
|
@ -14,11 +15,12 @@ import (
|
|||
// NormalChatTool 普通对话
|
||||
type NormalChatTool struct {
|
||||
llm *utils_ollama.Client
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
// NewNormalChatTool 实例普通对话
|
||||
func NewNormalChatTool(llm *utils_ollama.Client) *NormalChatTool {
|
||||
return &NormalChatTool{llm: llm}
|
||||
func NewNormalChatTool(llm *utils_ollama.Client, config *config.Config) *NormalChatTool {
|
||||
return &NormalChatTool{llm: llm, config: config}
|
||||
}
|
||||
|
||||
// Name 返回工具名称
|
||||
|
@ -56,11 +58,11 @@ func (w *NormalChatTool) Execute(ctx context.Context, requireData *entitys.Requi
|
|||
|
||||
// getMockZltxOrderDetail 获取模拟直连天下订单详情数据
|
||||
func (w *NormalChatTool) chat(requireData *entitys.RequireData, chat *NormalChat) (err error) {
|
||||
requireData.Ch <- entitys.Response{
|
||||
Index: w.Name(),
|
||||
Content: "<think></think>",
|
||||
Type: entitys.ResponseStream,
|
||||
}
|
||||
//requireData.Ch <- entitys.Response{
|
||||
// Index: w.Name(),
|
||||
// Content: "<think></think>",
|
||||
// Type: entitys.ResponseStream,
|
||||
//}
|
||||
|
||||
err = w.llm.ChatStream(context.TODO(), requireData.Ch, []api.Message{
|
||||
{
|
||||
|
@ -75,7 +77,7 @@ func (w *NormalChatTool) chat(requireData *entitys.RequireData, chat *NormalChat
|
|||
Role: "user",
|
||||
Content: chat.ChatContent,
|
||||
},
|
||||
}, w.Name())
|
||||
}, w.Name(), w.config.Ollama.GenerateModel)
|
||||
if err != nil {
|
||||
return fmt.Errorf(":%s", err)
|
||||
}
|
||||
|
|
|
@ -173,7 +173,7 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireDat
|
|||
Role: "user",
|
||||
Content: requireData.UserInput,
|
||||
},
|
||||
}, w.Name())
|
||||
}, w.Name(), "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("订单日志解析失败:%s", err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue