结构修改
This commit is contained in:
parent
c130900ac6
commit
ed7591975b
|
@ -21,6 +21,7 @@ import (
|
||||||
"gitea.cdlsxd.cn/self-tools/l_request"
|
"gitea.cdlsxd.cn/self-tools/l_request"
|
||||||
"github.com/gofiber/fiber/v2/log"
|
"github.com/gofiber/fiber/v2/log"
|
||||||
"github.com/gofiber/websocket/v2"
|
"github.com/gofiber/websocket/v2"
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/tmc/langchaingo/llms"
|
"github.com/tmc/langchaingo/llms"
|
||||||
"xorm.io/builder"
|
"xorm.io/builder"
|
||||||
)
|
)
|
||||||
|
@ -34,7 +35,7 @@ type AiRouterBiz struct {
|
||||||
taskImpl *impl.TaskImpl
|
taskImpl *impl.TaskImpl
|
||||||
hisImpl *impl.ChatImpl
|
hisImpl *impl.ChatImpl
|
||||||
conf *config.Config
|
conf *config.Config
|
||||||
utilAgent *utils_ollama.UtilOllama
|
ai *utils_ollama.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRouterService 创建路由服务
|
// NewRouterService 创建路由服务
|
||||||
|
@ -46,7 +47,7 @@ func NewAiRouterBiz(
|
||||||
taskImpl *impl.TaskImpl,
|
taskImpl *impl.TaskImpl,
|
||||||
hisImpl *impl.ChatImpl,
|
hisImpl *impl.ChatImpl,
|
||||||
conf *config.Config,
|
conf *config.Config,
|
||||||
utilAgent *utils_ollama.UtilOllama,
|
ai *utils_ollama.Client,
|
||||||
|
|
||||||
) *AiRouterBiz {
|
) *AiRouterBiz {
|
||||||
return &AiRouterBiz{
|
return &AiRouterBiz{
|
||||||
|
@ -57,7 +58,7 @@ func NewAiRouterBiz(
|
||||||
sysImpl: sysImpl,
|
sysImpl: sysImpl,
|
||||||
hisImpl: hisImpl,
|
hisImpl: hisImpl,
|
||||||
taskImpl: taskImpl,
|
taskImpl: taskImpl,
|
||||||
utilAgent: utilAgent,
|
ai: ai,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,22 +104,19 @@ func (r *AiRouterBiz) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSockRe
|
||||||
|
|
||||||
//意图预测
|
//意图预测
|
||||||
prompt := r.getPromptLLM(sysInfo, history, req.Text, task)
|
prompt := r.getPromptLLM(sysInfo, history, req.Text, task)
|
||||||
match, err := r.utilAgent.Llm.GenerateContent(context.TODO(), prompt,
|
toolDefinitions := r.registerTools(task)
|
||||||
//llms.WithTools(toolDefinitions),
|
match, err := r.ai.ToolSelect(context.TODO(), prompt, toolDefinitions)
|
||||||
//llms.WithToolChoice("tool_name"),
|
|
||||||
llms.WithJSONMode(),
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.SystemError
|
return errors.SystemError
|
||||||
}
|
}
|
||||||
log.Info(match)
|
log.Info(match)
|
||||||
var matchJson entitys.Match
|
//var matchJson entitys.Match
|
||||||
err = json.Unmarshal([]byte(match.Choices[0].Content), &matchJson)
|
//err = json.Unmarshal([]byte(match.Choices[0].Content), &matchJson)
|
||||||
if err != nil {
|
//if err != nil {
|
||||||
return errors.SystemError
|
// return errors.SystemError
|
||||||
}
|
//}
|
||||||
return r.handleMatch(c, &matchJson, task)
|
//return r.handleMatch(c, &matchJson, task)
|
||||||
//c.WriteMessage(1, []byte(msg.Choices[0].Content))
|
c.WriteMessage(1, []byte(match.Message.Content))
|
||||||
// 构建消息
|
// 构建消息
|
||||||
//messages := []entitys.Message{
|
//messages := []entitys.Message{
|
||||||
// {
|
// {
|
||||||
|
@ -327,20 +325,25 @@ func (r *AiRouterBiz) getTasks(sysId int32) (tasks []model.AiTask, err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *AiRouterBiz) registerTools(tasks []model.AiTask) []llms.Tool {
|
func (r *AiRouterBiz) registerTools(tasks []model.AiTask) []api.Tool {
|
||||||
taskPrompt := make([]llms.Tool, 0)
|
taskPrompt := make([]api.Tool, 0)
|
||||||
for _, task := range tasks {
|
for _, task := range tasks {
|
||||||
var taskConfig entitys.TaskConfig
|
var taskConfig entitys.TaskConfigDetail
|
||||||
err := json.Unmarshal([]byte(task.Config), &taskConfig)
|
err := json.Unmarshal([]byte(task.Config), &taskConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
taskPrompt = append(taskPrompt, llms.Tool{
|
|
||||||
|
taskPrompt = append(taskPrompt, api.Tool{
|
||||||
Type: "function",
|
Type: "function",
|
||||||
Function: &llms.FunctionDefinition{
|
Function: api.ToolFunction{
|
||||||
Name: task.Index,
|
Name: task.Index,
|
||||||
Description: task.Desc,
|
Description: task.Desc,
|
||||||
Parameters: taskConfig.Param,
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Type: taskConfig.Param.Type,
|
||||||
|
Required: taskConfig.Param.Required,
|
||||||
|
Properties: taskConfig.Param.Properties,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -365,30 +368,22 @@ func (r *AiRouterBiz) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, re
|
||||||
return prompt
|
return prompt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
|
func (r *AiRouterBiz) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []api.Message {
|
||||||
var (
|
var (
|
||||||
prompt = make([]llms.MessageContent, 0)
|
prompt = make([]api.Message, 0)
|
||||||
)
|
)
|
||||||
prompt = append(prompt, llms.MessageContent{
|
prompt = append(prompt, api.Message{
|
||||||
Role: llms.ChatMessageTypeSystem,
|
Role: string(llms.ChatMessageTypeSystem),
|
||||||
Parts: []llms.ContentPart{
|
Content: r.buildSystemPrompt(sysInfo.SysPrompt),
|
||||||
llms.TextPart(r.buildSystemPrompt(sysInfo.SysPrompt)),
|
}, api.Message{
|
||||||
},
|
Role: "chatHistory",
|
||||||
}, llms.MessageContent{
|
Content: pkg.JsonStringIgonErr(r.buildAssistant(history)),
|
||||||
Role: llms.ChatMessageTypeTool,
|
}, api.Message{
|
||||||
Parts: []llms.ContentPart{
|
Role: string(llms.ChatMessageTypeTool),
|
||||||
llms.TextPart(pkg.JsonStringIgonErr(r.buildAssistant(history))),
|
Content: pkg.JsonStringIgonErr(r.registerTools(tasks)),
|
||||||
},
|
}, api.Message{
|
||||||
}, llms.MessageContent{
|
Role: string(llms.ChatMessageTypeHuman),
|
||||||
Role: llms.ChatMessageTypeTool,
|
Content: reqInput,
|
||||||
Parts: []llms.ContentPart{
|
|
||||||
llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))),
|
|
||||||
},
|
|
||||||
}, llms.MessageContent{
|
|
||||||
Role: llms.ChatMessageTypeHuman,
|
|
||||||
Parts: []llms.ContentPart{
|
|
||||||
llms.TextPart(reqInput),
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
return prompt
|
return prompt
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
|
|
||||||
"gitea.cdlsxd.cn/self-tools/l_request"
|
"gitea.cdlsxd.cn/self-tools/l_request"
|
||||||
"github.com/gofiber/websocket/v2"
|
"github.com/gofiber/websocket/v2"
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ChatRequest 聊天请求
|
// ChatRequest 聊天请求
|
||||||
|
@ -93,6 +94,15 @@ type FuncApi struct {
|
||||||
type TaskConfig struct {
|
type TaskConfig struct {
|
||||||
Param interface{} `json:"param"`
|
Param interface{} `json:"param"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type TaskConfigDetail struct {
|
||||||
|
Param ConfigParam `json:"param"`
|
||||||
|
}
|
||||||
|
type ConfigParam struct {
|
||||||
|
Properties map[string]api.ToolProperty
|
||||||
|
Required []string `json:"required"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
}
|
||||||
type Match struct {
|
type Match struct {
|
||||||
Confidence float64 `json:"confidence"`
|
Confidence float64 `json:"confidence"`
|
||||||
Index string `json:"index"`
|
Index string `json:"index"`
|
||||||
|
|
|
@ -10,4 +10,5 @@ var ProviderSetClient = wire.NewSet(
|
||||||
NewRdb,
|
NewRdb,
|
||||||
NewGormDb,
|
NewGormDb,
|
||||||
utils_ollama.NewUtilOllama,
|
utils_ollama.NewUtilOllama,
|
||||||
|
utils_ollama.NewClient,
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,99 @@
|
||||||
|
package utils_ollama
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ai_scheduler/internal/config"
|
||||||
|
"ai_scheduler/internal/entitys"
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client Ollama客户端适配器
|
||||||
|
type Client struct {
|
||||||
|
client *api.Client
|
||||||
|
config *config.OllamaConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient 创建新的Ollama客户端
|
||||||
|
func NewClient(config *config.Config) (client *Client, cleanFunc func(), err error) {
|
||||||
|
client = &Client{
|
||||||
|
config: &config.Ollama,
|
||||||
|
}
|
||||||
|
url, err := client.getUrl()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
client.client = api.NewClient(url, http.DefaultClient)
|
||||||
|
|
||||||
|
cleanup := func() {
|
||||||
|
if client != nil {
|
||||||
|
client = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return client, cleanup, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToolSelect 工具选择
|
||||||
|
func (c *Client) ToolSelect(ctx context.Context, messages []api.Message, tools []api.Tool) (res api.ChatResponse, err error) {
|
||||||
|
// 构建聊天请求
|
||||||
|
req := &api.ChatRequest{
|
||||||
|
Model: c.config.Model,
|
||||||
|
Messages: messages,
|
||||||
|
Stream: new(bool), // 设置为false,不使用流式响应
|
||||||
|
Think: &api.ThinkValue{Value: true},
|
||||||
|
Tools: tools,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.client.Chat(ctx, req, func(resp api.ChatResponse) error {
|
||||||
|
res = resp
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertResponse 转换响应格式
|
||||||
|
func (c *Client) convertResponse(resp *api.ChatResponse) *entitys.ChatResponse {
|
||||||
|
//result := &entitys.ChatResponse{
|
||||||
|
// Message: resp.Message.Content,
|
||||||
|
// Finished: resp.Done,
|
||||||
|
//}
|
||||||
|
//
|
||||||
|
//// 转换工具调用
|
||||||
|
//if len(resp.Message.ToolCalls) > 0 {
|
||||||
|
// result.ToolCalls = make([]entitys.ToolCall, len(resp.Message.ToolCalls))
|
||||||
|
// for i, toolCall := range resp.Message.ToolCalls {
|
||||||
|
// // 转换函数参数
|
||||||
|
// argBytes, _ := json.Marshal(toolCall.Function.Arguments)
|
||||||
|
//
|
||||||
|
// result.ToolCalls[i] = entitys.ToolCall{
|
||||||
|
// ID: fmt.Sprintf("call_%d", i),
|
||||||
|
// Type: "function",
|
||||||
|
// Function: entitys.FunctionCall{
|
||||||
|
// Name: toolCall.Function.Name,
|
||||||
|
// Arguments: json.RawMessage(argBytes),
|
||||||
|
// },
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//}
|
||||||
|
|
||||||
|
//return result
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) getUrl() (*url.URL, error) {
|
||||||
|
baseURL := c.config.BaseURL
|
||||||
|
envURL := os.Getenv("OLLAMA_BASE_URL")
|
||||||
|
if envURL != "" {
|
||||||
|
baseURL = envURL
|
||||||
|
}
|
||||||
|
|
||||||
|
return url.Parse(baseURL)
|
||||||
|
}
|
|
@ -1,124 +0,0 @@
|
||||||
package utils_ollama
|
|
||||||
|
|
||||||
import (
|
|
||||||
"ai_scheduler/internal/config"
|
|
||||||
"ai_scheduler/internal/entitys"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
"github.com/tmc/langchaingo/llms/ollama"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Client Ollama客户端适配器
|
|
||||||
type Client struct {
|
|
||||||
client *api.Client
|
|
||||||
config *config.OllamaConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewClient 创建新的Ollama客户端
|
|
||||||
func NewClient(config *config.Config) (entitys.AIClient, func(), error) {
|
|
||||||
client, err := api.ClientFromEnvironment()
|
|
||||||
cleanup := func() {
|
|
||||||
if client != nil {
|
|
||||||
client = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, cleanup, fmt.Errorf("failed to create ollama client: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Client{
|
|
||||||
client: client,
|
|
||||||
config: &config.Ollama,
|
|
||||||
}, cleanup, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Chat 实现聊天功能
|
|
||||||
func (c *Client) Chat(ctx context.Context, messages []entitys.Message, tools []entitys.ToolDefinition) (*entitys.ChatResponse, error) {
|
|
||||||
// 构建聊天请求
|
|
||||||
req := &api.ChatRequest{
|
|
||||||
Model: c.config.Model,
|
|
||||||
Messages: make([]api.Message, len(messages)),
|
|
||||||
Stream: new(bool), // 设置为false,不使用流式响应
|
|
||||||
Think: &api.ThinkValue{Value: true},
|
|
||||||
}
|
|
||||||
|
|
||||||
// 转换消息格式
|
|
||||||
for i, msg := range messages {
|
|
||||||
req.Messages[i] = api.Message{
|
|
||||||
Role: msg.Role,
|
|
||||||
Content: msg.Content,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 添加工具定义
|
|
||||||
if len(tools) > 0 {
|
|
||||||
req.Tools = make([]api.Tool, len(tools))
|
|
||||||
for i, tool := range tools {
|
|
||||||
toolData, _ := json.Marshal(tool)
|
|
||||||
var apiTool api.Tool
|
|
||||||
json.Unmarshal(toolData, &apiTool)
|
|
||||||
req.Tools[i] = apiTool
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 发送请求
|
|
||||||
responseChan := make(chan api.ChatResponse)
|
|
||||||
errorChan := make(chan error)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
err := c.client.Chat(ctx, req, func(resp api.ChatResponse) error {
|
|
||||||
responseChan <- resp
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
errorChan <- err
|
|
||||||
}
|
|
||||||
close(responseChan)
|
|
||||||
close(errorChan)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// 等待响应
|
|
||||||
select {
|
|
||||||
case resp := <-responseChan:
|
|
||||||
return c.convertResponse(&resp), nil
|
|
||||||
case err := <-errorChan:
|
|
||||||
return nil, fmt.Errorf("chat request failed: %w", err)
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, ctx.Err()
|
|
||||||
case <-time.After(c.config.Timeout):
|
|
||||||
return nil, fmt.Errorf("chat request timeout")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// convertResponse 转换响应格式
|
|
||||||
func (c *Client) convertResponse(resp *api.ChatResponse) *entitys.ChatResponse {
|
|
||||||
//result := &entitys.ChatResponse{
|
|
||||||
// Message: resp.Message.Content,
|
|
||||||
// Finished: resp.Done,
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//// 转换工具调用
|
|
||||||
//if len(resp.Message.ToolCalls) > 0 {
|
|
||||||
// result.ToolCalls = make([]entitys.ToolCall, len(resp.Message.ToolCalls))
|
|
||||||
// for i, toolCall := range resp.Message.ToolCalls {
|
|
||||||
// // 转换函数参数
|
|
||||||
// argBytes, _ := json.Marshal(toolCall.Function.Arguments)
|
|
||||||
//
|
|
||||||
// result.ToolCalls[i] = entitys.ToolCall{
|
|
||||||
// ID: fmt.Sprintf("call_%d", i),
|
|
||||||
// Type: "function",
|
|
||||||
// Function: entitys.FunctionCall{
|
|
||||||
// Name: toolCall.Function.Name,
|
|
||||||
// Arguments: json.RawMessage(argBytes),
|
|
||||||
// },
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
//}
|
|
||||||
|
|
||||||
//return result
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -45,7 +45,7 @@ func NewManager(config *config.Config, llm *utils_ollama.UtilOllama) *Manager {
|
||||||
|
|
||||||
// 注册直连天下订单详情工具
|
// 注册直连天下订单详情工具
|
||||||
if config.Tools.ZltxOrderDetail.Enabled {
|
if config.Tools.ZltxOrderDetail.Enabled {
|
||||||
zltxOrderDetailTool := NewZltxOrderDetailTool(config.Tools.ZltxOrderDetail)
|
zltxOrderDetailTool := NewZltxOrderDetailTool(config.Tools.ZltxOrderDetail, m.llm)
|
||||||
m.tools[zltxOrderDetailTool.Name()] = zltxOrderDetailTool
|
m.tools[zltxOrderDetailTool.Name()] = zltxOrderDetailTool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ package tools
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/config"
|
"ai_scheduler/internal/config"
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
|
"ai_scheduler/internal/pkg/utils_ollama"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
@ -13,11 +14,12 @@ import (
|
||||||
// ZltxOrderDetailTool 直连天下订单详情工具
|
// ZltxOrderDetailTool 直连天下订单详情工具
|
||||||
type ZltxOrderDetailTool struct {
|
type ZltxOrderDetailTool struct {
|
||||||
config config.ToolConfig
|
config config.ToolConfig
|
||||||
|
llm *utils_ollama.UtilOllama
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewZltxOrderDetailTool 创建直连天下订单详情工具
|
// NewZltxOrderDetailTool 创建直连天下订单详情工具
|
||||||
func NewZltxOrderDetailTool(config config.ToolConfig) *ZltxOrderDetailTool {
|
func NewZltxOrderDetailTool(config config.ToolConfig, llm *utils_ollama.UtilOllama) *ZltxOrderDetailTool {
|
||||||
return &ZltxOrderDetailTool{config: config}
|
return &ZltxOrderDetailTool{config: config, llm: llm}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Name 返回工具名称
|
// Name 返回工具名称
|
||||||
|
|
Loading…
Reference in New Issue