geoGo/internal/ai_tool/hsyq.go

192 lines
5.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package ai_tool
import (
	"context"
	"encoding/json"
	"errors"
	"sync"
	"time"

	"geo/pkg"
	"geo/tmpl/errcode"

	"github.com/gofiber/fiber/v2/log"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model/responses"
	"github.com/volcengine/volcengine-go-sdk/volcengine"
)
// Hsyq is a process-wide registry of Volcengine Ark clients, one per API key,
// so repeated calls with the same key reuse a single client.
type Hsyq struct {
	// mu guards mapClient: getClient can be called from concurrent request
	// handlers, and an unguarded map read+write is a fatal data race in Go.
	mu        sync.RWMutex
	mapClient map[string]*arkruntime.Client
}

var (
	HsyqClient *Hsyq
	once       sync.Once
)

// NewHsyq returns the package singleton, creating it exactly once.
func NewHsyq() *Hsyq {
	once.Do(func() {
		HsyqClient = &Hsyq{
			mapClient: make(map[string]*arkruntime.Client),
		}
	})
	return HsyqClient
}

// getClient returns the cached Ark client for key, creating and caching it on
// first use. Safe for concurrent use.
func (h *Hsyq) getClient(key string) *arkruntime.Client {
	// Fast path: client already cached.
	h.mu.RLock()
	client, ok := h.mapClient[key]
	h.mu.RUnlock()
	if ok {
		return client
	}

	h.mu.Lock()
	defer h.mu.Unlock()
	// Re-check under the write lock: another goroutine may have created the
	// client between our RUnlock and Lock.
	if client, ok := h.mapClient[key]; ok {
		return client
	}
	client = arkruntime.NewClientWithApiKey(
		key,
		arkruntime.WithBaseUrl("https://ark.cn-beijing.volces.com/api/v3"),
		arkruntime.WithRegion("cn-beijing"),
		arkruntime.WithTimeout(2*time.Minute),
		arkruntime.WithRetryTimes(2),
	)
	h.mapClient[key] = client
	return client
}
// Chat issues a non-streaming chat completion against the Volcengine Ark
// ("火山引擎") API using the model named modelName, with "thinking" disabled.
// On error it returns the zero-value response, which callers must not use.
func (h *Hsyq) Chat(ctx context.Context, key string, modelName string, prompt []*model.ChatCompletionMessage) (model.ChatCompletionResponse, error) {
	req := model.CreateChatCompletionRequest{
		Model:    modelName,
		Messages: prompt,
		// new(bool) is a pointer to false: streaming explicitly disabled.
		Stream:   new(bool),
		Thinking: &model.Thinking{Type: model.ThinkingTypeDisabled},
	}
	resp, err := h.getClient(key).CreateChatCompletion(ctx, req)
	if err != nil {
		return model.ChatCompletionResponse{}, err
	}
	log.Info("token用量", resp.Usage.TotalTokens, "输入:", resp.Usage.PromptTokens, "输出:", resp.Usage.CompletionTokens)
	return resp, nil
}
// ChatWithRequest issues a context chat completion against the Volcengine Ark
// ("火山引擎") API with a caller-built request, logging token usage on success.
// On error it returns the zero-value response, which callers must not use.
func (h *Hsyq) ChatWithRequest(ctx context.Context, key string, request model.ContextChatCompletionRequest) (model.ChatCompletionResponse, error) {
	resp, err := h.getClient(key).CreateContextChatCompletion(ctx, request)
	if err != nil {
		return model.ChatCompletionResponse{}, err
	}
	log.Info("token用量", resp.Usage.TotalTokens, "输入:", resp.Usage.PromptTokens, "输出:", resp.Usage.CompletionTokens)
	return resp, nil
}
// CreateContextCache creates a session-mode context cache on the Ark API,
// seeded with prompt, with a 1-hour TTL and rolling-token truncation.
// It returns the server-assigned context ID for use in follow-up requests.
func (h *Hsyq) CreateContextCache(ctx context.Context, key string, modelName string, prompt []*model.ChatCompletionMessage) (string, error) {
	req := model.CreateContextRequest{
		Model:              modelName,
		Messages:           prompt,
		TTL:                volcengine.Int(3600), // seconds
		Mode:               model.ContextModeSession,
		TruncationStrategy: &model.TruncationStrategy{Type: model.TruncationStrategyTypeRollingTokens},
	}
	resp, err := h.getClient(key).CreateContext(ctx, req)
	if err != nil {
		return "", err
	}
	log.Info("token用量", resp.Usage.TotalTokens, "输入:", resp.Usage.PromptTokens, "输出:", resp.Usage.CompletionTokens)
	return resp.ID, nil
}
// CreateResponse calls the Ark Responses API with prompt as the input items,
// streaming and "thinking" disabled.
//
// If isRegis is true, prefix caching is enabled with a 1-hour expiry so the
// prompt prefix can be reused by later calls. If id is non-empty it is set as
// the previous response ID, chaining this request onto an earlier response.
func (h *Hsyq) CreateResponse(ctx context.Context, key string, modelName string, prompt []*responses.InputItem, id string, isRegis bool) (*responses.ResponseObject, error) {
	req := &responses.ResponsesRequest{
		Model: modelName,
		Input: &responses.ResponsesInput{
			Union: &responses.ResponsesInput_ListValue{
				ListValue: &responses.InputItemList{ListValue: prompt},
			},
		},
		// new(bool) is a pointer to false: streaming explicitly disabled.
		Stream:   new(bool),
		Thinking: &responses.ResponsesThinking{Type: responses.ThinkingType_disabled.Enum()},
	}
	if isRegis {
		prefix := true
		req.Caching = &responses.ResponsesCaching{Type: responses.CacheType_enabled.Enum(), Prefix: &prefix}
		req.ExpireAt = volcengine.Int64(time.Now().Unix() + 3600)
	}
	if len(id) != 0 {
		req.PreviousResponseId = &id
		// NOTE(review): a JSON text format (responses.ResponsesText) was
		// previously sketched here but never enabled; see RequestHsyqJson
		// for the variant that forces a JSON response format.
	}
	resp, err := h.getClient(key).CreateResponses(ctx, req)
	if err != nil {
		return nil, err
	}
	log.Info("token用量", resp.Usage.TotalTokens, "输入:", resp.Usage.InputTokens, "输出:", resp.Usage.OutputTokens)
	return resp, nil
}
// RequestHsyqJson calls the Ark Responses API with streaming and "thinking"
// disabled and the text format forced to json_object, so the model is
// constrained to emit a JSON reply.
func (h *Hsyq) RequestHsyqJson(ctx context.Context, key string, modelName string, prompt []*responses.InputItem) (*responses.ResponseObject, error) {
	req := responses.ResponsesRequest{
		Model: modelName,
		Input: &responses.ResponsesInput{
			Union: &responses.ResponsesInput_ListValue{
				ListValue: &responses.InputItemList{ListValue: prompt},
			},
		},
		// new(bool) is a pointer to false: streaming explicitly disabled.
		Stream:   new(bool),
		Thinking: &responses.ResponsesThinking{Type: responses.ThinkingType_disabled.Enum()},
		Text:     &responses.ResponsesText{Format: &responses.TextFormat{Type: responses.TextType_json_object}},
	}
	resp, err := h.getClient(key).CreateResponses(ctx, &req)
	if err != nil {
		// Return nil on error (consistent with CreateResponse); callers must
		// not use a response paired with a non-nil error.
		return nil, err
	}
	log.Info("token用量", resp.Usage.TotalTokens)
	return resp, nil
}
// RequestHsyqBot sends message to the Ark bot identified by botId, with
// streaming and "thinking" disabled and a JSON-schema response format, and
// returns the first choice's string content.
//
// Unlike the original implementation this guards against an empty Choices
// slice (which would otherwise panic on Choices[0]) and a missing Content.
func (h *Hsyq) RequestHsyqBot(ctx context.Context, key string, botId string, message []*model.ChatCompletionMessage) (*string, error) {
	req := model.BotChatCompletionRequest{
		BotId:          botId,
		Messages:       message,
		Stream:         false,
		Thinking:       &model.Thinking{Type: model.ThinkingTypeDisabled},
		ResponseFormat: &model.ResponseFormat{Type: model.ResponseFormatJSONSchema},
	}
	resp, err := h.getClient(key).CreateBotChatCompletion(ctx, req)
	if err != nil {
		return nil, err
	}
	log.Info("token用量", resp.Usage.TotalTokens)
	if len(resp.Choices) == 0 || resp.Choices[0].Message.Content == nil {
		return nil, errors.New("bot chat completion returned no usable choice")
	}
	return resp.Choices[0].Message.Content.StringValue, nil
}
// RequestHsyqBotToJson calls RequestHsyqBot and unmarshals its JSON reply into
// point, which must be a non-nil pointer. If the raw reply fails to parse, one
// repair pass (pkg.JsonRepair) is attempted before giving up.
//
// Fixes over the original: guards the *string returned by RequestHsyqBot
// against nil before dereferencing, passes point to json.Unmarshal uniformly
// (the original mixed &point and point), and avoids shadowing err.
func (h *Hsyq) RequestHsyqBotToJson(ctx context.Context, key string, botId string, message []*model.ChatCompletionMessage, point interface{}) error {
	content, err := h.RequestHsyqBot(ctx, key, botId, message)
	if err != nil {
		return err
	}
	if content == nil {
		return errcode.SysErr("生成失败,请重试")
	}
	if err := json.Unmarshal([]byte(*content), point); err == nil {
		return nil
	}
	// The model sometimes emits almost-JSON; try a repair pass before failing.
	repaired, repairErr := pkg.JsonRepair(*content)
	if repairErr != nil {
		return errcode.SysErr("生成失败,请重试")
	}
	if err := json.Unmarshal([]byte(repaired), point); err != nil {
		return errcode.SysErr("生成失败,请重试")
	}
	return nil
}