134 lines
2.7 KiB
Go
134 lines
2.7 KiB
Go
package utils
|
||
|
||
import (
|
||
"geo/internal/config"
|
||
"geo/utils/utils_gorm"
|
||
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// Db wraps a GORM handle and is the receiver for the raw-SQL helpers in this
// package (GetOne, GetAll, Execute, ExecuteMany).
type Db struct {
	// Client is the GORM handle every helper method issues its SQL through.
	Client *gorm.DB
}
|
||
|
||
func NewGormDb(c *config.Config) (*Db, func()) {
|
||
transDBClient, mf := utils_gorm.DBConn(&c.DB)
|
||
//directDBClient, df := directDB(c, hLog)
|
||
cleanup := func() {
|
||
mf()
|
||
//df()
|
||
}
|
||
return &Db{
|
||
Client: transDBClient,
|
||
//DirectDBClient: directDBClient,
|
||
}, cleanup
|
||
}
|
||
|
||
// GetOne 查询单条记录,返回 map
|
||
func (d *Db) GetOne(sql string, args ...interface{}) (map[string]interface{}, error) {
|
||
var result map[string]interface{}
|
||
|
||
// 使用 Raw 执行原生 SQL,Scan 到 map 需要先获取 rows
|
||
rows, err := d.Client.Raw(sql, args...).Rows()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
if rows.Next() {
|
||
// 获取列名
|
||
columns, err := rows.Columns()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 创建扫描用的切片
|
||
values := make([]interface{}, len(columns))
|
||
valuePtrs := make([]interface{}, len(columns))
|
||
for i := range values {
|
||
valuePtrs[i] = &values[i]
|
||
}
|
||
|
||
if err := rows.Scan(valuePtrs...); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
result = make(map[string]interface{})
|
||
for i, col := range columns {
|
||
result[col] = values[i]
|
||
}
|
||
return result, nil
|
||
}
|
||
return nil, nil
|
||
}
|
||
|
||
// GetAll 查询多条记录,返回 map 切片
|
||
func (d *Db) GetAll(sql string, args ...interface{}) ([]map[string]interface{}, error) {
|
||
rows, err := d.Client.Raw(sql, args...).Rows()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
results := make([]map[string]interface{}, 0)
|
||
|
||
for rows.Next() {
|
||
columns, err := rows.Columns()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
values := make([]interface{}, len(columns))
|
||
valuePtrs := make([]interface{}, len(columns))
|
||
for i := range values {
|
||
valuePtrs[i] = &values[i]
|
||
}
|
||
|
||
if err := rows.Scan(valuePtrs...); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
row := make(map[string]interface{})
|
||
for i, col := range columns {
|
||
row[col] = values[i]
|
||
}
|
||
results = append(results, row)
|
||
}
|
||
return results, nil
|
||
}
|
||
|
||
// Execute 执行单条 SQL(INSERT/UPDATE/DELETE),返回影响行数
|
||
func (d *Db) Execute(sql string, args ...interface{}) (int64, error) {
|
||
result := d.Client.Exec(sql, args...)
|
||
if result.Error != nil {
|
||
return 0, result.Error
|
||
}
|
||
return result.RowsAffected, nil
|
||
}
|
||
|
||
// ExecuteMany 批量执行 SQL,使用事务
|
||
func (d *Db) ExecuteMany(sql string, argsList [][]interface{}) (int64, error) {
|
||
var total int64
|
||
|
||
// 开始事务
|
||
tx := d.Client.Begin()
|
||
if tx.Error != nil {
|
||
return 0, tx.Error
|
||
}
|
||
|
||
for _, args := range argsList {
|
||
result := tx.Exec(sql, args...)
|
||
if result.Error != nil {
|
||
tx.Rollback()
|
||
return 0, result.Error
|
||
}
|
||
total += result.RowsAffected
|
||
}
|
||
|
||
// 提交事务
|
||
if err := tx.Commit().Error; err != nil {
|
||
return 0, err
|
||
}
|
||
return total, nil
|
||
}
|