geoGo/utils/gorm.go

134 lines
2.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package utils
import (
"geo/internal/config"
"geo/utils/utils_gorm"
"gorm.io/gorm"
)
// Db wraps a gorm connection and provides raw-SQL helpers (GetOne, GetAll,
// Execute, ExecuteMany) that return query rows as column-name -> value maps.
type Db struct {
	// Client is the gorm handle every helper method issues its SQL on.
	Client *gorm.DB
}
// NewGormDb opens a database client from the c.DB settings via
// utils_gorm.DBConn and wraps it in a Db.
//
// The second return value is a cleanup function that invokes the release
// function returned by utils_gorm.DBConn; callers should run it when the Db
// is no longer needed (presumably it closes the underlying pool — confirm in
// utils_gorm).
func NewGormDb(c *config.Config) (*Db, func()) {
	client, release := utils_gorm.DBConn(&c.DB)
	return &Db{Client: client}, func() { release() }
}
// GetOne runs a raw SQL query and returns the first result row as a map from
// column name to value. args are bound to the query's placeholders. Values
// are stored exactly as database/sql's Scan produces them for *interface{}
// destinations, with no type conversion.
//
// It returns (nil, nil) when the query yields no rows; rows after the first
// are ignored. Errors from executing the query, reading the column list,
// scanning, or iterating the result set are returned with a nil map.
func (d *Db) GetOne(sql string, args ...interface{}) (map[string]interface{}, error) {
	// Raw + Rows (rather than scanning into a struct) lets the column names be
	// discovered at runtime.
	rows, err := d.Client.Raw(sql, args...).Rows()
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	if !rows.Next() {
		// Next returns false both for an empty result set and for an
		// iteration/driver error; Err tells the two apart.
		if err := rows.Err(); err != nil {
			return nil, err
		}
		return nil, nil
	}

	columns, err := rows.Columns()
	if err != nil {
		return nil, err
	}

	// Scan needs one pointer per column; point each at a slot in values.
	values := make([]interface{}, len(columns))
	valuePtrs := make([]interface{}, len(columns))
	for i := range values {
		valuePtrs[i] = &values[i]
	}
	if err := rows.Scan(valuePtrs...); err != nil {
		return nil, err
	}

	result := make(map[string]interface{}, len(columns))
	for i, col := range columns {
		result[col] = values[i]
	}
	return result, nil
}
// GetAll runs a raw SQL query and returns every result row as a map from
// column name to value. args are bound to the query's placeholders. Values
// are stored exactly as database/sql's Scan produces them for *interface{}
// destinations, with no type conversion.
//
// An empty result yields a non-nil, zero-length slice. Errors from executing
// the query, reading the column list, scanning a row, or iterating the result
// set are returned with a nil slice; partial results are never returned.
func (d *Db) GetAll(sql string, args ...interface{}) ([]map[string]interface{}, error) {
	rows, err := d.Client.Raw(sql, args...).Rows()
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	// The column list is identical for every row, so read it once.
	columns, err := rows.Columns()
	if err != nil {
		return nil, err
	}

	// Scan buffers are reused across rows: every Scan overwrites each slot
	// with a fresh value (database/sql copies []byte sources for *interface{}
	// destinations), and each row map below keeps its own copy.
	values := make([]interface{}, len(columns))
	valuePtrs := make([]interface{}, len(columns))
	for i := range values {
		valuePtrs[i] = &values[i]
	}

	results := make([]map[string]interface{}, 0)
	for rows.Next() {
		if err := rows.Scan(valuePtrs...); err != nil {
			return nil, err
		}
		row := make(map[string]interface{}, len(columns))
		for i, col := range columns {
			row[col] = values[i]
		}
		results = append(results, row)
	}
	// Next returns false on both exhaustion and error; without this check a
	// mid-iteration failure would be reported as a truncated success.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return results, nil
}
// Execute runs a single raw SQL statement (e.g. INSERT/UPDATE/DELETE) with
// args bound to its placeholders and returns the number of rows it affected.
// On failure it returns 0 together with the error gorm recorded.
func (d *Db) Execute(sql string, args ...interface{}) (int64, error) {
	var affected int64
	res := d.Client.Exec(sql, args...)
	execErr := res.Error
	if execErr == nil {
		affected = res.RowsAffected
	}
	return affected, execErr
}
// ExecuteMany runs the same raw SQL statement once per entry in argsList,
// binding that entry's values to the statement's placeholders. All
// executions happen inside one transaction, and the return value is the sum
// of RowsAffected across them.
//
// If any execution fails, the transaction is rolled back and (0, err) is
// returned, so nothing from the batch is committed. If an execution panics,
// gorm rolls back before the panic propagates. A commit failure also yields
// (0, err). With an empty argsList no statement is executed and the total is 0.
func (d *Db) ExecuteMany(sql string, argsList [][]interface{}) (int64, error) {
	var total int64
	// gorm's Transaction begins, commits when fc returns nil, and rolls back
	// on error, panic, or commit failure. The previous manual Begin/Commit
	// leaked the open transaction on panic and could not nest inside an
	// existing transaction.
	err := d.Client.Transaction(func(tx *gorm.DB) error {
		for _, args := range argsList {
			result := tx.Exec(sql, args...)
			if result.Error != nil {
				return result.Error
			}
			total += result.RowsAffected
		}
		return nil
	})
	if err != nil {
		return 0, err
	}
	return total, nil
}