355 lines
12 KiB
Go
355 lines
12 KiB
Go
package biz
|
||
|
||
import (
|
||
"ai_scheduler/internal/biz/llm_service/third_party"
|
||
"ai_scheduler/internal/data/constants"
|
||
"ai_scheduler/internal/data/mongo_model"
|
||
|
||
"ai_scheduler/internal/data/impl"
|
||
dbmodel "ai_scheduler/internal/data/model"
|
||
"ai_scheduler/internal/entitys"
|
||
"ai_scheduler/internal/pkg"
|
||
"ai_scheduler/utils"
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
|
||
"github.com/google/uuid"
|
||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model/responses"
|
||
"github.com/volcengine/volcengine-go-sdk/volcengine"
|
||
"go.mongodb.org/mongo-driver/bson"
|
||
"go.mongodb.org/mongo-driver/mongo/options"
|
||
"xorm.io/builder"
|
||
|
||
"strings"
|
||
"time"
|
||
)
|
||
|
||
type AdviceChatBiz struct {
|
||
hsyq *third_party.Hsyq
|
||
rdb *utils.Rdb
|
||
aiAdviceSessionImpl *impl.AiAdviceSessionImpl
|
||
aiAdviceModelSupImpl *impl.AiAdviceModelSupImpl
|
||
advicerChatHisMongo *mongo_model.AdvicerChatHisMongo
|
||
mongo *pkg.Mongo
|
||
}
|
||
|
||
func NewAdviceChatBiz(
|
||
hsyq *third_party.Hsyq,
|
||
rdb *utils.Rdb,
|
||
aiAdviceSessionImpl *impl.AiAdviceSessionImpl,
|
||
aiAdviceModelSupImpl *impl.AiAdviceModelSupImpl,
|
||
advicerChatHisMongo *mongo_model.AdvicerChatHisMongo,
|
||
mongo *pkg.Mongo,
|
||
) *AdviceChatBiz {
|
||
return &AdviceChatBiz{
|
||
hsyq: hsyq,
|
||
rdb: rdb,
|
||
aiAdviceSessionImpl: aiAdviceSessionImpl,
|
||
aiAdviceModelSupImpl: aiAdviceModelSupImpl,
|
||
advicerChatHisMongo: advicerChatHisMongo,
|
||
mongo: mongo,
|
||
}
|
||
}
|
||
|
||
func (a *AdviceChatBiz) contextCache(ctx context.Context, chatData *entitys.ChatData, req *entitys.AdvicerChatRegistReq, projectInfo *entitys.AdvicerProjectInfoRes) (promptJson string, contextCache string, err error) {
|
||
switch constants.Mode(projectInfo.ModelInfo.Mode) {
|
||
case constants.ModeResponse:
|
||
prompt, err := a.buildBasePromptResponse(ctx, chatData, req)
|
||
if err != nil {
|
||
return "", "", err
|
||
}
|
||
cache, err := a.hsyq.CreateResponse(ctx, projectInfo.ModelInfo.Key, projectInfo.ModelInfo.ChatModel, prompt, "", true)
|
||
if err != nil {
|
||
return "", "", err
|
||
}
|
||
contextCache = cache.Id
|
||
promptJson = pkg.JsonStringIgonErr(prompt)
|
||
case constants.ModeContext:
|
||
prompt, err := a.buildBasePromptContext(ctx, chatData, req)
|
||
if err != nil {
|
||
return "", "", err
|
||
}
|
||
contextCache, err = a.hsyq.CreateContextCache(ctx, projectInfo.ModelInfo.Key, projectInfo.ModelInfo.ChatModel, prompt)
|
||
if err != nil {
|
||
return "", "", err
|
||
}
|
||
promptJson = pkg.JsonStringIgonErr(prompt)
|
||
default:
|
||
return "", "", fmt.Errorf("未知的mode类型:%d", projectInfo.ModelInfo.Mode)
|
||
}
|
||
return
|
||
}
|
||
|
||
func (a *AdviceChatBiz) Regis(ctx context.Context, chatData *entitys.ChatData, req *entitys.AdvicerChatRegistReq, projectInfo *entitys.AdvicerProjectInfoRes) (string, error) {
|
||
promptJson, contextCache, err := a.contextCache(ctx, chatData, req, projectInfo)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
sessionId := uuid.New().String()
|
||
//创建会话
|
||
_, err = a.aiAdviceSessionImpl.Add(&dbmodel.AiAdviceSession{
|
||
SessionID: sessionId,
|
||
ProjectID: projectInfo.Base.ProjectID,
|
||
SupID: projectInfo.Base.ModelSupID,
|
||
AdvicerVersionID: req.AdvicerVersionId,
|
||
ClientID: req.ClientId,
|
||
TalkSkillID: req.TalkSkillId,
|
||
Mission: req.Mission,
|
||
ContextCache: contextCache,
|
||
CreateAt: time.Now(),
|
||
})
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
err = a.rdb.Rdb.SetEx(ctx, sessionId, promptJson, 3600*time.Second).Err()
|
||
|
||
return sessionId, err
|
||
}
|
||
|
||
func (a *AdviceChatBiz) Chat(ctx context.Context, chat *entitys.AdvicerChatReq) (assistant mongo_model.Assistant, err error) {
|
||
var session dbmodel.AiAdviceSession
|
||
cond := builder.NewCond()
|
||
cond = cond.And(builder.Eq{"session_id": chat.SessionId})
|
||
err = a.aiAdviceSessionImpl.GetOneBySearchToStrut(&cond, &session)
|
||
if err != nil {
|
||
return assistant, err
|
||
}
|
||
if session.SessionID == "" {
|
||
return assistant, errors.New("未找到会话信息")
|
||
}
|
||
if len(chat.Content) == 0 {
|
||
return assistant, nil
|
||
}
|
||
var modelInfo dbmodel.AiAdviceModelSup
|
||
cond = builder.NewCond()
|
||
cond = cond.And(builder.Eq{"sup_id": session.SupID})
|
||
err = a.aiAdviceModelSupImpl.GetOneBySearchToStrut(&cond, &modelInfo)
|
||
if err != nil {
|
||
return assistant, err
|
||
}
|
||
if modelInfo.SupID == 0 {
|
||
return assistant, errors.New("未找到模型信息")
|
||
}
|
||
//basePromptJson, err := a.getChatDataFromStringSessionId(ctx, chat.SessionId)
|
||
//if err != nil {
|
||
// return nil, err
|
||
//}
|
||
chatHis, err := a.getChatHis(ctx, session.SessionID, 6)
|
||
prompt, err := a.buildChatPromptResponse(ctx, chat, &session, chatHis)
|
||
if err != nil {
|
||
return assistant, err
|
||
}
|
||
resContent, err := a.callLlmResponse(ctx, prompt, modelInfo.Key, modelInfo.ChatModel, session.ContextCache)
|
||
if err != nil {
|
||
return assistant, err
|
||
}
|
||
|
||
result := resContent.Output[0].GetOutputMessage().Content[0].GetText().GetText()
|
||
if err = json.Unmarshal([]byte(result), &assistant); err != nil {
|
||
return assistant, err
|
||
}
|
||
chatCtx, cancel := context.WithCancel(context.Background())
|
||
go func(session dbmodel.AiAdviceSession) {
|
||
defer cancel()
|
||
_, _ = a.mongo.Co(a.advicerChatHisMongo).InsertOne(chatCtx, &mongo_model.AdvicerChatHisMongo{
|
||
SessionId: chat.SessionId,
|
||
User: chat.Content,
|
||
Assistant: assistant,
|
||
InToken: resContent.Usage.InputTokens,
|
||
OutToken: resContent.Usage.OutputTokens,
|
||
CreatAt: time.Now(),
|
||
})
|
||
if assistant.MissionStatus == "fail" || assistant.MissionStatus == "complete" {
|
||
cond = builder.NewCond()
|
||
cond = cond.And(builder.Eq{"session_id": chat.SessionId})
|
||
session.MissionStatus = assistant.MissionStatus
|
||
session.MissionCompleteDesc = assistant.MissionCompleteDesc
|
||
_ = a.aiAdviceSessionImpl.UpdateByCond(&cond, session)
|
||
}
|
||
}(session)
|
||
return
|
||
}
|
||
|
||
func (a *AdviceChatBiz) buildChatPromptResponse(ctx context.Context, chat *entitys.AdvicerChatReq, session *dbmodel.AiAdviceSession, chatList []mongo_model.AdvicerChatHisMongoEntity) ([]*responses.InputItem, error) {
|
||
|
||
var message = make([]*responses.InputItem, 3)
|
||
message[0] = &responses.InputItem{
|
||
Union: &responses.InputItem_EasyMessage{
|
||
EasyMessage: &responses.ItemEasyMessage{
|
||
Role: responses.MessageRole_system,
|
||
Content: &responses.MessageContent{Union: &responses.MessageContent_StringValue{StringValue: a.taskPrompt(session)}},
|
||
},
|
||
},
|
||
}
|
||
message[1] = &responses.InputItem{
|
||
Union: &responses.InputItem_EasyMessage{
|
||
EasyMessage: &responses.ItemEasyMessage{
|
||
Role: responses.MessageRole_system,
|
||
Content: &responses.MessageContent{Union: &responses.MessageContent_StringValue{StringValue: "历史聊天记录:\n" + pkg.JsonStringIgonErr(chatList)}},
|
||
},
|
||
},
|
||
}
|
||
message[2] = &responses.InputItem{
|
||
Union: &responses.InputItem_EasyMessage{
|
||
EasyMessage: &responses.ItemEasyMessage{
|
||
Role: responses.MessageRole_user,
|
||
Content: &responses.MessageContent{Union: &responses.MessageContent_StringValue{StringValue: chat.Content}},
|
||
},
|
||
},
|
||
}
|
||
|
||
return message, nil
|
||
}
|
||
func (a *AdviceChatBiz) getChatHis(ctx context.Context, sessionId string, limit int64) (chatList []mongo_model.AdvicerChatHisMongoEntity, err error) {
|
||
chatList = make([]mongo_model.AdvicerChatHisMongoEntity, 0)
|
||
filter := bson.M{}
|
||
filter["sessionId"] = sessionId
|
||
cursor, err := a.mongo.Co(a.advicerChatHisMongo).Find(ctx, filter, options.Find().SetLimit(limit))
|
||
if err != nil {
|
||
return chatList, err
|
||
}
|
||
for cursor.Next(ctx) {
|
||
var chatHIS mongo_model.AdvicerChatHisMongo
|
||
if err := cursor.Decode(&chatHIS); err != nil {
|
||
return nil, err
|
||
}
|
||
chatList = append(chatList, chatHIS.Entity())
|
||
}
|
||
|
||
if err := cursor.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
return
|
||
}
|
||
|
||
func (a *AdviceChatBiz) buildChatPrompt(ctx context.Context, chat *entitys.AdvicerChatReq, session *dbmodel.AiAdviceSession, modelInfo *dbmodel.AiAdviceModelSup) (model.ContextChatCompletionRequest, error) {
|
||
var message = make([]*model.ChatCompletionMessage, 2)
|
||
message[0] = &model.ChatCompletionMessage{
|
||
Role: model.ChatMessageRoleUser,
|
||
Content: &model.ChatCompletionMessageContent{
|
||
StringValue: volcengine.String(chat.Content),
|
||
},
|
||
}
|
||
message[1] = &model.ChatCompletionMessage{
|
||
Role: model.ChatMessageRoleAssistant,
|
||
Content: &model.ChatCompletionMessageContent{
|
||
StringValue: volcengine.String(a.taskPrompt(session)),
|
||
},
|
||
}
|
||
|
||
req := model.ContextChatCompletionRequest{
|
||
ContextID: session.ContextCache,
|
||
Model: modelInfo.ChatModel,
|
||
Messages: message,
|
||
Stream: false,
|
||
}
|
||
return req, nil
|
||
}
|
||
|
||
func (a *AdviceChatBiz) taskPrompt(session *dbmodel.AiAdviceSession) string {
|
||
//type mission struct {
|
||
// missionName string
|
||
// status string
|
||
// missionCompleteDesc string
|
||
//}
|
||
//var m = &mission{
|
||
// missionName: session.Mission,
|
||
// status: pkg.Ter(session.MissionStatus == 1, "进行中", "已完成"),
|
||
// missionCompleteDesc: session.MissionCompleteDesc,
|
||
//}
|
||
//missionJon, _ := json.Marshal(m)
|
||
return "[当前时间]" + time.Now().Format("2006-01-02 15:04:05")
|
||
}
|
||
|
||
func (a *AdviceChatBiz) buildBasePromptContext(ctx context.Context, chatData *entitys.ChatData, req *entitys.AdvicerChatRegistReq) ([]*model.ChatCompletionMessage, error) {
|
||
var message = make([]*model.ChatCompletionMessage, 2)
|
||
message[0] = &model.ChatCompletionMessage{
|
||
Role: model.ChatMessageRoleSystem,
|
||
Content: &model.ChatCompletionMessageContent{
|
||
StringValue: volcengine.String(a.sysPrompt(chatData, req)),
|
||
},
|
||
}
|
||
message[1] = &model.ChatCompletionMessage{
|
||
Role: model.ChatMessageRoleSystem,
|
||
Content: &model.ChatCompletionMessageContent{
|
||
StringValue: volcengine.String(a.assistantPrompt(chatData)),
|
||
},
|
||
}
|
||
return message, nil
|
||
}
|
||
|
||
func (a *AdviceChatBiz) buildBasePromptResponse(ctx context.Context, chatData *entitys.ChatData, req *entitys.AdvicerChatRegistReq) ([]*responses.InputItem, error) {
|
||
var message = make([]*responses.InputItem, 2)
|
||
message[0] = &responses.InputItem{
|
||
Union: &responses.InputItem_EasyMessage{
|
||
EasyMessage: &responses.ItemEasyMessage{
|
||
Role: responses.MessageRole_system,
|
||
Content: &responses.MessageContent{Union: &responses.MessageContent_StringValue{StringValue: a.sysPrompt(chatData, req)}},
|
||
},
|
||
},
|
||
}
|
||
message[1] = &responses.InputItem{
|
||
Union: &responses.InputItem_EasyMessage{
|
||
EasyMessage: &responses.ItemEasyMessage{
|
||
Role: responses.MessageRole_system,
|
||
Content: &responses.MessageContent{Union: &responses.MessageContent_StringValue{StringValue: a.assistantPrompt(chatData)}},
|
||
},
|
||
},
|
||
}
|
||
return message, nil
|
||
}
|
||
|
||
func (a *AdviceChatBiz) setContent(ctx context.Context, basePromptJson string, content string, session *dbmodel.AiAdviceSession) ([]*model.ChatCompletionMessage, error) {
|
||
promptJson := strings.ReplaceAll(basePromptJson, "{{chat_content}}", content)
|
||
var basePrompt []*model.ChatCompletionMessage
|
||
err := json.Unmarshal([]byte(promptJson), &basePrompt)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return basePrompt, nil
|
||
}
|
||
|
||
func (a *AdviceChatBiz) sysPrompt(chatData *entitys.ChatData, req *entitys.AdvicerChatRegistReq) string {
|
||
var prompt strings.Builder
|
||
prompt.WriteString(constants.BasePrompt)
|
||
prompt.WriteString(req.Mission)
|
||
prompt.WriteString(constants.BasePrompt2)
|
||
|
||
return prompt.String()
|
||
|
||
}
|
||
|
||
func (a *AdviceChatBiz) assistantPrompt(chatData *entitys.ChatData) string {
|
||
return pkg.JsonStringIgonErr(chatData)
|
||
}
|
||
|
||
func (a *AdviceChatBiz) getChatDataFromStringSessionId(ctx context.Context, sessionId string) (basePromptJson string, err error) {
|
||
cache := a.rdb.Rdb.Get(ctx, sessionId)
|
||
if cache.Err() != nil {
|
||
err = cache.Err()
|
||
return
|
||
}
|
||
|
||
return cache.Val(), cache.Err()
|
||
}
|
||
|
||
func (a *AdviceChatBiz) callLlm(ctx context.Context, request model.ContextChatCompletionRequest, key string) (string, error) {
|
||
res, err := a.hsyq.ChatWithRequest(ctx, key, request)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
return *res.Choices[0].Message.Content.StringValue, nil
|
||
}
|
||
|
||
func (a *AdviceChatBiz) callLlmResponse(ctx context.Context, request []*responses.InputItem, key string, modelName string, id string) (*responses.ResponseObject, error) {
|
||
res, err := a.hsyq.CreateResponse(ctx, key, modelName, request, id, false)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return res, nil
|
||
}
|