Skip to content

Commit

Permalink
feat: added a new implementation for ReadQuery (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
noxecane authored Sep 13, 2020
1 parent 047177d commit 857f67a
Show file tree
Hide file tree
Showing 5 changed files with 304 additions and 100 deletions.
183 changes: 183 additions & 0 deletions query.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
package anansi

import (
"net/http"
"reflect"
"strconv"

"github.com/pkg/errors"
)

// ParseQuery converts a map of strings to strings to a struct. It uses
// the key and default struct tag to help it determine how to get keys and set defaults.
// Mind that it only supports the types
// - int(all bit sizes)
// - uint(all bit sizes)
// - float(all bit sizes)
// - string
// - ISO8601 time
func ParseQuery(query map[string]string, v interface{}) error {
result := reflect.ValueOf(v).Elem()
resultType := result.Type()

for i := 0; i < result.NumField(); i++ {
field := resultType.Field(i)
fieldVal := result.Field(i)
fieldType := field.Type
fieldKind := fieldVal.Kind()

// skip hidden fields
if field.PkgPath != "" {
continue
}

// get query parameter name
key := field.Tag.Get("key")
if key == "" {
key = field.Name
}

// for fields with default values
def := field.Tag.Get("default")

rawValue, ok := query[key]
if !ok {
if def == "" {
fieldVal.Set(reflect.Zero(fieldType))
continue
} else {
rawValue = def
}
}

// make sure we're not using a pointer
var ptr reflect.Value
if fieldKind == reflect.Ptr {
// get the underlying type of the pointer
fieldType = fieldType.Elem()
fieldKind = fieldType.Kind()

// create new pointer to hold the value
ptr = reflect.New(fieldType)
}

if !fieldVal.CanSet() {
return errors.Errorf("cannot set field %s", field.Name)
}

var out interface{}
var err error

switch fieldKind {
case reflect.Bool:
out, err = strconv.ParseBool(rawValue)
if err != nil {
return errors.Wrapf(err, "failed to parse bool %s", field.Name)
}
case reflect.Int:
i, err := strconv.ParseInt(rawValue, 10, fieldType.Bits())
if err != nil {
return errors.Wrapf(err, "failed to parse int %s", field.Name)
}
out = int(i)
case reflect.Int8:
i, err := strconv.ParseInt(rawValue, 10, 8)
if err != nil {
return errors.Wrapf(err, "failed to parse int8 %s", field.Name)
}
out = int8(i)
case reflect.Int16:
i, err := strconv.ParseInt(rawValue, 10, 16)
if err != nil {
return errors.Wrapf(err, "failed to parse int16 %s", field.Name)
}
out = int16(i)
case reflect.Int32:
i, err := strconv.ParseInt(rawValue, 10, 32)
if err != nil {
return errors.Wrapf(err, "failed to parse int32 %s", field.Name)
}
out = int32(i)
case reflect.Int64:
if out, err = strconv.ParseInt(rawValue, 10, 64); err != nil {
return errors.Wrapf(err, "failed to parse int64 %s", field.Name)
}
case reflect.Uint:
u, err := strconv.ParseUint(rawValue, 10, fieldType.Bits())
if err != nil {
return errors.Wrapf(err, "failed to parse int %s", field.Name)
}
out = uint(u)
case reflect.Uint8:
u, err := strconv.ParseUint(rawValue, 10, 8)
if err != nil {
return errors.Wrapf(err, "failed to parse int8 %s", field.Name)
}
out = uint8(u)
case reflect.Uint16:
u, err := strconv.ParseUint(rawValue, 10, 16)
if err != nil {
return errors.Wrapf(err, "failed to parse int16 %s", field.Name)
}
out = uint16(u)
case reflect.Uint32:
u, err := strconv.ParseUint(rawValue, 10, 32)
if err != nil {
return errors.Wrapf(err, "failed to parse int32 %s", field.Name)
}
out = uint32(u)
case reflect.Uint64:
if out, err = strconv.ParseUint(rawValue, 10, 64); err != nil {
return errors.Wrapf(err, "failed to parse int64 %s", field.Name)
}
case reflect.Float32:
f, err := strconv.ParseFloat(rawValue, fieldType.Bits())
if err != nil {
return errors.Wrapf(err, "failed to parse float %s", field.Name)
}
out = float32(f)
case reflect.Float64:
if out, err = strconv.ParseFloat(rawValue, fieldType.Bits()); err != nil {
return errors.Wrapf(err, "failed to parse float %s", field.Name)
}
case reflect.String:
out = rawValue
case reflect.Struct:
// attempt to parse a date value
if out, err = FromISO(rawValue); err != nil {
return errors.Errorf("this function doesn't support %s for the field '%s'", fieldType, field.Name)
}
default:
return errors.Errorf("this function doesn't support %s for the field '%s'", fieldType, field.Name)
}

// if original kind is pointer, save as pointer value
if fieldVal.Kind() == reflect.Ptr {
// set value pointer is pointing to
reflect.Indirect(ptr).Set(reflect.ValueOf(out))
fieldVal.Set(ptr)
} else {
fieldVal.Set(reflect.ValueOf(out))
}
}

return nil
}

// ReadQuery reads the query parameters of a request into a struct
func ReadQuery(r *http.Request, v interface{}) {
raw := r.URL.Query()
qMap := make(map[string]string)

for k := range raw {
qMap[k] = raw.Get(k)
}

if err := ParseQuery(qMap, v); err != nil {
panic(APIError{
Code: http.StatusBadRequest,
Message: "We could not parse your request query.",
Err: err,
})
}
}
103 changes: 103 additions & 0 deletions query_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package anansi

import (
"fmt"
"net/http"
"strconv"
"testing"
"time"

"syreclabs.com/go/faker"
)

func TestParseQuery(t *testing.T) {
type mapStruct struct {
Name string `key:"name"`
Age int32 `key:"age" default:"30"`
Limit *int `key:"limit"`
CreatedAt *time.Time `key:"date"`
}

t.Run("parses a simple map to a struct", func(t *testing.T) {
raw := map[string]string{"name": faker.Name().FirstName(), "age": faker.Number().Number(2)}
var sample mapStruct
if err := ParseQuery(raw, &sample); err != nil {
t.Fatal(err)
}

if raw["name"] != sample.Name {
t.Errorf("Expected name to be %s, got %s", raw["name"], sample.Name)
}

i, _ := strconv.Atoi(raw["age"])
if int32(i) != sample.Age {
t.Errorf("Expected age to be %s, got %d", raw["age"], sample.Age)
}
})

t.Run("parses pointer value", func(t *testing.T) {
raw := map[string]string{"date": "2020-09-01"}
var sample mapStruct
if err := ParseQuery(raw, &sample); err != nil {
t.Fatal(err)
}

if sample.CreatedAt == nil {
t.Fatal("Expected created_at to be set to a value, got nil")
}

if sample.CreatedAt.IsZero() {
t.Error("Expected created_at to be set to a value, got zero value")
}
})

t.Run("sets empty pointer value to nil", func(t *testing.T) {
raw := map[string]string{}
var sample mapStruct
if err := ParseQuery(raw, &sample); err != nil {
t.Fatal(err)
}

if sample.CreatedAt != nil {
t.Errorf("Expected created_at to be nil, got %v", sample.CreatedAt)
}
})

t.Run("zero value is not nil pointer", func(t *testing.T) {
raw := map[string]string{"limit": "0"}
var sample mapStruct
if err := ParseQuery(raw, &sample); err != nil {
t.Fatal(err)
}

if *sample.Limit != 0 {
t.Errorf("Expected age to be zero value, got %d", sample.Limit)
}
})
}

func TestReadQuery(t *testing.T) {
type myQuery struct {
Account string `key:"nuban"`
Start *time.Time `key:"from"`
End *time.Time `key:"to"`
}

req, err := http.NewRequest("GET", "https://sample.com", nil)
if err != nil {
t.Fatal(err)
}
nuban := faker.Number().Number(10)
req.URL.RawQuery = fmt.Sprintf("nuban=%s", nuban)

var sample myQuery
ReadQuery(req, &sample)

if sample.Account != nuban {
t.Errorf("Expected nuban to be %s, got %s", nuban, sample.Account)
}

if sample.Start != nil || sample.End != nil {
t.Error("Expected from and to nil")
}
}
56 changes: 0 additions & 56 deletions request_test.go

This file was deleted.

Loading

0 comments on commit 857f67a

Please sign in to comment.