ai-courseware/eino-project/internal/data/data.go

185 lines
4.6 KiB
Go

package data
import (
"context"
"database/sql"
"time"
"eino-project/internal/conf"
contextpkg "eino-project/internal/domain/context"
"eino-project/internal/domain/monitor"
"eino-project/internal/domain/vector"
"github.com/go-kratos/kratos/v2/log"
"github.com/redis/go-redis/v9"
"github.com/google/wire"
_ "github.com/mattn/go-sqlite3"
)
var (
	// ProviderSet is the Google Wire provider set for the data layer.
	// It currently registers a single constructor, NewData.
	ProviderSet wire.ProviderSet = wire.NewSet(NewData)
)
// Data is the data-layer container built by NewData. It holds the raw
// storage handles (SQL database and Redis client) alongside the domain
// services injected into NewData.
type Data struct {
	// DB is the SQL database opened from conf.Data.Database; NewData creates
	// the sessions, messages, knowledge_documents and orders tables in it.
	DB *sql.DB
	// RDB is the Redis client configured from conf.Data.Redis. NewData only
	// logs a warning when the initial PING fails, so the server may be
	// unreachable even though this client is non-nil.
	RDB *redis.Client
	// VectorService, DocProcessor and KnowledgeSearcher are the vector-domain
	// services passed into NewData; they are stored as-is.
	VectorService     vector.VectorService
	DocProcessor      vector.DocumentProcessor
	KnowledgeSearcher vector.KnowledgeSearcher
	// log is the helper wrapping the logger passed into NewData.
	log *log.Helper
	// ContextManager and Monitor are further services passed into NewData,
	// stored as-is.
	ContextManager contextpkg.ContextManager
	Monitor        monitor.Monitor
}
// NewData constructs the data layer.
//
// It opens the SQL database described by c.Database (driver name + DSN),
// creates the application tables via initTables, builds a Redis client from
// c.Redis, and bundles both handles with the injected domain services.
//
// A failed Redis PING is only logged as a warning; startup continues with the
// (possibly unreachable) client. The returned cleanup func closes the database
// and the Redis client, logging rather than returning any close error.
//
// If an error is returned, every resource opened so far has already been
// released and the returned *Data and cleanup func are nil.
func NewData(c *conf.Data, logger log.Logger, vs vector.VectorService, dp vector.DocumentProcessor, ks vector.KnowledgeSearcher, contextManager contextpkg.ContextManager, monitor monitor.Monitor) (*Data, func(), error) {
	helper := log.NewHelper(logger)

	// sql.Open only validates its arguments; the first real connection is
	// made by the statements executed in initTables.
	db, err := sql.Open(c.Database.Driver, c.Database.Source)
	if err != nil {
		return nil, nil, err
	}

	// Create the schema before acquiring any other resource so that, on
	// failure, the database handle is the only thing that must be released.
	if err := initTables(db); err != nil {
		if closeErr := db.Close(); closeErr != nil {
			helper.Errorf("Failed to close database: %v", closeErr)
		}
		return nil, nil, err
	}

	rdb := redis.NewClient(&redis.Options{
		Addr:         c.Redis.Addr,
		Password:     c.Redis.Password,
		DB:           int(c.Redis.Db),
		ReadTimeout:  c.Redis.ReadTimeout.AsDuration(),
		WriteTimeout: c.Redis.WriteTimeout.AsDuration(),
		DialTimeout:  c.Redis.DialTimeout.AsDuration(),
		PoolSize:     int(c.Redis.PoolSize),
	})

	// Probe Redis with a bounded timeout. A failure is tolerated: the data
	// layer is still returned and only a warning is logged.
	pingCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	err = rdb.Ping(pingCtx).Err()
	cancel()
	if err != nil {
		helper.Warnf("Redis connection failed: %v", err)
	}

	d := &Data{
		DB:                db,
		RDB:               rdb,
		VectorService:     vs,
		DocProcessor:      dp,
		KnowledgeSearcher: ks,
		log:               helper,
		ContextManager:    contextManager,
		Monitor:           monitor,
	}

	cleanup := func() {
		helper.Info("Closing data resources")
		if err := db.Close(); err != nil {
			helper.Errorf("Failed to close database: %v", err)
		}
		if err := rdb.Close(); err != nil {
			helper.Errorf("Failed to close redis: %v", err)
		}
	}
	return d, cleanup, nil
}
// initTables creates the application tables if they do not already exist and
// then seeds the orders table with mock data via insertMockOrders.
//
// Every DDL statement uses CREATE TABLE IF NOT EXISTS, so running this against
// an existing database is safe. Note that an existing table is never altered:
// a table created by an older schema keeps its old columns.
//
// The first failing statement's error is returned; this now includes failures
// while seeding mock orders.
//
// NOTE(review): SQLite only enforces the messages.session_id FOREIGN KEY when
// "PRAGMA foreign_keys = ON" is active on the connection — confirm whether
// enforcement is expected.
func initTables(db *sql.DB) error {
	ddl := []string{
		// Sessions table.
		`
	CREATE TABLE IF NOT EXISTS sessions (
		id TEXT PRIMARY KEY,
		title TEXT NOT NULL,
		created_at DATETIME DEFAULT CURRENT_TIMESTAMP
	);`,
		// Messages table; each row references a session.
		`
	CREATE TABLE IF NOT EXISTS messages (
		id INTEGER PRIMARY KEY AUTOINCREMENT,
		session_id TEXT NOT NULL,
		message TEXT NOT NULL,
		timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
		type TEXT DEFAULT 'user',
		FOREIGN KEY (session_id) REFERENCES sessions(id)
	);`,
		// Knowledge-base documents table.
		`
	CREATE TABLE IF NOT EXISTS knowledge_documents (
		id TEXT PRIMARY KEY,
		file_name TEXT NOT NULL,
		file_type TEXT NOT NULL,
		content_preview TEXT,
		upload_time DATETIME DEFAULT CURRENT_TIMESTAMP,
		status TEXT DEFAULT 'active'
	);`,
		// Orders table.
		`
	CREATE TABLE IF NOT EXISTS orders (
		order_id TEXT PRIMARY KEY,
		status TEXT NOT NULL,
		product TEXT NOT NULL,
		amount REAL NOT NULL,
		create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
		need_ai BOOLEAN DEFAULT FALSE
	);`,
	}
	for _, stmt := range ddl {
		if _, err := db.Exec(stmt); err != nil {
			return err
		}
	}

	// Seed errors are surfaced just like DDL errors.
	return insertMockOrders(db)
}

// insertMockOrders seeds the orders table with a fixed set of demo orders.
//
// It uses INSERT OR IGNORE, so rows whose order_id already exists are left
// untouched and re-running it is harmless. Because duplicate keys are already
// absorbed by OR IGNORE, any error that still occurs (missing table, column
// mismatch, I/O failure) is a genuine problem and is returned instead of being
// silently dropped. All rows are written in one transaction, so either every
// new row is inserted or none is.
func insertMockOrders(db *sql.DB) error {
	orders := []struct {
		orderID string
		status  string
		product string
		amount  float64
		needAI  bool
	}{
		{"ORD001", "已发货", "智能手机", 2999.00, false},
		{"ORD002", "处理中", "笔记本电脑", 5999.00, true},
		{"ORD003", "已完成", "无线耳机", 299.00, false},
		{"ORD004", "退款中", "平板电脑", 1999.00, true},
		{"ORD005", "已取消", "智能手表", 1299.00, false},
	}

	tx, err := db.Begin()
	if err != nil {
		return err
	}
	// After a successful Commit this Rollback is a no-op (sql.ErrTxDone).
	defer func() { _ = tx.Rollback() }()

	// Prepare once and reuse for every row instead of re-parsing per insert.
	stmt, err := tx.Prepare(`
	INSERT OR IGNORE INTO orders (order_id, status, product, amount, need_ai)
	VALUES (?, ?, ?, ?, ?)`)
	if err != nil {
		return err
	}
	defer stmt.Close()

	for _, o := range orders {
		if _, err := stmt.Exec(o.orderID, o.status, o.product, o.amount, o.needAI); err != nil {
			return err
		}
	}
	return tx.Commit()
}