256 lines
6.0 KiB
Go
256 lines
6.0 KiB
Go
package do
|
||
|
||
import (
|
||
"ai_scheduler/internal/config"
|
||
errors "ai_scheduler/internal/data/error"
|
||
"ai_scheduler/internal/data/impl"
|
||
"ai_scheduler/internal/data/model"
|
||
"ai_scheduler/internal/entitys"
|
||
"ai_scheduler/internal/pkg"
|
||
"ai_scheduler/tmpl/dataTemp"
|
||
"context"
|
||
"fmt"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
|
||
"gitea.cdlsxd.cn/self-tools/l_request"
|
||
"github.com/gofiber/fiber/v2/log"
|
||
"github.com/gofiber/websocket/v2"
|
||
|
||
"xorm.io/builder"
|
||
)
|
||
|
||
type Do struct {
|
||
Ctx *entitys.RequireData
|
||
sessionImpl *impl.SessionImpl
|
||
sysImpl *impl.SysImpl
|
||
taskImpl *impl.TaskImpl
|
||
hisImpl *impl.ChatImpl
|
||
conf *config.Config
|
||
}
|
||
|
||
func NewDo(
|
||
sysImpl *impl.SysImpl,
|
||
taskImpl *impl.TaskImpl,
|
||
hisImpl *impl.ChatImpl,
|
||
conf *config.Config,
|
||
) *Do {
|
||
return &Do{
|
||
conf: conf,
|
||
sysImpl: sysImpl,
|
||
hisImpl: hisImpl,
|
||
taskImpl: taskImpl,
|
||
}
|
||
}
|
||
|
||
func (d *Do) InitCtx(req *entitys.ChatSockRequest) *Do {
|
||
d.Ctx = &entitys.RequireData{
|
||
Req: req,
|
||
}
|
||
return d
|
||
}
|
||
|
||
func (d *Do) DataAuth(c *websocket.Conn) (err error) {
|
||
d.Ctx.Session = c.Query("x-session", "")
|
||
if len(d.Ctx.Session) == 0 {
|
||
err = errors.SessionNotFound
|
||
return
|
||
}
|
||
d.Ctx.Auth = c.Query("x-authorization", "")
|
||
if len(d.Ctx.Auth) == 0 {
|
||
err = errors.AuthNotFound
|
||
return
|
||
}
|
||
d.Ctx.Key = c.Query("x-app-key", "")
|
||
if len(d.Ctx.Key) == 0 {
|
||
err = errors.KeyNotFound
|
||
return
|
||
}
|
||
|
||
d.Ctx.Sys, err = d.getSysInfo()
|
||
if err != nil {
|
||
err = errors.SysErr("获取系统信息失败:%v", err.Error())
|
||
return
|
||
}
|
||
d.Ctx.Histories, err = d.getSessionChatHis()
|
||
if err != nil {
|
||
err = errors.SysErr("获取历史记录失败:%v", err.Error())
|
||
return
|
||
}
|
||
|
||
d.Ctx.Tasks, err = d.getTasks(d.Ctx.Sys.SysID)
|
||
if err != nil {
|
||
err = errors.SysErr("获取任务列表失败:%v", err.Error())
|
||
return
|
||
}
|
||
if err = d.getImgData(); err != nil {
|
||
return
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
func (d *Do) MakeCh(c *websocket.Conn) (ctx context.Context, deferFunc func()) {
|
||
d.Ctx.Ch = make(chan entitys.Response)
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
done := d.startMessageHandler(ctx, c, d.hisImpl)
|
||
return ctx, func() {
|
||
close(d.Ctx.Ch) //关闭主通道
|
||
<-done // 等待消息处理完成
|
||
cancel()
|
||
}
|
||
}
|
||
|
||
func (d *Do) getImgData() (err error) {
|
||
if len(d.Ctx.Req.Img) == 0 {
|
||
return
|
||
}
|
||
imgs := strings.Split(d.Ctx.Req.Img, ",")
|
||
if len(imgs) == 0 {
|
||
return
|
||
}
|
||
|
||
for k, img := range imgs {
|
||
baseErr := "获取第" + strconv.Itoa(k+1) + "张图片失败:"
|
||
entitys.ResLog(d.Ctx.Ch, "img_get_start", "正在获取第"+strconv.Itoa(k+1)+"张图片")
|
||
if err = pkg.ValidateImageURL(img); err != nil {
|
||
entitys.ResLog(d.Ctx.Ch, "", baseErr+":expected image content")
|
||
continue
|
||
}
|
||
req := l_request.Request{
|
||
Method: "GET",
|
||
Url: img,
|
||
Headers: map[string]string{
|
||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
|
||
"Accept": "image/webp,image/apng,image/*,*/*;q=0.8",
|
||
},
|
||
}
|
||
res, _err := req.Send()
|
||
if _err != nil {
|
||
entitys.ResLog(d.Ctx.Ch, "", baseErr+_err.Error())
|
||
continue
|
||
}
|
||
if _, ex := res.Headers["Content-Type"]; !ex {
|
||
entitys.ResLog(d.Ctx.Ch, "", baseErr+":Content-Type不存在")
|
||
continue
|
||
}
|
||
if !strings.HasPrefix(res.Headers["Content-Type"], "image/") {
|
||
entitys.ResLog(d.Ctx.Ch, "", baseErr+":expected image content")
|
||
continue
|
||
}
|
||
d.Ctx.ImgByte = append(d.Ctx.ImgByte, res.Content)
|
||
d.Ctx.ImgUrls = append(d.Ctx.ImgUrls, img)
|
||
entitys.ResLog(d.Ctx.Ch, "img_get_end", "第"+strconv.Itoa(k+1)+"张图片获取成功")
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
func (d *Do) getRequireData() (err error) {
|
||
|
||
return
|
||
}
|
||
|
||
func (d *Do) getSysInfo() (sysInfo model.AiSy, err error) {
|
||
cond := builder.NewCond()
|
||
cond = cond.And(builder.Eq{"app_key": d.Ctx.Key})
|
||
cond = cond.And(builder.IsNull{"delete_at"})
|
||
cond = cond.And(builder.Eq{"status": 1})
|
||
err = d.sysImpl.GetOneBySearchToStrut(&cond, &sysInfo)
|
||
return
|
||
}
|
||
|
||
func (d *Do) getSessionChatHis() (his []model.AiChatHi, err error) {
|
||
|
||
cond := builder.NewCond()
|
||
cond = cond.And(builder.Eq{"session_id": d.Ctx.Session})
|
||
|
||
_, err = d.hisImpl.GetListToStruct(&cond, &dataTemp.ReqPageBo{Limit: d.conf.Sys.SessionLen}, &his, "his_id desc")
|
||
|
||
return
|
||
}
|
||
|
||
func (d *Do) getTasks(sysId int32) (tasks []model.AiTask, err error) {
|
||
|
||
cond := builder.NewCond()
|
||
cond = cond.And(builder.Eq{"sys_id": sysId})
|
||
cond = cond.And(builder.IsNull{"delete_at"})
|
||
cond = cond.And(builder.Eq{"status": 1})
|
||
_, err = d.taskImpl.GetListToStruct(&cond, nil, &tasks, "")
|
||
|
||
return
|
||
}
|
||
|
||
// startMessageHandler 启动独立的消息处理协程
|
||
func (d *Do) startMessageHandler(
|
||
ctx context.Context,
|
||
c *websocket.Conn,
|
||
hisImpl *impl.ChatImpl,
|
||
) <-chan struct{} {
|
||
done := make(chan struct{})
|
||
var chat []string
|
||
|
||
go func() {
|
||
defer func() {
|
||
close(done)
|
||
// 保存历史记录
|
||
var his = []*model.AiChatHi{
|
||
{
|
||
SessionID: d.Ctx.Session,
|
||
Role: "user",
|
||
Content: d.Ctx.Req.Text, // 用户输入在外部处理
|
||
},
|
||
}
|
||
if len(chat) > 0 {
|
||
his = append(his, &model.AiChatHi{
|
||
SessionID: d.Ctx.Session,
|
||
Role: "assistant",
|
||
Content: strings.Join(chat, ""),
|
||
})
|
||
}
|
||
for _, hi := range his {
|
||
hisImpl.Add(hi)
|
||
}
|
||
}()
|
||
|
||
for v := range d.Ctx.Ch { // 自动检测通道关闭
|
||
if err := sendWithTimeout(c, v, 2*time.Second); err != nil {
|
||
log.Errorf("Send error: %v", err)
|
||
return
|
||
}
|
||
if v.Type == entitys.ResponseText || v.Type == entitys.ResponseStream || v.Type == entitys.ResponseJson {
|
||
chat = append(chat, v.Content)
|
||
}
|
||
}
|
||
}()
|
||
|
||
return done
|
||
}
|
||
|
||
// 辅助函数:带超时的 WebSocket 发送
|
||
func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Duration) error {
|
||
sendCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||
defer cancel()
|
||
|
||
done := make(chan error, 1)
|
||
go func() {
|
||
defer func() {
|
||
if r := recover(); r != nil {
|
||
done <- fmt.Errorf("panic in MsgSend: %v", r)
|
||
}
|
||
close(done)
|
||
}()
|
||
// 如果 MsgSend 阻塞,这里会卡住
|
||
err := entitys.MsgSend(c, data)
|
||
done <- err
|
||
}()
|
||
|
||
select {
|
||
case err := <-done:
|
||
return err
|
||
case <-sendCtx.Done():
|
||
return sendCtx.Err()
|
||
}
|
||
}
|