From f0171e80ef54f7cc1584a3d05c54f586efd2b38a Mon Sep 17 00:00:00 2001
From: renzhiyuan <465386466@qq.com>
Date: Sun, 27 Jul 2025 00:27:02 +0800
Subject: [PATCH] 1

---
 entity/base.go         |  12 ++++
 go.mod                 |  30 +++++++++
 handle.go              |   5 ++
 main.go                | 145 +++++++++++++++++++++++++++++++++++++++++
 onnx/onnx.go           |  24 +++++++
 pkg/doubao/constant.go |  20 ++++++
 pkg/doubao/doubao.go   |  59 +++++++++++++++++
 pkg/func.go            |  91 ++++++++++++++++++++++++++
 szr/szr.go             | 123 ++++++++++++++++++++++++++++++++++
 szr/szr_test.go        |  32 +++++++++
 10 files changed, 541 insertions(+)
 create mode 100644 entity/base.go
 create mode 100644 go.mod
 create mode 100644 handle.go
 create mode 100644 main.go
 create mode 100644 onnx/onnx.go
 create mode 100644 pkg/doubao/constant.go
 create mode 100644 pkg/doubao/doubao.go
 create mode 100644 pkg/func.go
 create mode 100644 szr/szr.go
 create mode 100644 szr/szr_test.go

diff --git a/entity/base.go b/entity/base.go
new file mode 100644
index 0000000..1811043
--- /dev/null
+++ b/entity/base.go
@@ -0,0 +1,12 @@
+// Package entity defines configuration types shared across the service.
+package entity
+
+// Config holds the runtime settings, populated from environment variables
+// through the `env` struct tags, with `envDefault` as the fallback value.
+type Config struct {
+	SampleRate    int    `env:"AUDIO_SAMPLE_RATE" envDefault:"16000"`
+	Channels      int    `env:"AUDIO_CHANNELS" envDefault:"1"`
+	BitDepth      int    `env:"AUDIO_BIT_DEPTH" envDefault:"16"`
+	Port          string `env:"PORT" envDefault:"8000"`
+	MaxBufferSize int    `env:"MAX_BUFFER_SIZE" envDefault:"104857600"` // 100MB
+}
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..9c63fb1
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,30 @@
+module l_szr_go
+
+go 1.24.0
+
+toolchain go1.24.5
+
+require (
+	github.com/caarlos0/env/v6 v6.10.1
+	github.com/gofiber/fiber/v2 v2.52.9
+	github.com/gofiber/websocket/v2 v2.2.1
+	github.com/volcengine/volcengine-go-sdk v1.1.23
+)
+
+require (
+	github.com/andybalholm/brotli v1.2.0 // indirect
+	github.com/fasthttp/websocket v1.5.3 // indirect
+	github.com/google/uuid v1.6.0 // indirect
+	github.com/jmespath/go-jmespath v0.4.0 // indirect
+	github.com/klauspost/compress v1.18.0 // indirect
+	github.com/mattn/go-colorable v0.1.14 // indirect
+	github.com/mattn/go-isatty v0.0.20 // indirect
+	github.com/mattn/go-runewidth v0.0.16 // indirect
+	github.com/rivo/uniseg v0.2.0 // indirect
+	github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee // indirect
+	github.com/valyala/bytebufferpool v1.0.0 // indirect
+	github.com/valyala/fasthttp v1.64.0 // indirect
+	github.com/volcengine/volc-sdk-golang v1.0.23 // indirect
+	golang.org/x/sys v0.34.0 // indirect
+	gopkg.in/yaml.v2 v2.2.8 // indirect
+)
diff --git a/handle.go b/handle.go
new file mode 100644
index 0000000..0d1e3d6
--- /dev/null
+++ b/handle.go
@@ -0,0 +1,5 @@
+package main
+
+// NOTE(review): this file previously held only an import block. Go rejects
+// unused imports, so that block broke the build of package main; it has been
+// removed until handler code is actually placed here.
diff --git a/main.go b/main.go
new file mode 100644
index 0000000..c5c6c56
--- /dev/null
+++ b/main.go
@@ -0,0 +1,145 @@
+package main
+
+import (
+	"bytes"
+	"fmt"
+	"log"
+	"os"
+	"os/signal"
+	"sync"
+	"syscall"
+	"time"
+
+	"github.com/caarlos0/env/v6"
+	"github.com/gofiber/fiber/v2"
+	"github.com/gofiber/websocket/v2"
+
+	"l_szr_go/entity"
+	"l_szr_go/pkg"
+)
+
+// Process-wide state: config is filled from the environment in main, and
+// activeConnections counts in-flight WebSocket handlers for shutdown.
+var (
+	config            entity.Config
+	activeConnections sync.WaitGroup
+)
+
+// Server builds the Fiber application: a /health probe and the /ws WebSocket
+// endpoint served by HandleWebSocket.
+func Server() *fiber.App {
+	app := fiber.New()
+
+	// Health check.
+	app.Get("/health", func(c *fiber.Ctx) error {
+		return c.SendStatus(fiber.StatusOK)
+	})
+
+	// Only WebSocket upgrade requests may reach the /ws handler.
+	app.Use("/ws", func(c *fiber.Ctx) error {
+		if websocket.IsWebSocketUpgrade(c) {
+			c.Locals("allowed", true)
+			return c.Next()
+		}
+		return fiber.ErrUpgradeRequired
+	})
+
+	app.Get("/ws", websocket.New(HandleWebSocket))
+	return app
+}
+
+// HandleWebSocket accumulates the PCM frames sent by one client and, once the
+// read loop ends, stores them as a WAV file named after the connect time.
+func HandleWebSocket(c *websocket.Conn) {
+	activeConnections.Add(1)
+	defer activeConnections.Done()
+
+	var audioBuffer bytes.Buffer
+	fileName := fmt.Sprintf("output_%s.wav",
+		time.Now().Format("20060102_150405.000"))
+
+	log.Printf("客户端连接: %s", c.RemoteAddr())
+
+	defer func() {
+		// SaveAsWav drains the buffer, so capture the size first; reading
+		// Len() after the save always reported 0 bytes.
+		size := audioBuffer.Len()
+		if size == 0 {
+			log.Println("未接收到音频数据,跳过保存")
+			return
+		}
+
+		if err := pkg.SaveAsWav(&audioBuffer, fileName, config); err != nil {
+			log.Printf("保存文件失败: %v", err)
+			return
+		}
+		log.Printf("音频已保存为 %s (%d字节)", fileName, size)
+	}()
+
+	for {
+		_, message, err := c.ReadMessage()
+		if err != nil {
+			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
+				log.Printf("客户端异常断开: %v", err)
+			} else {
+				log.Println("客户端主动断开")
+			}
+			break
+		}
+
+		// Enforce the per-connection buffer cap.
+		if audioBuffer.Len()+len(message) > config.MaxBufferSize {
+			log.Println("达到最大缓冲区大小,断开连接")
+			break
+		}
+
+		// Basic PCM sanity check: 16-bit samples need an even byte count.
+		if config.BitDepth == 16 && len(message)%2 != 0 {
+			log.Println("接收到无效的16位PCM数据,长度不是2的倍数")
+			continue
+		}
+
+		if _, err := audioBuffer.Write(message); err != nil {
+			log.Printf("写入缓冲区失败: %v", err)
+			break
+		}
+
+		log.Printf("接收到 %d 字节音频数据 (总大小: %d字节)",
+			len(message), audioBuffer.Len())
+	}
+}
+
+// main loads the configuration, serves until SIGINT/SIGTERM arrives, then
+// stops the listener and waits for open WebSocket sessions to finish.
+func main() {
+	if err := env.Parse(&config); err != nil {
+		log.Fatalf("加载配置失败: %v", err)
+	}
+	app := Server()
+	// Run the server in the background so main can wait for signals.
+	go func() {
+		log.Printf("服务器启动,监听端口 %s", config.Port)
+		log.Printf("音频配置: %dHz, %d通道, %d位深",
+			config.SampleRate, config.Channels, config.BitDepth)
+		if err := app.Listen(":" + config.Port); err != nil {
+			log.Fatalf("服务器启动失败: %v", err)
+		}
+	}()
+
+	// Graceful shutdown.
+	quit := make(chan os.Signal, 1)
+	signal.Notify(quit, os.Interrupt, syscall.SIGTERM)
+	<-quit
+	log.Println("正在关闭服务器...")
+
+	// Stop accepting new connections; the listener used to stay open, so new
+	// clients could keep arriving during the wait below.
+	if err := app.Shutdown(); err != nil {
+		log.Printf("关闭服务器失败: %v", err)
+	}
+
+	// Wait for every active WebSocket handler to finish saving.
+	activeConnections.Wait()
+	log.Println("所有连接已关闭")
+	os.Exit(0)
+}
diff --git a/onnx/onnx.go b/onnx/onnx.go
new file mode 100644
index 0000000..84d891a
--- /dev/null
+++ b/onnx/onnx.go
@@ -0,0 +1,24 @@
+package onnx
+
+// #include <stdlib.h>
+import "C"
+import "unsafe"
+
+// onnx creates an ONNX Runtime environment and a session for the model file
+// sensevoice_small.onnx, releasing both (and the C path string) on return.
+//
+// NOTE(review): CreateEnv, CreateSession, OrtReleaseEnv and OrtReleaseSession
+// are not declared by any cgo preamble in this file, so it cannot build until
+// the ONNX Runtime wrapper header is included above `import "C"` -- confirm
+// which header provides them. <stdlib.h> is included because C.free needs it.
+func onnx() {
+	// 1. Initialise the ONNX Runtime environment.
+	env := C.CreateEnv()
+	defer C.OrtReleaseEnv(env)
+
+	// 2. Load the model (replace with the real model path).
+	modelPath := C.CString("sensevoice_small.onnx")
+	defer C.free(unsafe.Pointer(modelPath))
+	session := C.CreateSession(env, modelPath)
+	defer C.OrtReleaseSession(session)
+}
diff --git a/pkg/doubao/constant.go b/pkg/doubao/constant.go
new file mode 100644
index 0000000..1f00ea6
--- /dev/null
+++ b/pkg/doubao/constant.go
@@ -0,0 +1,20 @@
+package doubao
+
+// UrlMap maps each UrlType to its Ark API endpoint on the cn-beijing host.
+var UrlMap = map[UrlType]string{
+	Text:      "https://ark.cn-beijing.volces.com/api/v3/chat/completions",
+	Video:     "https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks",
+	Embedding: "https://ark.cn-beijing.volces.com/api/v3/embeddings",
+	Token:     "https://ark.cn-beijing.volces.com/api/v3/tokenization",
+}
+
+// UrlType names an Ark API endpoint family; it is the key type of UrlMap.
+type UrlType string
+
+// Endpoint kinds available in UrlMap.
+const (
+	Text      UrlType = "text"
+	Video     UrlType = "video"
+	Embedding UrlType = "embedding"
+	Token     UrlType = "token"
+)
diff --git a/pkg/doubao/doubao.go b/pkg/doubao/doubao.go
new file mode 100644
index 0000000..b786787
--- /dev/null
+++ b/pkg/doubao/doubao.go
@@ -0,0 +1,59 @@
+// Package doubao is a thin client for Doubao chat models on Volcengine Ark.
+package doubao
+
+import (
+	"context"
+	"errors"
+	"time"
+
+	"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
+	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
+)
+
+// DouBao carries the model ID and API key used for chat completion calls.
+type DouBao struct {
+	Model string
+	Key   string
+}
+
+// NewDouBao returns a DouBao for the given model name and API key.
+func NewDouBao(modelName string, key string) *DouBao {
+	return &DouBao{
+		Model: modelName,
+		Key:   key,
+	}
+}
+
+// GetData sends Message as one chat completion request and passes the text of
+// the first returned choice to respHandle, returning whatever it returns.
+//
+// A request error is returned unchanged. A response with no choices, or whose
+// first choice carries no string content, yields an error instead of the
+// panic that indexing and dereferencing it blindly used to cause.
+func (o *DouBao) GetData(ctx context.Context, respHandle func(input string) (string, error), Message []*model.ChatCompletionMessage) (string, error) {
+	client := arkruntime.NewClientWithApiKey(
+		o.Key,
+		//arkruntime.WithBaseUrl(UrlMap[url]),
+		arkruntime.WithRegion("cn-beijing"),
+		arkruntime.WithTimeout(2*time.Minute),
+		arkruntime.WithRetryTimes(2),
+	)
+	req := model.CreateChatCompletionRequest{
+		Model:    o.Model,
+		Messages: Message,
+	}
+
+	resp, err := client.CreateChatCompletion(ctx, req)
+	if err != nil {
+		return "", err
+	}
+	if len(resp.Choices) == 0 {
+		return "", errors.New("doubao: response contains no choices")
+	}
+	content := resp.Choices[0].Message.Content
+	if content == nil || content.StringValue == nil {
+		return "", errors.New("doubao: first choice has no text content")
+	}
+
+	return respHandle(*content.StringValue)
+}
diff --git a/pkg/func.go b/pkg/func.go
new file mode 100644
index 0000000..26e2333
--- /dev/null
+++ b/pkg/func.go
@@ -0,0 +1,91 @@
+// Package pkg contains helpers for persisting received audio.
+package pkg
+
+import (
+	"bytes"
+	"encoding/binary"
+	"fmt"
+	"io"
+	"os"
+
+	"l_szr_go/entity"
+)
+
+// maxWavDataSize is the largest PCM payload whose RIFF chunk size
+// (36 + payload) still fits the 32-bit length fields of a WAV header.
+const maxWavDataSize = 1<<32 - 1 - 36
+
+// SaveAsWav writes the PCM bytes in buffer to filename as a WAV file, using
+// the sample rate, bit depth and channel count from config.
+//
+// The buffer is drained by a successful write. An empty buffer is an error,
+// and a failure to close the file is reported when nothing failed earlier,
+// since a close error can mean the data never reached disk.
+func SaveAsWav(buffer *bytes.Buffer, filename string, config entity.Config) (err error) {
+	if buffer.Len() == 0 {
+		return fmt.Errorf("缓冲区为空,无法保存")
+	}
+
+	file, err := os.Create(filename)
+	if err != nil {
+		return fmt.Errorf("创建文件失败: %w", err)
+	}
+	defer func() {
+		if cerr := file.Close(); cerr != nil && err == nil {
+			err = fmt.Errorf("关闭文件失败: %w", cerr)
+		}
+	}()
+
+	// Write the WAV header.
+	if err := writeWavHeader(file, config.SampleRate, config.BitDepth, config.Channels, buffer.Len()); err != nil {
+		return fmt.Errorf("写入 WAV 头部失败: %w", err)
+	}
+
+	// Write the PCM payload.
+	if _, err := buffer.WriteTo(file); err != nil {
+		return fmt.Errorf("写入 PCM 数据失败: %w", err)
+	}
+
+	return nil
+}
+
+// writeWavHeader writes the 44-byte canonical PCM WAV header for a payload
+// of dataSize bytes to w, little-endian. It rejects sizes that would not fit
+// the header's 32-bit fields instead of silently truncating them.
+func writeWavHeader(w io.Writer, sampleRate, bitDepth, channels, dataSize int) error {
+	if dataSize < 0 || int64(dataSize) > maxWavDataSize {
+		return fmt.Errorf("PCM 数据大小无效: %d", dataSize)
+	}
+
+	header := struct {
+		ChunkID       [4]byte
+		ChunkSize     uint32
+		Format        [4]byte
+		Subchunk1ID   [4]byte
+		Subchunk1Size uint32
+		AudioFormat   uint16
+		NumChannels   uint16
+		SampleRate    uint32
+		ByteRate      uint32
+		BlockAlign    uint16
+		BitsPerSample uint16
+		Subchunk2ID   [4]byte
+		Subchunk2Size uint32
+	}{
+		ChunkID:       [4]byte{'R', 'I', 'F', 'F'},
+		ChunkSize:     36 + uint32(dataSize),
+		Format:        [4]byte{'W', 'A', 'V', 'E'},
+		Subchunk1ID:   [4]byte{'f', 'm', 't', ' '},
+		Subchunk1Size: 16,
+		AudioFormat:   1, // PCM
+		NumChannels:   uint16(channels),
+		SampleRate:    uint32(sampleRate),
+		ByteRate:      uint32(sampleRate * channels * bitDepth / 8),
+		BlockAlign:    uint16(channels * bitDepth / 8),
+		BitsPerSample: uint16(bitDepth),
+		Subchunk2ID:   [4]byte{'d', 'a', 't', 'a'},
+		Subchunk2Size: uint32(dataSize),
+	}
+
+	return binary.Write(w, binary.LittleEndian, header)
+}
diff --git a/szr/szr.go b/szr/szr.go
new file mode 100644
index 0000000..6871270
--- /dev/null
+++ b/szr/szr.go
@@ -0,0 +1,123 @@
+// Package szr builds the interview-answer correction prompt and sends it to
+// a Doubao chat model.
+package szr
+
+import (
+	"context"
+	"errors"
+	"strings"
+	"sync"
+
+	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
+	"github.com/volcengine/volcengine-go-sdk/volcengine"
+
+	"l_szr_go/pkg/doubao"
+)
+
+// modelObj is the lazily created shared client; modelMu guards it.
+var (
+	modelMu  sync.Mutex
+	modelObj *doubao.DouBao
+)
+
+// Instance returns the shared DouBao client, creating it on first use and
+// replacing it when key or modelName differ from the cached client's, so a
+// caller is never handed a client built from an earlier caller's settings.
+func Instance(key, modelName string) *doubao.DouBao {
+	modelMu.Lock()
+	defer modelMu.Unlock()
+
+	if modelObj == nil || modelObj.Key != key || modelObj.Model != modelName {
+		modelObj = doubao.NewDouBao(modelName, key)
+	}
+	return modelObj
+}
+
+// View is one interview exchange to be corrected and scored.
+type View struct {
+	Ques       string   // interview question
+	CheckSkill []string // knowledge points the question covers
+	QuesType   string   // question category
+	Ans        string   // candidate's answer
+}
+
+// TextCorrect asks the model identified by modelStr to correct, score and
+// follow up on view.Ans, and returns the model's raw reply text (requested to
+// be the JSON object described by output). It returns an error when view is
+// nil or the request fails.
+func TextCorrect(ctx context.Context, view *View, key string, modelStr string) (result string, err error) {
+	if view == nil {
+		return "", errors.New("szr: view is nil")
+	}
+	client := Instance(key, modelStr)
+
+	var texts = make([]*model.ChatCompletionMessage, 4)
+	texts[0] = &model.ChatCompletionMessage{
+		Role: model.ChatMessageRoleSystem,
+		Content: &model.ChatCompletionMessageContent{
+			StringValue: volcengine.String("你现在是AI面试评估系统的文本修正模块,请根据以下输入完成修正任务:"),
+		},
+	}
+	texts[1] = &model.ChatCompletionMessage{
+		Role: model.ChatMessageRoleUser,
+		Content: &model.ChatCompletionMessageContent{
+			StringValue: volcengine.String(Input(view)),
+		},
+	}
+
+	texts[2] = &model.ChatCompletionMessage{
+		Role: model.ChatMessageRoleUser,
+		Content: &model.ChatCompletionMessageContent{
+			StringValue: volcengine.String(modify()),
+		},
+	}
+
+	texts[3] = &model.ChatCompletionMessage{
+		Role: model.ChatMessageRoleUser,
+		Content: &model.ChatCompletionMessageContent{
+			StringValue: volcengine.String(output()),
+		},
+	}
+
+	return client.GetData(ctx, func(input string) (string, error) {
+		return input, nil
+	}, texts)
+}
+
+// Input renders the "input information" section of the prompt from view.
+//
+// All placeholders are substituted in a single pass, so placeholder-looking
+// text inside a field (e.g. a question containing "{Ans}") is left as-is
+// rather than being expanded by a later replacement.
+func Input(view *View) string {
+	str := `1. 【输入信息】
+   - 面试问题: "{ques}"
+   - 问题类型: "{quesType}"
+   - 涉及知识点: "{checkSkill}"
+   - 面试者回答: "{Ans}"`
+	return strings.NewReplacer(
+		"{ques}", view.Ques,
+		"{quesType}", view.QuesType,
+		"{checkSkill}", strings.Join(view.CheckSkill, ","),
+		"{Ans}", view.Ans,
+	).Replace(str)
+}
+
+// modify returns the "correction requirements" section of the prompt.
+func modify() string {
+	return `2. 【修正要求】
+   - 仅修正语音转文字产生的错误(如谐音词/缺失字/多余字),保持回答原意不变
+   - 对技术术语、专业名词的错误必须修正(如将"Redis"误写为"瑞迪斯")
+   - 保留回答中的口语化表达(如"嗯""然后"),但需删除重复填充词(如连续3个"那个")
+   - 根据提问对回答的内容进行0-100分的打分
+   - 如果回答内容不足以满足涉及知识点,根据回答进行合理追问,分数小于20分就不用追问`
+}
+
+// output returns the "output format" section of the prompt: the JSON object
+// the model is asked to reply with. The not_enough_check_skill description
+// used to read 为满足 ("in order to satisfy"), an apparent homophone typo for
+// 未满足 ("not satisfied"), which contradicted the field name.
+func output() string {
+	return `3. 【输出格式(json)】
+   "{"text":"<修正后的完整文本>","score":<回答内容评分,数字类型>,"ask":"<追问内容文本>","not_enough_check_skill":"<未满足涉及的知识点(多个','隔开)>"}"`
+}
diff --git a/szr/szr_test.go b/szr/szr_test.go
new file mode 100644
index 0000000..97ebdea
--- /dev/null
+++ b/szr/szr_test.go
@@ -0,0 +1,32 @@
+package szr
+
+import (
+	"context"
+	"os"
+	"testing"
+)
+
+// modelInfo is the Doubao model ID exercised by the integration test.
+var modelInfo = "doubao-1-5-lite-32k-250115"
+
+// TestAddress calls the live Ark API, so it needs a key in ARK_API_KEY and is
+// skipped without one. The key is read from the environment because secrets
+// must not be committed; the key previously hardcoded here should be revoked.
+func TestAddress(t *testing.T) {
+	keyInfo := os.Getenv("ARK_API_KEY")
+	if keyInfo == "" {
+		t.Skip("ARK_API_KEY not set; skipping live API test")
+	}
+
+	view := &View{
+		Ques:       "讲一下你对mysql索引的理解以及日常使用注意点",
+		CheckSkill: []string{"mysql索引认知", "mysql索引的日常使用", "索引类型", "命名规范", "索引失效场景", "复合索引最左原则"},
+		QuesType:   "mysql索引",
+		Ans:        "我一般不用啊啊啊 索引,数据量太啊啊啊小了,没意大撒大撒义,索引不就反对反对是为了加速查询吗",
+	}
+	res, err := TextCorrect(context.Background(), view, keyInfo, modelInfo)
+	if err != nil {
+		t.Fatalf("TextCorrect: %v", err)
+	}
+	t.Log(res)
+}