301 lines
7.6 KiB
Go
301 lines
7.6 KiB
Go
package do
|
||
|
||
import (
|
||
"ai_scheduler/internal/config"
|
||
errors "ai_scheduler/internal/data/error"
|
||
"ai_scheduler/internal/data/impl"
|
||
"ai_scheduler/internal/data/model"
|
||
"ai_scheduler/internal/entitys"
|
||
"ai_scheduler/internal/gateway"
|
||
"ai_scheduler/internal/pkg"
|
||
"ai_scheduler/tmpl/dataTemp"
|
||
"context"
|
||
"fmt"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
|
||
"gitea.cdlsxd.cn/self-tools/l_request"
|
||
"github.com/gofiber/fiber/v2/log"
|
||
"github.com/gofiber/websocket/v2"
|
||
|
||
"xorm.io/builder"
|
||
)
|
||
|
||
type Do struct {
|
||
//Ctx *entitys.RequireData
|
||
sessionImpl *impl.SessionImpl
|
||
sysImpl *impl.SysImpl
|
||
taskImpl *impl.TaskImpl
|
||
hisImpl *impl.ChatImpl
|
||
conf *config.Config
|
||
}
|
||
|
||
func NewDo(
|
||
sysImpl *impl.SysImpl,
|
||
taskImpl *impl.TaskImpl,
|
||
hisImpl *impl.ChatImpl,
|
||
conf *config.Config,
|
||
) *Do {
|
||
return &Do{
|
||
conf: conf,
|
||
sysImpl: sysImpl,
|
||
hisImpl: hisImpl,
|
||
taskImpl: taskImpl,
|
||
}
|
||
}
|
||
|
||
func (d *Do) DataAuth(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) (err error) {
|
||
// 1. 验证客户端数据
|
||
if err = d.validateClientData(client, requireData); err != nil {
|
||
return err
|
||
}
|
||
|
||
// 2. 加载系统信息
|
||
if err = d.loadSystemInfo(ctx, client, requireData); err != nil {
|
||
return fmt.Errorf("获取系统信息失败: %w", err)
|
||
}
|
||
|
||
// 3. 加载任务列表
|
||
if err = d.loadTaskList(ctx, client, requireData); err != nil {
|
||
return fmt.Errorf("获取任务列表失败: %w", err)
|
||
}
|
||
|
||
// 4. 加载聊天历史
|
||
if err = d.loadChatHistory(ctx, requireData); err != nil {
|
||
return fmt.Errorf("获取历史记录失败: %w", err)
|
||
}
|
||
|
||
// 5. 加载图片数据
|
||
if err = d.getImgData(requireData); err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// 提取数据验证为单独函数
|
||
func (d *Do) validateClientData(client *gateway.Client, requireData *entitys.RequireData) error {
|
||
requireData.Session = client.GetSession()
|
||
if len(requireData.Session) == 0 {
|
||
return errors.SessionNotFound
|
||
}
|
||
|
||
requireData.Auth = client.GetAuth()
|
||
if len(requireData.Auth) == 0 {
|
||
return errors.AuthNotFound
|
||
}
|
||
|
||
requireData.Key = client.GetKey()
|
||
if len(requireData.Key) == 0 {
|
||
return errors.KeyNotFound
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// 获取系统信息的辅助函数
|
||
func (d *Do) loadSystemInfo(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) error {
|
||
if sysInfo := client.GetSysInfo(); sysInfo == nil {
|
||
sys, err := d.getSysInfo(requireData)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
client.SetSysInfo(&sys)
|
||
requireData.Sys = sys
|
||
} else {
|
||
requireData.Sys = *sysInfo
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// 获取任务列表的辅助函数
|
||
func (d *Do) loadTaskList(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) error {
|
||
if taskInfo := client.GetTasks(); len(taskInfo) == 0 {
|
||
tasks, err := d.getTasks(requireData.Sys.SysID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
requireData.Tasks = tasks
|
||
client.SetTasks(tasks)
|
||
} else {
|
||
requireData.Tasks = taskInfo
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// 获取历史记录的辅助函数
|
||
func (d *Do) loadChatHistory(ctx context.Context, requireData *entitys.RequireData) error {
|
||
histories, err := d.getSessionChatHis(requireData)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
requireData.Histories = histories
|
||
return nil
|
||
}
|
||
|
||
func (d *Do) MakeCh(c *websocket.Conn, requireData *entitys.RequireData) (ctx context.Context, deferFunc func()) {
|
||
requireData.Ch = make(chan entitys.Response)
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
done := d.startMessageHandler(ctx, c, requireData)
|
||
return ctx, func() {
|
||
close(requireData.Ch) //关闭主通道
|
||
<-done // 等待消息处理完成
|
||
cancel()
|
||
}
|
||
}
|
||
|
||
func (d *Do) getImgData(requireData *entitys.RequireData) (err error) {
|
||
if len(requireData.Req.Img) == 0 {
|
||
return
|
||
}
|
||
imgs := strings.Split(requireData.Req.Img, ",")
|
||
if len(imgs) == 0 {
|
||
return
|
||
}
|
||
|
||
for k, img := range imgs {
|
||
baseErr := "获取第" + strconv.Itoa(k+1) + "张图片失败:"
|
||
entitys.ResLog(requireData.Ch, "img_get_start", "正在获取第"+strconv.Itoa(k+1)+"张图片")
|
||
if err = pkg.ValidateImageURL(img); err != nil {
|
||
entitys.ResLog(requireData.Ch, "", baseErr+":expected image content")
|
||
continue
|
||
}
|
||
req := l_request.Request{
|
||
Method: "GET",
|
||
Url: img,
|
||
Headers: map[string]string{
|
||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
|
||
"Accept": "image/webp,image/apng,image/*,*/*;q=0.8",
|
||
},
|
||
}
|
||
res, _err := req.Send()
|
||
if _err != nil {
|
||
entitys.ResLog(requireData.Ch, "", baseErr+_err.Error())
|
||
continue
|
||
}
|
||
if _, ex := res.Headers["Content-Type"]; !ex {
|
||
entitys.ResLog(requireData.Ch, "", baseErr+":Content-Type不存在")
|
||
continue
|
||
}
|
||
if !strings.HasPrefix(res.Headers["Content-Type"], "image/") {
|
||
entitys.ResLog(requireData.Ch, "", baseErr+":expected image content")
|
||
continue
|
||
}
|
||
requireData.ImgByte = append(requireData.ImgByte, res.Content)
|
||
requireData.ImgUrls = append(requireData.ImgUrls, img)
|
||
entitys.ResLog(requireData.Ch, "img_get_end", "第"+strconv.Itoa(k+1)+"张图片获取成功")
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
func (d *Do) getRequireData() (err error) {
|
||
|
||
return
|
||
}
|
||
|
||
func (d *Do) getSysInfo(requireData *entitys.RequireData) (sysInfo model.AiSy, err error) {
|
||
cond := builder.NewCond()
|
||
cond = cond.And(builder.Eq{"app_key": requireData.Key})
|
||
cond = cond.And(builder.IsNull{"delete_at"})
|
||
cond = cond.And(builder.Eq{"status": 1})
|
||
err = d.sysImpl.GetOneBySearchToStrut(&cond, &sysInfo)
|
||
return
|
||
}
|
||
|
||
func (d *Do) getSessionChatHis(requireData *entitys.RequireData) (his []model.AiChatHi, err error) {
|
||
|
||
cond := builder.NewCond()
|
||
cond = cond.And(builder.Eq{"session_id": requireData.Session})
|
||
|
||
_, err = d.hisImpl.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: d.conf.Sys.SessionLen}, &his, "his_id desc")
|
||
|
||
return
|
||
}
|
||
|
||
func (d *Do) getTasks(sysId int32) (tasks []model.AiTask, err error) {
|
||
|
||
cond := builder.NewCond()
|
||
cond = cond.And(builder.Eq{"sys_id": sysId})
|
||
cond = cond.And(builder.IsNull{"delete_at"})
|
||
cond = cond.And(builder.Eq{"status": 1})
|
||
_, err = d.taskImpl.GetListToStruct(&cond, nil, &tasks, "")
|
||
|
||
return
|
||
}
|
||
|
||
// startMessageHandler 启动独立的消息处理协程
|
||
func (d *Do) startMessageHandler(
|
||
ctx context.Context,
|
||
c *websocket.Conn,
|
||
requireData *entitys.RequireData,
|
||
) <-chan struct{} {
|
||
done := make(chan struct{})
|
||
var chat []string
|
||
|
||
go func() {
|
||
defer func() {
|
||
close(done)
|
||
// 保存历史记录
|
||
var (
|
||
hisLog = &entitys.ChatHisLog{}
|
||
)
|
||
if len(chat) > 0 {
|
||
AiRes := &model.AiChatHi{
|
||
SessionID: requireData.Session,
|
||
Ques: requireData.Req.Text,
|
||
Ans: strings.Join(chat, ""),
|
||
Files: requireData.Req.Img,
|
||
}
|
||
d.hisImpl.AddWithData(AiRes)
|
||
hisLog.HisId = AiRes.HisID
|
||
}
|
||
|
||
_ = entitys.MsgSend(c, entitys.Response{
|
||
Content: pkg.JsonStringIgonErr(hisLog),
|
||
Type: entitys.ResponseEnd,
|
||
})
|
||
|
||
}()
|
||
|
||
for v := range requireData.Ch { // 自动检测通道关闭
|
||
if err := sendWithTimeout(c, v, 2*time.Second); err != nil {
|
||
log.Errorf("Send error: %v", err)
|
||
return
|
||
}
|
||
if v.Type == entitys.ResponseText || v.Type == entitys.ResponseStream || v.Type == entitys.ResponseJson {
|
||
chat = append(chat, v.Content)
|
||
}
|
||
}
|
||
}()
|
||
|
||
return done
|
||
}
|
||
|
||
// 辅助函数:带超时的 WebSocket 发送
|
||
func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Duration) error {
|
||
sendCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||
defer cancel()
|
||
|
||
done := make(chan error, 1)
|
||
go func() {
|
||
defer func() {
|
||
if r := recover(); r != nil {
|
||
done <- fmt.Errorf("panic in MsgSend: %v", r)
|
||
}
|
||
close(done)
|
||
}()
|
||
// 如果 MsgSend 阻塞,这里会卡住
|
||
err := entitys.MsgSend(c, data)
|
||
done <- err
|
||
}()
|
||
|
||
select {
|
||
case err := <-done:
|
||
return err
|
||
case <-sendCtx.Done():
|
||
return sendCtx.Err()
|
||
}
|
||
}
|