From 857f67afe6bc9494b374a5e7e7b6ffe7d4f44c28 Mon Sep 17 00:00:00 2001 From: Arewa Olakunle Date: Sun, 13 Sep 2020 11:11:17 +0100 Subject: [PATCH] feat: added a new implementation for ReadQuery (#37) --- query.go | 183 ++++++++++++++++++++++++++++++++++++++++++++++++ query_test.go | 103 +++++++++++++++++++++++++++ request_test.go | 56 --------------- requests.go | 44 ------------ time.go | 18 +++++ 5 files changed, 304 insertions(+), 100 deletions(-) create mode 100644 query.go create mode 100644 query_test.go delete mode 100644 request_test.go create mode 100644 time.go diff --git a/query.go b/query.go new file mode 100644 index 0000000..d1d59d7 --- /dev/null +++ b/query.go @@ -0,0 +1,183 @@ +package anansi + +import ( + "net/http" + "reflect" + "strconv" + + "github.com/pkg/errors" +) + +// ParseQuery converts a map of strings to strings to a struct. It uses +// the key and default struct tag to help it determine how to get keys and set defaults. +// Mind that it only supports the types +// - int(all bit sizes) +// - uint(all bit sizes) +// - float(all bit sizes) +// - string +// - ISO8601 time +func ParseQuery(query map[string]string, v interface{}) error { + result := reflect.ValueOf(v).Elem() + resultType := result.Type() + + for i := 0; i < result.NumField(); i++ { + field := resultType.Field(i) + fieldVal := result.Field(i) + fieldType := field.Type + fieldKind := fieldVal.Kind() + + // skip hidden fields + if field.PkgPath != "" { + continue + } + + // get query parameter name + key := field.Tag.Get("key") + if key == "" { + key = field.Name + } + + // for fields with default values + def := field.Tag.Get("default") + + rawValue, ok := query[key] + if !ok { + if def == "" { + fieldVal.Set(reflect.Zero(fieldType)) + continue + } else { + rawValue = def + } + } + + // make sure we're not using a pointer + var ptr reflect.Value + if fieldKind == reflect.Ptr { + // get the underlying type of the pointer + fieldType = fieldType.Elem() + fieldKind = fieldType.Kind() + + // create new pointer to hold the value + ptr = reflect.New(fieldType) + } + + if !fieldVal.CanSet() { + return errors.Errorf("cannot set field %s", field.Name) + } + + var out interface{} + var err error + + switch fieldKind { + case reflect.Bool: + out, err = strconv.ParseBool(rawValue) + if err != nil { + return errors.Wrapf(err, "failed to parse bool %s", field.Name) + } + case reflect.Int: + i, err := strconv.ParseInt(rawValue, 10, fieldType.Bits()) + if err != nil { + return errors.Wrapf(err, "failed to parse int %s", field.Name) + } + out = int(i) + case reflect.Int8: + i, err := strconv.ParseInt(rawValue, 10, 8) + if err != nil { + return errors.Wrapf(err, "failed to parse int8 %s", field.Name) + } + out = int8(i) + case reflect.Int16: + i, err := strconv.ParseInt(rawValue, 10, 16) + if err != nil { + return errors.Wrapf(err, "failed to parse int16 %s", field.Name) + } + out = int16(i) + case reflect.Int32: + i, err := strconv.ParseInt(rawValue, 10, 32) + if err != nil { + return errors.Wrapf(err, "failed to parse int32 %s", field.Name) + } + out = int32(i) + case reflect.Int64: + if out, err = strconv.ParseInt(rawValue, 10, 64); err != nil { + return errors.Wrapf(err, "failed to parse int64 %s", field.Name) + } + case reflect.Uint: + u, err := strconv.ParseUint(rawValue, 10, fieldType.Bits()) + if err != nil { + return errors.Wrapf(err, "failed to parse int %s", field.Name) + } + out = uint(u) + case reflect.Uint8: + u, err := strconv.ParseUint(rawValue, 10, 8) + if err != nil { + return errors.Wrapf(err, "failed to parse int8 %s", field.Name) + } + out = uint8(u) + case reflect.Uint16: + u, err := strconv.ParseUint(rawValue, 10, 16) + if err != nil { + return errors.Wrapf(err, "failed to parse int16 %s", field.Name) + } + out = uint16(u) + case reflect.Uint32: + u, err := strconv.ParseUint(rawValue, 10, 32) + if err != nil { + return errors.Wrapf(err, "failed to parse int32 %s", field.Name) + } + out = uint32(u) + case reflect.Uint64: + if out, err = strconv.ParseUint(rawValue, 10, 64); err != nil { + return errors.Wrapf(err, "failed to parse int64 %s", field.Name) + } + case reflect.Float32: + f, err := strconv.ParseFloat(rawValue, fieldType.Bits()) + if err != nil { + return errors.Wrapf(err, "failed to parse float %s", field.Name) + } + out = float32(f) + case reflect.Float64: + if out, err = strconv.ParseFloat(rawValue, fieldType.Bits()); err != nil { + return errors.Wrapf(err, "failed to parse float %s", field.Name) + } + case reflect.String: + out = rawValue + case reflect.Struct: + // attempt to parse a date value + if out, err = FromISO(rawValue); err != nil { + return errors.Errorf("this function doesn't support %s for the field '%s'", fieldType, field.Name) + } + default: + return errors.Errorf("this function doesn't support %s for the field '%s'", fieldType, field.Name) + } + + // if original kind is pointer, save as pointer value + if fieldVal.Kind() == reflect.Ptr { + // set value pointer is pointing to + reflect.Indirect(ptr).Set(reflect.ValueOf(out)) + fieldVal.Set(ptr) + } else { + fieldVal.Set(reflect.ValueOf(out)) + } + } + + return nil +} + +// ReadQuery reads the query parameters of a request into a struct +func ReadQuery(r *http.Request, v interface{}) { + raw := r.URL.Query() + qMap := make(map[string]string) + + for k := range raw { + qMap[k] = raw.Get(k) + } + + if err := ParseQuery(qMap, v); err != nil { + panic(APIError{ + Code: http.StatusBadRequest, + Message: "We could not parse your request query.", + Err: err, + }) + } +} diff --git a/query_test.go b/query_test.go new file mode 100644 index 0000000..32fce9a --- /dev/null +++ b/query_test.go @@ -0,0 +1,103 @@ +package anansi + +import ( + "fmt" + "net/http" + "strconv" + "testing" + "time" + + "syreclabs.com/go/faker" +) + +func TestParseQuery(t *testing.T) { + type mapStruct struct { + Name string `key:"name"` + Age int32 `key:"age" default:"30"` + Limit *int `key:"limit"` + CreatedAt *time.Time `key:"date"` + } + + t.Run("parses a simple map to a struct", func(t *testing.T) { + raw := map[string]string{"name": faker.Name().FirstName(), "age": faker.Number().Number(2)} + var sample mapStruct + if err := ParseQuery(raw, &sample); err != nil { + t.Fatal(err) + } + + if raw["name"] != sample.Name { + t.Errorf("Expected name to be %s, got %s", raw["name"], sample.Name) + } + + i, _ := strconv.Atoi(raw["age"]) + if int32(i) != sample.Age { + t.Errorf("Expected age to be %s, got %d", raw["age"], sample.Age) + } + }) + + t.Run("parses pointer value", func(t *testing.T) { + raw := map[string]string{"date": "2020-09-01"} + var sample mapStruct + if err := ParseQuery(raw, &sample); err != nil { + t.Fatal(err) + } + + if sample.CreatedAt == nil { + t.Fatal("Expected created_at to be set to a value, got nil") + } + + if sample.CreatedAt.IsZero() { + t.Error("Expected created_at to be set to a value, got zero value") + } + }) + + t.Run("sets empty pointer value to nil", func(t *testing.T) { + raw := map[string]string{} + var sample mapStruct + if err := ParseQuery(raw, &sample); err != nil { + t.Fatal(err) + } + + if sample.CreatedAt != nil { + t.Errorf("Expected created_at to be nil, got %v", sample.CreatedAt) + } + }) + + t.Run("zero value is not nil pointer", func(t *testing.T) { + raw := map[string]string{"limit": "0"} + var sample mapStruct + if err := ParseQuery(raw, &sample); err != nil { + t.Fatal(err) + } + + if *sample.Limit != 0 { + t.Errorf("Expected age to be zero value, got %d", sample.Limit) + } + }) +} + +func TestReadQuery(t *testing.T) { + type myQuery struct { + Account string `key:"nuban"` + Start *time.Time `key:"from"` + End *time.Time `key:"to"` + } + + req, err := http.NewRequest("GET", "https://sample.com", nil) + if err != nil { + t.Fatal(err) + } + nuban := faker.Number().Number(10) + req.URL.RawQuery = fmt.Sprintf("nuban=%s", nuban) + + var sample myQuery + ReadQuery(req, &sample) + + if sample.Account != nuban { + t.Errorf("Expected nuban to be %s, got %s", nuban, sample.Account) + } + + if sample.Start != nil || sample.End != nil { + t.Error("Expected from and to nil") + } +} diff --git a/request_test.go b/request_test.go deleted file mode 100644 index f2c0a07..0000000 --- a/request_test.go +++ /dev/null @@ -1,56 +0,0 @@ -package anansi - -import ( - "net/http" - "strconv" - "testing" - - ozzo "github.com/go-ozzo/ozzo-validation/v4" - "syreclabs.com/go/faker" -) - -type userQuery struct { - Name string `json:"name"` - Age int `json:"age"` -} - -func (u *userQuery) Validate() error { - return ozzo.ValidateStruct(u, - ozzo.Field(&u.Name, ozzo.Required), - ozzo.Field(&u.Age, ozzo.Required), - ) -} - -func TestReadQuery(t *testing.T) { - fName := faker.Name().FirstName() - ageOne := faker.Number().Number(2) - ageTwo := faker.Number().Number(2) - - req, err := http.NewRequest("GET", "https://google.com", nil) - if err != nil { - t.Fatal(err) - } - q := req.URL.Query() - - q.Add("name", fName) - q.Add("age", ageOne) - q.Add("age", ageTwo) - - req.URL.RawQuery = q.Encode() - - uQ := new(userQuery) - ReadQuery(req, uQ) - - if uQ.Name != fName { - t.Errorf("Expected name %s. got name %s", fName, uQ.Name) - } - - uQAge, err := strconv.Atoi(ageOne) - if err != nil { - t.Fatal(err) - } - - if uQ.Age != uQAge { - t.Errorf("Expected name %s. got name %s", fName, uQ.Name) - } -} diff --git a/requests.go b/requests.go index 5f7cb6e..8a56ebf 100644 --- a/requests.go +++ b/requests.go @@ -13,7 +13,6 @@ import ( "github.com/go-chi/chi" ozzo "github.com/go-ozzo/ozzo-validation/v4" - "github.com/mitchellh/mapstructure" ) // ReadBody extracts the bytes in a request body without destroying the contents of the body @@ -78,49 +77,6 @@ func ReadJSON(r *http.Request, v interface{}) { } } -// ReadQuery reads the requests URL query parameters into a struct. -// It doesn't support multi-value parameters -func ReadQuery(r *http.Request, v interface{}) { - raw := r.URL.Query() - qMap := make(map[string]string) - - for k := range raw { - qMap[k] = raw.Get(k) - } - - // convert claims data map to struct - config := &mapstructure.DecoderConfig{ - Result: v, - TagName: "json", - WeaklyTypedInput: true, - } - decoder, err := mapstructure.NewDecoder(config) - - if err != nil { - panic(APIError{ - Code: http.StatusBadRequest, - Message: "We cannot parse your request body.", - Err: err, - }) - } - - if err := decoder.Decode(qMap); err != nil { - panic(APIError{ - Code: http.StatusBadRequest, - Message: "We cannot parse your request body.", - Err: err, - }) - } - - if err := ozzo.Validate(v); err != nil { - panic(APIError{ - Code: http.StatusBadRequest, - Message: "We could not validate your request.", - Meta: err, - }) - } -} - // IDParam extracts a uint URL parameter from the given request func IDParam(r *http.Request, name string) uint { param := chi.URLParam(r, name) diff --git a/time.go b/time.go new file mode 100644 index 0000000..319d8b0 --- /dev/null +++ b/time.go @@ -0,0 +1,18 @@ +package anansi + +import "time" + +var ISO_FORMATS = [3]string{"2006-01-02T15:04:05.000Z", "2006-01-02", "2006-01-02T15:04:05"} + +func FromISO(date string) (time.Time, error) { + var err error + + for _, format := range ISO_FORMATS { + createdAt, err := time.Parse(format, date) + if err == nil { + return createdAt, nil + } + } + + return time.Time{}, err +}