diff --git a/conform.go b/conform.go index c66a2ea..ea1ea84 100644 --- a/conform.go +++ b/conform.go @@ -185,6 +185,41 @@ func formatName(s string) string { return strings.Title(patterns["name"].FindString(first)) } +func getSliceElemType(t reflect.Type) reflect.Type { + var elType reflect.Type + if t.Kind() == reflect.Ptr { + elType = t.Elem().Elem() + } else { + elType = t.Elem() + } + + return elType +} + +func transformValue(tags string, val reflect.Value) reflect.Value { + if val.Kind() == reflect.Ptr && val.IsNil() { + return val + } + + var oldStr string + if val.Kind() == reflect.Ptr { + oldStr = val.Elem().String() + } else { + oldStr = val.String() + } + + newStr := transformString(oldStr, tags) + + var newVal reflect.Value + if val.Kind() == reflect.Ptr { + newVal = reflect.ValueOf(&newStr) + } else { + newVal = reflect.ValueOf(newStr) + } + + return newVal.Convert(val.Type()) +} + // Strings conforms strings based on reflection tags func Strings(iface interface{}) error { ifv := reflect.ValueOf(iface) @@ -201,15 +236,23 @@ func Strings(iface interface{}) error { switch el.Kind() { case reflect.Slice: if el.CanInterface() { - if slice, ok := el.Interface().([]string); ok { - for i, input := range slice { - tags := v.Tag.Get("conform") - slice[i] = transformString(input, tags) + elType := getSliceElemType(v.Type) + + // allow strings and string pointers + str := "" + if elType.ConvertibleTo(reflect.TypeOf(str)) || elType.ConvertibleTo(reflect.TypeOf(&str)) { + tags := v.Tag.Get("conform") + for i := 0; i < el.Len(); i++ { + el.Index(i).Set(transformValue(tags, el.Index(i))) } } else { val := reflect.ValueOf(el.Interface()) for i := 0; i < val.Len(); i++ { - Strings(val.Index(i).Addr().Interface()) + elVal := val.Index(i) + if elVal.Kind() != reflect.Ptr { + elVal = elVal.Addr() + } + Strings(elVal.Interface()) } } } diff --git a/conform_test.go b/conform_test.go index 5acaa0c..9c667ff 100644 --- a/conform_test.go +++ b/conform_test.go @@ -622,3 +622,139 @@ func (t *testSuite) TestMap() { assert.Equal("pickles", s.Catmap["cat1"].Name, "s.StructMap[cat1].Name should be trimmed") } + +func (t *testSuite) TestNilArrayPointerType() { + assert := assert.New(t.T()) + + type Post struct { + HashTags *[]string `conform:"trim"` + } + p := Post{ + } + + Strings(&p) + assert.Nil(p.HashTags, 0) +} + +func (t *testSuite) TestStringPointerArrayType() { + assert := assert.New(t.T()) + + type Post struct { + HashTags []*string `conform:"trim"` + } + h := " hashtag " + p := Post{ + HashTags: []*string{&h, nil}, + } + + Strings(&p) + assert.Len(p.HashTags, 2) + assert.Equal("hashtag", *p.HashTags[0]) + assert.Nil(p.HashTags[1]) +} + +func (t *testSuite) TestStringArrayPointerType() { + assert := assert.New(t.T()) + + type Post struct { + HashTagsPtr *[]string `conform:"trim"` + } + h := " hashtag " + p := Post{ + HashTagsPtr: &[]string{h}, + } + + Strings(&p) + assert.Len(*p.HashTagsPtr, 1) + assert.Equal("hashtag", (*p.HashTagsPtr)[0]) +} + +func (t *testSuite) TestStringPointerArrayPointerType() { + assert := assert.New(t.T()) + + type Post struct { + HashTagsPtr *[]*string `conform:"trim"` + } + h := " hashtag " + p := Post{ + HashTagsPtr: &[]*string{&h, nil}, + } + + Strings(&p) + assert.Len(*p.HashTagsPtr, 2) + assert.Equal("hashtag", *(*p.HashTagsPtr)[0]) + assert.Nil((*p.HashTagsPtr)[1]) +} + +func (t *testSuite) TestCustomStringPointerArrayType() { + assert := assert.New(t.T()) + + type String string + type Post struct { + HashTags []*String `conform:"trim"` + } + h := String(" hashtag ") + p := Post{ + HashTags: []*String{&h, nil}, + } + + Strings(&p) + assert.Len(p.HashTags, 2) + assert.Equal(String("hashtag"), *p.HashTags[0]) + assert.Nil(p.HashTags[1]) +} + +func (t *testSuite) TestCustomStringArrayPointerType() { + assert := assert.New(t.T()) + + type String string + type Post struct { + HashTagsPtr *[]String `conform:"trim"` + } + h := String(" hashtag ") + p := Post{ + HashTagsPtr: &[]String{h}, + } + + Strings(&p) + assert.Len(*p.HashTagsPtr, 1) + assert.Equal(String("hashtag"), (*p.HashTagsPtr)[0]) +} + +func (t *testSuite) TestCustomStringPointerArrayPointerType() { + assert := assert.New(t.T()) + + type String string + type Post struct { + HashTagsPtr *[]*String `conform:"trim"` + } + h := String(" hashtag ") + p := Post{ + HashTagsPtr: &[]*String{&h, nil}, + } + + Strings(&p) + assert.Len(*p.HashTagsPtr, 2) + assert.Equal(String("hashtag"), *(*p.HashTagsPtr)[0]) + assert.Nil((*p.HashTagsPtr)[1]) +} + +func (t *testSuite) TestEmbeddedArrayOfStructs() { + assert := assert.New(t.T()) + + type Bar struct { + Baz string `conform:"trim"` + } + type Foo struct { + Bars *[]*Bar + } + + f := Foo{ + Bars: &[]*Bar{ + {Baz: " baz "}, + }, + } + + Strings(&f) + assert.Equal("baz", (*f.Bars)[0].Baz) +}