结构修改
This commit is contained in:
parent
f052000d59
commit
366a0b1e85
|
@ -6,6 +6,7 @@ server:
|
||||||
ollama:
|
ollama:
|
||||||
base_url: "http://127.0.0.1:11434"
|
base_url: "http://127.0.0.1:11434"
|
||||||
model: "qwen3-coder:480b-cloud"
|
model: "qwen3-coder:480b-cloud"
|
||||||
|
generate_model: "deepseek-v3.1:671b-cloud"
|
||||||
timeout: "120s"
|
timeout: "120s"
|
||||||
level: "info"
|
level: "info"
|
||||||
format: "json"
|
format: "json"
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package llm_service
|
package llm_service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"ai_scheduler/internal/config"
|
||||||
"ai_scheduler/internal/data/model"
|
"ai_scheduler/internal/data/model"
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
"ai_scheduler/internal/pkg"
|
"ai_scheduler/internal/pkg"
|
||||||
|
@ -16,18 +17,24 @@ import (
|
||||||
|
|
||||||
type OllamaService struct {
|
type OllamaService struct {
|
||||||
client *utils_ollama.Client
|
client *utils_ollama.Client
|
||||||
|
config *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOllamaGenerate(
|
func NewOllamaGenerate(
|
||||||
client *utils_ollama.Client,
|
client *utils_ollama.Client,
|
||||||
|
config *config.Config,
|
||||||
) *OllamaService {
|
) *OllamaService {
|
||||||
return &OllamaService{
|
return &OllamaService{
|
||||||
client: client,
|
client: client,
|
||||||
|
config: config,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entitys.RequireData) (msg string, err error) {
|
func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entitys.RequireData) (msg string, err error) {
|
||||||
prompt := r.getPrompt(requireData.Sys, requireData.Histories, requireData.UserInput, requireData.Tasks)
|
prompt, err := r.getPrompt(ctx, requireData)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
toolDefinitions := r.registerToolsOllama(requireData.Tasks)
|
toolDefinitions := r.registerToolsOllama(requireData.Tasks)
|
||||||
match, err := r.client.ToolSelect(context.TODO(), prompt, toolDefinitions)
|
match, err := r.client.ToolSelect(context.TODO(), prompt, toolDefinitions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -53,21 +60,45 @@ func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entity
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *OllamaService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []api.Message {
|
func (r *OllamaService) getPrompt(ctx context.Context, requireData *entitys.RequireData) ([]api.Message, error) {
|
||||||
|
|
||||||
var (
|
var (
|
||||||
prompt = make([]api.Message, 0)
|
prompt = make([]api.Message, 0)
|
||||||
)
|
)
|
||||||
prompt = append(prompt, api.Message{
|
prompt = append(prompt, api.Message{
|
||||||
Role: "system",
|
Role: "system",
|
||||||
Content: buildSystemPrompt(sysInfo.SysPrompt),
|
Content: buildSystemPrompt(requireData.Sys.SysPrompt),
|
||||||
}, api.Message{
|
}, api.Message{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(buildAssistant(history))),
|
Content: fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(buildAssistant(requireData.Histories))),
|
||||||
}, api.Message{
|
}, api.Message{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
Content: reqInput,
|
Content: requireData.UserInput,
|
||||||
})
|
})
|
||||||
return prompt
|
|
||||||
|
if len(requireData.ImgByte) > 0 {
|
||||||
|
_, err := r.RecognizeWithImg(ctx, requireData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
return prompt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) {
|
||||||
|
requireData.Ch <- entitys.Response{
|
||||||
|
Index: "",
|
||||||
|
Content: "图片识别中。。。",
|
||||||
|
Type: entitys.ResponseLoading,
|
||||||
|
}
|
||||||
|
desc, err = r.client.Generation(ctx, &api.GenerateRequest{
|
||||||
|
Model: r.config.Ollama.GenerateModel,
|
||||||
|
Stream: new(bool),
|
||||||
|
System: "识别图片内容",
|
||||||
|
Prompt: requireData.UserInput,
|
||||||
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool {
|
func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool {
|
||||||
|
|
|
@ -181,7 +181,7 @@ func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Dura
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *AiRouterBiz) getImgData(imgUrl string, requireData *entitys.RequireData) (err error) {
|
func (r *AiRouterBiz) getImgData(imgUrl string, requireData *entitys.RequireData) (err error) {
|
||||||
return
|
|
||||||
if len(imgUrl) == 0 {
|
if len(imgUrl) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,6 +41,7 @@ type ServerConfig struct {
|
||||||
type OllamaConfig struct {
|
type OllamaConfig struct {
|
||||||
BaseURL string `mapstructure:"base_url"`
|
BaseURL string `mapstructure:"base_url"`
|
||||||
Model string `mapstructure:"model"`
|
Model string `mapstructure:"model"`
|
||||||
|
GenerateModel string `mapstructure:"generate_model"`
|
||||||
Timeout time.Duration `mapstructure:"timeout"`
|
Timeout time.Duration `mapstructure:"timeout"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -59,10 +59,13 @@ func (c *Client) ToolSelect(ctx context.Context, messages []api.Message, tools [
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) ChatStream(ctx context.Context, ch chan entitys.Response, messages []api.Message, index string) (err error) {
|
func (c *Client) ChatStream(ctx context.Context, ch chan entitys.Response, messages []api.Message, index string, model string) (err error) {
|
||||||
|
if len(model) == 0 {
|
||||||
|
model = c.config.Model
|
||||||
|
}
|
||||||
// 构建聊天请求
|
// 构建聊天请求
|
||||||
req := &api.ChatRequest{
|
req := &api.ChatRequest{
|
||||||
Model: c.config.Model,
|
Model: model,
|
||||||
Messages: messages,
|
Messages: messages,
|
||||||
Stream: nil,
|
Stream: nil,
|
||||||
Think: &api.ThinkValue{Value: true},
|
Think: &api.ThinkValue{Value: true},
|
||||||
|
@ -90,6 +93,14 @@ func (c *Client) ChatStream(ctx context.Context, ch chan entitys.Response, messa
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) Generation(ctx context.Context, generateRequest *api.GenerateRequest) (result api.GenerateResponse, err error) {
|
||||||
|
err = c.client.Generate(ctx, generateRequest, func(resp api.GenerateResponse) error {
|
||||||
|
result = resp
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// convertResponse 转换响应格式
|
// convertResponse 转换响应格式
|
||||||
func (c *Client) convertResponse(resp *api.ChatResponse) *entitys.ChatResponse {
|
func (c *Client) convertResponse(resp *api.ChatResponse) *entitys.ChatResponse {
|
||||||
//result := &entitys.ChatResponse{
|
//result := &entitys.ChatResponse{
|
||||||
|
|
|
@ -71,7 +71,7 @@ func NewManager(config *config.Config, llm *utils_ollama.Client) *Manager {
|
||||||
}
|
}
|
||||||
|
|
||||||
// 普通对话
|
// 普通对话
|
||||||
chat := NewNormalChatTool(m.llm)
|
chat := NewNormalChatTool(m.llm, config)
|
||||||
m.tools[chat.Name()] = chat
|
m.tools[chat.Name()] = chat
|
||||||
|
|
||||||
return m
|
return m
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package tools
|
package tools
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"ai_scheduler/internal/config"
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
"ai_scheduler/internal/pkg"
|
"ai_scheduler/internal/pkg"
|
||||||
"ai_scheduler/internal/pkg/utils_ollama"
|
"ai_scheduler/internal/pkg/utils_ollama"
|
||||||
|
@ -14,11 +15,12 @@ import (
|
||||||
// NormalChatTool 普通对话
|
// NormalChatTool 普通对话
|
||||||
type NormalChatTool struct {
|
type NormalChatTool struct {
|
||||||
llm *utils_ollama.Client
|
llm *utils_ollama.Client
|
||||||
|
config *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewNormalChatTool 实例普通对话
|
// NewNormalChatTool 实例普通对话
|
||||||
func NewNormalChatTool(llm *utils_ollama.Client) *NormalChatTool {
|
func NewNormalChatTool(llm *utils_ollama.Client, config *config.Config) *NormalChatTool {
|
||||||
return &NormalChatTool{llm: llm}
|
return &NormalChatTool{llm: llm, config: config}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Name 返回工具名称
|
// Name 返回工具名称
|
||||||
|
@ -56,11 +58,11 @@ func (w *NormalChatTool) Execute(ctx context.Context, requireData *entitys.Requi
|
||||||
|
|
||||||
// getMockZltxOrderDetail 获取模拟直连天下订单详情数据
|
// getMockZltxOrderDetail 获取模拟直连天下订单详情数据
|
||||||
func (w *NormalChatTool) chat(requireData *entitys.RequireData, chat *NormalChat) (err error) {
|
func (w *NormalChatTool) chat(requireData *entitys.RequireData, chat *NormalChat) (err error) {
|
||||||
requireData.Ch <- entitys.Response{
|
//requireData.Ch <- entitys.Response{
|
||||||
Index: w.Name(),
|
// Index: w.Name(),
|
||||||
Content: "<think></think>",
|
// Content: "<think></think>",
|
||||||
Type: entitys.ResponseStream,
|
// Type: entitys.ResponseStream,
|
||||||
}
|
//}
|
||||||
|
|
||||||
err = w.llm.ChatStream(context.TODO(), requireData.Ch, []api.Message{
|
err = w.llm.ChatStream(context.TODO(), requireData.Ch, []api.Message{
|
||||||
{
|
{
|
||||||
|
@ -75,7 +77,7 @@ func (w *NormalChatTool) chat(requireData *entitys.RequireData, chat *NormalChat
|
||||||
Role: "user",
|
Role: "user",
|
||||||
Content: chat.ChatContent,
|
Content: chat.ChatContent,
|
||||||
},
|
},
|
||||||
}, w.Name())
|
}, w.Name(), w.config.Ollama.GenerateModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf(":%s", err)
|
return fmt.Errorf(":%s", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -173,7 +173,7 @@ func (w *ZltxOrderDetailTool) getZltxOrderDetail(requireData *entitys.RequireDat
|
||||||
Role: "user",
|
Role: "user",
|
||||||
Content: requireData.UserInput,
|
Content: requireData.UserInput,
|
||||||
},
|
},
|
||||||
}, w.Name())
|
}, w.Name(), "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("订单日志解析失败:%s", err)
|
return fmt.Errorf("订单日志解析失败:%s", err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue