ai_scheduler/internal/biz/advice_file.go

190 lines
5.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package biz
import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"sort"
	"strings"
	"time"

	"github.com/gofiber/fiber/v2/log"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
	"github.com/volcengine/volcengine-go-sdk/volcengine"

	"ai_scheduler/internal/biz/llm_service/third_party"
	"ai_scheduler/internal/entitys"
	"ai_scheduler/internal/pkg"
)
// AdviceFileBiz turns sales-conversation text into structured advice data by
// prompting an LLM through the Hsyq client.
type AdviceFileBiz struct {
	// hsyq is the client every model request in this file goes through.
	hsyq *third_party.Hsyq
}
// NewAdviceFileBiz builds an AdviceFileBiz that sends its model requests
// through the given Hsyq client.
func NewAdviceFileBiz(hsyq *third_party.Hsyq) *AdviceFileBiz {
	biz := new(AdviceFileBiz)
	biz.hsyq = hsyq
	return biz
}
// key is the credential passed to RequestHsyq for every model call.
//
// SECURITY NOTE(review): this is a hard-coded API credential committed to
// source control. It should be moved to configuration / a secret store and
// the current value rotated.
const key = "236ba4b6-9daa-4755-b22f-2fd274cd223a"

// fileModel is the model used for the main extraction call in WordAna.
const fileModel = "doubao-seed-1-8-251228"

// jsonModel is the model used by fixJson to repair malformed JSON answers.
const jsonModel = "doubao-seed-1-6-flash-250828"
// DataMap maps each extraction key (the JSON key the model is asked to use)
// to a prototype value. The prototypes supply the prompt examples via
// Role/Desc/Example. parseResponse decodes into prototype.Copy(), never into
// the prototype itself.
var DataMap = map[string]entitys.AdviceData{
	"dialectFeatures":       &entitys.DialectFeatures{},
	"sentencePatterns":      &entitys.SentencePatterns{},
	"personalityTags":       &entitys.PersonalityTags{},
	"toneTags":              &entitys.ToneTags{},
	"signatureDialogues":    &entitys.SignatureDialogues{},
	"regionValue":           &entitys.RegionValue{},
	"competitionComparison": &entitys.CompetitionComparison{},
	"coreSellingPoints":     &entitys.CoreSellingPoints{},
	"supportingFacilities":  &entitys.SupportingFacilities{},
	"developerBacking":      &entitys.DeveloperBacking{},
	"needsMining":           &entitys.NeedsMining{},
	"painPointResponse":     &entitys.PainPointResponse{},
	"valueBuilding":         &entitys.ValueBuilding{},
	"closingTechniques":     &entitys.ClosingTechniques{},
	"communicationRhythm":   &entitys.CommunicationRhythm{},
	"customer":              &entitys.Customer{},
}
// WordAna extracts structured sales-advice data from a conversation text.
//
// It builds a prompt from wordContent plus the examples returned by
// getAllExamples, sends it to fileModel, and parses the JSON answer. If the
// answer is not valid JSON, parseResponse first asks jsonModel to repair it.
// The parsed entries are then grouped by role.
//
// The prompt, the raw model answer and the grouped result are dumped under
// ./cache/<yyyyMMddHHmm>/ for debugging. Those dumps are best-effort: failing
// to create the directory or write a file never fails the request.
func (a *AdviceFileBiz) WordAna(ctx context.Context, wordContent string) (map[entitys.AdviceRole]map[string]entitys.AdviceData, error) {
	// NOTE(review): the directory has minute resolution, so two requests in
	// the same minute share (and overwrite) the same dump files.
	dir := "./cache/" + time.Now().Format("200601021504")
	// MkdirAll rather than Mkdir: with Mkdir the dumps silently never happened
	// whenever ./cache itself did not exist yet. The error is deliberately
	// ignored because the dumps are optional.
	_ = os.MkdirAll(dir, 0755)
	dump := func(name string, data []byte) {
		// Best-effort debug output; a write failure must not fail the request.
		_ = os.WriteFile(dir+"/"+name, data, 0644)
	}

	// Build the prompt from the example prototypes.
	examples := a.getAllExamples()
	prompt := a.buildSimplePrompt(wordContent, examples)
	dump("requset.json", []byte(prompt)) // (sic) file name kept unchanged

	// Ask the model to extract the information.
	anaContent, err := a.callLlm(ctx, prompt, fileModel)
	if err != nil {
		return nil, err
	}
	dump("res.json", []byte(anaContent))

	// Validate/repair and decode the JSON answer.
	data, err := a.parseResponse(ctx, []byte(anaContent))
	if err != nil {
		return nil, err
	}

	// Group the decoded entries by role.
	resData := a.cateData(data)
	dump("extracted.json", pkg.JsonByteIgonErr(resData))
	return resData, nil
}
// cateData buckets the parsed entries by the role each one reports via
// Role(). Inside each bucket the original key is kept.
func (a *AdviceFileBiz) cateData(data map[string]entitys.AdviceData) map[entitys.AdviceRole]map[string]entitys.AdviceData {
	grouped := make(map[entitys.AdviceRole]map[string]entitys.AdviceData)
	for name, item := range data {
		role := item.Role()
		bucket, exists := grouped[role]
		if !exists {
			bucket = make(map[string]entitys.AdviceData)
			grouped[role] = bucket
		}
		bucket[name] = item
	}
	return grouped
}
// parseResponse decodes the model's answer into one AdviceData value per key.
//
// responseByte is expected to be a JSON object whose keys are DataMap names.
// If it is not valid JSON, fixJson is asked exactly once to repair it. Each
// recognised value is decoded into a fresh DataMap[k].Copy(). Keys that are
// not in DataMap are skipped.
//
// On any error the returned map is nil.
func (a *AdviceFileBiz) parseResponse(ctx context.Context, responseByte []byte) (resultOutPut map[string]entitys.AdviceData, err error) {
	// Only one repair attempt is made.
	if !json.Valid(responseByte) {
		responseByte, err = a.fixJson(ctx, responseByte)
		if err != nil {
			return nil, fmt.Errorf("json格式错误修复失败:%w", err)
		}
		if !json.Valid(responseByte) {
			return nil, fmt.Errorf("json格式错误")
		}
	}

	// Keep each value raw so it is decoded straight into its target type.
	// This avoids a map[string]interface{} round trip, which turns every
	// number into float64 before re-encoding it.
	var raw map[string]json.RawMessage
	if err = json.Unmarshal(responseByte, &raw); err != nil {
		return nil, err
	}

	resultOutPut = make(map[string]entitys.AdviceData, len(raw))
	for k, v := range raw {
		prototype, ok := DataMap[k]
		if !ok {
			// Skip keys the model invented instead of aborting. Aborting
			// mid-loop would keep only whichever keys happened to be visited
			// first, and map iteration order is random.
			continue
		}
		newData := prototype.Copy()
		if err = json.Unmarshal(v, newData); err != nil {
			return nil, fmt.Errorf("解析%s失败:%w", k, err)
		}
		resultOutPut[k] = newData
	}
	return resultOutPut, nil
}
// fixJson asks jsonModel to repair malformed JSON and returns the model's
// answer verbatim. The caller still has to validate the result.
func (a *AdviceFileBiz) fixJson(ctx context.Context, brokenJSON []byte) ([]byte, error) {
	// The answer goes straight into json.Valid / json.Unmarshal, so the prompt
	// must request bare JSON only. The previous wording also asked for an
	// explanation of the fixes, which made any compliant answer invalid JSON.
	prompt := "你是一个专业的JSON修复专家。请帮我修复以下错误的JSON格式。\n\n要求\n1. 保持原有数据的结构和内容不变\n2. 修复JSON语法错误\n3. 输出格式化的正确JSON\n4. 只输出修复后的JSON本身,不要包含任何解释说明,也不要使用markdown代码块\n\n错误的JSON\n" + string(brokenJSON) + "\n\n请直接输出修复后的JSON。"
	fixed, err := a.callLlm(ctx, prompt, jsonModel)
	if err != nil {
		return nil, err
	}
	return []byte(fixed), nil
}
// callLlm sends prompt as a single user message to modelName through the
// Hsyq client and returns the text content of the first choice.
//
// It returns an error, instead of panicking, when the response carries no
// choices or the first choice has no text content. Token usage is logged
// for every successful request.
func (a *AdviceFileBiz) callLlm(ctx context.Context, prompt string, modelName string) (string, error) {
	message := []*model.ChatCompletionMessage{
		{
			Role: model.ChatMessageRoleUser,
			Content: &model.ChatCompletionMessageContent{
				StringValue: volcengine.String(prompt),
			},
		},
	}
	res, err := a.hsyq.RequestHsyq(ctx, key, modelName, message)
	if err != nil {
		return "", err
	}
	// Infof rather than Info: Info uses Sprint semantics, which printed the
	// label and the number with no separator.
	log.Infof("token用量: %d", res.Usage.TotalTokens)

	if len(res.Choices) == 0 {
		return "", fmt.Errorf("模型%s未返回任何结果", modelName)
	}
	content := res.Choices[0].Message.Content
	if content == nil || content.StringValue == nil {
		return "", fmt.Errorf("模型%s返回内容为空", modelName)
	}
	return *content.StringValue, nil
}
// getAllExamples returns the package-level DataMap prototypes used as prompt
// examples. It returns the shared map itself, not a copy, so callers should
// treat it as read-only.
func (a *AdviceFileBiz) getAllExamples() map[string]entitys.AdviceData {
	examples := DataMap
	return examples
}
// buildSimplePrompt renders the extraction prompt for wordContent.
//
// For every entry in examples it emits a "=== name (role:desc)===" section
// containing the example JSON. It then asks the model to answer with one JSON
// object keyed by those names. Sections are emitted in sorted key order so the
// same input always yields the same prompt; Go map iteration order is random.
func (a *AdviceFileBiz) buildSimplePrompt(wordContent string, examples map[string]entitys.AdviceData) string {
	count := fmt.Sprintf("%d", len(examples))
	// The sample wrapper at the end must use the real (lowerCamel) key names,
	// because parseResponse looks keys up in DataMap verbatim.
	template := `分析以下房地产销售对话,按指定格式提取信息:
对话内容:
%s
请按照以下` + count + `个格式生成JSON数据key为格式名称value为对应值
%s
输出要求:
1. 所有内容必须严格基于提供的对话原文,不得编造(重要!)
2. 每个结构体一个JSON对象
3. 严格按照示例格式
4. 将上述生成的` + count + `个JSON对象,json不需要有可读性,不要有特殊符号,比如"\n",用map[string]json来包裹所有json对象{"supportingFacilities":{...},"signatureDialogues":[{...},{...}]}`

	// Deterministic section order.
	names := make([]string, 0, len(examples))
	for name := range examples {
		names = append(names, name)
	}
	sort.Strings(names)

	// Build the per-format example sections.
	var formats strings.Builder
	for _, name := range names {
		example := examples[name]
		fmt.Fprintf(&formats, "=== %s (%s:%s)===\n示例%s\n\n", name, entitys.RoleDesc[example.Role()], example.Desc(), example.Example())
	}
	return fmt.Sprintf(template, wordContent, formats.String())
}