package dbtool import ( "database/sql" "errors" "reflect" "strconv" "strings" "time" ) // ScanToStruct - scan query rows to struct func ScanToStruct(rows *sql.Rows, ss interface{}) (err error) { data, err := ResultToMap(rows) if err != nil { return err } if len(data) == 0 { return nil } err = MapToStruct(data[0], ss) return nil } func ScanToStructAll(rows *sql.Rows, ss interface{}) error { sliceVal := reflect.Indirect(reflect.ValueOf(ss)) if sliceVal.Kind() != reflect.Slice && sliceVal.Kind() != reflect.Map { return errors.New("need a pointer to a slice or a map") } data, err := ResultToMap(rows) if err != nil { return err } for i := range data { } return nil } func ResultToMap(r *sql.Rows) ([]map[string]interface{}, error) { result := make([]map[string]interface{}, 0) cols, err := r.Columns() if err != nil { return nil, err } vals := make([]sql.RawBytes, len(cols)) scans := make([]interface{}, len(vals)) for i := range vals { scans[i] = &vals[i] } for r.Next() { tmp := make(map[string]interface{}) err := r.Scan(scans...) if err != nil { return nil, err } for i, v := range vals { tmp[cols[i]] = v } result = append(result, tmp) } return result, nil } func MapToStruct(data map[string]interface{}, out interface{}) error { ss := reflect.ValueOf(out).Elem() for i := 0; i < ss.NumField(); i++ { tag := ss.Type().Field(i).Tag.Get("sql") name := ss.Type().Field(i).Name fname := strings.ToLower(name) if len(tag) > 0 { fname = tag } if fname == "-" { continue } d, ok := data[fname] if !ok { continue } switch ss.Field(i).Interface().(type) { case string: ss.Field(i).SetString(string(d.(sql.RawBytes))) case int, int8, int16, int32, int64: str := string(d.(sql.RawBytes)) tmpi, err := strconv.ParseInt(str, 10, 64) if err != nil { return err } ss.Field(i).SetInt(tmpi) case uint, uint8, uint16, uint32, uint64: str := string(d.(sql.RawBytes)) tmpi, err := strconv.ParseUint(str, 10, 64) if err != nil { return err } ss.Field(i).SetUint(tmpi) case bool: str := string(d.(sql.RawBytes)) tmpb, err := strconv.ParseBool(str) if err != nil { return err } ss.Field(i).SetBool(tmpb) case float32, float64: str := string(d.(sql.RawBytes)) tmpf, err := strconv.ParseFloat(str, 64) if err != nil { return err } ss.Field(i).SetFloat(tmpf) case time.Time: str := string(d.(sql.RawBytes)) t, err := time.Parse(time.RFC3339, str) if err != nil { return err } ss.Field(i).Set(reflect.ValueOf(t)) default: str := string(d.(sql.RawBytes)) ss.Field(i).Set(reflect.ValueOf(str)) } } return nil }