56 lines
1.3 KiB
Go
56 lines
1.3 KiB
Go
package doubao
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
|
|
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
|
"github.com/volcengine/volcengine-go-sdk/volcengine"
|
|
"time"
|
|
)
|
|
|
|
// DouBao holds the configuration used to call a Doubao chat model through
// the Volcengine Ark runtime: the model identifier placed in each request
// and an API key kept alongside it.
type DouBao struct {
	Model, Key string
}
|
|
|
|
func NewDouBao(modelName string, key string) *DouBao {
|
|
return &DouBao{
|
|
Model: modelName,
|
|
Key: key,
|
|
}
|
|
}
|
|
|
|
func (o *DouBao) GetData(ctx context.Context, key string, role string, respHandle func(input string) (string, error), text ...string) (string, error) {
|
|
var Message = make([]*model.ChatCompletionMessage, len(text))
|
|
|
|
client := arkruntime.NewClientWithApiKey(
|
|
key,
|
|
//arkruntime.WithBaseUrl(UrlMap[url]),
|
|
arkruntime.WithRegion("cn-beijing"),
|
|
arkruntime.WithTimeout(2*time.Minute),
|
|
arkruntime.WithRetryTimes(2),
|
|
)
|
|
for k, v := range text {
|
|
Message[k] = &model.ChatCompletionMessage{
|
|
Role: role,
|
|
Content: &model.ChatCompletionMessageContent{
|
|
StringValue: volcengine.String(v),
|
|
},
|
|
}
|
|
}
|
|
req := model.CreateChatCompletionRequest{
|
|
Model: o.Model,
|
|
Messages: Message,
|
|
}
|
|
|
|
resp, err := client.CreateChatCompletion(ctx, req)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
fmt.Printf("用量:%d", resp.Usage.TotalTokens)
|
|
result, err := respHandle(*resp.Choices[0].Message.Content.StringValue)
|
|
|
|
return result, err
|
|
}
|