ai_scheduler/internal/biz/llm_service/ollama.go

129 lines
3.2 KiB
Go

package llm_service
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/utils_ollama"
"context"
"encoding/json"
"errors"
"fmt"
"github.com/gofiber/fiber/v2/log"
"github.com/ollama/ollama/api"
)
// OllamaService performs intent recognition (tool selection) and image
// description requests against an Ollama model. Construct it with
// NewOllamaGenerate.
type OllamaService struct {
	// client is the project's Ollama wrapper; used for ToolSelect and Generation calls.
	client *utils_ollama.Client
	// config supplies the application settings; the generate model name is read
	// from config.Ollama.GenerateModel.
	config *config.Config
}
// NewOllamaGenerate returns an OllamaService that talks to Ollama through
// client and reads its model settings from config. Neither argument is
// validated or copied; the service keeps the pointers as given.
func NewOllamaGenerate(client *utils_ollama.Client, config *config.Config) *OllamaService {
	svc := new(OllamaService)
	svc.client = client
	svc.config = config
	return svc
}
// IntentRecognize asks the Ollama model which of requireData.Tasks matches the
// user's input, offering each task as a tool (see registerToolsOllama).
//
// It returns the model's text content when it replies with text. When it
// instead answers with a tool call, it returns a JSON-encoded entitys.Match
// built from the first tool call, with Confidence 1 and IsMatch true.
// If the model returns neither content nor a tool call, a user-facing error
// asking the user to rephrase is returned. Errors from prompt construction
// and from the model call are returned unchanged.
func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entitys.RequireData) (msg string, err error) {
	prompt, err := r.getPrompt(ctx, requireData)
	if err != nil {
		return "", err
	}

	toolDefinitions := r.registerToolsOllama(requireData.Tasks)
	// Use the caller's ctx (not context.TODO) so request cancellation and
	// deadlines reach the model call, matching getPrompt/Generation.
	match, err := r.client.ToolSelect(ctx, prompt, toolDefinitions)
	if err != nil {
		return "", err
	}
	// gofiber's log.Info is Print-style; Infof is needed for the %v verb.
	log.Infof("意图识别结果: %v", pkg.JsonStringIgonErr(match))

	if len(match.Message.Content) > 0 {
		return match.Message.Content, nil
	}
	// len check (not a nil check) so an empty, non-nil slice cannot panic on [0].
	if len(match.Message.ToolCalls) == 0 {
		return "", errors.New("不太明白你想表达的意思呢,可以在仔细描述一下您所需要的内容吗,感谢感谢")
	}

	call := match.Message.ToolCalls[0].Function
	matchFromTools := &entitys.Match{
		Confidence: 1,
		Index:      call.Name, // tool names are task indexes (see registerToolsOllama)
		Parameters: pkg.JsonStringIgonErr(call.Arguments),
		IsMatch:    true,
	}
	return pkg.JsonStringIgonErr(matchFromTools), nil
}
// getPrompt builds the chat transcript sent for intent recognition. The
// transcript has three messages in order:
//   - system: the system prompt built from requireData.Sys.SysPrompt
//   - assistant: the JSON-serialized chat history, prefixed with "聊天记录:"
//   - user: requireData.UserInput
//
// When requireData.ImgByte is non-empty, RecognizeWithImg is run first. Any
// error it returns aborts prompt construction.
//
// NOTE(review): the description returned by RecognizeWithImg is discarded and
// never reaches the prompt. Confirm whether it should be added, e.g. to the
// user message.
func (r *OllamaService) getPrompt(ctx context.Context, requireData *entitys.RequireData) ([]api.Message, error) {
	systemContent := buildSystemPrompt(requireData.Sys.SysPrompt)
	historyContent := fmt.Sprintf("聊天记录:%s", pkg.JsonStringIgonErr(buildAssistant(requireData.Histories)))

	messages := []api.Message{
		{Role: "system", Content: systemContent},
		{Role: "assistant", Content: historyContent},
		{Role: "user", Content: requireData.UserInput},
	}

	if len(requireData.ImgByte) == 0 {
		return messages, nil
	}
	if _, imgErr := r.RecognizeWithImg(ctx, requireData); imgErr != nil {
		return nil, imgErr
	}
	return messages, nil
}
// RecognizeWithImg first pushes a "loading" notice onto requireData.Ch. It then
// sends a non-streaming generate request to the configured GenerateModel, using
// requireData.UserInput as the prompt, and returns the model's response.
//
// The notice is sent with a select on ctx.Done(). If the request is cancelled
// while the send cannot proceed, ctx.Err() is returned instead of blocking on
// the channel forever.
//
// NOTE(review): requireData.ImgByte is not attached to the GenerateRequest
// (api.GenerateRequest has an Images field), so the model only sees the text
// prompt. Confirm ImgByte's type and forward it.
func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) {
	loading := entitys.Response{
		Index:   "",
		Content: "图片识别中。。。",
		Type:    entitys.ResponseLoading,
	}
	select {
	case requireData.Ch <- loading:
	case <-ctx.Done():
		return desc, ctx.Err()
	}

	stream := false // request a single, complete (non-streamed) response
	desc, err = r.client.Generation(ctx, &api.GenerateRequest{
		Model:  r.config.Ollama.GenerateModel,
		Stream: &stream,
		System: "识别图片内容",
		Prompt: requireData.UserInput,
	})
	return desc, err
}
// registerToolsOllama converts the configured AI tasks into Ollama "function"
// tool definitions:
//   - each tool is named after task.Index, which IntentRecognize reads back as
//     the match index;
//   - its description is task.Desc;
//   - its parameter schema comes from the Param section of the task's JSON config.
//
// Tasks whose config cannot be decoded are skipped. Each skip is logged so that
// one bad task is visible without disabling intent recognition for the rest.
func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool {
	tools := make([]api.Tool, 0, len(tasks))
	for _, task := range tasks {
		var taskConfig entitys.TaskConfigDetail
		if err := json.Unmarshal([]byte(task.Config), &taskConfig); err != nil {
			log.Warnf("registerToolsOllama: skipping task %q, invalid config: %v", task.Index, err)
			continue
		}
		tools = append(tools, api.Tool{
			Type: "function",
			Function: api.ToolFunction{
				Name:        task.Index,
				Description: task.Desc,
				Parameters: api.ToolFunctionParameters{
					Type:       taskConfig.Param.Type,
					Required:   taskConfig.Param.Required,
					Properties: taskConfig.Param.Properties,
				},
			},
		})
	}
	return tools
}