package do import ( "ai_scheduler/internal/config" errors "ai_scheduler/internal/data/error" "ai_scheduler/internal/data/impl" "ai_scheduler/internal/data/model" "ai_scheduler/internal/entitys" "ai_scheduler/internal/pkg" "ai_scheduler/tmpl/dataTemp" "context" "fmt" "strconv" "strings" "time" "gitea.cdlsxd.cn/self-tools/l_request" "github.com/gofiber/fiber/v2/log" "github.com/gofiber/websocket/v2" "xorm.io/builder" ) type Do struct { Ctx *entitys.RequireData sessionImpl *impl.SessionImpl sysImpl *impl.SysImpl taskImpl *impl.TaskImpl hisImpl *impl.ChatImpl conf *config.Config } func NewDo( sysImpl *impl.SysImpl, taskImpl *impl.TaskImpl, hisImpl *impl.ChatImpl, conf *config.Config, ) *Do { return &Do{ conf: conf, sysImpl: sysImpl, hisImpl: hisImpl, taskImpl: taskImpl, } } func (d *Do) InitCtx(req *entitys.ChatSockRequest) *Do { d.Ctx = &entitys.RequireData{ Req: req, } return d } func (d *Do) DataAuth(c *websocket.Conn) (err error) { d.Ctx.Session = c.Query("x-session", "") if len(d.Ctx.Session) == 0 { err = errors.SessionNotFound return } d.Ctx.Auth = c.Query("x-authorization", "") if len(d.Ctx.Auth) == 0 { err = errors.AuthNotFound return } d.Ctx.Key = c.Query("x-app-key", "") if len(d.Ctx.Key) == 0 { err = errors.KeyNotFound return } d.Ctx.Sys, err = d.getSysInfo() if err != nil { err = errors.SysErr("获取系统信息失败:%v", err.Error()) return } d.Ctx.Histories, err = d.getSessionChatHis() if err != nil { err = errors.SysErr("获取历史记录失败:%v", err.Error()) return } d.Ctx.Tasks, err = d.getTasks(d.Ctx.Sys.SysID) if err != nil { err = errors.SysErr("获取任务列表失败:%v", err.Error()) return } if err = d.getImgData(); err != nil { return } return } func (d *Do) MakeCh(c *websocket.Conn) (ctx context.Context, deferFunc func()) { d.Ctx.Ch = make(chan entitys.Response) ctx, cancel := context.WithCancel(context.Background()) done := d.startMessageHandler(ctx, c, d.hisImpl) return ctx, func() { close(d.Ctx.Ch) //关闭主通道 <-done // 等待消息处理完成 cancel() } } func (d *Do) getImgData() (err error) { if len(d.Ctx.Req.Img) == 0 { return } imgs := strings.Split(d.Ctx.Req.Img, ",") if len(imgs) == 0 { return } for k, img := range imgs { baseErr := "获取第" + strconv.Itoa(k+1) + "张图片失败:" entitys.ResLog(d.Ctx.Ch, "img_get_start", "正在获取第"+strconv.Itoa(k+1)+"张图片") if err = pkg.ValidateImageURL(img); err != nil { entitys.ResLog(d.Ctx.Ch, "", baseErr+":expected image content") continue } req := l_request.Request{ Method: "GET", Url: img, Headers: map[string]string{ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36", "Accept": "image/webp,image/apng,image/*,*/*;q=0.8", }, } res, _err := req.Send() if _err != nil { entitys.ResLog(d.Ctx.Ch, "", baseErr+_err.Error()) continue } if _, ex := res.Headers["Content-Type"]; !ex { entitys.ResLog(d.Ctx.Ch, "", baseErr+":Content-Type不存在") continue } if !strings.HasPrefix(res.Headers["Content-Type"], "image/") { entitys.ResLog(d.Ctx.Ch, "", baseErr+":expected image content") continue } d.Ctx.ImgByte = append(d.Ctx.ImgByte, res.Content) d.Ctx.ImgUrls = append(d.Ctx.ImgUrls, img) entitys.ResLog(d.Ctx.Ch, "img_get_end", "第"+strconv.Itoa(k+1)+"张图片获取成功") } return } func (d *Do) getRequireData() (err error) { return } func (d *Do) getSysInfo() (sysInfo model.AiSy, err error) { cond := builder.NewCond() cond = cond.And(builder.Eq{"app_key": d.Ctx.Key}) cond = cond.And(builder.IsNull{"delete_at"}) cond = cond.And(builder.Eq{"status": 1}) err = d.sysImpl.GetOneBySearchToStrut(&cond, &sysInfo) return } func (d *Do) getSessionChatHis() (his []model.AiChatHi, err error) { cond := builder.NewCond() cond = cond.And(builder.Eq{"session_id": d.Ctx.Session}) _, err = d.hisImpl.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: d.conf.Sys.SessionLen}, &his, "his_id desc") return } func (d *Do) getTasks(sysId int32) (tasks []model.AiTask, err error) { cond := builder.NewCond() cond = cond.And(builder.Eq{"sys_id": sysId}) cond = cond.And(builder.IsNull{"delete_at"}) cond = cond.And(builder.Eq{"status": 1}) _, err = d.taskImpl.GetListToStruct(&cond, nil, &tasks, "") return } // startMessageHandler 启动独立的消息处理协程 func (d *Do) startMessageHandler( ctx context.Context, c *websocket.Conn, hisImpl *impl.ChatImpl, ) <-chan struct{} { done := make(chan struct{}) var chat []string go func() { defer func() { close(done) // 保存历史记录 var his = []*model.AiChatHi{ { SessionID: d.Ctx.Session, Role: "user", Content: d.Ctx.Req.Text, // 用户输入在外部处理 }, } if len(chat) > 0 { his = append(his, &model.AiChatHi{ SessionID: d.Ctx.Session, Role: "assistant", Content: strings.Join(chat, ""), }) } for _, hi := range his { hisImpl.Add(hi) } }() for v := range d.Ctx.Ch { // 自动检测通道关闭 if err := sendWithTimeout(c, v, 2*time.Second); err != nil { log.Errorf("Send error: %v", err) return } if v.Type == entitys.ResponseText || v.Type == entitys.ResponseStream || v.Type == entitys.ResponseJson { chat = append(chat, v.Content) } } }() return done } // 辅助函数:带超时的 WebSocket 发送 func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Duration) error { sendCtx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() done := make(chan error, 1) go func() { defer func() { if r := recover(); r != nil { done <- fmt.Errorf("panic in MsgSend: %v", r) } close(done) }() // 如果 MsgSend 阻塞,这里会卡住 err := entitys.MsgSend(c, data) done <- err }() select { case err := <-done: return err case <-sendCtx.Done(): return sendCtx.Err() } }