package llm_service

import (
	"context"
	"errors"

	"ai_scheduler/internal/config"
	"ai_scheduler/internal/data/impl"
	"ai_scheduler/internal/entitys"
	"ai_scheduler/internal/pkg"
	"ai_scheduler/internal/pkg/utils_ollama"
	"ai_scheduler/internal/pkg/utils_vllm"

	"github.com/ollama/ollama/api"
)
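
// OllamaService wraps the Ollama client (with a vLLM client available as an
// alternative backend), the service configuration, and the chat-history store.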
type OllamaService struct {
	client     *utils_ollama.Client
	vllmClient *utils_vllm.Client
	config     *config.Config
	chatHis    *impl.ChatHisImpl
}
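
// NewOllamaGenerate builds an OllamaService from its injected dependencies.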
func NewOllamaGenerate(
	client *utils_ollama.Client,
	vllmClient *utils_vllm.Client,
	config *config.Config,
	chatHis *impl.ChatHisImpl,
) *OllamaService {
	return &OllamaService{
		client:     client,
		vllmClient: vllmClient,
		config:     config,
		chatHis:    chatHis,
	}
}
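
// IntentRecognize asks the model to pick a tool for the user's prompt. When the
// model answers with a tool call instead of plain content, the call is converted
// into an entitys.Match and serialized back into the message content.
//
// Minimal usage sketch, assuming a constructed service svc and a task list
// registeredTasks (only the Prompt and Tools fields of entitys.ToolSelect are
// used in this file):
//
//	msg, err := svc.IntentRecognize(ctx, &entitys.ToolSelect{
//		Prompt: "check tomorrow's weather in Beijing",
//		Tools:  registeredTasks,
//	})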
func (r *OllamaService) IntentRecognize(ctx context.Context, req *entitys.ToolSelect) (msg string, err error) {
	toolDefinitions := r.registerToolsOllama(req.Tools)
	match, err := r.client.ToolSelect(ctx, req.Prompt, toolDefinitions)
	if err != nil {
		return
	}

	if len(match.Message.Content) == 0 {
		if len(match.Message.ToolCalls) > 0 {
			// The model replied with a tool call only; wrap it as a Match so the
			// caller still receives a JSON payload in the message content.
			matchFromTools := &entitys.Match{
				Confidence: 1,
				Index:      match.Message.ToolCalls[0].Function.Name,
				Parameters: pkg.JsonStringIgonErr(match.Message.ToolCalls[0].Function.Arguments),
				IsMatch:    true,
			}
			match.Message.Content = pkg.JsonStringIgonErr(matchFromTools)
		} else {
			err = errors.New("I'm not quite sure what you mean; could you describe what you need in a bit more detail? Thanks!")
			return
		}
	}

	msg = match.Message.Content
	return
}
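
// The image-recognition variants below (Ollama generate API and vLLM chat API)
// are currently commented out.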
//func (r *OllamaService) RecognizeWithImg(ctx context.Context, imgByte []api.ImageData, ch chan entitys.Response) (desc api.GenerateResponse, err error) {
//	if imgByte == nil {
//		return
//	}
//	entitys.ResLog(ch, "recognize_img_start", "Recognizing image...")
//
//	desc, err = r.client.Generation(ctx, &api.GenerateRequest{
//		Model:     r.config.Ollama.VlModel,
//		Stream:    new(bool),
//		System:    r.config.DefaultPrompt.ImgRecognize.SystemPrompt,
//		Prompt:    r.config.DefaultPrompt.ImgRecognize.UserPrompt,
//		Images:    imgByte,
//		KeepAlive: &api.Duration{Duration: 3600 * time.Second},
//		//Think: &api.ThinkValue{Value: false},
//	})
//	if err != nil {
//		return
//	}
//	entitys.ResLog(ch, "recognize_img_end", "Image recognition finished, result: "+desc.Response)
//	return
//}

//func (r *OllamaService) RecognizeWithImgVllm(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) {
//	if requireData.ImgByte == nil {
//		return
//	}
//	entitys.ResLog(requireData.Ch, "recognize_img_start", "Recognizing image...")
//
//	outMsg, err := r.vllmClient.RecognizeWithImg(ctx,
//		r.config.DefaultPrompt.ImgRecognize.SystemPrompt,
//		r.config.DefaultPrompt.ImgRecognize.UserPrompt,
//		requireData.ImgUrls,
//	)
//	if err != nil {
//		return api.GenerateResponse{}, err
//	}
//
//	desc = api.GenerateResponse{
//		Response: outMsg.Content,
//	}
//
//	entitys.ResLog(requireData.Ch, "recognize_img_end", "Image recognition finished, result: "+desc.Response)
//	return
//}
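
// registerToolsOllama converts the registered tasks into Ollama tool definitions,
// one "function" tool per task, carrying the task's parameter schema.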
func (r *OllamaService) registerToolsOllama(tasks []entitys.RegistrationTask) []api.Tool {
	taskPrompt := make([]api.Tool, 0, len(tasks))
	for _, task := range tasks {
		taskPrompt = append(taskPrompt, api.Tool{
			Type: "function",
			Function: api.ToolFunction{
				Name:        task.Name,
				Description: task.Desc,
				Parameters: api.ToolFunctionParameters{
					Type:       task.TaskConfigDetail.Param.Type,
					Required:   task.TaskConfigDetail.Param.Required,
					Properties: task.TaskConfigDetail.Param.Properties,
				},
			},
		})
	}
	return taskPrompt
}