ai_scheduler/internal/biz/advice_chat.go

355 lines
12 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package biz
import (
"ai_scheduler/internal/biz/llm_service/third_party"
"ai_scheduler/internal/data/constants"
"ai_scheduler/internal/data/mongo_model"
"ai_scheduler/internal/data/impl"
dbmodel "ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/pkg"
"ai_scheduler/utils"
"context"
"encoding/json"
"errors"
"fmt"
"github.com/google/uuid"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model/responses"
"github.com/volcengine/volcengine-go-sdk/volcengine"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/options"
"xorm.io/builder"
"strings"
"time"
)
// AdviceChatBiz bundles the dependencies used by the advisor chat flow:
// session registration, per-turn chatting and chat-history persistence.
type AdviceChatBiz struct {
	// hsyq is the LLM client used to create context caches / responses and
	// to run chat completions.
	hsyq *third_party.Hsyq
	// rdb is the Redis wrapper; Regis stores the base prompt JSON in it
	// under the session id.
	rdb *utils.Rdb
	// aiAdviceSessionImpl is the store for advisor session rows.
	aiAdviceSessionImpl *impl.AiAdviceSessionImpl
	// aiAdviceModelSupImpl is the store for model supplier rows.
	aiAdviceModelSupImpl *impl.AiAdviceModelSupImpl
	// advicerChatHisMongo is handed to mongo.Co to select the chat-history
	// collection.
	advicerChatHisMongo *mongo_model.AdvicerChatHisMongo
	// mongo is the MongoDB wrapper that yields collection handles via Co.
	mongo *pkg.Mongo
}
// NewAdviceChatBiz wires the given dependencies into a new AdviceChatBiz.
// Every argument is stored as-is; no validation is performed.
func NewAdviceChatBiz(
	hsyq *third_party.Hsyq,
	rdb *utils.Rdb,
	aiAdviceSessionImpl *impl.AiAdviceSessionImpl,
	aiAdviceModelSupImpl *impl.AiAdviceModelSupImpl,
	advicerChatHisMongo *mongo_model.AdvicerChatHisMongo,
	mongo *pkg.Mongo,
) *AdviceChatBiz {
	biz := new(AdviceChatBiz)
	biz.hsyq = hsyq
	biz.rdb = rdb
	biz.aiAdviceSessionImpl = aiAdviceSessionImpl
	biz.aiAdviceModelSupImpl = aiAdviceModelSupImpl
	biz.advicerChatHisMongo = advicerChatHisMongo
	biz.mongo = mongo
	return biz
}
// contextCache builds the base prompt for a new session and registers it
// with the LLM backend according to the project's model mode.
//
// For constants.ModeResponse the prompt is sent through CreateResponse and
// the returned response Id becomes the context cache id; for
// constants.ModeContext the prompt is sent through CreateContextCache and
// its result is used directly. Any other mode yields an error.
//
// It returns the JSON-encoded prompt, the context cache id and an error.
func (a *AdviceChatBiz) contextCache(ctx context.Context, chatData *entitys.ChatData, req *entitys.AdvicerChatRegistReq, projectInfo *entitys.AdvicerProjectInfoRes) (promptJson string, contextCache string, err error) {
	mode := constants.Mode(projectInfo.ModelInfo.Mode)

	if mode == constants.ModeResponse {
		items, buildErr := a.buildBasePromptResponse(ctx, chatData, req)
		if buildErr != nil {
			return "", "", buildErr
		}
		resp, createErr := a.hsyq.CreateResponse(ctx, projectInfo.ModelInfo.Key, projectInfo.ModelInfo.ChatModel, items, "", true)
		if createErr != nil {
			return "", "", createErr
		}
		contextCache = resp.Id
		promptJson = pkg.JsonStringIgonErr(items)
		return promptJson, contextCache, nil
	}

	if mode == constants.ModeContext {
		messages, buildErr := a.buildBasePromptContext(ctx, chatData, req)
		if buildErr != nil {
			return "", "", buildErr
		}
		cacheID, createErr := a.hsyq.CreateContextCache(ctx, projectInfo.ModelInfo.Key, projectInfo.ModelInfo.ChatModel, messages)
		if createErr != nil {
			return "", "", createErr
		}
		contextCache = cacheID
		promptJson = pkg.JsonStringIgonErr(messages)
		return promptJson, contextCache, nil
	}

	return "", "", fmt.Errorf("未知的mode类型%d", projectInfo.ModelInfo.Mode)
}
// Regis registers a new advisor chat session.
//
// It creates the LLM-side context for the base prompt, stores a session row
// carrying a freshly generated UUID, and finally caches the base prompt JSON
// in Redis under that UUID for one hour.
//
// It returns the new session id. If only the Redis write fails, the session
// id is still returned together with the error.
func (a *AdviceChatBiz) Regis(ctx context.Context, chatData *entitys.ChatData, req *entitys.AdvicerChatRegistReq, projectInfo *entitys.AdvicerProjectInfoRes) (string, error) {
	basePrompt, cacheID, err := a.contextCache(ctx, chatData, req, projectInfo)
	if err != nil {
		return "", err
	}

	sessionID := uuid.New().String()

	// Create the session record.
	record := &dbmodel.AiAdviceSession{
		SessionID:        sessionID,
		ProjectID:        projectInfo.Base.ProjectID,
		SupID:            projectInfo.Base.ModelSupID,
		AdvicerVersionID: req.AdvicerVersionId,
		ClientID:         req.ClientId,
		TalkSkillID:      req.TalkSkillId,
		Mission:          req.Mission,
		ContextCache:     cacheID,
		CreateAt:         time.Now(),
	}
	if _, err = a.aiAdviceSessionImpl.Add(record); err != nil {
		return "", err
	}

	// Cache the base prompt keyed by session id; the entry expires after an hour.
	if err = a.rdb.Rdb.SetEx(ctx, sessionID, basePrompt, time.Hour).Err(); err != nil {
		return sessionID, err
	}
	return sessionID, nil
}
// Chat handles one user turn of an existing advisor session.
//
// Steps:
//  1. Load the session by chat.SessionId and the model supplier row it
//     references; either one missing is an error.
//  2. Return a zero Assistant with a nil error when chat.Content is empty.
//  3. Fetch up to six stored history entries, build the prompt (task prompt,
//     history, user message) and call the LLM with the session's context
//     cache id.
//  4. Decode the text of the first output message as JSON into assistant.
//  5. In a background goroutine, store the exchange in MongoDB and, when the
//     reply's MissionStatus is "fail" or "complete", update the session row.
//
// The background writes are best-effort: their errors are discarded and do
// not affect the returned result.
func (a *AdviceChatBiz) Chat(ctx context.Context, chat *entitys.AdvicerChatReq) (assistant mongo_model.Assistant, err error) {
	// Number of stored history entries sent along with each request.
	const historyLimit = 6

	var session dbmodel.AiAdviceSession
	cond := builder.NewCond()
	cond = cond.And(builder.Eq{"session_id": chat.SessionId})
	if err = a.aiAdviceSessionImpl.GetOneBySearchToStrut(&cond, &session); err != nil {
		return assistant, err
	}
	if session.SessionID == "" {
		return assistant, errors.New("未找到会话信息")
	}
	if len(chat.Content) == 0 {
		return assistant, nil
	}

	var modelInfo dbmodel.AiAdviceModelSup
	cond = builder.NewCond()
	cond = cond.And(builder.Eq{"sup_id": session.SupID})
	if err = a.aiAdviceModelSupImpl.GetOneBySearchToStrut(&cond, &modelInfo); err != nil {
		return assistant, err
	}
	if modelInfo.SupID == 0 {
		return assistant, errors.New("未找到模型信息")
	}

	// Previously this error was overwritten by the next assignment and a
	// failed history lookup went unnoticed; surface it instead.
	chatHis, err := a.getChatHis(ctx, session.SessionID, historyLimit)
	if err != nil {
		return assistant, err
	}
	prompt, err := a.buildChatPromptResponse(ctx, chat, &session, chatHis)
	if err != nil {
		return assistant, err
	}
	resContent, err := a.callLlmResponse(ctx, prompt, modelInfo.Key, modelInfo.ChatModel, session.ContextCache)
	if err != nil {
		return assistant, err
	}

	// Guard the index accesses below: a nil response, an empty output list
	// or a first item that is not a message would otherwise panic.
	if resContent == nil || len(resContent.Output) == 0 {
		return assistant, errors.New("llm response contains no output")
	}
	outMsg := resContent.Output[0].GetOutputMessage()
	if outMsg == nil || len(outMsg.Content) == 0 {
		return assistant, errors.New("llm response contains no output message content")
	}
	result := outMsg.Content[0].GetText().GetText()
	if err = json.Unmarshal([]byte(result), &assistant); err != nil {
		return assistant, err
	}

	// Snapshot everything the goroutine needs so it never touches the
	// caller-owned request or this function's variables after we return.
	sessionID := chat.SessionId
	userContent := chat.Content
	reply := assistant

	// The persistence work is detached from the request context so it is not
	// cancelled when the request finishes.
	chatCtx, cancel := context.WithCancel(context.Background())
	go func(session dbmodel.AiAdviceSession) {
		defer cancel()
		// Best-effort history write; failures are intentionally ignored.
		// NOTE(review): Usage is dereferenced without a guard here, as in the
		// original code — confirm the backend always populates it.
		_, _ = a.mongo.Co(a.advicerChatHisMongo).InsertOne(chatCtx, &mongo_model.AdvicerChatHisMongo{
			SessionId: sessionID,
			User:      userContent,
			Assistant: reply,
			InToken:   resContent.Usage.InputTokens,
			OutToken:  resContent.Usage.OutputTokens,
			CreatAt:   time.Now(),
		})
		if reply.MissionStatus == "fail" || reply.MissionStatus == "complete" {
			updateCond := builder.NewCond()
			updateCond = updateCond.And(builder.Eq{"session_id": sessionID})
			session.MissionStatus = reply.MissionStatus
			session.MissionCompleteDesc = reply.MissionCompleteDesc
			// Best-effort status update; failures are intentionally ignored.
			_ = a.aiAdviceSessionImpl.UpdateByCond(&updateCond, session)
		}
	}(session)

	return assistant, nil
}
// buildChatPromptResponse assembles the three input items sent to the LLM
// for one chat turn, in this order:
//
//  1. a system message holding the task prompt,
//  2. a system message holding the JSON-encoded chat history,
//  3. a user message holding the content of the current request.
//
// The returned error is always nil.
func (a *AdviceChatBiz) buildChatPromptResponse(ctx context.Context, chat *entitys.AdvicerChatReq, session *dbmodel.AiAdviceSession, chatList []mongo_model.AdvicerChatHisMongoEntity) ([]*responses.InputItem, error) {
	taskText := a.taskPrompt(session)
	historyText := "历史聊天记录:\n" + pkg.JsonStringIgonErr(chatList)
	userText := chat.Content

	items := []*responses.InputItem{
		{
			Union: &responses.InputItem_EasyMessage{
				EasyMessage: &responses.ItemEasyMessage{
					Role:    responses.MessageRole_system,
					Content: &responses.MessageContent{Union: &responses.MessageContent_StringValue{StringValue: taskText}},
				},
			},
		},
		{
			Union: &responses.InputItem_EasyMessage{
				EasyMessage: &responses.ItemEasyMessage{
					Role:    responses.MessageRole_system,
					Content: &responses.MessageContent{Union: &responses.MessageContent_StringValue{StringValue: historyText}},
				},
			},
		},
		{
			Union: &responses.InputItem_EasyMessage{
				EasyMessage: &responses.ItemEasyMessage{
					Role:    responses.MessageRole_user,
					Content: &responses.MessageContent{Union: &responses.MessageContent_StringValue{StringValue: userText}},
				},
			},
		},
	}
	return items, nil
}
// getChatHis loads at most limit chat-history documents whose "sessionId"
// field equals sessionId and converts each one with Entity().
//
// On a query error it returns an empty (non-nil) slice together with the
// error; on a decode or cursor error it returns nil and the error.
//
// NOTE(review): the query specifies no sort order, so which documents are
// returned once a session has more than limit entries depends on the
// server's default ordering — confirm whether the most recent entries are
// the ones intended.
func (a *AdviceChatBiz) getChatHis(ctx context.Context, sessionId string, limit int64) (chatList []mongo_model.AdvicerChatHisMongoEntity, err error) {
	chatList = make([]mongo_model.AdvicerChatHisMongoEntity, 0)
	filter := bson.M{"sessionId": sessionId}

	cursor, err := a.mongo.Co(a.advicerChatHisMongo).Find(ctx, filter, options.Find().SetLimit(limit))
	if err != nil {
		return chatList, err
	}
	// Always release the cursor; the original code never closed it, which
	// leaked it on every call. A Close error is not actionable here.
	defer func() { _ = cursor.Close(ctx) }()

	for cursor.Next(ctx) {
		var record mongo_model.AdvicerChatHisMongo
		if err = cursor.Decode(&record); err != nil {
			return nil, err
		}
		chatList = append(chatList, record.Entity())
	}
	// Next returns false both at the end of the result set and on failure;
	// Err distinguishes the two.
	if err = cursor.Err(); err != nil {
		return nil, err
	}
	return chatList, nil
}
// buildChatPrompt builds a non-streaming context chat completion request
// bound to the session's context cache id and the supplier's chat model.
//
// The request carries two messages: the user's content (user role) followed
// by the task prompt (assistant role). The returned error is always nil.
func (a *AdviceChatBiz) buildChatPrompt(ctx context.Context, chat *entitys.AdvicerChatReq, session *dbmodel.AiAdviceSession, modelInfo *dbmodel.AiAdviceModelSup) (model.ContextChatCompletionRequest, error) {
	userMsg := &model.ChatCompletionMessage{
		Role: model.ChatMessageRoleUser,
		Content: &model.ChatCompletionMessageContent{
			StringValue: volcengine.String(chat.Content),
		},
	}
	taskMsg := &model.ChatCompletionMessage{
		Role: model.ChatMessageRoleAssistant,
		Content: &model.ChatCompletionMessageContent{
			StringValue: volcengine.String(a.taskPrompt(session)),
		},
	}

	return model.ContextChatCompletionRequest{
		ContextID: session.ContextCache,
		Model:     modelInfo.ChatModel,
		Messages:  []*model.ChatCompletionMessage{userMsg, taskMsg},
		Stream:    false,
	}, nil
}
// taskPrompt returns the per-turn task prompt: a fixed label followed by the
// current local time formatted as "2006-01-02 15:04:05".
//
// The session argument is currently not used.
func (a *AdviceChatBiz) taskPrompt(session *dbmodel.AiAdviceSession) string {
	const layout = "2006-01-02 15:04:05"
	now := time.Now().Format(layout)
	return "[当前时间]" + now
}
// buildBasePromptContext builds the base prompt as chat completion messages:
// a system message with the system prompt, followed by a second system
// message with the JSON-encoded chat data. The returned error is always nil.
func (a *AdviceChatBiz) buildBasePromptContext(ctx context.Context, chatData *entitys.ChatData, req *entitys.AdvicerChatRegistReq) ([]*model.ChatCompletionMessage, error) {
	systemText := volcengine.String(a.sysPrompt(chatData, req))
	dataText := volcengine.String(a.assistantPrompt(chatData))

	messages := []*model.ChatCompletionMessage{
		{
			Role:    model.ChatMessageRoleSystem,
			Content: &model.ChatCompletionMessageContent{StringValue: systemText},
		},
		{
			Role:    model.ChatMessageRoleSystem,
			Content: &model.ChatCompletionMessageContent{StringValue: dataText},
		},
	}
	return messages, nil
}
// buildBasePromptResponse builds the base prompt as response input items:
// a system message with the system prompt, followed by a second system
// message with the JSON-encoded chat data. The returned error is always nil.
func (a *AdviceChatBiz) buildBasePromptResponse(ctx context.Context, chatData *entitys.ChatData, req *entitys.AdvicerChatRegistReq) ([]*responses.InputItem, error) {
	systemText := a.sysPrompt(chatData, req)
	dataText := a.assistantPrompt(chatData)

	items := []*responses.InputItem{
		{
			Union: &responses.InputItem_EasyMessage{
				EasyMessage: &responses.ItemEasyMessage{
					Role:    responses.MessageRole_system,
					Content: &responses.MessageContent{Union: &responses.MessageContent_StringValue{StringValue: systemText}},
				},
			},
		},
		{
			Union: &responses.InputItem_EasyMessage{
				EasyMessage: &responses.ItemEasyMessage{
					Role:    responses.MessageRole_system,
					Content: &responses.MessageContent{Union: &responses.MessageContent_StringValue{StringValue: dataText}},
				},
			},
		},
	}
	return items, nil
}
// setContent substitutes content for every "{{chat_content}}" placeholder in
// basePromptJson and decodes the result into chat completion messages.
//
// The content is JSON-escaped before substitution. The original code spliced
// the raw text into the JSON document, so content containing quotes,
// backslashes or line breaks either failed to decode or altered the
// structure of the document. This assumes the placeholder sits inside a JSON
// string value — TODO confirm against the code that produces basePromptJson.
//
// The session argument is currently not used.
func (a *AdviceChatBiz) setContent(ctx context.Context, basePromptJson string, content string, session *dbmodel.AiAdviceSession) ([]*model.ChatCompletionMessage, error) {
	quoted, err := json.Marshal(content)
	if err != nil {
		return nil, err
	}
	// Marshal of a string always yields a quoted literal; strip the quotes
	// to get the escaped text that is safe inside an existing JSON string.
	escaped := string(quoted[1 : len(quoted)-1])

	promptJson := strings.ReplaceAll(basePromptJson, "{{chat_content}}", escaped)

	var basePrompt []*model.ChatCompletionMessage
	if err := json.Unmarshal([]byte(promptJson), &basePrompt); err != nil {
		return nil, err
	}
	return basePrompt, nil
}
// sysPrompt returns the system prompt: the request's mission text wrapped
// between constants.BasePrompt and constants.BasePrompt2.
//
// The chatData argument is currently not used.
func (a *AdviceChatBiz) sysPrompt(chatData *entitys.ChatData, req *entitys.AdvicerChatRegistReq) string {
	return constants.BasePrompt + req.Mission + constants.BasePrompt2
}
// assistantPrompt returns chatData encoded as a JSON string via
// pkg.JsonStringIgonErr.
func (a *AdviceChatBiz) assistantPrompt(chatData *entitys.ChatData) string {
	encoded := pkg.JsonStringIgonErr(chatData)
	return encoded
}
// getChatDataFromStringSessionId reads the value stored in Redis under
// sessionId (Regis stores the base prompt JSON under that key).
//
// Any error reported by the Redis command is returned with an empty string.
func (a *AdviceChatBiz) getChatDataFromStringSessionId(ctx context.Context, sessionId string) (basePromptJson string, err error) {
	cmd := a.rdb.Rdb.Get(ctx, sessionId)
	if getErr := cmd.Err(); getErr != nil {
		return "", getErr
	}
	return cmd.Val(), nil
}
// callLlm sends a context chat completion request using the given API key
// and returns the string content of the first choice.
//
// It returns an error when the call fails, when the response carries no
// choices, or when the first choice has no string content. The original code
// indexed and dereferenced these unconditionally and panicked in such cases.
func (a *AdviceChatBiz) callLlm(ctx context.Context, request model.ContextChatCompletionRequest, key string) (string, error) {
	res, err := a.hsyq.ChatWithRequest(ctx, key, request)
	if err != nil {
		return "", err
	}
	if len(res.Choices) == 0 {
		return "", errors.New("llm response contains no choices")
	}
	content := res.Choices[0].Message.Content
	if content == nil || content.StringValue == nil {
		return "", errors.New("llm response has no string content")
	}
	return *content.StringValue, nil
}
// callLlmResponse issues a CreateResponse call for the given input items,
// API key, model name and id (Chat passes the session's context cache id),
// with the trailing flag set to false.
//
// On failure it returns nil and the error.
func (a *AdviceChatBiz) callLlmResponse(ctx context.Context, request []*responses.InputItem, key string, modelName string, id string) (*responses.ResponseObject, error) {
	resp, callErr := a.hsyq.CreateResponse(ctx, key, modelName, request, id, false)
	if callErr == nil {
		return resp, nil
	}
	return nil, callErr
}