159 lines
4.1 KiB
Go
159 lines
4.1 KiB
Go
package llm_service
|
|
|
|
import (
|
|
"ai_scheduler/internal/config"
|
|
"ai_scheduler/internal/data/model"
|
|
"ai_scheduler/internal/entitys"
|
|
"ai_scheduler/internal/pkg"
|
|
"ai_scheduler/internal/pkg/utils_ollama"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/ollama/ollama/api"
|
|
)
|
|
|
|
type OllamaService struct {
|
|
client *utils_ollama.Client
|
|
config *config.Config
|
|
}
|
|
|
|
func NewOllamaGenerate(
|
|
client *utils_ollama.Client,
|
|
config *config.Config,
|
|
) *OllamaService {
|
|
return &OllamaService{
|
|
client: client,
|
|
config: config,
|
|
}
|
|
}
|
|
|
|
func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entitys.RequireData) (msg string, err error) {
|
|
prompt, err := r.getPrompt(ctx, requireData)
|
|
if err != nil {
|
|
return
|
|
}
|
|
toolDefinitions := r.registerToolsOllama(requireData.Tasks)
|
|
|
|
match, err := r.client.ToolSelect(ctx, prompt, toolDefinitions)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
if len(match.Message.Content) == 0 {
|
|
if match.Message.ToolCalls != nil {
|
|
var matchFromTools = &entitys.Match{
|
|
Confidence: 1,
|
|
Index: match.Message.ToolCalls[0].Function.Name,
|
|
Parameters: pkg.JsonStringIgonErr(match.Message.ToolCalls[0].Function.Arguments),
|
|
IsMatch: true,
|
|
}
|
|
match.Message.Content = pkg.JsonStringIgonErr(matchFromTools)
|
|
} else {
|
|
err = errors.New("不太明白你想表达的意思呢,可以在仔细描述一下您所需要的内容吗,感谢感谢")
|
|
return
|
|
}
|
|
}
|
|
|
|
msg = match.Message.Content
|
|
return
|
|
}
|
|
|
|
func (r *OllamaService) getPrompt(ctx context.Context, requireData *entitys.RequireData) ([]api.Message, error) {
|
|
|
|
var (
|
|
prompt = make([]api.Message, 0)
|
|
)
|
|
prompt = append(prompt, api.Message{
|
|
Role: "system",
|
|
Content: buildSystemPrompt(requireData.Sys.SysPrompt),
|
|
}, api.Message{
|
|
Role: "assistant",
|
|
Content: "### 聊天记录:" + pkg.JsonStringIgonErr(buildAssistant(requireData.Histories)),
|
|
}, api.Message{
|
|
Role: "user",
|
|
Content: r.getUserContent(requireData),
|
|
//Images: requireData.ImgByte,
|
|
})
|
|
|
|
if len(requireData.ImgByte) > 0 {
|
|
desc, err := r.RecognizeWithImg(ctx, requireData)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var imgs strings.Builder
|
|
imgs.WriteString("### 用户上传图片解析内容:\n")
|
|
|
|
prompt = append(prompt, api.Message{
|
|
Role: "image_desc",
|
|
Content: "" + desc.Response,
|
|
})
|
|
}
|
|
return prompt, nil
|
|
}
|
|
|
|
func (r *OllamaService) getUserContent(requireData *entitys.RequireData) string {
|
|
var content strings.Builder
|
|
content.WriteString(requireData.Req.Text)
|
|
if len(requireData.ImgByte) > 0 {
|
|
content.WriteString("\n")
|
|
content.WriteString("### 图片内容已经解析到image_desc里")
|
|
}
|
|
|
|
if len(requireData.Req.Tags) > 0 {
|
|
content.WriteString("\n")
|
|
content.WriteString("### 工具必须使用:")
|
|
content.WriteString(requireData.Req.Tags)
|
|
}
|
|
return content.String()
|
|
}
|
|
|
|
// RecognizeWithImg asks the configured vision model (config.Ollama.VlModel)
// to describe the images in requireData.ImgByte. It uses the image-recognition
// system and user prompts from config.DefaultPrompt.ImgRecognize.
//
// Progress is reported on requireData.Ch through entitys.ResLog: once before
// the call, and once after a successful call together with the recognized text.
//
// When no images are attached, it returns a zero GenerateResponse and a nil
// error without contacting the model. On error, the returned response is the
// zero value and must be ignored.
func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) {
	// Use len rather than a nil check so an empty, non-nil image list is also
	// treated as "nothing to recognize" (matching the caller's len > 0 guard).
	if len(requireData.ImgByte) == 0 {
		return api.GenerateResponse{}, nil
	}

	entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...")

	desc, err = r.client.Generation(ctx, &api.GenerateRequest{
		Model: r.config.Ollama.VlModel,
		// Pointer to false: request a single, non-streamed response.
		Stream: new(bool),
		System: r.config.DefaultPrompt.ImgRecognize.SystemPrompt,
		Prompt: r.config.DefaultPrompt.ImgRecognize.UserPrompt,
		Images: requireData.ImgByte,
		// Ask Ollama to keep the vision model loaded for an hour after this request.
		KeepAlive: &api.Duration{Duration: time.Hour},
	})
	if err != nil {
		return api.GenerateResponse{}, err
	}

	entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response)
	return desc, nil
}
|
|
|
|
// registerToolsOllama converts the configured AI tasks into Ollama
// function-calling tool definitions.
//
// Each task becomes a "function" tool named task.Index and described by
// task.Desc. Its parameter schema (type, required fields, properties) is read
// from the JSON stored in task.Config. A task whose Config does not unmarshal
// into an entitys.TaskConfigDetail is skipped, so the model is not offered it.
// The result is never nil; it is empty when no task could be converted.
func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool {
	tools := make([]api.Tool, 0, len(tasks))
	for i := range tasks {
		task := &tasks[i] // index instead of range-by-value to avoid copying each task

		var taskConfig entitys.TaskConfigDetail
		if err := json.Unmarshal([]byte(task.Config), &taskConfig); err != nil {
			// Best effort: one malformed tool config must not prevent the
			// remaining tools from being offered.
			// NOTE(review): consider logging task.Index here so bad configs are visible.
			continue
		}

		tools = append(tools, api.Tool{
			Type: "function",
			Function: api.ToolFunction{
				Name:        task.Index,
				Description: task.Desc,
				Parameters: api.ToolFunctionParameters{
					Type:       taskConfig.Param.Type,
					Required:   taskConfig.Param.Required,
					Properties: taskConfig.Param.Properties,
				},
			},
		})
	}
	return tools
}
|