package do import ( "ai_scheduler/internal/config" errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/model" "ai_scheduler/internal/entitys" "ai_scheduler/internal/gateway" "ai_scheduler/internal/pkg" "ai_scheduler/tmpl/dataTemp" "context" "fmt" "strconv" "strings" "time" "gitea.cdlsxd.cn/self-tools/l_request" "github.com/gofiber/fiber/v2/log" "github.com/gofiber/websocket/v2" "xorm.io/builder" ) type Do struct { //Ctx *entitys.RequireData sessionImpl *impl.SessionImpl sysImpl *impl.SysImpl taskImpl *impl.TaskImpl hisImpl *impl.ChatImpl conf *config.Config } func NewDo( sysImpl *impl.SysImpl, taskImpl *impl.TaskImpl, hisImpl *impl.ChatImpl, conf *config.Config, ) *Do { return &Do{ conf: conf, sysImpl: sysImpl, hisImpl: hisImpl, taskImpl: taskImpl, } } func (d *Do) DataAuth(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) (err error) { // 1. 验证客户端数据 if err = d.validateClientData(client, requireData); err != nil { return err } // 2. 加载系统信息 if err = d.loadSystemInfo(ctx, client, requireData); err != nil { return fmt.Errorf("获取系统信息失败: %w", err) } // 3. 加载任务列表 if err = d.loadTaskList(ctx, client, requireData); err != nil { return fmt.Errorf("获取任务列表失败: %w", err) } // 4. 加载聊天历史 if err = d.loadChatHistory(ctx, requireData); err != nil { return fmt.Errorf("获取历史记录失败: %w", err) } // 5. 加载图片数据 if err = d.getImgData(requireData); err != nil { return err } return nil } // 提取数据验证为单独函数 func (d *Do) validateClientData(client *gateway.Client, requireData *entitys.RequireData) error { requireData.Session = client.GetSession() if len(requireData.Session) == 0 { return errors.SessionNotFound } requireData.Auth = client.GetAuth() if len(requireData.Auth) == 0 { return errors.AuthNotFound } requireData.Key = client.GetKey() if len(requireData.Key) == 0 { return errors.KeyNotFound } return nil } // 获取系统信息的辅助函数 func (d *Do) loadSystemInfo(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) error { if sysInfo := client.GetSysInfo(); sysInfo == nil { sys, err := d.getSysInfo(requireData) if err != nil { return err } client.SetSysInfo(&sys) requireData.Sys = sys } else { requireData.Sys = *sysInfo } return nil } // 获取任务列表的辅助函数 func (d *Do) loadTaskList(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) error { if taskInfo := client.GetTasks(); len(taskInfo) == 0 { tasks, err := d.getTasks(requireData.Sys.SysID) if err != nil { return err } requireData.Tasks = tasks client.SetTasks(tasks) } else { requireData.Tasks = taskInfo } return nil } // 获取历史记录的辅助函数 func (d *Do) loadChatHistory(ctx context.Context, requireData *entitys.RequireData) error { histories, err := d.getSessionChatHis(requireData) if err != nil { return err } requireData.Histories = histories return nil } func (d *Do) MakeCh(c *websocket.Conn, requireData *entitys.RequireData) (ctx context.Context, deferFunc func()) { requireData.Ch = make(chan entitys.Response) ctx, cancel := context.WithCancel(context.Background()) done := d.startMessageHandler(ctx, c, requireData) return ctx, func() { close(requireData.Ch) //关闭主通道 <-done // 等待消息处理完成 cancel() } } func (d *Do) getImgData(requireData *entitys.RequireData) (err error) { if len(requireData.Req.Img) == 0 { return } imgs := strings.Split(requireData.Req.Img, ",") if len(imgs) == 0 { return } for k, img := range imgs { baseErr := "获取第" + strconv.Itoa(k+1) + "张图片失败:" entitys.ResLog(requireData.Ch, "img_get_start", "正在获取第"+strconv.Itoa(k+1)+"张图片") if err = pkg.ValidateImageURL(img); err != nil { entitys.ResLog(requireData.Ch, "", baseErr+":expected image content") continue } req := l_request.Request{ Method: "GET", Url: img, Headers: map[string]string{ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36", "Accept": "image/webp,image/apng,image/*,*/*;q=0.8", }, } res, _err := req.Send() if _err != nil { entitys.ResLog(requireData.Ch, "", baseErr+_err.Error()) continue } if _, ex := res.Headers["Content-Type"]; !ex { entitys.ResLog(requireData.Ch, "", baseErr+":Content-Type不存在") continue } if !strings.HasPrefix(res.Headers["Content-Type"], "image/") { entitys.ResLog(requireData.Ch, "", baseErr+":expected image content") continue } requireData.ImgByte = append(requireData.ImgByte, res.Content) requireData.ImgUrls = append(requireData.ImgUrls, img) entitys.ResLog(requireData.Ch, "img_get_end", "第"+strconv.Itoa(k+1)+"张图片获取成功") } return } func (d *Do) getRequireData() (err error) { return } func (d *Do) getSysInfo(requireData *entitys.RequireData) (sysInfo model.AiSy, err error) { cond := builder.NewCond() cond = cond.And(builder.Eq{"app_key": requireData.Key}) cond = cond.And(builder.IsNull{"delete_at"}) cond = cond.And(builder.Eq{"status": 1}) err = d.sysImpl.GetOneBySearchToStrut(&cond, &sysInfo) return } func (d *Do) getSessionChatHis(requireData *entitys.RequireData) (his []model.AiChatHi, err error) { cond := builder.NewCond() cond = cond.And(builder.Eq{"session_id": requireData.Session}) _, err = d.hisImpl.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: d.conf.Sys.SessionLen}, &his, "his_id desc") return } func (d *Do) getTasks(sysId int32) (tasks []model.AiTask, err error) { cond := builder.NewCond() cond = cond.And(builder.Eq{"sys_id": sysId}) cond = cond.And(builder.IsNull{"delete_at"}) cond = cond.And(builder.Eq{"status": 1}) _, err = d.taskImpl.GetListToStruct(&cond, nil, &tasks, "") return } // startMessageHandler 启动独立的消息处理协程 func (d *Do) startMessageHandler( ctx context.Context, c *websocket.Conn, requireData *entitys.RequireData, ) <-chan struct{} { done := make(chan struct{}) var chat []string go func() { defer func() { close(done) // 保存历史记录 var ( hisLog = &entitys.ChatHisLog{} ) if len(chat) > 0 { AiRes := &model.AiChatHi{ SessionID: requireData.Session, Ques: requireData.Req.Text, Ans: strings.Join(chat, ""), Files: requireData.Req.Img, } d.hisImpl.AddWithData(AiRes) hisLog.HisId = AiRes.HisID } _ = entitys.MsgSend(c, entitys.Response{ Content: pkg.JsonStringIgonErr(hisLog), Type: entitys.ResponseEnd, }) }() for v := range requireData.Ch { // 自动检测通道关闭 if err := sendWithTimeout(c, v, 2*time.Second); err != nil { log.Errorf("Send error: %v", err) return } if v.Type == entitys.ResponseText || v.Type == entitys.ResponseStream || v.Type == entitys.ResponseJson { chat = append(chat, v.Content) } } }() return done } // 辅助函数:带超时的 WebSocket 发送 func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Duration) error { sendCtx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() done := make(chan error, 1) go func() { defer func() { if r := recover(); r != nil { done <- fmt.Errorf("panic in MsgSend: %v", r) } close(done) }() // 如果 MsgSend 阻塞,这里会卡住 err := entitys.MsgSend(c, data) done <- err }() select { case err := <-done: return err case <-sendCtx.Done(): return sendCtx.Err() } }