This commit is contained in:
renzhiyuan 2025-07-27 00:27:02 +08:00
commit f0171e80ef
10 changed files with 434 additions and 0 deletions

9
entity/base.go Normal file
View File

@ -0,0 +1,9 @@
package entity
type Config struct {
SampleRate int `env:"AUDIO_SAMPLE_RATE" envDefault:"16000"`
Channels int `env:"AUDIO_CHANNELS" envDefault:"1"`
BitDepth int `env:"AUDIO_BIT_DEPTH" envDefault:"16"`
Port string `env:"PORT" envDefault:"8000"`
MaxBufferSize int `env:"MAX_BUFFER_SIZE" envDefault:"104857600"` // 100MB
}

30
go.mod Normal file
View File

@ -0,0 +1,30 @@
module l_szr_go
go 1.24.0
toolchain go1.24.5
require (
github.com/caarlos0/env/v6 v6.10.1
github.com/gofiber/fiber/v2 v2.52.9
github.com/gofiber/websocket/v2 v2.2.1
github.com/volcengine/volcengine-go-sdk v1.1.23
)
require (
github.com/andybalholm/brotli v1.2.0 // indirect
github.com/fasthttp/websocket v1.5.3 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasthttp v1.64.0 // indirect
github.com/volcengine/volc-sdk-golang v1.0.23 // indirect
golang.org/x/sys v0.34.0 // indirect
gopkg.in/yaml.v2 v2.2.8 // indirect
)

13
handle.go Normal file
View File

@ -0,0 +1,13 @@
package main
import (
"bytes"
"encoding/binary"
"fmt"
"github.com/gofiber/websocket/v2"
"io"
"log"
"os"
"time"
)

127
main.go Normal file
View File

@ -0,0 +1,127 @@
package main
import (
"bytes"
"fmt"
"github.com/caarlos0/env/v6"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/websocket/v2"
"l_szr_go/pkg"
"os/signal"
"syscall"
"time"
"l_szr_go/entity"
"log"
"os"
"sync"
)
var (
config entity.Config
activeConnections sync.WaitGroup
)
func Server() *fiber.App {
app := fiber.New()
// 健康检查
app.Get("/health", func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
// WebSocket 路由
app.Use("/ws", func(c *fiber.Ctx) error {
if websocket.IsWebSocketUpgrade(c) {
c.Locals("allowed", true)
return c.Next()
}
return fiber.ErrUpgradeRequired
})
app.Get("/ws", websocket.New(HandleWebSocket))
return app
}
func HandleWebSocket(c *websocket.Conn) {
activeConnections.Add(1)
defer activeConnections.Done()
var audioBuffer bytes.Buffer
fileName := fmt.Sprintf("output_%s.wav",
time.Now().Format("20060102_150405.000"))
log.Printf("客户端连接: %s", c.RemoteAddr())
defer func() {
if audioBuffer.Len() == 0 {
log.Println("未接收到音频数据,跳过保存")
return
}
if err := pkg.SaveAsWav(&audioBuffer, fileName, config); err != nil {
log.Printf("保存文件失败: %v", err)
} else {
log.Printf("音频已保存为 %s (%d字节)", fileName, audioBuffer.Len())
}
}()
for {
_, message, err := c.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
log.Printf("客户端异常断开: %v", err)
} else {
log.Println("客户端主动断开")
}
break
}
// 检查缓冲区大小
if audioBuffer.Len()+len(message) > config.MaxBufferSize {
log.Println("达到最大缓冲区大小,断开连接")
break
}
// 验证PCM数据 (基本验证)
if config.BitDepth == 16 && len(message)%2 != 0 {
log.Println("接收到无效的16位PCM数据长度不是2的倍数")
continue
}
if _, err := audioBuffer.Write(message); err != nil {
log.Printf("写入缓冲区失败: %v", err)
break
}
log.Printf("接收到 %d 字节音频数据 (总大小: %d字节)",
len(message), audioBuffer.Len())
}
}
func main() {
if err := env.Parse(&config); err != nil {
log.Fatalf("加载配置失败: %v", err)
}
app := Server()
// 启动服务器
go func() {
log.Printf("服务器启动,监听端口 %s", config.Port)
log.Printf("音频配置: %dHz, %d通道, %d位深",
config.SampleRate, config.Channels, config.BitDepth)
if err := app.Listen(":" + config.Port); err != nil {
log.Fatalf("服务器启动失败: %v", err)
}
}()
// 优雅关闭
quit := make(chan os.Signal, 1)
signal.Notify(quit, os.Interrupt, syscall.SIGTERM)
<-quit
log.Println("正在关闭服务器...")
// 等待所有连接完成
activeConnections.Wait()
log.Println("所有连接已关闭")
os.Exit(0)
}

17
onnx/onnx.go Normal file
View File

@ -0,0 +1,17 @@
package onnx
import "C"
import "unsafe"
func onnx() {
// 1. 初始化 ONNX Runtime 环境
env := C.CreateEnv()
defer C.OrtReleaseEnv(env)
// 2. 加载模型(替换为实际模型路径)
modelPath := C.CString("sensevoice_small.onnx")
defer C.free(unsafe.Pointer(modelPath))
session := C.CreateSession(env, modelPath)
defer C.OrtReleaseSession(session)
}

17
pkg/doubao/constant.go Normal file
View File

@ -0,0 +1,17 @@
package doubao
var UrlMap = map[UrlType]string{
Text: "https://ark.cn-beijing.volces.com/api/v3/chat/completions",
Video: "https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks",
Embedding: "https://ark.cn-beijing.volces.com/api/v3/embeddings",
Token: "https://ark.cn-beijing.volces.com/api/v3/tokenization",
}
type UrlType string
const (
Text UrlType = "text"
Video UrlType = "video"
Embedding UrlType = "embedding"
Token UrlType = "token"
)

43
pkg/doubao/doubao.go Normal file
View File

@ -0,0 +1,43 @@
package doubao
import (
"context"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
"time"
)
type DouBao struct {
Model string
Key string
}
func NewDouBao(modelName string, key string) *DouBao {
return &DouBao{
Model: modelName,
Key: key,
}
}
func (o *DouBao) GetData(ctx context.Context, respHandle func(input string) (string, error), Message []*model.ChatCompletionMessage) (string, error) {
client := arkruntime.NewClientWithApiKey(
o.Key,
//arkruntime.WithBaseUrl(UrlMap[url]),
arkruntime.WithRegion("cn-beijing"),
arkruntime.WithTimeout(2*time.Minute),
arkruntime.WithRetryTimes(2),
)
req := model.CreateChatCompletionRequest{
Model: o.Model,
Messages: Message,
}
resp, err := client.CreateChatCompletion(ctx, req)
if err != nil {
return "", err
}
result, err := respHandle(*resp.Choices[0].Message.Content.StringValue)
return result, err
}

68
pkg/func.go Normal file
View File

@ -0,0 +1,68 @@
package pkg
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"l_szr_go/entity"
"os"
)
func SaveAsWav(buffer *bytes.Buffer, filename string, config entity.Config) error {
if buffer.Len() == 0 {
return fmt.Errorf("缓冲区为空,无法保存")
}
file, err := os.Create(filename)
if err != nil {
return fmt.Errorf("创建文件失败: %v", err)
}
defer file.Close()
// 写入 WAV 头部
if err := writeWavHeader(file, config.SampleRate, config.BitDepth, config.Channels, buffer.Len()); err != nil {
return fmt.Errorf("写入 WAV 头部失败: %v", err)
}
// 写入 PCM 数据
if _, err := buffer.WriteTo(file); err != nil {
return fmt.Errorf("写入 PCM 数据失败: %v", err)
}
return nil
}
func writeWavHeader(w io.Writer, sampleRate, bitDepth, channels, dataSize int) error {
header := struct {
ChunkID [4]byte
ChunkSize uint32
Format [4]byte
Subchunk1ID [4]byte
Subchunk1Size uint32
AudioFormat uint16
NumChannels uint16
SampleRate uint32
ByteRate uint32
BlockAlign uint16
BitsPerSample uint16
Subchunk2ID [4]byte
Subchunk2Size uint32
}{
ChunkID: [4]byte{'R', 'I', 'F', 'F'},
ChunkSize: 36 + uint32(dataSize),
Format: [4]byte{'W', 'A', 'V', 'E'},
Subchunk1ID: [4]byte{'f', 'm', 't', ' '},
Subchunk1Size: 16,
AudioFormat: 1, // PCM
NumChannels: uint16(channels),
SampleRate: uint32(sampleRate),
ByteRate: uint32(sampleRate * channels * bitDepth / 8),
BlockAlign: uint16(channels * bitDepth / 8),
BitsPerSample: uint16(bitDepth),
Subchunk2ID: [4]byte{'d', 'a', 't', 'a'},
Subchunk2Size: uint32(dataSize),
}
return binary.Write(w, binary.LittleEndian, header)
}

88
szr/szr.go Normal file
View File

@ -0,0 +1,88 @@
package szr
import (
"context"
"l_szr_go/pkg/doubao"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
"github.com/volcengine/volcengine-go-sdk/volcengine"
"strings"
)
var modelObj *doubao.DouBao
func Instance(key, model string) *doubao.DouBao {
if modelObj == nil {
modelObj = doubao.NewDouBao(model, key)
}
return modelObj
}
type View struct {
Ques string
CheckSkill []string
QuesType string
Ans string
}
func TextCorrect(ctx context.Context, view *View, key string, modelStr string) (result string, err error) {
Instance(key, modelStr)
var texts = make([]*model.ChatCompletionMessage, 4)
texts[0] = &model.ChatCompletionMessage{
Role: model.ChatMessageRoleSystem,
Content: &model.ChatCompletionMessageContent{
StringValue: volcengine.String("你现在是AI面试评估系统的文本修正模块请根据以下输入完成修正任务"),
},
}
texts[1] = &model.ChatCompletionMessage{
Role: model.ChatMessageRoleUser,
Content: &model.ChatCompletionMessageContent{
StringValue: volcengine.String(Input(view)),
},
}
texts[2] = &model.ChatCompletionMessage{
Role: model.ChatMessageRoleUser,
Content: &model.ChatCompletionMessageContent{
StringValue: volcengine.String(modify()),
},
}
texts[3] = &model.ChatCompletionMessage{
Role: model.ChatMessageRoleUser,
Content: &model.ChatCompletionMessageContent{
StringValue: volcengine.String(output()),
},
}
return modelObj.GetData(ctx, func(input string) (string, error) {
return input, nil
}, texts)
}
func Input(view *View) string {
str := `1. 输入信息
- 面试问题: "{ques}"
- 问题类型: "{quesType}"
- 涉及知识点: "{checkSkill}"
- 面试者回答: "{Ans}"`
str = strings.ReplaceAll(str, "{ques}", view.Ques)
str = strings.ReplaceAll(str, "{quesType}", view.QuesType)
str = strings.ReplaceAll(str, "{checkSkill}", strings.Join(view.CheckSkill, ","))
str = strings.ReplaceAll(str, "{Ans}", view.Ans)
return str
}
func modify() string {
return `2. 修正要求
- 仅修正语音转文字产生的错误如谐音词/缺失字/多余字保持回答原意不变
- 对技术术语专业名词的错误必须修正如将"Redis"误写为"瑞迪斯"
- 保留回答中的口语化表达"嗯""然后"但需删除重复填充词如连续3个"那个"
- 根据提问对回答的内容进行0-100分的打分
- 如果回答内容不足以满足涉及知识点根据回答进行合理追问分数小于20分就不用追问`
}
func output() string {
return `3. 输出格式json
"{"text":"<修正后的完整文本>","score":<回答内容评分,数字类型>,"ask":"<追问内容文本>","not_enough_check_skill":"<为满足涉及的知识点(多个','隔开)>"}"`
}

22
szr/szr_test.go Normal file
View File

@ -0,0 +1,22 @@
package szr
import (
"context"
"testing"
)
var (
modelInfo = "doubao-1-5-lite-32k-250115"
keyInfo = "236ba4b6-9daa-4755-b22f-2fd274cd223a"
)
func TestAddress(t *testing.T) {
view := &View{
Ques: "讲一下你对mysql索引的理解以及日常使用注意点",
CheckSkill: []string{"mysql索引认知", "mysql索引的日常使用", "索引类型", "命名规范", "索引失效场景", "复合索引最左原则"},
QuesType: "mysql索引",
Ans: "我一般不用啊啊啊 索引,数据量太啊啊啊小了,没意大撒大撒义,索引不就反对反对是为了加速查询吗",
}
res, err := TextCorrect(context.Background(), view, keyInfo, modelInfo)
t.Log(res, err)
}