Compare commits

..

No commits in common. "0ce09611defb9207db32d51b01f8d6aa15508f15" and "addcf2d24d4ce60fa254c9f458f8cb56bccfe86d" have entirely different histories.

5 changed files with 21 additions and 133 deletions

View File

@ -2,24 +2,18 @@ package biz
import ( import (
"ai_scheduler/internal/config" "ai_scheduler/internal/config"
"ai_scheduler/internal/data/constant"
errors "ai_scheduler/internal/data/error" errors "ai_scheduler/internal/data/error"
"ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model" "ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg" "ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/mapstructure"
"ai_scheduler/internal/pkg/utils_ollama" "ai_scheduler/internal/pkg/utils_ollama"
"ai_scheduler/internal/tools" "ai_scheduler/internal/tools"
"ai_scheduler/tmpl/dataTemp" "ai_scheduler/tmpl/dataTemp"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "log"
"net/http"
"strings"
"gitea.cdlsxd.cn/self-tools/l_request"
"github.com/gofiber/fiber/v2/log"
"github.com/gofiber/websocket/v2" "github.com/gofiber/websocket/v2"
"github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/llms"
"xorm.io/builder" "xorm.io/builder"
@ -47,7 +41,6 @@ func NewAiRouterBiz(
hisImpl *impl.ChatImpl, hisImpl *impl.ChatImpl,
conf *config.Config, conf *config.Config,
utilAgent *utils_ollama.UtilOllama, utilAgent *utils_ollama.UtilOllama,
) entitys.RouterService { ) entitys.RouterService {
return &AiRouterService{ return &AiRouterService{
//aiClient: aiClient, //aiClient: aiClient,
@ -98,27 +91,26 @@ func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSo
return errors.SystemError return errors.SystemError
} }
//toolDefinitions := r.registerTools(task) toolDefinitions := r.registerTools(task)
//prompt := r.getPrompt(sysInfo, history, req.Text) //prompt := r.getPrompt(sysInfo, history, req.Text)
//意图预测 //意图预测
prompt := r.getPromptLLM(sysInfo, history, req.Text, task) //msg, err := r.utilAgent.Llm.Call(context.TODO(), pkg.JsonStringIgonErr(prompt),
match, err := r.utilAgent.Llm.GenerateContent(context.TODO(), prompt, // llms.WithTools(toolDefinitions),
//llms.WithTools(toolDefinitions), // //llms.WithToolChoice(llms.FunctionCallBehaviorAuto),
//llms.WithToolChoice("tool_name"), // llms.WithJSONMode(),
//)
prompt := r.getPromptLLM(sysInfo, history, req.Text)
msg, err := r.utilAgent.Llm.GenerateContent(context.TODO(), prompt,
llms.WithTools(toolDefinitions),
llms.WithToolChoice("tool_name"),
llms.WithJSONMode(), llms.WithJSONMode(),
) )
if err != nil { if err != nil {
return errors.SystemError return errors.SystemError
} }
log.Info(match)
var matchJson entitys.Match c.WriteMessage(1, []byte(msg.Choices[0].Content))
err = json.Unmarshal([]byte(match.Choices[0].Content), &matchJson)
if err != nil {
return errors.SystemError
}
return r.handleMatch(c, &matchJson, task)
//c.WriteMessage(1, []byte(msg.Choices[0].Content))
// 构建消息 // 构建消息
//messages := []entitys.Message{ //messages := []entitys.Message{
// { // {
@ -201,76 +193,6 @@ func (r *AiRouterService) RouteWithSocket(c *websocket.Conn, req *entitys.ChatSo
return nil return nil
} }
func (r *AiRouterService) handleMatch(c *websocket.Conn, matchJson *entitys.Match, tasks []model.AiTask) (err error) {
defer func() {
c.WriteMessage(1, []byte("EOF"))
}()
if !matchJson.IsMatch {
c.WriteMessage(1, []byte(matchJson.Reasoning))
return
}
var pointTask *model.AiTask
for _, task := range tasks {
if task.Index == matchJson.Index {
pointTask = &task
break
}
}
if pointTask == nil || pointTask.Index == "other" {
return r.handleOtherTask(c, matchJson)
}
var res []byte
switch pointTask.Type {
case constant.TaskTypeApi:
res, err = r.handleApiTask(c, matchJson, pointTask)
default:
return r.handleOtherTask(c, matchJson)
}
return
}
func (r *AiRouterService) handleOtherTask(c *websocket.Conn, matchJson *entitys.Match) (err error) {
c.WriteMessage(1, []byte(matchJson.Reasoning))
return
}
func (r *AiRouterService) handleApiTask(c *websocket.Conn, matchJson *entitys.Match, task *model.AiTask) (resByte []byte, err error) {
var (
request l_request.Request
auth = c.Headers("X-Authorization", "")
requestParam map[string]interface{}
)
err = json.Unmarshal([]byte(matchJson.Parameters), &requestParam)
if err != nil {
return
}
request.Url = strings.ReplaceAll(task.Config, "${authorization}", auth)
for k, v := range requestParam {
task.Config = strings.ReplaceAll(task.Config, "${"+k+"}", fmt.Sprintf("%v", v))
}
var configData entitys.ConfigData
err = json.Unmarshal([]byte(task.Config), &configData)
if err != nil {
return
}
err = mapstructure.Decode(configData.Do, &request)
if err != nil {
return
}
if len(request.Url) == 0 {
err = errors.NewBusinessErr("00022", "api地址获取失败")
return
}
res, err := request.Send()
if err != nil {
return
}
return res.Content, nil
}
func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) { func (r *AiRouterService) getSessionChatHis(sessionId string) (his []model.AiChatHi, err error) {
cond := builder.NewCond() cond := builder.NewCond()
@ -339,7 +261,7 @@ func (r *AiRouterService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi
return prompt return prompt
} }
func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent { func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiChatHi, reqInput string) []llms.MessageContent {
var ( var (
prompt = make([]llms.MessageContent, 0) prompt = make([]llms.MessageContent, 0)
) )
@ -353,11 +275,6 @@ func (r *AiRouterService) getPromptLLM(sysInfo model.AiSy, history []model.AiCha
Parts: []llms.ContentPart{ Parts: []llms.ContentPart{
llms.TextPart(pkg.JsonStringIgonErr(r.buildAssistant(history))), llms.TextPart(pkg.JsonStringIgonErr(r.buildAssistant(history))),
}, },
}, llms.MessageContent{
Role: llms.ChatMessageTypeTool,
Parts: []llms.ContentPart{
llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))),
},
}, llms.MessageContent{ }, llms.MessageContent{
Role: llms.ChatMessageTypeHuman, Role: llms.ChatMessageTypeHuman,
Parts: []llms.ContentPart{ Parts: []llms.ContentPart{

View File

@ -11,6 +11,6 @@ const (
type TaskType int32 type TaskType int32
const ( const (
TaskTypeApi = 1 TaskTypeApi ConnStatus = iota + 1
TaskTypeKnowle = 2 TaskTypeKnowle
) )

View File

@ -69,9 +69,9 @@ type Tool interface {
Execute(ctx context.Context, args json.RawMessage) (interface{}, error) Execute(ctx context.Context, args json.RawMessage) (interface{}, error)
} }
type ConfigData struct { // AIClient AI客户端接口
Param map[string]interface{} `json:"param"` type AIClient interface {
Do map[string]interface{} `json:"do"` Chat(ctx context.Context, messages []Message, tools []ToolDefinition) (*ChatResponse, error)
} }
// Message 消息 // Message 消息
@ -88,13 +88,7 @@ type FuncApi struct {
type TaskConfig struct { type TaskConfig struct {
Param interface{} `json:"param"` Param interface{} `json:"param"`
} }
type Match struct {
Confidence float64 `json:"confidence"`
Index string `json:"index"`
IsMatch bool `json:"is_match"`
Parameters string `json:"parameters"`
Reasoning string `json:"reasoning"`
}
type ChatHis struct { type ChatHis struct {
SessionId string `json:"session_id"` SessionId string `json:"session_id"`
Messages []HisMessage `json:"messages"` Messages []HisMessage `json:"messages"`

View File

@ -18,7 +18,7 @@ func NewUtilOllama(c *config.Config, logger log.AllLogger) *UtilOllama {
ollama.WithModel(c.Ollama.Model), ollama.WithModel(c.Ollama.Model),
ollama.WithHTTPClient(http.DefaultClient), ollama.WithHTTPClient(http.DefaultClient),
ollama.WithServerURL(getUrl(c)), ollama.WithServerURL(getUrl(c)),
ollama.WithKeepAlive("-1s"), ollama.WithKeepAlive("1h"),
) )
if err != nil { if err != nil {
logger.Fatal(err) logger.Fatal(err)

View File

@ -2,11 +2,8 @@ package test
import ( import (
"ai_scheduler/internal/entitys" "ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg/mapstructure"
"encoding/json" "encoding/json"
"testing" "testing"
"gitea.cdlsxd.cn/self-tools/l_request"
) )
func Test_task(t *testing.T) { func Test_task(t *testing.T) {
@ -15,23 +12,3 @@ func Test_task(t *testing.T) {
err := json.Unmarshal([]byte(config), &c) err := json.Unmarshal([]byte(config), &c)
t.Log(err) t.Log(err)
} }
type configData struct {
Param map[string]interface{} `json:"param"`
Do map[string]interface{} `json:"do"`
}
func Test_task2(t *testing.T) {
var (
c l_request.Request
config configData
)
configJson := `{"param": {"type": "object", "optional": [], "required": ["order_number"], "properties": {"order_number": {"type": "string", "description": "订单编号/流水号"}}}, "do": {"url": "http://www.baidu.com/${number}", "headers": {"Authorization": "${authorization}"}, "method": "GET"}}`
err := json.Unmarshal([]byte(configJson), &config)
if err != nil {
panic(err)
}
mapstructure.Decode(config.Do, &c)
t.Log(err)
}