refactor: optimize ollama service and update dependencies
This commit is contained in:
parent
782006b8d3
commit
586ad6a124
1
go.mod
1
go.mod
|
|
@ -58,7 +58,6 @@ require (
|
|||
github.com/evanphx/json-patch v0.5.2 // indirect
|
||||
github.com/fasthttp/websocket v1.5.3 // indirect
|
||||
github.com/fsnotify/fsnotify v1.6.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
||||
github.com/goph/emperror v0.17.2 // indirect
|
||||
github.com/gorilla/websocket v1.5.0 // indirect
|
||||
|
|
|
|||
2
go.sum
2
go.sum
|
|
@ -174,8 +174,6 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo
|
|||
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
|
||||
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
|
||||
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||
github.com/gabriel-vasile/mimetype v1.4.11 h1:AQvxbp830wPhHTqc1u7nzoLT+ZFxGY7emj5DR5DYFik=
|
||||
github.com/gabriel-vasile/mimetype v1.4.11/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
|
||||
github.com/garyburd/redigo v1.6.0/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY=
|
||||
|
|
|
|||
|
|
@ -3,19 +3,15 @@ package llm_service
|
|||
import (
|
||||
"ai_scheduler/internal/config"
|
||||
"ai_scheduler/internal/data/impl"
|
||||
"ai_scheduler/internal/data/model"
|
||||
"ai_scheduler/internal/entitys"
|
||||
"ai_scheduler/internal/pkg"
|
||||
"ai_scheduler/internal/pkg/utils_ollama"
|
||||
"ai_scheduler/internal/pkg/utils_vllm"
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
type OllamaService struct {
|
||||
|
|
@ -41,12 +37,8 @@ func NewOllamaGenerate(
|
|||
|
||||
func (r *OllamaService) IntentRecognize(ctx context.Context, req *entitys.ToolSelect) (msg string, err error) {
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
toolDefinitions := r.registerToolsOllama(req.Tools)
|
||||
|
||||
match, err := r.client.ToolSelect(ctx, rec.Prompt, toolDefinitions)
|
||||
match, err := r.client.ToolSelect(ctx, req.Prompt, toolDefinitions)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -70,132 +62,63 @@ func (r *OllamaService) IntentRecognize(ctx context.Context, req *entitys.ToolSe
|
|||
return
|
||||
}
|
||||
|
||||
func (r *OllamaService) getPrompt(ctx context.Context, requireData *entitys.RequireData) ([]api.Message, error) {
|
||||
//func (r *OllamaService) RecognizeWithImg(ctx context.Context, imgByte []api.ImageData, ch chan entitys.Response) (desc api.GenerateResponse, err error) {
|
||||
// if imgByte == nil {
|
||||
// return
|
||||
// }
|
||||
// entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...")
|
||||
//
|
||||
// desc, err = r.client.Generation(ctx, &api.GenerateRequest{
|
||||
// Model: r.config.Ollama.VlModel,
|
||||
// Stream: new(bool),
|
||||
// System: r.config.DefaultPrompt.ImgRecognize.SystemPrompt,
|
||||
// Prompt: r.config.DefaultPrompt.ImgRecognize.UserPrompt,
|
||||
// Images: requireData.ImgByte,
|
||||
// KeepAlive: &api.Duration{Duration: 3600 * time.Second},
|
||||
// //Think: &api.ThinkValue{Value: false},
|
||||
// })
|
||||
// if err != nil {
|
||||
// return
|
||||
// }
|
||||
// entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response)
|
||||
// return
|
||||
//}
|
||||
|
||||
var (
|
||||
prompt = make([]api.Message, 0)
|
||||
)
|
||||
content, err := r.getUserContent(ctx, requireData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
prompt = append(prompt, api.Message{
|
||||
Role: "system",
|
||||
Content: buildSystemPrompt(requireData.Sys.SysPrompt),
|
||||
}, api.Message{
|
||||
Role: "assistant",
|
||||
Content: "### 聊天记录:" + pkg.JsonStringIgonErr(buildAssistant(requireData.Histories)),
|
||||
}, api.Message{
|
||||
Role: "user",
|
||||
Content: content,
|
||||
})
|
||||
//func (r *OllamaService) RecognizeWithImgVllm(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) {
|
||||
// if requireData.ImgByte == nil {
|
||||
// return
|
||||
// }
|
||||
// entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...")
|
||||
//
|
||||
// outMsg, err := r.vllmClient.RecognizeWithImg(ctx,
|
||||
// r.config.DefaultPrompt.ImgRecognize.SystemPrompt,
|
||||
// r.config.DefaultPrompt.ImgRecognize.UserPrompt,
|
||||
// requireData.ImgUrls,
|
||||
// )
|
||||
// if err != nil {
|
||||
// return api.GenerateResponse{}, err
|
||||
// }
|
||||
//
|
||||
// desc = api.GenerateResponse{
|
||||
// Response: outMsg.Content,
|
||||
// }
|
||||
//
|
||||
// entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response)
|
||||
// return
|
||||
//}
|
||||
|
||||
return prompt, nil
|
||||
}
|
||||
|
||||
func (r *OllamaService) getUserContent(ctx context.Context, requireData *entitys.RequireData) (string, error) {
|
||||
var content strings.Builder
|
||||
content.WriteString(requireData.Req.Text)
|
||||
if len(requireData.ImgByte) > 0 {
|
||||
content.WriteString("\n")
|
||||
}
|
||||
|
||||
if len(requireData.Req.Tags) > 0 {
|
||||
content.WriteString("\n")
|
||||
content.WriteString("### 工具必须使用:")
|
||||
content.WriteString(requireData.Req.Tags)
|
||||
}
|
||||
|
||||
if len(requireData.ImgByte) > 0 {
|
||||
// desc, err := r.RecognizeWithImg(ctx, requireData)
|
||||
desc, err := r.RecognizeWithImgVllm(ctx, requireData)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
content.WriteString("### 上传图片解析内容:\n")
|
||||
content.WriteString(requireData.Req.Tags)
|
||||
content.WriteString(desc.Response)
|
||||
}
|
||||
|
||||
if requireData.Req.MarkHis > 0 {
|
||||
var his model.AiChatHi
|
||||
cond := builder.NewCond()
|
||||
cond = cond.And(builder.Eq{"his_id": requireData.Req.MarkHis})
|
||||
err := r.chatHis.GetOneBySearchToStrut(&cond, &his)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
content.WriteString("### 引用历史聊天记录:\n")
|
||||
content.WriteString(pkg.JsonStringIgonErr(BuildChatHisMessage([]model.AiChatHi{his})))
|
||||
}
|
||||
return content.String(), nil
|
||||
}
|
||||
|
||||
func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) {
|
||||
if requireData.ImgByte == nil {
|
||||
func (r *OllamaService) RecognizeWithImg(ctx context.Context, imgByte []api.ImageData, ch chan entitys.Response) (desc api.GenerateResponse, err error) {
|
||||
if imgByte == nil {
|
||||
return
|
||||
}
|
||||
entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...")
|
||||
|
||||
desc, err = r.client.Generation(ctx, &api.GenerateRequest{
|
||||
Model: r.config.Ollama.VlModel,
|
||||
Stream: new(bool),
|
||||
System: r.config.DefaultPrompt.ImgRecognize.SystemPrompt,
|
||||
Prompt: r.config.DefaultPrompt.ImgRecognize.UserPrompt,
|
||||
Images: requireData.ImgByte,
|
||||
KeepAlive: &api.Duration{Duration: 3600 * time.Second},
|
||||
//Think: &api.ThinkValue{Value: false},
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response)
|
||||
return
|
||||
}
|
||||
|
||||
func (r *OllamaService) RecognizeWithImgVllm(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) {
|
||||
if requireData.ImgByte == nil {
|
||||
return
|
||||
}
|
||||
entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...")
|
||||
|
||||
outMsg, err := r.vllmClient.RecognizeWithImg(ctx,
|
||||
r.config.DefaultPrompt.ImgRecognize.SystemPrompt,
|
||||
r.config.DefaultPrompt.ImgRecognize.UserPrompt,
|
||||
requireData.ImgUrls,
|
||||
)
|
||||
if err != nil {
|
||||
return api.GenerateResponse{}, err
|
||||
}
|
||||
|
||||
desc = api.GenerateResponse{
|
||||
Response: outMsg.Content,
|
||||
}
|
||||
|
||||
entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response)
|
||||
return
|
||||
}
|
||||
|
||||
func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool {
|
||||
func (r *OllamaService) registerToolsOllama(tasks []entitys.RegistrationTask) []api.Tool {
|
||||
taskPrompt := make([]api.Tool, 0)
|
||||
for _, task := range tasks {
|
||||
var taskConfig entitys.TaskConfigDetail
|
||||
err := json.Unmarshal([]byte(task.Config), &taskConfig)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
taskPrompt = append(taskPrompt, api.Tool{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: task.Index,
|
||||
Name: task.Name,
|
||||
Description: task.Desc,
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: taskConfig.Param.Type,
|
||||
Required: taskConfig.Param.Required,
|
||||
Properties: taskConfig.Param.Properties,
|
||||
Type: task.TaskConfigDetail.Param.Type,
|
||||
Required: task.TaskConfigDetail.Param.Required,
|
||||
Properties: task.TaskConfigDetail.Param.Properties,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
|
|
|||
|
|
@ -13,6 +13,9 @@ type Recognize struct {
|
|||
}
|
||||
|
||||
type RegistrationTask struct {
|
||||
Name string
|
||||
Desc string
|
||||
TaskConfigDetail TaskConfigDetail
|
||||
}
|
||||
|
||||
type RecognizeUserContent struct {
|
||||
|
|
|
|||
Loading…
Reference in New Issue