88 lines
2.2 KiB
Go
88 lines
2.2 KiB
Go
package llm_service
|
|
|
|
import (
|
|
"ai_scheduler/internal/data/model"
|
|
"ai_scheduler/internal/entitys"
|
|
"ai_scheduler/internal/pkg"
|
|
"ai_scheduler/internal/pkg/utils_langchain"
|
|
"context"
|
|
"encoding/json"
|
|
|
|
"github.com/tmc/langchaingo/llms"
|
|
)
|
|
|
|
type LangChainService struct {
|
|
client *utils_langchain.UtilLangChain
|
|
}
|
|
|
|
func NewLangChainGenerate(
|
|
client *utils_langchain.UtilLangChain,
|
|
) *LangChainService {
|
|
|
|
return &LangChainService{
|
|
client: client,
|
|
}
|
|
}
|
|
|
|
func (r *LangChainService) IntentRecognize(ctx context.Context, sysInfo model.AiSy, history []model.AiChatHi, userInput string, tasks []model.AiTask) (msg string, err error) {
|
|
prompt := r.getPrompt(sysInfo, history, userInput, tasks)
|
|
AgentClient := r.client.Get()
|
|
defer r.client.Put(AgentClient)
|
|
match, err := AgentClient.Llm.GenerateContent(
|
|
ctx, // 使用可取消的上下文
|
|
prompt,
|
|
llms.WithJSONMode(),
|
|
)
|
|
msg = match.Choices[0].Content
|
|
return
|
|
}
|
|
|
|
func (r *LangChainService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []llms.MessageContent {
|
|
var (
|
|
prompt = make([]llms.MessageContent, 0)
|
|
)
|
|
prompt = append(prompt, llms.MessageContent{
|
|
Role: llms.ChatMessageTypeSystem,
|
|
Parts: []llms.ContentPart{
|
|
llms.TextPart(buildSystemPrompt(sysInfo.SysPrompt)),
|
|
},
|
|
}, llms.MessageContent{
|
|
Role: llms.ChatMessageTypeTool,
|
|
Parts: []llms.ContentPart{
|
|
llms.TextPart(pkg.JsonStringIgonErr(buildAssistant(history))),
|
|
},
|
|
}, llms.MessageContent{
|
|
Role: llms.ChatMessageTypeTool,
|
|
Parts: []llms.ContentPart{
|
|
llms.TextPart(pkg.JsonStringIgonErr(r.registerTools(tasks))),
|
|
},
|
|
}, llms.MessageContent{
|
|
Role: llms.ChatMessageTypeHuman,
|
|
Parts: []llms.ContentPart{
|
|
llms.TextPart(reqInput),
|
|
},
|
|
})
|
|
return prompt
|
|
}
|
|
|
|
func (r *LangChainService) registerTools(tasks []model.AiTask) []llms.Tool {
|
|
taskPrompt := make([]llms.Tool, 0)
|
|
for _, task := range tasks {
|
|
var taskConfig entitys.TaskConfig
|
|
err := json.Unmarshal([]byte(task.Config), &taskConfig)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
taskPrompt = append(taskPrompt, llms.Tool{
|
|
Type: "function",
|
|
Function: &llms.FunctionDefinition{
|
|
Name: task.Index,
|
|
Description: task.Desc,
|
|
Parameters: taskConfig.Param,
|
|
},
|
|
})
|
|
|
|
}
|
|
return taskPrompt
|
|
}
|