ai_scheduler/internal/biz/do/ctx.go

356 lines
8.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package do
import (
"ai_scheduler/internal/config"
errors "ai_scheduler/internal/data/error"
"ai_scheduler/internal/data/impl"
"ai_scheduler/internal/data/model"
"ai_scheduler/internal/entitys"
"ai_scheduler/internal/gateway"
"ai_scheduler/internal/pkg"
"ai_scheduler/tmpl/dataTemp"
"context"
"encoding/json"
"fmt"
"net/http"
"strconv"
"strings"
"time"
"gitea.cdlsxd.cn/self-tools/l_request"
"github.com/gofiber/fiber/v2/log"
"github.com/gofiber/websocket/v2"
"xorm.io/builder"
)
// Do bundles the data-access implementations and configuration used to
// prepare a chat request (client validation, system/task lookup, chat
// history, images, permissions) and to forward responses over a websocket.
type Do struct {
	//Ctx *entitys.RequireData
	sessionImpl *impl.SessionImpl // NOTE(review): not set by NewDo; nil unless assigned elsewhere in the package
	sysImpl     *impl.SysImpl     // used by getSysInfo to load the system record by app key
	taskImpl    *impl.TaskImpl    // used by getTasks to list a system's tasks
	hisImpl     *impl.ChatImpl    // used to read chat history and to store new history rows
	conf        *config.Config    // application config (e.g. Sys.SessionLen, PermissionConfig)
}
// NewDo builds a Do from its data-access implementations and configuration.
// The sessionImpl field is not accepted here and stays nil.
func NewDo(
	sysImpl *impl.SysImpl,
	taskImpl *impl.TaskImpl,
	hisImpl *impl.ChatImpl,
	conf *config.Config,
) *Do {
	d := new(Do)
	d.sysImpl = sysImpl
	d.taskImpl = taskImpl
	d.hisImpl = hisImpl
	d.conf = conf
	return d
}
// DataAuth validates the client connection and loads everything a chat
// request needs into requireData. The steps run in this order:
// session/auth/key from the client, the system record, the system's task list,
// the session's chat history, any referenced images, and the user's
// permission codes.
//
// It stops at the first failing step. Errors from client validation and image
// loading are returned unchanged. Errors from the other steps are wrapped with
// %w under a step-specific message.
//
// NOTE(review): image loading reports progress on requireData.Ch, so confirm
// that the channel is set up (MakeCh) before DataAuth is called.
func (d *Do) DataAuth(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) (err error) {
	steps := []struct {
		run    func() error
		format string // empty: return the step's error unwrapped
	}{
		{run: func() error { return d.validateClientData(client, requireData) }},
		{run: func() error { return d.loadSystemInfo(ctx, client, requireData) }, format: "获取系统信息失败: %w"},
		{run: func() error { return d.loadTaskList(ctx, client, requireData) }, format: "获取任务列表失败: %w"},
		{run: func() error { return d.loadChatHistory(ctx, requireData) }, format: "获取历史记录失败: %w"},
		{run: func() error { return d.getImgData(requireData) }},
		{run: func() error {
			_, permErr := d.LoadUserPermission(client, requireData)
			return permErr
		}, format: "获取用户权限失败: %w"},
	}
	for _, step := range steps {
		stepErr := step.run()
		if stepErr == nil {
			continue
		}
		if step.format == "" {
			return stepErr
		}
		return fmt.Errorf(step.format, stepErr)
	}
	return nil
}
// validateClientData copies the session, auth token and key from client into
// requireData, one at a time. It stops at the first empty value and returns
// the matching sentinel: SessionNotFound, AuthNotFound or KeyNotFound.
// Fields after the failing one are left untouched.
func (d *Do) validateClientData(client *gateway.Client, requireData *entitys.RequireData) error {
	if requireData.Session = client.GetSession(); len(requireData.Session) == 0 {
		return errors.SessionNotFound
	}
	if requireData.Auth = client.GetAuth(); len(requireData.Auth) == 0 {
		return errors.AuthNotFound
	}
	if requireData.Key = client.GetKey(); len(requireData.Key) == 0 {
		return errors.KeyNotFound
	}
	return nil
}
// loadSystemInfo fills requireData.Sys, preferring the copy cached on client.
// On a cache miss the record is fetched with getSysInfo (which matches on
// requireData.Key). It is then cached on client and copied into requireData.
// ctx is currently unused.
func (d *Do) loadSystemInfo(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) error {
	cached := client.GetSysInfo()
	if cached != nil {
		requireData.Sys = *cached
		return nil
	}

	fetched, err := d.getSysInfo(requireData)
	if err != nil {
		return err
	}
	client.SetSysInfo(&fetched)
	requireData.Sys = fetched
	return nil
}
// loadTaskList fills requireData.Tasks, preferring the non-empty list cached
// on client. Otherwise it queries the tasks for requireData.Sys.SysID and
// caches the result on client. requireData.Sys must already be populated.
// ctx is currently unused.
func (d *Do) loadTaskList(ctx context.Context, client *gateway.Client, requireData *entitys.RequireData) error {
	if cached := client.GetTasks(); len(cached) > 0 {
		requireData.Tasks = cached
		return nil
	}

	fetched, err := d.getTasks(requireData.Sys.SysID)
	if err != nil {
		return err
	}
	requireData.Tasks = fetched
	client.SetTasks(fetched)
	return nil
}
// loadChatHistory loads the recent chat history of requireData.Session into
// requireData.Histories. On error the field is left unchanged. ctx is
// currently unused.
func (d *Do) loadChatHistory(ctx context.Context, requireData *entitys.RequireData) error {
	history, err := d.getSessionChatHis(requireData)
	if err == nil {
		requireData.Histories = history
	}
	return err
}
// MakeCh installs a fresh unbuffered response channel on requireData.Ch and
// starts the goroutine that forwards its messages to c (startMessageHandler).
//
// It returns a cancellable context derived from context.Background (not from
// any request context) and a cleanup func. The cleanup func closes
// requireData.Ch, waits for the forwarding goroutine to signal completion,
// then cancels ctx. Call it exactly once, after the last send on
// requireData.Ch; a second call would panic closing an already-closed channel.
func (d *Do) MakeCh(c *websocket.Conn, requireData *entitys.RequireData) (ctx context.Context, deferFunc func()) {
	requireData.Ch = make(chan entitys.Response)

	var cancel context.CancelFunc
	ctx, cancel = context.WithCancel(context.Background())
	handlerDone := d.startMessageHandler(ctx, c, requireData)

	deferFunc = func() {
		close(requireData.Ch) // signals the handler's range loop to finish
		<-handlerDone         // block until the handler reports completion
		cancel()
	}
	return ctx, deferFunc
}
// getImgData downloads the images listed in requireData.Req.Img, a
// comma-separated list of URLs. For each image fetched successfully, it
// appends the body to requireData.ImgByte and the URL to requireData.ImgUrls
// (both at the same point, so they stay paired).
//
// Whitespace around each URL is trimmed, and empty entries (e.g. from a
// trailing comma) are ignored. Progress and per-image failures are reported
// with entitys.ResLog on requireData.Ch.
//
// A failing image is skipped and never makes this function fail, so the
// returned error is currently always nil. Failure causes are an invalid URL,
// a request error, or a missing or non-image Content-Type.
func (d *Do) getImgData(requireData *entitys.RequireData) (err error) {
	if len(requireData.Req.Img) == 0 {
		return nil
	}
	for k, raw := range strings.Split(requireData.Req.Img, ",") {
		img := strings.TrimSpace(raw)
		if img == "" {
			continue
		}
		pos := strconv.Itoa(k + 1) // 1-based position in the submitted list
		baseErr := "获取第" + pos + "张图片失败:"
		entitys.ResLog(requireData.Ch, "img_get_start", "正在获取第"+pos+"张图片")

		// Deliberately not assigned to the named result err: a skipped image
		// must not surface as this function's error, which would abort DataAuth.
		if pkg.ValidateImageURL(img) != nil {
			entitys.ResLog(requireData.Ch, "", baseErr+"expected image content")
			continue
		}

		req := l_request.Request{
			Method: "GET",
			Url:    img,
			Headers: map[string]string{
				"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
				"Accept":     "image/webp,image/apng,image/*,*/*;q=0.8",
			},
		}
		res, sendErr := req.Send()
		if sendErr != nil {
			entitys.ResLog(requireData.Ch, "", baseErr+sendErr.Error())
			continue
		}

		contentType, ok := res.Headers["Content-Type"]
		if !ok {
			entitys.ResLog(requireData.Ch, "", baseErr+"Content-Type不存在")
			continue
		}
		if !strings.HasPrefix(contentType, "image/") {
			entitys.ResLog(requireData.Ch, "", baseErr+"expected image content")
			continue
		}

		requireData.ImgByte = append(requireData.ImgByte, res.Content)
		requireData.ImgUrls = append(requireData.ImgUrls, img)
		entitys.ResLog(requireData.Ch, "img_get_end", "第"+pos+"张图片获取成功")
	}
	return nil
}
// getRequireData is a placeholder: it does nothing and always returns nil.
func (d *Do) getRequireData() (err error) {
	return nil
}
// getSysInfo loads a single system record whose app_key equals
// requireData.Key, with delete_at NULL and status = 1.
func (d *Do) getSysInfo(requireData *entitys.RequireData) (sysInfo model.AiSy, err error) {
	where := builder.NewCond().
		And(builder.Eq{"app_key": requireData.Key}).
		And(builder.IsNull{"delete_at"}).
		And(builder.Eq{"status": 1})
	err = d.sysImpl.GetOneBySearchToStrut(&where, &sysInfo)
	return sysInfo, err
}
// getSessionChatHis lists chat history rows for requireData.Session, ordered
// by his_id descending and limited to conf.Sys.SessionLen rows. The total
// count returned by GetListToStruct is discarded.
func (d *Do) getSessionChatHis(requireData *entitys.RequireData) (his []model.AiChatHi, err error) {
	bySession := builder.NewCond().And(builder.Eq{"session_id": requireData.Session})
	page := &dataTemp.ReqPageBo{Limit: d.conf.Sys.SessionLen}
	_, err = d.hisImpl.GetListToStruct(&bySession, page, &his, "his_id desc")
	return his, err
}
// getTasks lists all tasks of system sysId with delete_at NULL and
// status = 1. No pagination or explicit ordering is applied.
func (d *Do) getTasks(sysId int32) (tasks []model.AiTask, err error) {
	where := builder.NewCond().
		And(builder.Eq{"sys_id": sysId}).
		And(builder.IsNull{"delete_at"}).
		And(builder.Eq{"status": 1})
	_, err = d.taskImpl.GetListToStruct(&where, nil, &tasks, "")
	return tasks, err
}
// startMessageHandler starts a goroutine that forwards every entitys.Response
// received on requireData.Ch to the websocket connection c. Each message gets
// a 2-second send timeout (see sendWithTimeout).
//
// Text, stream and JSON responses are also collected. Once requireData.Ch is
// closed, the collected answer (if any) is saved with hisImpl.AddWithData,
// together with the question text and image list. A final entitys.ResponseEnd
// message carrying the new history ID is then sent.
//
// The returned channel is closed only after all of that has finished. A caller
// waiting on it (MakeCh's cleanup func) therefore knows that persistence and
// the end message are complete.
//
// If a send fails, the goroutine stops writing to c but keeps receiving from
// requireData.Ch until it is closed. MakeCh creates that channel unbuffered,
// so abandoning the receive loop would block every later producer forever.
// The end message is also skipped after a failure: a timed-out write may still
// be running in sendWithTimeout's background goroutine, and websocket
// connections do not support concurrent writers.
//
// ctx is currently unused.
func (d *Do) startMessageHandler(
	ctx context.Context,
	c *websocket.Conn,
	requireData *entitys.RequireData,
) <-chan struct{} {
	done := make(chan struct{})
	go func() {
		// Deferred first so it runs last: done must not be signalled until the
		// history row is saved and the end message has been sent.
		defer close(done)

		var (
			chat       []string
			sendFailed bool
		)
		for v := range requireData.Ch { // exits when the producer closes Ch
			if sendFailed {
				// Drain and discard so producers never block on the channel.
				continue
			}
			if err := sendWithTimeout(c, v, 2*time.Second); err != nil {
				log.Errorf("Send error: %v", err)
				sendFailed = true
				continue
			}
			if v.Type == entitys.ResponseText || v.Type == entitys.ResponseStream || v.Type == entitys.ResponseJson {
				chat = append(chat, v.Content)
			}
		}

		// Persist this question/answer turn if anything was produced.
		hisLog := &entitys.ChatHisLog{}
		if len(chat) > 0 {
			aiRes := &model.AiChatHi{
				SessionID: requireData.Session,
				Ques:      requireData.Req.Text,
				Ans:       strings.Join(chat, ""),
				Files:     requireData.Req.Img,
			}
			// NOTE(review): any result of AddWithData is ignored here; confirm
			// whether it reports errors that should be logged.
			d.hisImpl.AddWithData(aiRes)
			hisLog.HisId = aiRes.HisID
		}

		if sendFailed {
			return
		}
		// Best-effort: the end frame's send error is intentionally ignored.
		_ = entitys.MsgSend(c, entitys.Response{
			Content: pkg.JsonStringIgonErr(hisLog),
			Type:    entitys.ResponseEnd,
		})
	}()
	return done
}
// sendWithTimeout sends data over c with entitys.MsgSend and gives up after
// timeout.
//
// MsgSend runs in its own goroutine so a blocked write cannot stall the
// caller. A panic inside it is recovered and returned as an error.
//
// On timeout the function returns context.DeadlineExceeded. The background
// MsgSend call is not cancelled. If it eventually returns, its result goes
// into a buffered channel, so that goroutine can still exit.
func sendWithTimeout(c *websocket.Conn, data entitys.Response, timeout time.Duration) error {
	timer := time.NewTimer(timeout)
	defer timer.Stop()

	result := make(chan error, 1)
	go func() {
		defer close(result)
		defer func() {
			if r := recover(); r != nil {
				result <- fmt.Errorf("panic in MsgSend: %v", r)
			}
		}()
		result <- entitys.MsgSend(c, data)
	}()

	select {
	case err := <-result:
		return err
	case <-timer.C:
		return context.DeadlineExceeded
	}
}
// LoadUserPermission returns the user's permission codes from the unified
// login platform.
//
// Codes already cached on client are returned without a request. Otherwise it
// sends a GET to conf.PermissionConfig.PermissionURL with the system ID
// appended, using client's auth token as a Bearer credential. The "codes"
// field of the JSON response is decoded, cached on client and returned.
//
// Errors: a transport error from Send, SysErr("获取用户权限失败") on a non-200
// status, or a JSON decoding error. In every error case codes is nil.
func (d *Do) LoadUserPermission(client *gateway.Client, requireData *entitys.RequireData) (codes []string, err error) {
	if cached := client.GetCodes(); len(cached) > 0 {
		return cached, nil
	}

	request := l_request.Request{
		Method: "GET",
		Url:    d.conf.PermissionConfig.PermissionURL + strconv.Itoa(int(requireData.Sys.SysID)),
		Headers: map[string]string{
			"User-Agent":    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
			"Accept":        "application/json, text/plain, */*",
			"Authorization": "Bearer " + client.GetAuth(),
		},
	}

	res, err := request.Send()
	if err != nil {
		return nil, err
	}
	if res.StatusCode != http.StatusOK {
		return nil, errors.SysErr("获取用户权限失败")
	}

	var payload struct {
		Codes []string `json:"codes"`
	}
	if err = json.Unmarshal([]byte(res.Text), &payload); err != nil {
		return nil, err
	}

	client.SetCodes(payload.Codes)
	return payload.Codes, nil
}