ai_scheduler/internal/biz/llm_service/ollama.go

129 lines
3.6 KiB
Go

package llm_service
import (
"ai_scheduler/internal/config"
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"ai_scheduler/internal/pkg/utils_ollama"
"ai_scheduler/internal/pkg/utils_vllm"
"context"
"errors"
"github.com/ollama/ollama/api"
)
// OllamaService bundles the model clients and dependencies used for intent
// recognition (LLM tool selection) in the biz layer.
type OllamaService struct {
// client is the Ollama client; IntentRecognize uses it for tool selection.
client *utils_ollama.Client
// vllmClient is a vLLM client. NOTE(review): in this file it is only
// referenced by the commented-out RecognizeWithImgVllm — confirm it is still needed.
vllmClient *utils_vllm.Client
// config is the application configuration. NOTE(review): in this file it is
// only referenced by commented-out code.
config *config.Config
// chatHis presumably provides chat-history data access (type lives in
// internal/data/impl). NOTE(review): not referenced by the code visible here.
chatHis *impl.ChatHisImpl
}
// NewOllamaGenerate constructs an OllamaService from its injected
// dependencies. Every argument is stored as-is; nothing is validated,
// copied, or defaulted, so nil arguments produce nil fields.
func NewOllamaGenerate(
	client *utils_ollama.Client,
	vllmClient *utils_vllm.Client,
	config *config.Config,
	chatHis *impl.ChatHisImpl,
) *OllamaService {
	svc := new(OllamaService)
	svc.client = client
	svc.vllmClient = vllmClient
	svc.config = config
	svc.chatHis = chatHis
	return svc
}
// IntentRecognize asks the Ollama model to pick one of req.Tools for
// req.Prompt and returns the model's decision as a string.
//
// The tools in req.Tools are converted into Ollama function definitions and
// passed to r.client.ToolSelect together with req.Prompt. Then:
//   - if the reply carries text content, that text is returned verbatim;
//   - otherwise, if the reply carries tool calls, the first call is turned
//     into an entitys.Match (Confidence 1, IsMatch true, Index = tool name,
//     Parameters = JSON of the call arguments) and returned JSON-encoded;
//   - otherwise a user-facing error is returned asking the user to rephrase.
//
// Errors from ToolSelect are returned unchanged.
func (r *OllamaService) IntentRecognize(ctx context.Context, req *entitys.ToolSelect) (msg string, err error) {
	toolDefinitions := r.registerToolsOllama(req.Tools)
	match, err := r.client.ToolSelect(ctx, req.Prompt, toolDefinitions)
	if err != nil {
		return "", err
	}

	if len(match.Message.Content) > 0 {
		return match.Message.Content, nil
	}

	// Check the length, not just nil: a non-nil but empty ToolCalls slice
	// would otherwise panic on the [0] index below.
	if len(match.Message.ToolCalls) == 0 {
		// User-facing message: "I don't quite understand what you mean; could
		// you describe what you need in more detail? Thanks."
		return "", errors.New("不太明白你想表达的意思呢,可以在仔细描述一下您所需要的内容吗,感谢感谢")
	}

	// The model answered via native tool calling only; synthesize the same
	// Match JSON a text answer would carry, using the first call.
	call := match.Message.ToolCalls[0].Function
	matchFromTools := &entitys.Match{
		Confidence: 1,
		Index:      call.Name,
		Parameters: pkg.JsonStringIgonErr(call.Arguments),
		IsMatch:    true,
	}
	return pkg.JsonStringIgonErr(matchFromTools), nil
}
//func (r *OllamaService) RecognizeWithImg(ctx context.Context, imgByte []api.ImageData, ch chan entitys.Response) (desc api.GenerateResponse, err error) {
// if imgByte == nil {
// return
// }
// entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...")
//
// desc, err = r.client.Generation(ctx, &api.GenerateRequest{
// Model: r.config.Ollama.VlModel,
// Stream: new(bool),
// System: r.config.DefaultPrompt.ImgRecognize.SystemPrompt,
// Prompt: r.config.DefaultPrompt.ImgRecognize.UserPrompt,
// Images: requireData.ImgByte,
// KeepAlive: &api.Duration{Duration: 3600 * time.Second},
// //Think: &api.ThinkValue{Value: false},
// })
// if err != nil {
// return
// }
// entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response)
// return
//}
//func (r *OllamaService) RecognizeWithImgVllm(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) {
// if requireData.ImgByte == nil {
// return
// }
// entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...")
//
// outMsg, err := r.vllmClient.RecognizeWithImg(ctx,
// r.config.DefaultPrompt.ImgRecognize.SystemPrompt,
// r.config.DefaultPrompt.ImgRecognize.UserPrompt,
// requireData.ImgUrls,
// )
// if err != nil {
// return api.GenerateResponse{}, err
// }
//
// desc = api.GenerateResponse{
// Response: outMsg.Content,
// }
//
// entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response)
// return
//}
// registerToolsOllama converts registered tasks into Ollama tool definitions.
//
// Each task becomes one tool of type "function" whose name and description
// come from task.Name and task.Desc. Its parameter schema (type, required
// list, properties) is taken from task.TaskConfigDetail.Param. Output order
// matches input order. The returned slice is never nil; it is empty when
// tasks is empty.
func (r *OllamaService) registerToolsOllama(tasks []entitys.RegistrationTask) []api.Tool {
	// Pre-size: exactly one tool is produced per task.
	tools := make([]api.Tool, 0, len(tasks))
	for i := range tasks {
		// Take a pointer rather than ranging by value to avoid copying each task.
		task := &tasks[i]
		tools = append(tools, api.Tool{
			Type: "function",
			Function: api.ToolFunction{
				Name:        task.Name,
				Description: task.Desc,
				Parameters: api.ToolFunctionParameters{
					Type:       task.TaskConfigDetail.Param.Type,
					Required:   task.TaskConfigDetail.Param.Required,
					Properties: task.TaskConfigDetail.Param.Properties,
				},
			},
		})
	}
	return tools
}