This commit is contained in:
commit
f0171e80ef
|
@ -0,0 +1,9 @@
|
|||
package entity
|
||||
|
||||
type Config struct {
|
||||
SampleRate int `env:"AUDIO_SAMPLE_RATE" envDefault:"16000"`
|
||||
Channels int `env:"AUDIO_CHANNELS" envDefault:"1"`
|
||||
BitDepth int `env:"AUDIO_BIT_DEPTH" envDefault:"16"`
|
||||
Port string `env:"PORT" envDefault:"8000"`
|
||||
MaxBufferSize int `env:"MAX_BUFFER_SIZE" envDefault:"104857600"` // 100MB
|
||||
}
|
|
@ -0,0 +1,30 @@
|
|||
module l_szr_go
|
||||
|
||||
go 1.24.0
|
||||
|
||||
toolchain go1.24.5
|
||||
|
||||
require (
|
||||
github.com/caarlos0/env/v6 v6.10.1
|
||||
github.com/gofiber/fiber/v2 v2.52.9
|
||||
github.com/gofiber/websocket/v2 v2.2.1
|
||||
github.com/volcengine/volcengine-go-sdk v1.1.23
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/andybalholm/brotli v1.2.0 // indirect
|
||||
github.com/fasthttp/websocket v1.5.3 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/jmespath/go-jmespath v0.4.0 // indirect
|
||||
github.com/klauspost/compress v1.18.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.16 // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee // indirect
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
github.com/valyala/fasthttp v1.64.0 // indirect
|
||||
github.com/volcengine/volc-sdk-golang v1.0.23 // indirect
|
||||
golang.org/x/sys v0.34.0 // indirect
|
||||
gopkg.in/yaml.v2 v2.2.8 // indirect
|
||||
)
|
|
@ -0,0 +1,13 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"github.com/gofiber/websocket/v2"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
)
|
|
@ -0,0 +1,127 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"github.com/caarlos0/env/v6"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/websocket/v2"
|
||||
"l_szr_go/pkg"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"l_szr_go/entity"
|
||||
"log"
|
||||
"os"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
config entity.Config
|
||||
activeConnections sync.WaitGroup
|
||||
)
|
||||
|
||||
func Server() *fiber.App {
|
||||
app := fiber.New()
|
||||
|
||||
// 健康检查
|
||||
app.Get("/health", func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
// WebSocket 路由
|
||||
app.Use("/ws", func(c *fiber.Ctx) error {
|
||||
if websocket.IsWebSocketUpgrade(c) {
|
||||
c.Locals("allowed", true)
|
||||
return c.Next()
|
||||
}
|
||||
return fiber.ErrUpgradeRequired
|
||||
})
|
||||
|
||||
app.Get("/ws", websocket.New(HandleWebSocket))
|
||||
return app
|
||||
}
|
||||
|
||||
func HandleWebSocket(c *websocket.Conn) {
|
||||
activeConnections.Add(1)
|
||||
defer activeConnections.Done()
|
||||
|
||||
var audioBuffer bytes.Buffer
|
||||
fileName := fmt.Sprintf("output_%s.wav",
|
||||
time.Now().Format("20060102_150405.000"))
|
||||
|
||||
log.Printf("客户端连接: %s", c.RemoteAddr())
|
||||
|
||||
defer func() {
|
||||
if audioBuffer.Len() == 0 {
|
||||
log.Println("未接收到音频数据,跳过保存")
|
||||
return
|
||||
}
|
||||
|
||||
if err := pkg.SaveAsWav(&audioBuffer, fileName, config); err != nil {
|
||||
log.Printf("保存文件失败: %v", err)
|
||||
} else {
|
||||
log.Printf("音频已保存为 %s (%d字节)", fileName, audioBuffer.Len())
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
_, message, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
log.Printf("客户端异常断开: %v", err)
|
||||
} else {
|
||||
log.Println("客户端主动断开")
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// 检查缓冲区大小
|
||||
if audioBuffer.Len()+len(message) > config.MaxBufferSize {
|
||||
log.Println("达到最大缓冲区大小,断开连接")
|
||||
break
|
||||
}
|
||||
|
||||
// 验证PCM数据 (基本验证)
|
||||
if config.BitDepth == 16 && len(message)%2 != 0 {
|
||||
log.Println("接收到无效的16位PCM数据,长度不是2的倍数")
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := audioBuffer.Write(message); err != nil {
|
||||
log.Printf("写入缓冲区失败: %v", err)
|
||||
break
|
||||
}
|
||||
|
||||
log.Printf("接收到 %d 字节音频数据 (总大小: %d字节)",
|
||||
len(message), audioBuffer.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
if err := env.Parse(&config); err != nil {
|
||||
log.Fatalf("加载配置失败: %v", err)
|
||||
}
|
||||
app := Server()
|
||||
// 启动服务器
|
||||
go func() {
|
||||
log.Printf("服务器启动,监听端口 %s", config.Port)
|
||||
log.Printf("音频配置: %dHz, %d通道, %d位深",
|
||||
config.SampleRate, config.Channels, config.BitDepth)
|
||||
if err := app.Listen(":" + config.Port); err != nil {
|
||||
log.Fatalf("服务器启动失败: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 优雅关闭
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, os.Interrupt, syscall.SIGTERM)
|
||||
<-quit
|
||||
log.Println("正在关闭服务器...")
|
||||
|
||||
// 等待所有连接完成
|
||||
activeConnections.Wait()
|
||||
log.Println("所有连接已关闭")
|
||||
os.Exit(0)
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
package onnx
|
||||
|
||||
import "C"
|
||||
import "unsafe"
|
||||
|
||||
func onnx() {
|
||||
// 1. 初始化 ONNX Runtime 环境
|
||||
env := C.CreateEnv()
|
||||
defer C.OrtReleaseEnv(env)
|
||||
|
||||
// 2. 加载模型(替换为实际模型路径)
|
||||
modelPath := C.CString("sensevoice_small.onnx")
|
||||
defer C.free(unsafe.Pointer(modelPath))
|
||||
session := C.CreateSession(env, modelPath)
|
||||
defer C.OrtReleaseSession(session)
|
||||
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
package doubao
|
||||
|
||||
var UrlMap = map[UrlType]string{
|
||||
Text: "https://ark.cn-beijing.volces.com/api/v3/chat/completions",
|
||||
Video: "https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks",
|
||||
Embedding: "https://ark.cn-beijing.volces.com/api/v3/embeddings",
|
||||
Token: "https://ark.cn-beijing.volces.com/api/v3/tokenization",
|
||||
}
|
||||
|
||||
type UrlType string
|
||||
|
||||
const (
|
||||
Text UrlType = "text"
|
||||
Video UrlType = "video"
|
||||
Embedding UrlType = "embedding"
|
||||
Token UrlType = "token"
|
||||
)
|
|
@ -0,0 +1,43 @@
|
|||
package doubao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
|
||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
"time"
|
||||
)
|
||||
|
||||
type DouBao struct {
|
||||
Model string
|
||||
Key string
|
||||
}
|
||||
|
||||
func NewDouBao(modelName string, key string) *DouBao {
|
||||
return &DouBao{
|
||||
Model: modelName,
|
||||
Key: key,
|
||||
}
|
||||
}
|
||||
|
||||
func (o *DouBao) GetData(ctx context.Context, respHandle func(input string) (string, error), Message []*model.ChatCompletionMessage) (string, error) {
|
||||
|
||||
client := arkruntime.NewClientWithApiKey(
|
||||
o.Key,
|
||||
//arkruntime.WithBaseUrl(UrlMap[url]),
|
||||
arkruntime.WithRegion("cn-beijing"),
|
||||
arkruntime.WithTimeout(2*time.Minute),
|
||||
arkruntime.WithRetryTimes(2),
|
||||
)
|
||||
req := model.CreateChatCompletionRequest{
|
||||
Model: o.Model,
|
||||
Messages: Message,
|
||||
}
|
||||
|
||||
resp, err := client.CreateChatCompletion(ctx, req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
result, err := respHandle(*resp.Choices[0].Message.Content.StringValue)
|
||||
|
||||
return result, err
|
||||
}
|
|
@ -0,0 +1,68 @@
|
|||
package pkg
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"l_szr_go/entity"
|
||||
"os"
|
||||
)
|
||||
|
||||
func SaveAsWav(buffer *bytes.Buffer, filename string, config entity.Config) error {
|
||||
if buffer.Len() == 0 {
|
||||
return fmt.Errorf("缓冲区为空,无法保存")
|
||||
}
|
||||
|
||||
file, err := os.Create(filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建文件失败: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// 写入 WAV 头部
|
||||
if err := writeWavHeader(file, config.SampleRate, config.BitDepth, config.Channels, buffer.Len()); err != nil {
|
||||
return fmt.Errorf("写入 WAV 头部失败: %v", err)
|
||||
}
|
||||
|
||||
// 写入 PCM 数据
|
||||
if _, err := buffer.WriteTo(file); err != nil {
|
||||
return fmt.Errorf("写入 PCM 数据失败: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeWavHeader(w io.Writer, sampleRate, bitDepth, channels, dataSize int) error {
|
||||
header := struct {
|
||||
ChunkID [4]byte
|
||||
ChunkSize uint32
|
||||
Format [4]byte
|
||||
Subchunk1ID [4]byte
|
||||
Subchunk1Size uint32
|
||||
AudioFormat uint16
|
||||
NumChannels uint16
|
||||
SampleRate uint32
|
||||
ByteRate uint32
|
||||
BlockAlign uint16
|
||||
BitsPerSample uint16
|
||||
Subchunk2ID [4]byte
|
||||
Subchunk2Size uint32
|
||||
}{
|
||||
ChunkID: [4]byte{'R', 'I', 'F', 'F'},
|
||||
ChunkSize: 36 + uint32(dataSize),
|
||||
Format: [4]byte{'W', 'A', 'V', 'E'},
|
||||
Subchunk1ID: [4]byte{'f', 'm', 't', ' '},
|
||||
Subchunk1Size: 16,
|
||||
AudioFormat: 1, // PCM
|
||||
NumChannels: uint16(channels),
|
||||
SampleRate: uint32(sampleRate),
|
||||
ByteRate: uint32(sampleRate * channels * bitDepth / 8),
|
||||
BlockAlign: uint16(channels * bitDepth / 8),
|
||||
BitsPerSample: uint16(bitDepth),
|
||||
Subchunk2ID: [4]byte{'d', 'a', 't', 'a'},
|
||||
Subchunk2Size: uint32(dataSize),
|
||||
}
|
||||
|
||||
return binary.Write(w, binary.LittleEndian, header)
|
||||
}
|
|
@ -0,0 +1,88 @@
|
|||
package szr
|
||||
|
||||
import (
|
||||
"context"
|
||||
"l_szr_go/pkg/doubao"
|
||||
|
||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
"github.com/volcengine/volcengine-go-sdk/volcengine"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var modelObj *doubao.DouBao
|
||||
|
||||
func Instance(key, model string) *doubao.DouBao {
|
||||
if modelObj == nil {
|
||||
modelObj = doubao.NewDouBao(model, key)
|
||||
}
|
||||
return modelObj
|
||||
}
|
||||
|
||||
type View struct {
|
||||
Ques string
|
||||
CheckSkill []string
|
||||
QuesType string
|
||||
Ans string
|
||||
}
|
||||
|
||||
func TextCorrect(ctx context.Context, view *View, key string, modelStr string) (result string, err error) {
|
||||
Instance(key, modelStr)
|
||||
var texts = make([]*model.ChatCompletionMessage, 4)
|
||||
texts[0] = &model.ChatCompletionMessage{
|
||||
Role: model.ChatMessageRoleSystem,
|
||||
Content: &model.ChatCompletionMessageContent{
|
||||
StringValue: volcengine.String("你现在是AI面试评估系统的文本修正模块,请根据以下输入完成修正任务:"),
|
||||
},
|
||||
}
|
||||
texts[1] = &model.ChatCompletionMessage{
|
||||
Role: model.ChatMessageRoleUser,
|
||||
Content: &model.ChatCompletionMessageContent{
|
||||
StringValue: volcengine.String(Input(view)),
|
||||
},
|
||||
}
|
||||
|
||||
texts[2] = &model.ChatCompletionMessage{
|
||||
Role: model.ChatMessageRoleUser,
|
||||
Content: &model.ChatCompletionMessageContent{
|
||||
StringValue: volcengine.String(modify()),
|
||||
},
|
||||
}
|
||||
|
||||
texts[3] = &model.ChatCompletionMessage{
|
||||
Role: model.ChatMessageRoleUser,
|
||||
Content: &model.ChatCompletionMessageContent{
|
||||
StringValue: volcengine.String(output()),
|
||||
},
|
||||
}
|
||||
|
||||
return modelObj.GetData(ctx, func(input string) (string, error) {
|
||||
return input, nil
|
||||
}, texts)
|
||||
}
|
||||
|
||||
func Input(view *View) string {
|
||||
str := `1. 【输入信息】
|
||||
- 面试问题: "{ques}"
|
||||
- 问题类型: "{quesType}"
|
||||
- 涉及知识点: "{checkSkill}"
|
||||
- 面试者回答: "{Ans}"`
|
||||
str = strings.ReplaceAll(str, "{ques}", view.Ques)
|
||||
str = strings.ReplaceAll(str, "{quesType}", view.QuesType)
|
||||
str = strings.ReplaceAll(str, "{checkSkill}", strings.Join(view.CheckSkill, ","))
|
||||
str = strings.ReplaceAll(str, "{Ans}", view.Ans)
|
||||
return str
|
||||
}
|
||||
|
||||
func modify() string {
|
||||
return `2. 【修正要求】
|
||||
- 仅修正语音转文字产生的错误(如谐音词/缺失字/多余字),保持回答原意不变
|
||||
- 对技术术语、专业名词的错误必须修正(如将"Redis"误写为"瑞迪斯")
|
||||
- 保留回答中的口语化表达(如"嗯""然后"),但需删除重复填充词(如连续3个"那个")
|
||||
- 根据提问对回答的内容进行0-100分的打分
|
||||
- 如果回答内容不足以满足涉及知识点,根据回答进行合理追问,分数小于20分就不用追问`
|
||||
}
|
||||
|
||||
func output() string {
|
||||
return `3. 【输出格式(json)】
|
||||
"{"text":"<修正后的完整文本>","score":<回答内容评分,数字类型>,"ask":"<追问内容文本>","not_enough_check_skill":"<为满足涉及的知识点(多个','隔开)>"}"`
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
package szr
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var (
|
||||
modelInfo = "doubao-1-5-lite-32k-250115"
|
||||
keyInfo = "236ba4b6-9daa-4755-b22f-2fd274cd223a"
|
||||
)
|
||||
|
||||
func TestAddress(t *testing.T) {
|
||||
view := &View{
|
||||
Ques: "讲一下你对mysql索引的理解以及日常使用注意点",
|
||||
CheckSkill: []string{"mysql索引认知", "mysql索引的日常使用", "索引类型", "命名规范", "索引失效场景", "复合索引最左原则"},
|
||||
QuesType: "mysql索引",
|
||||
Ans: "我一般不用啊啊啊 索引,数据量太啊啊啊小了,没意大撒大撒义,索引不就反对反对是为了加速查询吗",
|
||||
}
|
||||
res, err := TextCorrect(context.Background(), view, keyInfo, modelInfo)
|
||||
t.Log(res, err)
|
||||
}
|
Loading…
Reference in New Issue