74 lines
1.6 KiB
Go
74 lines
1.6 KiB
Go
package data
|
|
|
|
import (
|
|
"center-api/internal/conf"
|
|
"center-api/internal/data/ent"
|
|
"center-api/internal/data/ent/intercept"
|
|
"context"
|
|
ent2 "entgo.io/ent"
|
|
"entgo.io/ent/dialect/sql"
|
|
"fmt"
|
|
"github.com/go-kratos/kratos/v2/log"
|
|
_ "github.com/go-sql-driver/mysql"
|
|
)
|
|
|
|
// ContextTxKey is the context key under which an in-flight *ent.Tx is stored,
// so that code receiving the context can join that transaction (see Db.GetDb).
type ContextTxKey struct{}
|
|
|
|
// Db .
|
|
type Db struct {
|
|
Client *ent.Client
|
|
}
|
|
|
|
// NewDb .
|
|
func NewDb(c *conf.Bootstrap, hLog *log.Helper) (*Db, func(), error) {
|
|
client, err := buildDbClient(c.Data.Db)
|
|
|
|
cleanup := func() {
|
|
if client != nil {
|
|
if err := client.Close(); err != nil {
|
|
hLog.Error("关闭 masterClient 失败:", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return &Db{Client: client}, cleanup, err
|
|
}
|
|
|
|
func (d Db) GetDb(ctx context.Context) *ent.Client {
|
|
tx, ok := ctx.Value(ContextTxKey{}).(*ent.Tx)
|
|
if ok {
|
|
return tx.Client()
|
|
}
|
|
return d.Client
|
|
}
|
|
|
|
// buildDbClient 构建db client
|
|
func buildDbClient(dbConf *conf.Data_Database) (*ent.Client, error) {
|
|
db, err := sql.Open(dbConf.GetDriver(), dbConf.GetSource())
|
|
if err != nil {
|
|
fmt.Println("连接db 失败: ", err)
|
|
return nil, err
|
|
}
|
|
db.DB().SetMaxIdleConns(int(dbConf.MaxIdle))
|
|
db.DB().SetMaxOpenConns(int(dbConf.MaxOpen))
|
|
db.DB().SetConnMaxLifetime(dbConf.MaxLifetime.AsDuration())
|
|
var client *ent.Client
|
|
client = ent.NewClient(ent.Driver(db))
|
|
|
|
//关闭默认启用的 select distinct
|
|
client.Intercept(
|
|
intercept.Func(func(ctx context.Context, q intercept.Query) error {
|
|
// Skip setting Unique if the modifier was set explicitly.
|
|
if ent2.QueryFromContext(ctx).Unique == nil {
|
|
q.Unique(false)
|
|
}
|
|
return nil
|
|
}),
|
|
)
|
|
if dbConf.IsDebug {
|
|
client = client.Debug()
|
|
}
|
|
return client, nil
|
|
}
|