100 lines
1.7 KiB
Go
100 lines
1.7 KiB
Go
package db
|
|
|
|
import (
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"excel_export/biz/export"
|
|
"fmt"
|
|
"gorm.io/driver/mysql"
|
|
"gorm.io/gorm"
|
|
"reflect"
|
|
)
|
|
|
|
// Compile-time assertion that *Db satisfies export.DataFetcher; a typed nil
// pointer is enough for the check and allocates nothing.
var _ export.DataFetcher = (*Db)(nil)
|
|
|
|
type Db struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
// NewDb opens a MySQL database through GORM using the DSN str and returns a
// Db that runs queries on it.
//
// Any error from gorm.Open is returned wrapped with context (use errors.Is /
// errors.As to inspect it). The DSN is intentionally not included in the
// error message because it may contain credentials.
func NewDb(str string) (*Db, error) {
	db, err := gorm.Open(mysql.Open(str))
	if err != nil {
		return nil, fmt.Errorf("opening mysql connection: %w", err)
	}
	// To log every executed SQL statement while debugging, use db = db.Debug() here.
	return &Db{db: db}, nil
}
|
|
|
|
// Fetch executes the raw SQL query s and returns the whole result set as
// export data.
//
// The returned Title holds the column names in result order. Data holds one
// []interface{} per row, with one cell per column, converted by
// scanDestValue.
//
// s is executed verbatim via GORM's Raw, so callers must not build it by
// interpolating untrusted input.
//
// An error is returned if the query fails, column metadata cannot be read, a
// row cannot be scanned, a cell value cannot be converted, or row iteration
// itself fails. No partial data is returned alongside an error.
func (d *Db) Fetch(s string) (*export.Data, error) {
	rows, err := d.db.Raw(s).Rows()
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	titles, values, err := getRowStruct(rows)
	if err != nil {
		return nil, err
	}

	var data []interface{}
	for rows.Next() {
		// The scan destinations in values are reused for every row, so a
		// failed Scan would otherwise leave the previous row's values behind.
		if err := rows.Scan(values...); err != nil {
			return nil, fmt.Errorf("scanning row %d: %w", len(data), err)
		}

		row := make([]interface{}, len(values))
		for i, dest := range values {
			cell, err := scanDestValue(dest)
			if err != nil {
				return nil, fmt.Errorf("converting row %d column %d: %w", len(data), i, err)
			}
			row[i] = cell
		}
		data = append(data, row)
	}
	// rows.Next returns false both when the result set is exhausted and when
	// iteration fails; only rows.Err distinguishes the two.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterating rows: %w", err)
	}

	return &export.Data{
		Title: titles,
		Data:  data,
	}, nil
}

// scanDestValue converts one scan destination (a pointer produced by
// getRowStruct and filled by rows.Scan) into the plain value stored in an
// exported row.
//
// Pointers and interfaces are unwrapped until a concrete value is reached,
// and a nil one yields nil. driver.Valuer values are unwrapped via Value().
// sql.RawBytes is copied into a string, because its memory is only valid
// until the next call to rows.Next, Scan or Close. Any other value is
// returned as-is.
func scanDestValue(dest interface{}) (interface{}, error) {
	rv := reflect.ValueOf(dest)
	for rv.Kind() == reflect.Ptr || rv.Kind() == reflect.Interface {
		if rv.IsNil() {
			return nil, nil
		}
		rv = rv.Elem()
	}
	if !rv.IsValid() {
		// dest itself was a nil interface.
		return nil, nil
	}

	switch v := rv.Interface().(type) {
	case driver.Valuer:
		return v.Value()
	case sql.RawBytes:
		return string(v), nil
	default:
		return v, nil
	}
}
|
|
|
|
// getRowStruct reads the column metadata of rows and prepares scan
// destinations for it.
//
// It returns the column names (titles) and one scan destination per column
// (values). Each destination is a pointer to a fresh zero value of that
// column's driver-reported scan type, so the slice can be passed directly to
// rows.Scan(values...).
//
// An error is returned if the column names or column types cannot be read,
// or if the driver reports no scan type for a column. Errors from
// database/sql are wrapped with %w.
func getRowStruct(rows *sql.Rows) (titles []string, values []interface{}, err error) {
	titles, err = rows.Columns()
	if err != nil {
		return nil, nil, fmt.Errorf("Columns: %w", err)
	}

	columnTypes, err := rows.ColumnTypes()
	if err != nil {
		return nil, nil, fmt.Errorf("ColumnTypes: %w", err)
	}

	values = make([]interface{}, len(columnTypes))
	for i, ct := range columnTypes {
		scanType := ct.ScanType()
		if scanType == nil {
			return nil, nil, fmt.Errorf("scantype is null for column %q", ct.Name())
		}
		values[i] = reflect.New(scanType).Interface()
	}
	return titles, values, nil
}
|