77 lines
2.2 KiB
Go
77 lines
2.2 KiB
Go
package third_party
|
||
|
||
import (
	"context"
	"sync"
	"time"

	"github.com/gofiber/fiber/v2/log"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model/responses"
)
|
||
|
||
type Hsyq struct {
|
||
mapClient map[string]*arkruntime.Client
|
||
}
|
||
|
||
func NewHsyq() *Hsyq {
|
||
return &Hsyq{
|
||
mapClient: make(map[string]*arkruntime.Client),
|
||
}
|
||
}
|
||
|
||
func (h *Hsyq) getClient(key string) *arkruntime.Client {
|
||
var client *arkruntime.Client
|
||
if _, ok := h.mapClient[key]; ok {
|
||
client = h.mapClient[key]
|
||
} else {
|
||
client = arkruntime.NewClientWithApiKey(
|
||
key,
|
||
arkruntime.WithRegion("cn-beijing"),
|
||
arkruntime.WithTimeout(2*time.Minute),
|
||
arkruntime.WithRetryTimes(2),
|
||
)
|
||
h.mapClient[key] = client
|
||
}
|
||
return client
|
||
}
|
||
|
||
// 火山引擎
|
||
func (h *Hsyq) RequestHsyq(ctx context.Context, key string, modelName string, prompt []*model.ChatCompletionMessage) (model.ChatCompletionResponse, error) {
|
||
req := model.CreateChatCompletionRequest{
|
||
Model: modelName,
|
||
Messages: prompt,
|
||
Stream: new(bool),
|
||
Thinking: &model.Thinking{Type: model.ThinkingTypeDisabled},
|
||
}
|
||
|
||
resp, err := h.getClient(key).CreateChatCompletion(ctx, req)
|
||
if err != nil {
|
||
return model.ChatCompletionResponse{ID: ""}, err
|
||
}
|
||
log.Info("token用量:", resp.Usage.TotalTokens, "输入:", resp.Usage.PromptTokens, "输出:", resp.Usage.CompletionTokens)
|
||
return resp, err
|
||
}
|
||
|
||
func (h *Hsyq) RequestHsyqJson(ctx context.Context, key string, modelName string, prompt []*responses.InputItem) (*responses.ResponseObject, error) {
|
||
req := responses.ResponsesRequest{
|
||
Model: modelName,
|
||
Input: &responses.ResponsesInput{
|
||
Union: &responses.ResponsesInput_ListValue{
|
||
ListValue: &responses.InputItemList{ListValue: prompt},
|
||
},
|
||
},
|
||
Stream: new(bool),
|
||
Thinking: &responses.ResponsesThinking{Type: responses.ThinkingType_disabled.Enum()},
|
||
Text: &responses.ResponsesText{Format: &responses.TextFormat{Type: responses.TextType_json_object}},
|
||
}
|
||
|
||
resp, err := h.getClient(key).CreateResponses(ctx, &req)
|
||
if err != nil {
|
||
return resp, err
|
||
}
|
||
log.Info("token用量:", resp.Usage.TotalTokens)
|
||
return resp, err
|
||
}
|