diff --git a/database.go b/database.go deleted file mode 100644 index 636bab8..0000000 --- a/database.go +++ /dev/null @@ -1 +0,0 @@ -package database diff --git a/dbtool/dbtool.go b/dbtool/dbtool.go new file mode 100644 index 0000000..60812e7 --- /dev/null +++ b/dbtool/dbtool.go @@ -0,0 +1,135 @@ +package dbtool + +import ( + "database/sql" + "errors" + "reflect" + "strconv" + "strings" + "time" +) + +// ScanToStruct - scan query rows to struct +func ScanToStruct(rows *sql.Rows, ss interface{}) (err error) { + data, err := ResultToMap(rows) + if err != nil { + return err + } + if len(data) == 0 { + return nil + } + err = MapToStruct(data[0], ss) + return nil +} + +func ScanToStructAll(rows *sql.Rows, ss interface{}) error { + sliceVal := reflect.Indirect(reflect.ValueOf(ss)) + if sliceVal.Kind() != reflect.Slice && sliceVal.Kind() != reflect.Map { + return errors.New("need a pointer to a slice or a map") + } + + data, err := ResultToMap(rows) + if err != nil { + return err + } + + for i := range data { + + } + + return nil +} + +func ResultToMap(r *sql.Rows) ([]map[string]interface{}, error) { + result := make([]map[string]interface{}, 0) + cols, err := r.Columns() + if err != nil { + return nil, err + } + vals := make([]sql.RawBytes, len(cols)) + scans := make([]interface{}, len(vals)) + + for i := range vals { + scans[i] = &vals[i] + } + + for r.Next() { + tmp := make(map[string]interface{}) + err := r.Scan(scans...) + if err != nil { + return nil, err + } + for i, v := range vals { + tmp[cols[i]] = v + } + result = append(result, tmp) + } + + return result, nil +} + +func MapToStruct(data map[string]interface{}, out interface{}) error { + ss := reflect.ValueOf(out).Elem() + for i := 0; i < ss.NumField(); i++ { + tag := ss.Type().Field(i).Tag.Get("sql") + name := ss.Type().Field(i).Name + fname := strings.ToLower(name) + if len(tag) > 0 { + fname = tag + } + + if fname == "-" { + continue + } + + d, ok := data[fname] + if !ok { + continue + } + + switch ss.Field(i).Interface().(type) { + case string: + ss.Field(i).SetString(string(d.(sql.RawBytes))) + case int, int8, int16, int32, int64: + str := string(d.(sql.RawBytes)) + tmpi, err := strconv.ParseInt(str, 10, 64) + if err != nil { + return err + } + ss.Field(i).SetInt(tmpi) + case uint, uint8, uint16, uint32, uint64: + str := string(d.(sql.RawBytes)) + tmpi, err := strconv.ParseUint(str, 10, 64) + if err != nil { + return err + } + ss.Field(i).SetUint(tmpi) + case bool: + str := string(d.(sql.RawBytes)) + tmpb, err := strconv.ParseBool(str) + if err != nil { + return err + } + ss.Field(i).SetBool(tmpb) + case float32, float64: + str := string(d.(sql.RawBytes)) + tmpf, err := strconv.ParseFloat(str, 64) + if err != nil { + return err + } + ss.Field(i).SetFloat(tmpf) + case time.Time: + str := string(d.(sql.RawBytes)) + t, err := time.Parse(time.RFC3339, str) + if err != nil { + return err + } + ss.Field(i).Set(reflect.ValueOf(t)) + default: + str := string(d.(sql.RawBytes)) + ss.Field(i).Set(reflect.ValueOf(str)) + } + } + + return nil +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..a985bd1 --- /dev/null +++ b/main.go @@ -0,0 +1,36 @@ +package main + +import ( + "database/sql" + "fmt" + "log" + + _ "github.com/lib/pq" +) + +// FileT - file db schema +type FileT struct { + ID string `db:"id"` + Name string `db:"name"` +} + +func main() { + connStr := "postgres://postgres@localhost:5432/mystorage?sslmode=disable" + db, err := sql.Open("postgres", connStr) + handleError(err) + + rows, err := db.Query(`select id, name from "storage"."files" where "tmp" = $1 and "trash" = $2 limit 2`, false, false) + handleError(err) + + var id, name string + for rows.Next() { + rows.Scan(&id, &name) + fmt.Printf("id: %s, name: %s \n", id, name) + } +} + +func handleError(err error) { + if err != nil { + log.Fatal(err) + } +}