169 lines
4.6 KiB
Go
169 lines
4.6 KiB
Go
package do
|
|
|
|
import (
	"ai_scheduler/internal/biz/handle"
	"ai_scheduler/internal/config"
	"ai_scheduler/internal/data/constants"
	"ai_scheduler/internal/entitys"
	"ai_scheduler/internal/pkg"
	"ai_scheduler/internal/pkg/utils_vllm"
	"context"
	"fmt"
	"strings"

	"github.com/ollama/ollama/api"
)
|
|
|
|
type PromptOption interface {
|
|
CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes []api.Message, err error)
|
|
}
|
|
|
|
type WithSys struct {
|
|
Config *config.Config
|
|
}
|
|
|
|
func (f *WithSys) CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes []api.Message, err error) {
|
|
var (
|
|
prompt = make([]api.Message, 0) // 初始化一个空的api.Message切片
|
|
)
|
|
// 获取用户内容,如果出错则直接返回错误
|
|
content, err := f.getUserContent(ctx, rec)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// 构建提示消息列表,包含系统提示、助手回复和用户内容
|
|
mes = append(prompt, api.Message{
|
|
Role: "system", // 系统角色
|
|
Content: rec.SystemPrompt, // 系统提示内容
|
|
}, api.Message{
|
|
Role: "assistant", // 助手角色
|
|
Content: "### 聊天记录:" + pkg.JsonStringIgonErr(rec.ChatHis), // 助手回复内容
|
|
}, api.Message{
|
|
Role: "user", // 用户角色
|
|
Content: content.String(), // 用户输入内容
|
|
})
|
|
|
|
return
|
|
}
|
|
|
|
func (f *WithSys) getUserContent(ctx context.Context, rec *entitys.Recognize) (content strings.Builder, err error) {
|
|
var hasFile bool
|
|
if len(rec.UserContent.File) > 0 {
|
|
hasFile = true
|
|
}
|
|
content.WriteString(rec.UserContent.Text)
|
|
if hasFile {
|
|
content.WriteString("\n")
|
|
}
|
|
|
|
if len(rec.UserContent.Tag) > 0 {
|
|
content.WriteString("\n")
|
|
content.WriteString("### 工具必须使用:")
|
|
content.WriteString(rec.UserContent.Tag)
|
|
}
|
|
|
|
if len(rec.ChatHis.Messages) > 0 {
|
|
content.WriteString("### 引用历史聊天记录:\n")
|
|
content.WriteString(pkg.JsonStringIgonErr(rec.ChatHis))
|
|
}
|
|
|
|
if hasFile {
|
|
content.WriteString("\n")
|
|
content.WriteString("### 文件内容:\n")
|
|
for _, file := range rec.UserContent.File {
|
|
handle.HandleRecognizeFile(file)
|
|
// 文件识别
|
|
switch file.FileType {
|
|
case constants.FileTypeImage:
|
|
entitys.ResLog(rec.Ch, "recognize_img_start", "图片识别中...")
|
|
var imageContent string
|
|
imageContent, err = f.recognizeWithImgVllm(ctx, file)
|
|
if err != nil {
|
|
return
|
|
}
|
|
entitys.ResLog(rec.Ch, "recognize_img_end", "图片识别完成,识别内容:"+imageContent)
|
|
|
|
// 解析结果回写到file
|
|
file.FileRec = imageContent
|
|
default:
|
|
content.WriteString(file.FileRec)
|
|
}
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// recognizeWithImgVllm transcribes an image attachment to text with the vLLM
// vision model.
//
// It returns ("", nil) without contacting the model when file is nil, has no
// FileData, or is not an image. Otherwise it:
//  1. creates a vLLM client from f.Config and defers the cleanup function that
//     NewClient returns;
//  2. sends the image bytes and MIME type together with the configured
//     image-recognition system and user prompts;
//  3. returns the content of the model's reply.
//
// Errors from client creation and from recognition are wrapped with context;
// errors.Is and errors.As still see the underlying cause.
//
// NOTE(review): a new client is created for every image. If NewClient is
// expensive, consider reusing one client across the files of a request.
func (f *WithSys) recognizeWithImgVllm(ctx context.Context, file *entitys.RecognizeFile) (content string, err error) {
	if file == nil || file.FileData == nil || file.FileType != constants.FileTypeImage {
		return "", nil
	}

	client, cleanup, err := utils_vllm.NewClient(f.Config)
	if err != nil {
		return "", fmt.Errorf("create vllm client: %w", err)
	}
	defer cleanup()

	outMsg, err := client.RecognizeWithImgBytes(ctx,
		f.Config.DefaultPrompt.ImgRecognize.SystemPrompt,
		f.Config.DefaultPrompt.ImgRecognize.UserPrompt,
		file.FileData,
		file.FileRealMime,
	)
	if err != nil {
		return "", fmt.Errorf("recognize image with vllm: %w", err)
	}

	return outMsg.Content, nil
}
|
|
|
|
// WithDingTalkBot is a stateless PromptOption. Its prompt has the same three
// messages as WithSys, but attached files are not read or recognized.
type WithDingTalkBot struct{}
|
|
|
|
func (f *WithDingTalkBot) CreatePrompt(ctx context.Context, rec *entitys.Recognize) (mes []api.Message, err error) {
|
|
var (
|
|
prompt = make([]api.Message, 0) // 初始化一个空的api.Message切片
|
|
)
|
|
// 获取用户内容,如果出错则直接返回错误
|
|
content, err := f.getUserContent(ctx, rec)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// 构建提示消息列表,包含系统提示、助手回复和用户内容
|
|
mes = append(prompt, api.Message{
|
|
Role: "system", // 系统角色
|
|
Content: rec.SystemPrompt, // 系统提示内容
|
|
}, api.Message{
|
|
Role: "assistant", // 助手角色
|
|
Content: "### 聊天记录:" + pkg.JsonStringIgonErr(rec.ChatHis), // 助手回复内容
|
|
}, api.Message{
|
|
Role: "user", // 用户角色
|
|
Content: content.String(), // 用户输入内容
|
|
})
|
|
|
|
return
|
|
}
|
|
|
|
func (f *WithDingTalkBot) getUserContent(ctx context.Context, rec *entitys.Recognize) (content strings.Builder, err error) {
|
|
var hasFile bool
|
|
if rec.UserContent.File != nil && len(rec.UserContent.File) > 0 {
|
|
hasFile = true
|
|
}
|
|
content.WriteString(rec.UserContent.Text)
|
|
if hasFile {
|
|
content.WriteString("\n")
|
|
}
|
|
|
|
if len(rec.UserContent.Tag) > 0 {
|
|
content.WriteString("\n")
|
|
content.WriteString("### 工具必须使用:")
|
|
content.WriteString(rec.UserContent.Tag)
|
|
}
|
|
|
|
if len(rec.ChatHis.Messages) > 0 {
|
|
content.WriteString("### 引用历史聊天记录:\n")
|
|
content.WriteString(pkg.JsonStringIgonErr(rec.ChatHis))
|
|
}
|
|
|
|
return
|
|
}
|