package dbtool import ( "database/sql" "errors" "reflect" "strconv" "strings" "time" ) // ScanToStruct - scan query rows to struct func ScanToStruct(rows *sql.Rows, ss interface{}) (ok bool, err error) { data, err := ResultToMap(rows) if err != nil { return false, err } if len(data) == 0 { return false, nil } err = MapToStruct(data[0], ss) if err != nil { return false, err } return true, nil } // ScanToStructAll - func ScanToStructAll(rows *sql.Rows, ss interface{}) error { sliceVal := reflect.Indirect(reflect.ValueOf(ss)) if sliceVal.Kind() != reflect.Slice { return errors.New("need a pointer to a slice ") } vt := sliceVal.Type().Elem() if vt.Kind() != reflect.Struct { return errors.New("need struct slice") } data, err := ResultToMap(rows) if err != nil { return err } for i := range data { tmp := reflect.New(vt) iface := tmp.Interface() err = MapToStruct(data[i], iface) if err != nil { return err } sliceVal.Set(reflect.Append(sliceVal, reflect.ValueOf(iface).Elem())) } return nil } // ResultToMap - func ResultToMap(r *sql.Rows) ([]map[string]interface{}, error) { result := make([]map[string]interface{}, 0) cols, err := r.Columns() if err != nil { return nil, err } vals := make([]sql.RawBytes, len(cols)) scans := make([]interface{}, len(vals)) for i := range vals { scans[i] = &vals[i] } for r.Next() { tmp := make(map[string]interface{}) err := r.Scan(scans...) if err != nil { return nil, err } for i, v := range vals { tmp[cols[i]] = v vals[i] = nil } result = append(result, tmp) } return result, nil } // MapToStruct - func MapToStruct(data map[string]interface{}, out interface{}) error { ss := reflect.ValueOf(out).Elem() for i := 0; i < ss.NumField(); i++ { tag := ss.Type().Field(i).Tag.Get("sql") name := ss.Type().Field(i).Name fname := strings.ToLower(name) if len(tag) > 0 { fname = tag } if fname == "-" { continue } d, ok := data[fname] if !ok { continue } switch ss.Field(i).Interface().(type) { case string: ss.Field(i).SetString(string(d.(sql.RawBytes))) case int, int8, int16, int32, int64: str := string(d.(sql.RawBytes)) tmpi, err := strconv.ParseInt(str, 10, 64) if err != nil { return err } ss.Field(i).SetInt(tmpi) case uint, uint8, uint16, uint32, uint64: str := string(d.(sql.RawBytes)) tmpi, err := strconv.ParseUint(str, 10, 64) if err != nil { return err } ss.Field(i).SetUint(tmpi) case bool: str := string(d.(sql.RawBytes)) tmpb, err := strconv.ParseBool(str) if err != nil { return err } ss.Field(i).SetBool(tmpb) case float32, float64: str := string(d.(sql.RawBytes)) tmpf, err := strconv.ParseFloat(str, 64) if err != nil { return err } ss.Field(i).SetFloat(tmpf) case time.Time: str := string(d.(sql.RawBytes)) t, err := time.Parse(time.RFC3339, str) if err != nil { return err } ss.Field(i).Set(reflect.ValueOf(t)) default: str := string(d.(sql.RawBytes)) ss.Field(i).Set(reflect.ValueOf(str)) } } return nil }