package ai_tool

import (
	"context"
	"encoding/json"
	"sync"
	"time"

	"geo/pkg"
	"geo/tmpl/errcode"

	"github.com/gofiber/fiber/v2/log"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model/responses"
	"github.com/volcengine/volcengine-go-sdk/volcengine"
)

// Hsyq wraps Volcengine Ark runtime clients, caching one client per API key.
type Hsyq struct {
	mu        sync.Mutex // guards mapClient: the singleton is shared across goroutines
	mapClient map[string]*arkruntime.Client
}

var (
	HsyqClient *Hsyq
	once       sync.Once
)

// NewHsyq returns the process-wide Hsyq singleton, initializing it exactly once.
func NewHsyq() *Hsyq {
	once.Do(func() {
		HsyqClient = &Hsyq{
			mapClient: make(map[string]*arkruntime.Client),
		}
	})
	return HsyqClient
}

// getClient returns the cached Ark client for the given API key, creating and
// caching a new one on first use. Safe for concurrent callers.
func (h *Hsyq) getClient(key string) *arkruntime.Client {
	h.mu.Lock()
	defer h.mu.Unlock()
	if client, ok := h.mapClient[key]; ok {
		return client
	}
	client := arkruntime.NewClientWithApiKey(
		key,
		arkruntime.WithBaseUrl("https://ark.cn-beijing.volces.com/api/v3"),
		arkruntime.WithRegion("cn-beijing"),
		arkruntime.WithTimeout(2*time.Minute),
		arkruntime.WithRetryTimes(2),
	)
	h.mapClient[key] = client
	return client
}

// Chat performs a non-streaming chat completion against the Volcengine Ark API
// with thinking disabled, and logs token usage on success.
func (h *Hsyq) Chat(ctx context.Context, key string, modelName string, prompt []*model.ChatCompletionMessage) (model.ChatCompletionResponse, error) {
	req := model.CreateChatCompletionRequest{
		Model:    modelName,
		Messages: prompt,
		Stream:   new(bool), // pointer to false: explicitly non-streaming
		Thinking: &model.Thinking{Type: model.ThinkingTypeDisabled},
	}
	resp, err := h.getClient(key).CreateChatCompletion(ctx, req)
	if err != nil {
		return model.ChatCompletionResponse{ID: ""}, err
	}
	log.Info("token用量:", resp.Usage.TotalTokens, "输入:", resp.Usage.PromptTokens, "输出:", resp.Usage.CompletionTokens)
	return resp, nil
}

// ChatWithRequest performs a context-based chat completion using a caller-built
// request, logging token usage on success.
func (h *Hsyq) ChatWithRequest(ctx context.Context, key string, request model.ContextChatCompletionRequest) (model.ChatCompletionResponse, error) {
	resp, err := h.getClient(key).CreateContextChatCompletion(ctx, request)
	if err != nil {
		return model.ChatCompletionResponse{ID: ""}, err
	}
	log.Info("token用量:", resp.Usage.TotalTokens, "输入:", resp.Usage.PromptTokens, "输出:", resp.Usage.CompletionTokens)
	return resp, nil
}

// CreateContextCache creates a session-mode context cache (TTL 1 hour, rolling
// token truncation) seeded with the given messages, and returns its context ID.
func (h *Hsyq) CreateContextCache(ctx context.Context, key string, modelName string, prompt []*model.ChatCompletionMessage) (string, error) {
	req := model.CreateContextRequest{
		Model:              modelName,
		Messages:           prompt,
		TTL:                volcengine.Int(3600),
		Mode:               model.ContextModeSession,
		TruncationStrategy: &model.TruncationStrategy{Type: model.TruncationStrategyTypeRollingTokens},
	}
	resp, err := h.getClient(key).CreateContext(ctx, req)
	if err != nil {
		return "", err
	}
	log.Info("token用量:", resp.Usage.TotalTokens, "输入:", resp.Usage.PromptTokens, "输出:", resp.Usage.CompletionTokens)
	return resp.ID, nil
}

// CreateResponse calls the Responses API with thinking disabled.
// When isRegis is true it enables prefix caching with a 1-hour expiry; when id
// is non-empty it continues from that previous response.
func (h *Hsyq) CreateResponse(ctx context.Context, key string, modelName string, prompt []*responses.InputItem, id string, isRegis bool) (*responses.ResponseObject, error) {
	req := &responses.ResponsesRequest{
		Model: modelName,
		Input: &responses.ResponsesInput{
			Union: &responses.ResponsesInput_ListValue{
				ListValue: &responses.InputItemList{ListValue: prompt},
			},
		},
		Stream:   new(bool), // pointer to false: explicitly non-streaming
		Thinking: &responses.ResponsesThinking{Type: responses.ThinkingType_disabled.Enum()},
	}
	if isRegis {
		prefix := true
		req.Caching = &responses.ResponsesCaching{Type: responses.CacheType_enabled.Enum(), Prefix: &prefix}
		req.ExpireAt = volcengine.Int64(time.Now().Unix() + 3600)
	}
	if len(id) != 0 {
		req.PreviousResponseId = &id
	}
	resp, err := h.getClient(key).CreateResponses(ctx, req)
	if err != nil {
		return nil, err
	}
	log.Info("token用量:", resp.Usage.TotalTokens, "输入:", resp.Usage.InputTokens, "输出:", resp.Usage.OutputTokens)
	return resp, nil
}

// RequestHsyqJson calls the Responses API with thinking disabled and the output
// format forced to JSON object, logging token usage on success.
func (h *Hsyq) RequestHsyqJson(ctx context.Context, key string, modelName string, prompt []*responses.InputItem) (*responses.ResponseObject, error) {
	req := responses.ResponsesRequest{
		Model: modelName,
		Input: &responses.ResponsesInput{
			Union: &responses.ResponsesInput_ListValue{
				ListValue: &responses.InputItemList{ListValue: prompt},
			},
		},
		Stream:   new(bool), // pointer to false: explicitly non-streaming
		Thinking: &responses.ResponsesThinking{Type: responses.ThinkingType_disabled.Enum()},
		Text:     &responses.ResponsesText{Format: &responses.TextFormat{Type: responses.TextType_json_object}},
	}
	resp, err := h.getClient(key).CreateResponses(ctx, &req)
	if err != nil {
		return resp, err
	}
	log.Info("token用量:", resp.Usage.TotalTokens)
	return resp, nil
}

// RequestHsyqBot runs a non-streaming bot chat completion (JSON-schema response
// format, thinking disabled) and returns the first choice's message content.
// Returns an error instead of panicking when the response carries no choices.
func (h *Hsyq) RequestHsyqBot(ctx context.Context, key string, botId string, message []*model.ChatCompletionMessage) (*string, error) {
	req := model.BotChatCompletionRequest{
		BotId:          botId,
		Messages:       message,
		Stream:         false,
		Thinking:       &model.Thinking{Type: model.ThinkingTypeDisabled},
		ResponseFormat: &model.ResponseFormat{Type: model.ResponseFormatJSONSchema},
	}
	resp, err := h.getClient(key).CreateBotChatCompletion(ctx, req)
	if err != nil {
		return nil, err
	}
	log.Info("token用量:", resp.Usage.TotalTokens)
	// Guard against an empty choice list; indexing blindly would panic.
	if len(resp.Choices) == 0 {
		return nil, errcode.SysErr("生成失败,请重试")
	}
	return resp.Choices[0].Message.Content.StringValue, nil
}

// RequestHsyqBotToJson calls RequestHsyqBot and unmarshals its JSON output into
// point (which must be a pointer). If the raw content is not valid JSON it
// attempts a repair pass via pkg.JsonRepair before giving up.
func (h *Hsyq) RequestHsyqBotToJson(ctx context.Context, key string, botId string, message []*model.ChatCompletionMessage, point interface{}) error {
	content, err := h.RequestHsyqBot(ctx, key, botId, message)
	if err != nil {
		return err
	}
	if content == nil {
		// The bot returned no textual content; nothing to decode.
		return errcode.SysErr("生成失败,请重试")
	}
	// BUG FIX: point already holds the caller's pointer; passing &point
	// (*interface{}) would decode into a local generic value and silently
	// leave the caller's target untouched.
	if err := json.Unmarshal([]byte(*content), point); err != nil {
		contentStr, repairErr := pkg.JsonRepair(*content)
		if repairErr != nil {
			return errcode.SysErr("生成失败,请重试")
		}
		if err := json.Unmarshal([]byte(contentStr), point); err != nil {
			return errcode.SysErr("生成失败,请重试")
		}
	}
	return nil
}