geoGo/internal/ai/hsyq.go

181 lines
5.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package third_party
import (
	"context"
	"errors"
	"sync"
	"time"

	"github.com/gofiber/fiber/v2/log"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model/responses"
	"github.com/volcengine/volcengine-go-sdk/volcengine"
)
// Hsyq caches Volcano Engine Ark runtime clients, one per API key, so that
// repeated requests with the same key reuse the same client.
type Hsyq struct {
	// mu guards mapClient. The instance is shared process-wide (see
	// HsyqClient), and unsynchronized concurrent map access is a fatal
	// runtime error in Go.
	mu sync.Mutex
	// mapClient maps an API key to the client created for it.
	mapClient map[string]*arkruntime.Client
}

var (
	// HsyqClient is the shared instance; it is nil until NewHsyq is called.
	HsyqClient *Hsyq
	// once ensures HsyqClient is initialized exactly once.
	once sync.Once
)

// NewHsyq returns the shared Hsyq instance, creating it on the first call.
// Every call returns the same pointer.
func NewHsyq() *Hsyq {
	once.Do(func() {
		HsyqClient = &Hsyq{
			mapClient: make(map[string]*arkruntime.Client),
		}
	})
	return HsyqClient
}

// getClient returns the cached client for key, creating and caching a new
// one on first use. New clients target the cn-beijing Ark endpoint with a
// two-minute timeout and two retries. It is safe for concurrent use.
func (h *Hsyq) getClient(key string) *arkruntime.Client {
	h.mu.Lock()
	defer h.mu.Unlock()

	if client, ok := h.mapClient[key]; ok {
		return client
	}

	// Tolerate an Hsyq that was built without NewHsyq: writing to a nil map
	// would panic.
	if h.mapClient == nil {
		h.mapClient = make(map[string]*arkruntime.Client)
	}

	client := arkruntime.NewClientWithApiKey(
		key,
		arkruntime.WithBaseUrl("https://ark.cn-beijing.volces.com/api/v3"),
		arkruntime.WithRegion("cn-beijing"),
		arkruntime.WithTimeout(2*time.Minute),
		arkruntime.WithRetryTimes(2),
	)
	h.mapClient[key] = client
	return client
}
// Chat sends a non-streaming chat completion request to Volcano Engine with
// thinking disabled and returns the response.
//
// key selects (or lazily creates) the client, modelName is passed as the
// request model, and prompt is the message list. On failure the zero-value
// response is returned together with the error. On success the token usage
// is logged.
func (h *Hsyq) Chat(ctx context.Context, key string, modelName string, prompt []*model.ChatCompletionMessage) (model.ChatCompletionResponse, error) {
	streaming := false
	client := h.getClient(key)

	completion, err := client.CreateChatCompletion(ctx, model.CreateChatCompletionRequest{
		Model:    modelName,
		Messages: prompt,
		Stream:   &streaming,
		Thinking: &model.Thinking{Type: model.ThinkingTypeDisabled},
	})
	if err != nil {
		return model.ChatCompletionResponse{}, err
	}

	log.Info(
		"token用量", completion.Usage.TotalTokens,
		"输入:", completion.Usage.PromptTokens,
		"输出:", completion.Usage.CompletionTokens,
	)
	return completion, nil
}
// ChatWithRequest forwards a caller-built context chat completion request to
// Volcano Engine using the client associated with key.
//
// On failure the zero-value response is returned together with the error.
// On success the token usage is logged and the response is returned.
func (h *Hsyq) ChatWithRequest(ctx context.Context, key string, request model.ContextChatCompletionRequest) (model.ChatCompletionResponse, error) {
	client := h.getClient(key)

	completion, err := client.CreateContextChatCompletion(ctx, request)
	if err != nil {
		return model.ChatCompletionResponse{}, err
	}

	log.Info(
		"token用量", completion.Usage.TotalTokens,
		"输入:", completion.Usage.PromptTokens,
		"输出:", completion.Usage.CompletionTokens,
	)
	return completion, nil
}
// CreateContextCache creates a session-mode context from prompt for the given
// model and returns the ID of the created context.
//
// The request uses a TTL of 3600 (presumably seconds; confirm against the SDK)
// and the rolling-tokens truncation strategy. On failure an empty ID is
// returned together with the error. On success the token usage is logged.
func (h *Hsyq) CreateContextCache(ctx context.Context, key string, modelName string, prompt []*model.ChatCompletionMessage) (string, error) {
	truncation := &model.TruncationStrategy{Type: model.TruncationStrategyTypeRollingTokens}

	created, err := h.getClient(key).CreateContext(ctx, model.CreateContextRequest{
		Model:              modelName,
		Messages:           prompt,
		TTL:                volcengine.Int(3600),
		Mode:               model.ContextModeSession,
		TruncationStrategy: truncation,
	})
	if err != nil {
		return "", err
	}

	log.Info(
		"token用量", created.Usage.TotalTokens,
		"输入:", created.Usage.PromptTokens,
		"输出:", created.Usage.CompletionTokens,
	)
	return created.ID, nil
}
// CreateResponse issues a non-streaming Responses API call with thinking
// disabled, using prompt as the input item list.
//
// When isRegis is true, prefix caching is enabled and the response is set to
// expire 3600 seconds from now. When id is non-empty it is sent as the
// previous response ID. On failure nil is returned together with the error.
// On success the token usage is logged.
func (h *Hsyq) CreateResponse(ctx context.Context, key string, modelName string, prompt []*responses.InputItem, id string, isRegis bool) (*responses.ResponseObject, error) {
	input := &responses.ResponsesInput{
		Union: &responses.ResponsesInput_ListValue{
			ListValue: &responses.InputItemList{ListValue: prompt},
		},
	}
	request := &responses.ResponsesRequest{
		Model:    modelName,
		Input:    input,
		Stream:   new(bool),
		Thinking: &responses.ResponsesThinking{Type: responses.ThinkingType_disabled.Enum()},
	}

	if isRegis {
		usePrefix := true
		request.Caching = &responses.ResponsesCaching{
			Type:   responses.CacheType_enabled.Enum(),
			Prefix: &usePrefix,
		}
		request.ExpireAt = volcengine.Int64(time.Now().Add(time.Hour).Unix())
	}

	if id != "" {
		request.PreviousResponseId = &id
		// JSON-object output format for chained responses is currently disabled:
		//request.Text = &responses.ResponsesText{
		//	Format: &responses.TextFormat{
		//		Type: responses.TextType_json_object,
		//		Schema:
		//	}
		//}
	}

	result, err := h.getClient(key).CreateResponses(ctx, request)
	if err != nil {
		return nil, err
	}

	log.Info(
		"token用量", result.Usage.TotalTokens,
		"输入:", result.Usage.InputTokens,
		"输出:", result.Usage.OutputTokens,
	)
	return result, nil
}
// RequestHsyqJson issues a non-streaming Responses API call with thinking
// disabled and the text format set to JSON object, using prompt as the input
// item list.
//
// The SDK's response and error are returned unchanged. The total token usage
// is logged only when the call succeeds.
func (h *Hsyq) RequestHsyqJson(ctx context.Context, key string, modelName string, prompt []*responses.InputItem) (*responses.ResponseObject, error) {
	input := &responses.ResponsesInput{
		Union: &responses.ResponsesInput_ListValue{
			ListValue: &responses.InputItemList{ListValue: prompt},
		},
	}
	jsonText := &responses.ResponsesText{
		Format: &responses.TextFormat{Type: responses.TextType_json_object},
	}
	request := &responses.ResponsesRequest{
		Model:    modelName,
		Input:    input,
		Stream:   new(bool),
		Thinking: &responses.ResponsesThinking{Type: responses.ThinkingType_disabled.Enum()},
		Text:     jsonText,
	}

	result, err := h.getClient(key).CreateResponses(ctx, request)
	if err == nil {
		log.Info("token用量", result.Usage.TotalTokens)
	}
	return result, err
}
// RequestHsyqBot sends message to the bot identified by botId as a
// non-streaming request with thinking disabled and a JSON-object response
// format, and returns the first choice's message content marshalled as JSON.
//
// The client for key is obtained through getClient, so it is created on first
// use rather than assumed to be cached already. An error is returned if the
// SDK call fails, if the response contains no choices, or if the first
// choice carries no message content. The total token usage is logged after a
// successful call.
//
// Example message value:
//
//	[]*model.ChatCompletionMessage{
//		{
//			Role: model.ChatMessageRoleSystem,
//			Content: &model.ChatCompletionMessageContent{
//				StringValue: volcengine.String("You are Doubao, an AI assistant developed by ByteDance"),
//			},
//		},
//		{
//			Role: model.ChatMessageRoleUser,
//			Content: &model.ChatCompletionMessageContent{
//				StringValue: volcengine.String("What are some common cruciferous plants?"),
//			},
//		},
//	}
func (h *Hsyq) RequestHsyqBot(ctx context.Context, key string, botId string, message []*model.ChatCompletionMessage) ([]byte, error) {
	req := model.BotChatCompletionRequest{
		BotId:          botId,
		Messages:       message,
		Stream:         false,
		Thinking:       &model.Thinking{Type: model.ThinkingTypeDisabled},
		ResponseFormat: &model.ResponseFormat{Type: model.ResponseFormatJsonObject},
	}

	// Go through getClient: indexing the cache directly yields a nil client
	// for any key that has not been used before.
	resp, err := h.getClient(key).CreateBotChatCompletion(ctx, req)
	if err != nil {
		return nil, err
	}
	log.Info("token用量", resp.Usage.TotalTokens)

	// Guard the accesses below so an empty or content-less response becomes
	// an error instead of a panic.
	if len(resp.Choices) == 0 {
		return nil, errors.New("hsyq: bot chat completion returned no choices")
	}
	content := resp.Choices[0].Message.Content
	if content == nil {
		return nil, errors.New("hsyq: bot chat completion returned no message content")
	}
	return content.MarshalJSON()
}