149 lines
3.0 KiB
Go
149 lines
3.0 KiB
Go
package dbtool
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// ScanToStruct - scan query rows to struct
|
|
func ScanToStruct(rows *sql.Rows, ss interface{}) (ok bool, err error) {
|
|
data, err := ResultToMap(rows)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
if len(data) == 0 {
|
|
return false, nil
|
|
}
|
|
err = MapToStruct(data[0], ss)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
return true, nil
|
|
}
|
|
|
|
func ScanToStructAll(rows *sql.Rows, ss interface{}) error {
|
|
sliceVal := reflect.Indirect(reflect.ValueOf(ss))
|
|
if sliceVal.Kind() != reflect.Slice {
|
|
return errors.New("need a pointer to a slice ")
|
|
}
|
|
vt := sliceVal.Type().Elem()
|
|
if vt.Kind() != reflect.Struct {
|
|
return errors.New("need struct slice")
|
|
}
|
|
|
|
data, err := ResultToMap(rows)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for i := range data {
|
|
tmp := reflect.New(vt)
|
|
iface := tmp.Interface()
|
|
err = MapToStruct(data[i], iface)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
sliceVal.Set(reflect.Append(sliceVal, reflect.ValueOf(iface).Elem()))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ResultToMap reads every remaining row of r and returns one map per row,
// keyed by column name.
//
// Each value is an sql.RawBytes holding the column value as converted to bytes
// by database/sql, or a nil sql.RawBytes for SQL NULL. The bytes are copied out
// of the buffers owned by database/sql and the driver, so the returned maps stay
// valid after r advances or is closed. If two columns share a name, the later
// column wins.
//
// An empty result set yields an empty, non-nil slice. An error from r.Columns,
// r.Scan or the iteration itself (r.Err) is returned with a nil slice.
// ResultToMap never calls r.Close; after an early error return the caller must
// close r.
func ResultToMap(r *sql.Rows) ([]map[string]interface{}, error) {
	cols, err := r.Columns()
	if err != nil {
		return nil, err
	}

	vals := make([]sql.RawBytes, len(cols))
	scans := make([]interface{}, len(vals))
	for i := range vals {
		scans[i] = &vals[i]
	}

	result := make([]map[string]interface{}, 0)
	for r.Next() {
		if err := r.Scan(scans...); err != nil {
			return nil, err
		}
		row := make(map[string]interface{}, len(cols))
		for i, v := range vals {
			// RawBytes are only valid until the next Next, Scan or Close;
			// database/sql reuses the same buffer for the next row, so keep
			// a private copy.
			row[cols[i]] = cloneRawBytes(v)
		}
		result = append(result, row)
	}
	// Next returns false both at the end of the rows and on failure; only Err
	// distinguishes the two.
	if err := r.Err(); err != nil {
		return nil, err
	}
	return result, nil
}

// cloneRawBytes returns an independent copy of b, keeping nil (SQL NULL) as nil.
func cloneRawBytes(b sql.RawBytes) sql.RawBytes {
	if b == nil {
		return nil
	}
	c := make(sql.RawBytes, len(b))
	copy(c, b)
	return c
}
|
|
|
|
// MapToStruct copies the values in data into the fields of the struct that
// out points to.
//
// A field is matched to a key of data by its `sql` struct tag or, when the tag
// is empty, by its lower-cased name; the tag "-" excludes the field. Fields
// with no matching key, and unexported fields, are left untouched.
//
// Values are expected in the form ResultToMap produces (sql.RawBytes, with a
// nil value for SQL NULL); []byte, string and untyped nil are accepted too.
// Each value is converted according to the destination field:
//
//   - if a pointer to the field implements sql.Scanner (sql.NullString,
//     custom types, ...), its Scan method receives a []byte, or nil for NULL;
//   - otherwise a NULL value resets the field to its zero value;
//   - time.Time fields accept RFC 3339 first, then "2006-01-02 15:04:05" and
//     "2006-01-02" (the last two are interpreted as UTC);
//   - string kinds receive the text; bool, integer and float kinds, including
//     named types such as `type Status int`, are parsed with strconv, and
//     numbers that do not fit the field's size are rejected;
//   - []byte fields receive a copy of the bytes;
//   - any other field that can hold a string (e.g. interface{}) receives the
//     value as a string.
//
// An error is returned when out is not a non-nil pointer to a struct, when a
// value in data has an unsupported type, when parsing fails, or when a field's
// type cannot be filled. Fields handled before the failing one keep their new
// values.
func MapToStruct(data map[string]interface{}, out interface{}) error {
	ptr := reflect.ValueOf(out)
	if ptr.Kind() != reflect.Ptr || ptr.IsNil() || ptr.Elem().Kind() != reflect.Struct {
		return errors.New("dbtool: MapToStruct needs a non-nil pointer to a struct")
	}
	ss := ptr.Elem()
	st := ss.Type()

	for i := 0; i < st.NumField(); i++ {
		sf := st.Field(i)
		fname := sf.Tag.Get("sql")
		if fname == "" {
			fname = strings.ToLower(sf.Name)
		}
		if fname == "-" {
			continue
		}

		d, ok := data[fname]
		if !ok {
			continue
		}

		field := ss.Field(i)
		// Unexported fields cannot be set via reflection; skip them instead
		// of panicking.
		if !field.CanSet() {
			continue
		}

		raw, isNull, ok := toRawBytes(d)
		if !ok {
			return errors.New("dbtool: unsupported value type for field " + sf.Name)
		}
		if err := setFieldFromRaw(field, raw, isNull); err != nil {
			return err
		}
	}
	return nil
}

// timeType is compared against field types to route time.Time to parseDBTime.
var timeType = reflect.TypeOf(time.Time{})

// toRawBytes normalizes a value accepted by MapToStruct to bytes. isNull is
// true for SQL NULL (untyped nil, or a nil sql.RawBytes / []byte); ok is false
// for any other value type.
func toRawBytes(v interface{}) (raw []byte, isNull, ok bool) {
	switch b := v.(type) {
	case nil:
		return nil, true, true
	case sql.RawBytes:
		return b, b == nil, true
	case []byte:
		return b, b == nil, true
	case string:
		return []byte(b), false, true
	default:
		return nil, false, false
	}
}

// setFieldFromRaw converts raw according to field's type and stores it in
// field, which must be settable.
func setFieldFromRaw(field reflect.Value, raw []byte, isNull bool) error {
	// Types that know how to decode themselves take precedence. Scanners
	// expect a plain []byte (not sql.RawBytes) or nil for NULL.
	if sc, ok := field.Addr().Interface().(sql.Scanner); ok {
		if isNull {
			return sc.Scan(nil)
		}
		return sc.Scan(raw)
	}
	if isNull {
		field.Set(reflect.Zero(field.Type()))
		return nil
	}

	str := string(raw)
	if field.Type() == timeType {
		t, err := parseDBTime(str)
		if err != nil {
			return err
		}
		field.Set(reflect.ValueOf(t))
		return nil
	}

	switch field.Kind() {
	case reflect.String:
		field.SetString(str)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		// Parsing at the field's bit size rejects values SetInt would
		// otherwise truncate silently.
		n, err := strconv.ParseInt(str, 10, field.Type().Bits())
		if err != nil {
			return err
		}
		field.SetInt(n)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		n, err := strconv.ParseUint(str, 10, field.Type().Bits())
		if err != nil {
			return err
		}
		field.SetUint(n)
	case reflect.Bool:
		b, err := strconv.ParseBool(str)
		if err != nil {
			return err
		}
		field.SetBool(b)
	case reflect.Float32, reflect.Float64:
		f, err := strconv.ParseFloat(str, field.Type().Bits())
		if err != nil {
			return err
		}
		field.SetFloat(f)
	default:
		switch {
		case field.Kind() == reflect.Slice && field.Type().Elem().Kind() == reflect.Uint8:
			// Copy so the field never aliases the caller's buffer.
			field.SetBytes(append([]byte(nil), raw...))
		case reflect.TypeOf(str).AssignableTo(field.Type()):
			field.Set(reflect.ValueOf(str))
		default:
			return errors.New("dbtool: unsupported field type " + field.Type().String())
		}
	}
	return nil
}

// parseDBTime parses s as RFC 3339, falling back to the common SQL layouts
// "2006-01-02 15:04:05" and "2006-01-02" (parsed as UTC). If nothing matches,
// the RFC 3339 parse error is returned.
func parseDBTime(s string) (time.Time, error) {
	t, firstErr := time.Parse(time.RFC3339, s)
	if firstErr == nil {
		return t, nil
	}
	for _, layout := range []string{"2006-01-02 15:04:05", "2006-01-02"} {
		if t, err := time.Parse(layout, s); err == nil {
			return t, nil
		}
	}
	return time.Time{}, firstErr
}
|