ai_scheduler/internal/biz/advice.go

157 lines
4.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package biz
import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"sort"
	"strings"

	"ai_scheduler/internal/biz/llm_service/third_party"
	"ai_scheduler/internal/entitys"

	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
	"github.com/volcengine/volcengine-go-sdk/volcengine"
)
// AdviceBiz uses an LLM to extract structured information from real-estate
// sales dialogues. Its only dependency is the Hsyq client used for model calls.
type AdviceBiz struct {
	hsyq *third_party.Hsyq // client that callLlm sends chat requests through
}

// NewAdviceBiz returns an AdviceBiz whose model requests go through hsyq.
func NewAdviceBiz(hsyq *third_party.Hsyq) *AdviceBiz {
	biz := new(AdviceBiz)
	biz.hsyq = hsyq
	return biz
}
// key is passed as the first credential-like argument to RequestHsyq by callLlm.
//
// NOTE(review): this looks like an API credential committed to source control.
// It should be loaded from configuration or the environment and rotated; it is
// kept verbatim here only to preserve current behavior.
const key = "236ba4b6-9daa-4755-b22f-2fd274cd223a"

// modelName is the model identifier that callLlm passes to RequestHsyq.
const modelName = "doubao-seed-1-8-251228"
// dataMap lists every extraction format the model is asked to fill in.
//
// Each key is a format name. It titles a "=== name ===" section of the prompt
// built by buildSimplePrompt, and parseResponse (via getStructNames) uses it as
// a top-level key of the extracted result. Each value is the Example() output
// of the matching entitys type, embedded in the prompt as that format's sample.
// Entries are listed alphabetically; map order carries no meaning.
var dataMap = map[string]string{
	"ClosingTechniques":     (&entitys.ClosingTechniques{}).Example(),
	"CommunicationRhythm":   (&entitys.CommunicationRhythm{}).Example(),
	"CompetitionComparison": (&entitys.CompetitionComparison{}).Example(),
	"CoreSellingPoints":     (&entitys.CoreSellingPoints{}).Example(),
	"DeveloperBacking":      (&entitys.DeveloperBacking{}).Example(),
	"DialectFeatures":       (&entitys.DialectFeatures{}).Example(),
	"NeedsMining":           (&entitys.NeedsMining{}).Example(),
	"PainPointResponse":     (&entitys.PainPointResponse{}).Example(),
	"PersonalityTags":       (&entitys.PersonalityTags{}).Example(),
	"RegionValue":           (&entitys.RegionValue{}).Example(),
	"SentencePatterns":      (&entitys.SentencePatterns{}).Example(),
	"SignatureDialogues":    (&entitys.SignatureDialogues{}).Example(),
	"SupportingFacilities":  (&entitys.SupportingFacilities{}).Example(),
	"ToneTags":              (&entitys.ToneTags{}).Example(),
	"ValueBuilding":         (&entitys.ValueBuilding{}).Example(),
}
// WordAna asks the LLM to extract structured sales information from
// wordContent. It writes the parsed result, as indented JSON, to
// extracted.json in the current working directory.
//
// As debugging aids it also dumps the raw prompt to requset.json and the raw
// model reply to res.json. Failures to write those two files are ignored.
//
// It returns an error if the model call fails, or if the extracted data cannot
// be encoded or written to extracted.json. In that case the success message is
// not printed.
func (a *AdviceBiz) WordAna(ctx context.Context, wordContent string) error {
	examples := a.getAllExamples()
	prompt := a.buildSimplePrompt(wordContent, examples)
	// Best-effort debug dump; failing to write it must not abort the analysis.
	// NOTE(review): the file name is misspelled ("requset"); kept as-is in case
	// something already reads it.
	_ = os.WriteFile("requset.json", []byte(prompt), 0644)

	anaContent, err := a.callLlm(ctx, prompt)
	if err != nil {
		return err
	}
	_ = os.WriteFile("res.json", []byte(anaContent), 0644) // best-effort debug dump

	data := a.parseResponse(anaContent)
	jsonData, err := json.MarshalIndent(data, "", " ")
	if err != nil {
		return fmt.Errorf("encoding extracted data: %w", err)
	}
	// This file is the actual output, so a write failure is reported rather
	// than followed by a misleading success message.
	if err := os.WriteFile("extracted.json", jsonData, 0644); err != nil {
		return fmt.Errorf("writing extracted.json: %w", err)
	}
	fmt.Println("✅ 数据已保存到 extracted.json")
	return nil
}
// callLlm sends prompt as a single user message to modelName through the Hsyq
// client and returns the plain-text content of the first choice.
//
// It returns an error if the request fails, or if the reply carries no choices
// or no string content. The original code panicked on an out-of-range index or
// a nil dereference in those cases.
func (a *AdviceBiz) callLlm(ctx context.Context, prompt string) (string, error) {
	messages := []*model.ChatCompletionMessage{
		{
			Role: model.ChatMessageRoleUser,
			Content: &model.ChatCompletionMessageContent{
				StringValue: volcengine.String(prompt),
			},
		},
	}
	res, err := a.hsyq.RequestHsyq(ctx, key, modelName, messages)
	if err != nil {
		return "", fmt.Errorf("requesting model %s: %w", modelName, err)
	}
	if len(res.Choices) == 0 {
		return "", fmt.Errorf("model %s returned no choices", modelName)
	}
	content := res.Choices[0].Message.Content
	// Content and StringValue are both pointers; a multimodal or empty reply
	// may leave either nil.
	if content == nil || content.StringValue == nil {
		return "", fmt.Errorf("model %s returned no text content", modelName)
	}
	return *content.StringValue, nil
}
// getAllExamples returns the format-name → example mapping used to build the
// extraction prompt.
//
// The result is a fresh copy of the package-level dataMap. Callers may
// therefore modify it without corrupting the table shared by every AdviceBiz.
func (a *AdviceBiz) getAllExamples() map[string]string {
	examples := make(map[string]string, len(dataMap))
	for name, example := range dataMap {
		examples[name] = example
	}
	return examples
}
// buildSimplePrompt renders the extraction prompt for wordContent.
//
// The prompt embeds the dialogue and renders each entry of examples as a
// "=== name ===" section followed by "示例" and its example. It then asks the
// model to answer with one JSON object keyed by those names.
//
// Sections are emitted in sorted name order, so the same input always yields
// the same prompt. Go randomizes map iteration order, so the previous prompt
// differed from call to call.
func (a *AdviceBiz) buildSimplePrompt(wordContent string, examples map[string]string) string {
	count := fmt.Sprintf("%d", len(examples))
	// The simplest possible prompt template. The two %s verbs receive the
	// dialogue and the rendered format sections, in that order.
	template := `分析以下房地产销售对话,按指定格式提取信息:
对话内容:
%s
请按照以下` + count + `个格式生成JSON数据key为格式名称value为对应值
%s
输出要求:
1. 每个结构体一个JSON对象
2. 所有内容必须严格基于提供的对话原文,不得编造
3. 严格按照示例格式
4. 将上述生成的` + count + `个JSON对象,json不需要有可读性,不要有特殊符号,比如"\n",用map[string]json来包裹所有json对象{"SupportingFacilities":{...},"SignatureDialogues":[{...},{...}]}`

	names := make([]string, 0, len(examples))
	for name := range examples {
		names = append(names, name)
	}
	sort.Strings(names)

	// Build the per-format sections.
	var formats strings.Builder
	for _, name := range names {
		fmt.Fprintf(&formats, "=== %s ===\n示例%s\n\n", name, examples[name])
	}
	return fmt.Sprintf(template, wordContent, formats.String())
}
// parseResponse extracts the per-format JSON values from the model's reply.
//
// The prompt asks for one JSON object whose top-level keys are the format
// names, so that shape is tried first. The text between the first '{' and the
// last '}' is decoded as an object; this also skips surrounding prose or
// Markdown code fences. Every key that is a known format name is kept.
//
// If that yields nothing, it falls back to the chunk heuristic. The reply is
// split on blank lines and each chunk's outermost {...} span is decoded. The
// value is stored under the first format name, in sorted order, that the chunk
// mentions. Chunks without valid JSON are skipped.
//
// The returned map is never nil; it is empty when nothing could be parsed.
func (a *AdviceBiz) parseResponse(response string) map[string]interface{} {
	result := make(map[string]interface{})
	names := getStructNames()
	// A fixed order keeps chunk classification deterministic when a chunk
	// mentions several names.
	sort.Strings(names)

	// Preferred shape: a single object wrapping every format.
	if start, end := strings.Index(response, "{"), strings.LastIndex(response, "}"); start != -1 && end > start {
		var wrapped map[string]interface{}
		if err := json.Unmarshal([]byte(response[start:end+1]), &wrapped); err == nil {
			for _, name := range names {
				if v, ok := wrapped[name]; ok {
					result[name] = v
				}
			}
			if len(result) > 0 {
				return result
			}
		}
	}

	// Fallback: one JSON value per blank-line-separated chunk.
	for _, part := range strings.Split(response, "\n\n") {
		part = strings.TrimSpace(part)
		start := strings.Index(part, "{")
		end := strings.LastIndex(part, "}")
		if start == -1 || end <= start {
			continue
		}
		var data interface{}
		if err := json.Unmarshal([]byte(part[start:end+1]), &data); err != nil {
			continue
		}
		// The JSON span is a substring of part, so matching against part also
		// covers a quoted "name" inside the JSON.
		for _, name := range names {
			if strings.Contains(part, name) {
				result[name] = data
				break
			}
		}
	}
	return result
}
// getStructNames returns the names of all extraction formats (the keys of
// dataMap) in ascending lexical order.
//
// Sorting matters because parseResponse tests names one by one and keeps the
// first match. Go randomizes map iteration order, so an unsorted list made
// that classification nondeterministic.
func getStructNames() []string {
	names := make([]string, 0, len(dataMap))
	for name := range dataMap {
		names = append(names, name)
	}
	sort.Strings(names)
	return names
}