ai_scheduler/internal/biz/llm_service/ollama.go

82 lines
2.0 KiB
Go

package llm_service
import (
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/utils_ollama"
"context"
"encoding/json"
"fmt"
"github.com/gofiber/fiber/v2/log"
"github.com/ollama/ollama/api"
)
// OllamaService implements intent recognition on top of an Ollama chat
// endpoint: it builds a prompt from system/history/user input and exposes
// the configured AI tasks as tool definitions for the model to select from.
type OllamaService struct {
client *utils_ollama.Client // wrapped Ollama API client (project-local helper)
}
// NewOllamaGenerate constructs an OllamaService bound to the given
// Ollama client. Intended for wire/DI-style injection.
func NewOllamaGenerate(client *utils_ollama.Client) *OllamaService {
svc := &OllamaService{client: client}
return svc
}
// IntentRecognize asks the model to pick an intent/tool for the user's input.
// It assembles the chat prompt (system prompt + serialized history + user
// input), registers the configured tasks as tool definitions, and returns the
// model's message content.
//
// Fixes vs. previous version:
//   - the caller's ctx is now passed to ToolSelect (was context.TODO(), which
//     silently dropped cancellation and deadlines);
//   - the result is logged only after the error check (match may be nil on
//     error);
//   - log.Infof is used because the message carries a %v format verb;
//   - the error is wrapped with the failing operation for traceability.
func (r *OllamaService) IntentRecognize(ctx context.Context, requireData *entitys.RequireData) (msg string, err error) {
prompt := r.getPrompt(requireData.Sys, requireData.Histories, requireData.UserInput, requireData.Tasks)
toolDefinitions := r.registerToolsOllama(requireData.Tasks)
match, err := r.client.ToolSelect(ctx, prompt, toolDefinitions)
if err != nil {
return "", fmt.Errorf("ollama tool select: %w", err)
}
log.Infof("意图识别结果: %v", match)
return match.Message.Content, nil
}
// getPrompt builds the three-message chat transcript sent to the model:
// a system message (from the system prompt), an assistant message carrying
// the JSON-serialized chat history, and the raw user input.
// NOTE: tasks is currently unused here; it is kept for signature stability.
func (r *OllamaService) getPrompt(sysInfo model.AiSy, history []model.AiChatHi, reqInput string, tasks []model.AiTask) []api.Message {
historyJSON := pkg.JsonStringIgonErr(buildAssistant(history))
messages := []api.Message{
{Role: "system", Content: buildSystemPrompt(sysInfo.SysPrompt)},
{Role: "assistant", Content: fmt.Sprintf("聊天记录:%s", historyJSON)},
{Role: "user", Content: reqInput},
}
return messages
}
// registerToolsOllama converts configured AI tasks into Ollama tool
// definitions. Each task's Config field is expected to hold a JSON-encoded
// entitys.TaskConfigDetail; tasks whose config fails to parse are skipped.
//
// Fixes vs. previous version: unmarshal failures were swallowed silently,
// making misconfigured tasks invisible — they are now logged before being
// skipped. The result slice is also pre-sized to the known upper bound.
func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool {
tools := make([]api.Tool, 0, len(tasks))
for _, task := range tasks {
var taskConfig entitys.TaskConfigDetail
if err := json.Unmarshal([]byte(task.Config), &taskConfig); err != nil {
// Skip, but leave a trace so bad task configs are discoverable.
log.Warnf("registerToolsOllama: skipping task %q, invalid config: %v", task.Index, err)
continue
}
tools = append(tools, api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: task.Index,
Description: task.Desc,
Parameters: api.ToolFunctionParameters{
Type: taskConfig.Param.Type,
Required: taskConfig.Param.Required,
Properties: taskConfig.Param.Properties,
},
},
})
}
return tools
}