ai_scheduler/internal/biz/llm_service/third_party/hsyq.go

140 lines
4.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package third_party
import (
	"context"
	"sync"
	"time"

	"github.com/gofiber/fiber/v2/log"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model/responses"
	"github.com/volcengine/volcengine-go-sdk/volcengine"
)
// Hsyq wraps the Volcengine Ark (火山引擎) runtime SDK and caches one
// arkruntime.Client per API key so repeated calls with the same key reuse
// the same client.
//
// The client cache is guarded by mu, so one *Hsyq may be shared between
// goroutines. The zero value is usable; NewHsyq is kept for existing callers.
// Hsyq contains a mutex and must not be copied after first use.
type Hsyq struct {
	mu        sync.Mutex
	mapClient map[string]*arkruntime.Client
}

// NewHsyq returns an Hsyq with an empty client cache.
func NewHsyq() *Hsyq {
	return &Hsyq{
		mapClient: make(map[string]*arkruntime.Client),
	}
}

// getClient returns the cached client for key, creating and caching one on
// first use. New clients target region "cn-beijing" with a 2-minute timeout
// and 2 retries.
func (h *Hsyq) getClient(key string) *arkruntime.Client {
	// Map reads and writes must be serialized: concurrent map writes are a
	// fatal runtime error in Go, not a recoverable panic.
	h.mu.Lock()
	defer h.mu.Unlock()

	if client, ok := h.mapClient[key]; ok {
		return client
	}
	if h.mapClient == nil {
		// Support a zero-value Hsyq that was not built via NewHsyq.
		h.mapClient = make(map[string]*arkruntime.Client)
	}
	client := arkruntime.NewClientWithApiKey(
		key,
		arkruntime.WithRegion("cn-beijing"),
		arkruntime.WithTimeout(2*time.Minute),
		arkruntime.WithRetryTimes(2),
	)
	h.mapClient[key] = client
	return client
}
// Chat sends a non-streaming chat completion request to Volcengine Ark
// (火山引擎) for modelName with the given messages, with thinking disabled.
// On success it logs total, prompt and completion token counts and returns
// the response; on failure it returns an empty response and the SDK error.
func (h *Hsyq) Chat(ctx context.Context, key string, modelName string, prompt []*model.ChatCompletionMessage) (model.ChatCompletionResponse, error) {
	var empty model.ChatCompletionResponse
	streaming := false
	client := h.getClient(key)
	result, callErr := client.CreateChatCompletion(ctx, model.CreateChatCompletionRequest{
		Model:    modelName,
		Messages: prompt,
		Stream:   &streaming,
		Thinking: &model.Thinking{Type: model.ThinkingTypeDisabled},
	})
	if callErr != nil {
		return empty, callErr
	}
	usage := result.Usage
	log.Info("token用量", usage.TotalTokens, "输入:", usage.PromptTokens, "输出:", usage.CompletionTokens)
	return result, nil
}
// ChatWithRequest forwards a caller-built context chat completion request to
// Volcengine Ark (火山引擎) using the client for key. On success it logs total,
// prompt and completion token counts; on failure it returns an empty response
// and the SDK error.
func (h *Hsyq) ChatWithRequest(ctx context.Context, key string, request model.ContextChatCompletionRequest) (model.ChatCompletionResponse, error) {
	var empty model.ChatCompletionResponse
	client := h.getClient(key)
	result, callErr := client.CreateContextChatCompletion(ctx, request)
	if callErr != nil {
		return empty, callErr
	}
	usage := result.Usage
	log.Info("token用量", usage.TotalTokens, "输入:", usage.PromptTokens, "输出:", usage.CompletionTokens)
	return result, nil
}
// CreateContextCache creates a session-mode context on Volcengine Ark seeded
// with the given messages and returns its ID.
//
// The context is created with TTL 3600 (one hour, assuming the API unit is
// seconds — confirm against the Ark docs) and the rolling-tokens truncation
// strategy. Token usage is logged on success; on failure "" and the SDK error
// are returned.
func (h *Hsyq) CreateContextCache(ctx context.Context, key string, modelName string, prompt []*model.ChatCompletionMessage) (string, error) {
	client := h.getClient(key)
	created, callErr := client.CreateContext(ctx, model.CreateContextRequest{
		Model:    modelName,
		Messages: prompt,
		TTL:      volcengine.Int(3600),
		Mode:     model.ContextModeSession,
		TruncationStrategy: &model.TruncationStrategy{
			Type: model.TruncationStrategyTypeRollingTokens,
		},
	})
	if callErr != nil {
		return "", callErr
	}
	usage := created.Usage
	log.Info("token用量", usage.TotalTokens, "输入:", usage.PromptTokens, "输出:", usage.CompletionTokens)
	return created.ID, nil
}
// CreateResponse calls the Volcengine Ark Responses API with the given input
// items, non-streaming and with thinking disabled.
//
// If isRegis is true, prefix caching is enabled and ExpireAt is set to one
// hour from now (Unix seconds). If id is non-empty it is sent as the
// previous response ID so this request continues that response.
// On success token usage is logged (when the response carries it) and the
// response is returned; on failure it returns nil and the SDK error.
func (h *Hsyq) CreateResponse(ctx context.Context, key string, modelName string, prompt []*responses.InputItem, id string, isRegis bool) (*responses.ResponseObject, error) {
	req := &responses.ResponsesRequest{
		Model: modelName,
		Input: &responses.ResponsesInput{
			Union: &responses.ResponsesInput_ListValue{
				ListValue: &responses.InputItemList{ListValue: prompt},
			},
		},
		Stream:   new(bool),
		Thinking: &responses.ResponsesThinking{Type: responses.ThinkingType_disabled.Enum()},
	}
	if isRegis {
		prefix := true
		req.Caching = &responses.ResponsesCaching{Type: responses.CacheType_enabled.Enum(), Prefix: &prefix}
		req.ExpireAt = volcengine.Int64(time.Now().Unix() + 3600)
	}
	if id != "" {
		req.PreviousResponseId = &id
		// TODO: JSON-object output for chained requests; schema not yet defined.
		//req.Text = &responses.ResponsesText{
		//	Format: &responses.TextFormat{
		//		Type: responses.TextType_json_object,
		//		Schema:
		//	}
		//}
	}
	resp, err := h.getClient(key).CreateResponses(ctx, req)
	if err != nil {
		return nil, err
	}
	// resp.Usage is a pointer; dereferencing it unguarded would turn a
	// successful API call into a nil-pointer panic when usage is absent.
	if u := resp.Usage; u != nil {
		log.Info("token用量", u.TotalTokens, "输入:", u.InputTokens, "输出:", u.OutputTokens)
	} else {
		log.Warn("token用量: response has no usage")
	}
	return resp, nil
}
// RequestHsyqJson calls the Volcengine Ark Responses API with the given input
// items, non-streaming and with thinking disabled, requesting the
// json_object text output format.
//
// On success token usage is logged (when the response carries it) and the
// response is returned. On failure it returns nil and the SDK error,
// matching CreateResponse; a partially populated response is never returned
// alongside an error.
func (h *Hsyq) RequestHsyqJson(ctx context.Context, key string, modelName string, prompt []*responses.InputItem) (*responses.ResponseObject, error) {
	req := &responses.ResponsesRequest{
		Model: modelName,
		Input: &responses.ResponsesInput{
			Union: &responses.ResponsesInput_ListValue{
				ListValue: &responses.InputItemList{ListValue: prompt},
			},
		},
		Stream:   new(bool),
		Thinking: &responses.ResponsesThinking{Type: responses.ThinkingType_disabled.Enum()},
		Text:     &responses.ResponsesText{Format: &responses.TextFormat{Type: responses.TextType_json_object}},
	}
	resp, err := h.getClient(key).CreateResponses(ctx, req)
	if err != nil {
		return nil, err
	}
	// resp.Usage is a pointer; guard it so a successful call cannot panic.
	if u := resp.Usage; u != nil {
		log.Info("token用量", u.TotalTokens, "输入:", u.InputTokens, "输出:", u.OutputTokens)
	} else {
		log.Warn("token用量: response has no usage")
	}
	return resp, nil
}