refactor: optimize ollama service and update dependencies
This commit is contained in:
parent
782006b8d3
commit
586ad6a124
1
go.mod
1
go.mod
|
|
@ -58,7 +58,6 @@ require (
|
||||||
github.com/evanphx/json-patch v0.5.2 // indirect
|
github.com/evanphx/json-patch v0.5.2 // indirect
|
||||||
github.com/fasthttp/websocket v1.5.3 // indirect
|
github.com/fasthttp/websocket v1.5.3 // indirect
|
||||||
github.com/fsnotify/fsnotify v1.6.0 // indirect
|
github.com/fsnotify/fsnotify v1.6.0 // indirect
|
||||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
|
||||||
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
||||||
github.com/goph/emperror v0.17.2 // indirect
|
github.com/goph/emperror v0.17.2 // indirect
|
||||||
github.com/gorilla/websocket v1.5.0 // indirect
|
github.com/gorilla/websocket v1.5.0 // indirect
|
||||||
|
|
|
||||||
2
go.sum
2
go.sum
|
|
@ -174,8 +174,6 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo
|
||||||
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
|
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
|
||||||
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
|
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
|
||||||
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
|
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
|
||||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
|
||||||
github.com/gabriel-vasile/mimetype v1.4.11 h1:AQvxbp830wPhHTqc1u7nzoLT+ZFxGY7emj5DR5DYFik=
|
github.com/gabriel-vasile/mimetype v1.4.11 h1:AQvxbp830wPhHTqc1u7nzoLT+ZFxGY7emj5DR5DYFik=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.11/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
|
github.com/gabriel-vasile/mimetype v1.4.11/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
|
||||||
github.com/garyburd/redigo v1.6.0/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY=
|
github.com/garyburd/redigo v1.6.0/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY=
|
||||||
|
|
|
||||||
|
|
@ -3,19 +3,15 @@ package llm_service
|
||||||
import (
|
import (
|
||||||
"ai_scheduler/internal/config"
|
"ai_scheduler/internal/config"
|
||||||
"ai_scheduler/internal/data/impl"
|
"ai_scheduler/internal/data/impl"
|
||||||
"ai_scheduler/internal/data/model"
|
|
||||||
"ai_scheduler/internal/entitys"
|
"ai_scheduler/internal/entitys"
|
||||||
"ai_scheduler/internal/pkg"
|
"ai_scheduler/internal/pkg"
|
||||||
"ai_scheduler/internal/pkg/utils_ollama"
|
"ai_scheduler/internal/pkg/utils_ollama"
|
||||||
"ai_scheduler/internal/pkg/utils_vllm"
|
"ai_scheduler/internal/pkg/utils_vllm"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"xorm.io/builder"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type OllamaService struct {
|
type OllamaService struct {
|
||||||
|
|
@ -41,12 +37,8 @@ func NewOllamaGenerate(
|
||||||
|
|
||||||
func (r *OllamaService) IntentRecognize(ctx context.Context, req *entitys.ToolSelect) (msg string, err error) {
|
func (r *OllamaService) IntentRecognize(ctx context.Context, req *entitys.ToolSelect) (msg string, err error) {
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
toolDefinitions := r.registerToolsOllama(req.Tools)
|
toolDefinitions := r.registerToolsOllama(req.Tools)
|
||||||
|
match, err := r.client.ToolSelect(ctx, req.Prompt, toolDefinitions)
|
||||||
match, err := r.client.ToolSelect(ctx, rec.Prompt, toolDefinitions)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -70,132 +62,63 @@ func (r *OllamaService) IntentRecognize(ctx context.Context, req *entitys.ToolSe
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *OllamaService) getPrompt(ctx context.Context, requireData *entitys.RequireData) ([]api.Message, error) {
|
//func (r *OllamaService) RecognizeWithImg(ctx context.Context, imgByte []api.ImageData, ch chan entitys.Response) (desc api.GenerateResponse, err error) {
|
||||||
|
// if imgByte == nil {
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
// entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...")
|
||||||
|
//
|
||||||
|
// desc, err = r.client.Generation(ctx, &api.GenerateRequest{
|
||||||
|
// Model: r.config.Ollama.VlModel,
|
||||||
|
// Stream: new(bool),
|
||||||
|
// System: r.config.DefaultPrompt.ImgRecognize.SystemPrompt,
|
||||||
|
// Prompt: r.config.DefaultPrompt.ImgRecognize.UserPrompt,
|
||||||
|
// Images: requireData.ImgByte,
|
||||||
|
// KeepAlive: &api.Duration{Duration: 3600 * time.Second},
|
||||||
|
// //Think: &api.ThinkValue{Value: false},
|
||||||
|
// })
|
||||||
|
// if err != nil {
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
// entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response)
|
||||||
|
// return
|
||||||
|
//}
|
||||||
|
|
||||||
var (
|
//func (r *OllamaService) RecognizeWithImgVllm(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) {
|
||||||
prompt = make([]api.Message, 0)
|
// if requireData.ImgByte == nil {
|
||||||
)
|
// return
|
||||||
content, err := r.getUserContent(ctx, requireData)
|
// }
|
||||||
if err != nil {
|
// entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...")
|
||||||
return nil, err
|
//
|
||||||
}
|
// outMsg, err := r.vllmClient.RecognizeWithImg(ctx,
|
||||||
prompt = append(prompt, api.Message{
|
// r.config.DefaultPrompt.ImgRecognize.SystemPrompt,
|
||||||
Role: "system",
|
// r.config.DefaultPrompt.ImgRecognize.UserPrompt,
|
||||||
Content: buildSystemPrompt(requireData.Sys.SysPrompt),
|
// requireData.ImgUrls,
|
||||||
}, api.Message{
|
// )
|
||||||
Role: "assistant",
|
// if err != nil {
|
||||||
Content: "### 聊天记录:" + pkg.JsonStringIgonErr(buildAssistant(requireData.Histories)),
|
// return api.GenerateResponse{}, err
|
||||||
}, api.Message{
|
// }
|
||||||
Role: "user",
|
//
|
||||||
Content: content,
|
// desc = api.GenerateResponse{
|
||||||
})
|
// Response: outMsg.Content,
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response)
|
||||||
|
// return
|
||||||
|
//}
|
||||||
|
|
||||||
return prompt, nil
|
func (r *OllamaService) registerToolsOllama(tasks []entitys.RegistrationTask) []api.Tool {
|
||||||
}
|
|
||||||
|
|
||||||
func (r *OllamaService) getUserContent(ctx context.Context, requireData *entitys.RequireData) (string, error) {
|
|
||||||
var content strings.Builder
|
|
||||||
content.WriteString(requireData.Req.Text)
|
|
||||||
if len(requireData.ImgByte) > 0 {
|
|
||||||
content.WriteString("\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(requireData.Req.Tags) > 0 {
|
|
||||||
content.WriteString("\n")
|
|
||||||
content.WriteString("### 工具必须使用:")
|
|
||||||
content.WriteString(requireData.Req.Tags)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(requireData.ImgByte) > 0 {
|
|
||||||
// desc, err := r.RecognizeWithImg(ctx, requireData)
|
|
||||||
desc, err := r.RecognizeWithImgVllm(ctx, requireData)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
content.WriteString("### 上传图片解析内容:\n")
|
|
||||||
content.WriteString(requireData.Req.Tags)
|
|
||||||
content.WriteString(desc.Response)
|
|
||||||
}
|
|
||||||
|
|
||||||
if requireData.Req.MarkHis > 0 {
|
|
||||||
var his model.AiChatHi
|
|
||||||
cond := builder.NewCond()
|
|
||||||
cond = cond.And(builder.Eq{"his_id": requireData.Req.MarkHis})
|
|
||||||
err := r.chatHis.GetOneBySearchToStrut(&cond, &his)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
content.WriteString("### 引用历史聊天记录:\n")
|
|
||||||
content.WriteString(pkg.JsonStringIgonErr(BuildChatHisMessage([]model.AiChatHi{his})))
|
|
||||||
}
|
|
||||||
return content.String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *OllamaService) RecognizeWithImg(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) {
|
|
||||||
if requireData.ImgByte == nil {
|
|
||||||
func (r *OllamaService) RecognizeWithImg(ctx context.Context, imgByte []api.ImageData, ch chan entitys.Response) (desc api.GenerateResponse, err error) {
|
|
||||||
if imgByte == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...")
|
|
||||||
|
|
||||||
desc, err = r.client.Generation(ctx, &api.GenerateRequest{
|
|
||||||
Model: r.config.Ollama.VlModel,
|
|
||||||
Stream: new(bool),
|
|
||||||
System: r.config.DefaultPrompt.ImgRecognize.SystemPrompt,
|
|
||||||
Prompt: r.config.DefaultPrompt.ImgRecognize.UserPrompt,
|
|
||||||
Images: requireData.ImgByte,
|
|
||||||
KeepAlive: &api.Duration{Duration: 3600 * time.Second},
|
|
||||||
//Think: &api.ThinkValue{Value: false},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *OllamaService) RecognizeWithImgVllm(ctx context.Context, requireData *entitys.RequireData) (desc api.GenerateResponse, err error) {
|
|
||||||
if requireData.ImgByte == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
entitys.ResLog(requireData.Ch, "recognize_img_start", "图片识别中...")
|
|
||||||
|
|
||||||
outMsg, err := r.vllmClient.RecognizeWithImg(ctx,
|
|
||||||
r.config.DefaultPrompt.ImgRecognize.SystemPrompt,
|
|
||||||
r.config.DefaultPrompt.ImgRecognize.UserPrompt,
|
|
||||||
requireData.ImgUrls,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return api.GenerateResponse{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
desc = api.GenerateResponse{
|
|
||||||
Response: outMsg.Content,
|
|
||||||
}
|
|
||||||
|
|
||||||
entitys.ResLog(requireData.Ch, "recognize_img_end", "图片识别完成,识别内容:"+desc.Response)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *OllamaService) registerToolsOllama(tasks []model.AiTask) []api.Tool {
|
|
||||||
taskPrompt := make([]api.Tool, 0)
|
taskPrompt := make([]api.Tool, 0)
|
||||||
for _, task := range tasks {
|
for _, task := range tasks {
|
||||||
var taskConfig entitys.TaskConfigDetail
|
|
||||||
err := json.Unmarshal([]byte(task.Config), &taskConfig)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
taskPrompt = append(taskPrompt, api.Tool{
|
taskPrompt = append(taskPrompt, api.Tool{
|
||||||
Type: "function",
|
Type: "function",
|
||||||
Function: api.ToolFunction{
|
Function: api.ToolFunction{
|
||||||
Name: task.Index,
|
Name: task.Name,
|
||||||
Description: task.Desc,
|
Description: task.Desc,
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: taskConfig.Param.Type,
|
Type: task.TaskConfigDetail.Param.Type,
|
||||||
Required: taskConfig.Param.Required,
|
Required: task.TaskConfigDetail.Param.Required,
|
||||||
Properties: taskConfig.Param.Properties,
|
Properties: task.TaskConfigDetail.Param.Properties,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,9 @@ type Recognize struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type RegistrationTask struct {
|
type RegistrationTask struct {
|
||||||
|
Name string
|
||||||
|
Desc string
|
||||||
|
TaskConfigDetail TaskConfigDetail
|
||||||
}
|
}
|
||||||
|
|
||||||
type RecognizeUserContent struct {
|
type RecognizeUserContent struct {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue