Skip to content

Commit 053a094

Browse files
committedJun 28, 2021
refactor splitTag function (#1960)
Reviewed-on: https://gitea.com/xorm/xorm/pulls/1960 Co-authored-by: Lunny Xiao <xiaolunwen@gmail.com> Co-committed-by: Lunny Xiao <xiaolunwen@gmail.com>
1 parent 44f892f commit 053a094

File tree

5 files changed

+774
-243
lines changed

5 files changed

+774
-243
lines changed
 

‎.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,5 @@ test.db.sql
3636
*coverage.out
3737
test.db
3838
integrations/*.sql
39-
integrations/test_sqlite*
39+
integrations/test_sqlite*
40+
cover.out

‎tags/parser.go

+150-180
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ package tags
77
import (
88
"encoding/gob"
99
"errors"
10-
"fmt"
1110
"reflect"
1211
"strings"
1312
"sync"
@@ -23,7 +22,7 @@ import (
2322

2423
var (
2524
// ErrUnsupportedType represents an unsupported type error
26-
ErrUnsupportedType = errors.New("Unsupported type")
25+
ErrUnsupportedType = errors.New("unsupported type")
2726
)
2827

2928
// Parser represents a parser for xorm tag
@@ -125,6 +124,145 @@ func addIndex(indexName string, table *schemas.Table, col *schemas.Column, index
125124
}
126125
}
127126

127+
var ErrIgnoreField = errors.New("field will be ignored")
128+
129+
func (parser *Parser) parseFieldWithNoTag(field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) {
130+
var sqlType schemas.SQLType
131+
if fieldValue.CanAddr() {
132+
if _, ok := fieldValue.Addr().Interface().(convert.Conversion); ok {
133+
sqlType = schemas.SQLType{Name: schemas.Text}
134+
}
135+
}
136+
if _, ok := fieldValue.Interface().(convert.Conversion); ok {
137+
sqlType = schemas.SQLType{Name: schemas.Text}
138+
} else {
139+
sqlType = schemas.Type2SQLType(field.Type)
140+
}
141+
col := schemas.NewColumn(parser.columnMapper.Obj2Table(field.Name),
142+
field.Name, sqlType, sqlType.DefaultLength,
143+
sqlType.DefaultLength2, true)
144+
145+
if field.Type.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) {
146+
col.IsAutoIncrement = true
147+
col.IsPrimaryKey = true
148+
col.Nullable = false
149+
}
150+
return col, nil
151+
}
152+
153+
func (parser *Parser) parseFieldWithTags(table *schemas.Table, field reflect.StructField, fieldValue reflect.Value, tags []tag) (*schemas.Column, error) {
154+
var col = &schemas.Column{
155+
FieldName: field.Name,
156+
Nullable: true,
157+
IsPrimaryKey: false,
158+
IsAutoIncrement: false,
159+
MapType: schemas.TWOSIDES,
160+
Indexes: make(map[string]int),
161+
DefaultIsEmpty: true,
162+
}
163+
164+
var ctx = Context{
165+
table: table,
166+
col: col,
167+
fieldValue: fieldValue,
168+
indexNames: make(map[string]int),
169+
parser: parser,
170+
}
171+
172+
for j, tag := range tags {
173+
if ctx.ignoreNext {
174+
ctx.ignoreNext = false
175+
continue
176+
}
177+
178+
ctx.tag = tag
179+
ctx.tagUname = strings.ToUpper(tag.name)
180+
181+
if j > 0 {
182+
ctx.preTag = strings.ToUpper(tags[j-1].name)
183+
}
184+
if j < len(tags)-1 {
185+
ctx.nextTag = tags[j+1].name
186+
} else {
187+
ctx.nextTag = ""
188+
}
189+
190+
if h, ok := parser.handlers[ctx.tagUname]; ok {
191+
if err := h(&ctx); err != nil {
192+
return nil, err
193+
}
194+
} else {
195+
if strings.HasPrefix(ctx.tag.name, "'") && strings.HasSuffix(ctx.tag.name, "'") {
196+
col.Name = ctx.tag.name[1 : len(ctx.tag.name)-1]
197+
} else {
198+
col.Name = ctx.tag.name
199+
}
200+
}
201+
202+
if ctx.hasCacheTag {
203+
if parser.cacherMgr.GetDefaultCacher() != nil {
204+
parser.cacherMgr.SetCacher(table.Name, parser.cacherMgr.GetDefaultCacher())
205+
} else {
206+
parser.cacherMgr.SetCacher(table.Name, caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000))
207+
}
208+
}
209+
if ctx.hasNoCacheTag {
210+
parser.cacherMgr.SetCacher(table.Name, nil)
211+
}
212+
}
213+
214+
if col.SQLType.Name == "" {
215+
col.SQLType = schemas.Type2SQLType(field.Type)
216+
}
217+
parser.dialect.SQLType(col)
218+
if col.Length == 0 {
219+
col.Length = col.SQLType.DefaultLength
220+
}
221+
if col.Length2 == 0 {
222+
col.Length2 = col.SQLType.DefaultLength2
223+
}
224+
if col.Name == "" {
225+
col.Name = parser.columnMapper.Obj2Table(field.Name)
226+
}
227+
228+
if ctx.isUnique {
229+
ctx.indexNames[col.Name] = schemas.UniqueType
230+
} else if ctx.isIndex {
231+
ctx.indexNames[col.Name] = schemas.IndexType
232+
}
233+
234+
for indexName, indexType := range ctx.indexNames {
235+
addIndex(indexName, table, col, indexType)
236+
}
237+
238+
return col, nil
239+
}
240+
241+
func (parser *Parser) parseField(table *schemas.Table, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) {
242+
var (
243+
tag = field.Tag
244+
ormTagStr = strings.TrimSpace(tag.Get(parser.identifier))
245+
)
246+
if ormTagStr == "-" {
247+
return nil, ErrIgnoreField
248+
}
249+
if ormTagStr == "" {
250+
return parser.parseFieldWithNoTag(field, fieldValue)
251+
}
252+
tags, err := splitTag(ormTagStr)
253+
if err != nil {
254+
return nil, err
255+
}
256+
return parser.parseFieldWithTags(table, field, fieldValue, tags)
257+
}
258+
259+
func isNotTitle(n string) bool {
260+
for _, c := range n {
261+
return unicode.IsLower(c)
262+
}
263+
return true
264+
}
265+
128266
// Parse parses a struct as a table information
129267
func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) {
130268
t := v.Type()
@@ -140,193 +278,25 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) {
140278
table.Type = t
141279
table.Name = names.GetTableName(parser.tableMapper, v)
142280

143-
var idFieldColName string
144-
var hasCacheTag, hasNoCacheTag bool
145-
146281
for i := 0; i < t.NumField(); i++ {
147-
var isUnexportField bool
148-
for _, c := range t.Field(i).Name {
149-
if unicode.IsLower(c) {
150-
isUnexportField = true
151-
}
152-
break
153-
}
154-
if isUnexportField {
282+
if isNotTitle(t.Field(i).Name) {
155283
continue
156284
}
157285

158-
tag := t.Field(i).Tag
159-
ormTagStr := tag.Get(parser.identifier)
160-
var col *schemas.Column
161-
fieldValue := v.Field(i)
162-
fieldType := fieldValue.Type()
163-
164-
if ormTagStr != "" {
165-
col = &schemas.Column{
166-
FieldName: t.Field(i).Name,
167-
Nullable: true,
168-
IsPrimaryKey: false,
169-
IsAutoIncrement: false,
170-
MapType: schemas.TWOSIDES,
171-
Indexes: make(map[string]int),
172-
DefaultIsEmpty: true,
173-
}
174-
tags := splitTag(ormTagStr)
175-
176-
if len(tags) > 0 {
177-
if tags[0] == "-" {
178-
continue
179-
}
180-
181-
var ctx = Context{
182-
table: table,
183-
col: col,
184-
fieldValue: fieldValue,
185-
indexNames: make(map[string]int),
186-
parser: parser,
187-
}
188-
189-
if strings.HasPrefix(strings.ToUpper(tags[0]), "EXTENDS") {
190-
pStart := strings.Index(tags[0], "(")
191-
if pStart > -1 && strings.HasSuffix(tags[0], ")") {
192-
var tagPrefix = strings.TrimFunc(tags[0][pStart+1:len(tags[0])-1], func(r rune) bool {
193-
return r == '\'' || r == '"'
194-
})
195-
196-
ctx.params = []string{tagPrefix}
197-
}
198-
199-
if err := ExtendsTagHandler(&ctx); err != nil {
200-
return nil, err
201-
}
202-
continue
203-
}
204-
205-
for j, key := range tags {
206-
if ctx.ignoreNext {
207-
ctx.ignoreNext = false
208-
continue
209-
}
210-
211-
k := strings.ToUpper(key)
212-
ctx.tagName = k
213-
ctx.params = []string{}
214-
215-
pStart := strings.Index(k, "(")
216-
if pStart == 0 {
217-
return nil, errors.New("( could not be the first character")
218-
}
219-
if pStart > -1 {
220-
if !strings.HasSuffix(k, ")") {
221-
return nil, fmt.Errorf("field %s tag %s cannot match ) character", col.FieldName, key)
222-
}
223-
224-
ctx.tagName = k[:pStart]
225-
ctx.params = strings.Split(key[pStart+1:len(k)-1], ",")
226-
}
227-
228-
if j > 0 {
229-
ctx.preTag = strings.ToUpper(tags[j-1])
230-
}
231-
if j < len(tags)-1 {
232-
ctx.nextTag = tags[j+1]
233-
} else {
234-
ctx.nextTag = ""
235-
}
236-
237-
if h, ok := parser.handlers[ctx.tagName]; ok {
238-
if err := h(&ctx); err != nil {
239-
return nil, err
240-
}
241-
} else {
242-
if strings.HasPrefix(key, "'") && strings.HasSuffix(key, "'") {
243-
col.Name = key[1 : len(key)-1]
244-
} else {
245-
col.Name = key
246-
}
247-
}
248-
249-
if ctx.hasCacheTag {
250-
hasCacheTag = true
251-
}
252-
if ctx.hasNoCacheTag {
253-
hasNoCacheTag = true
254-
}
255-
}
256-
257-
if col.SQLType.Name == "" {
258-
col.SQLType = schemas.Type2SQLType(fieldType)
259-
}
260-
parser.dialect.SQLType(col)
261-
if col.Length == 0 {
262-
col.Length = col.SQLType.DefaultLength
263-
}
264-
if col.Length2 == 0 {
265-
col.Length2 = col.SQLType.DefaultLength2
266-
}
267-
if col.Name == "" {
268-
col.Name = parser.columnMapper.Obj2Table(t.Field(i).Name)
269-
}
270-
271-
if ctx.isUnique {
272-
ctx.indexNames[col.Name] = schemas.UniqueType
273-
} else if ctx.isIndex {
274-
ctx.indexNames[col.Name] = schemas.IndexType
275-
}
276-
277-
for indexName, indexType := range ctx.indexNames {
278-
addIndex(indexName, table, col, indexType)
279-
}
280-
}
281-
} else {
282-
var sqlType schemas.SQLType
283-
if fieldValue.CanAddr() {
284-
if _, ok := fieldValue.Addr().Interface().(convert.Conversion); ok {
285-
sqlType = schemas.SQLType{Name: schemas.Text}
286-
}
287-
}
288-
if _, ok := fieldValue.Interface().(convert.Conversion); ok {
289-
sqlType = schemas.SQLType{Name: schemas.Text}
290-
} else {
291-
sqlType = schemas.Type2SQLType(fieldType)
292-
}
293-
col = schemas.NewColumn(parser.columnMapper.Obj2Table(t.Field(i).Name),
294-
t.Field(i).Name, sqlType, sqlType.DefaultLength,
295-
sqlType.DefaultLength2, true)
286+
var (
287+
field = t.Field(i)
288+
fieldValue = v.Field(i)
289+
)
296290

297-
if fieldType.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) {
298-
idFieldColName = col.Name
299-
}
300-
}
301-
if col.IsAutoIncrement {
302-
col.Nullable = false
291+
col, err := parser.parseField(table, field, fieldValue)
292+
if err == ErrIgnoreField {
293+
continue
294+
} else if err != nil {
295+
return nil, err
303296
}
304297

305298
table.AddColumn(col)
306299
} // end for
307300

308-
if idFieldColName != "" && len(table.PrimaryKeys) == 0 {
309-
col := table.GetColumn(idFieldColName)
310-
col.IsPrimaryKey = true
311-
col.IsAutoIncrement = true
312-
col.Nullable = false
313-
table.PrimaryKeys = append(table.PrimaryKeys, col.Name)
314-
table.AutoIncrement = col.Name
315-
}
316-
317-
if hasCacheTag {
318-
if parser.cacherMgr.GetDefaultCacher() != nil { // !nash! use engine's cacher if provided
319-
//engine.logger.Info("enable cache on table:", table.Name)
320-
parser.cacherMgr.SetCacher(table.Name, parser.cacherMgr.GetDefaultCacher())
321-
} else {
322-
//engine.logger.Info("enable LRU cache on table:", table.Name)
323-
parser.cacherMgr.SetCacher(table.Name, caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000))
324-
}
325-
}
326-
if hasNoCacheTag {
327-
//engine.logger.Info("disable cache on table:", table.Name)
328-
parser.cacherMgr.SetCacher(table.Name, nil)
329-
}
330-
331301
return table, nil
332302
}

‎tags/parser_test.go

+455-3
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,16 @@ package tags
66

77
import (
88
"reflect"
9+
"strings"
910
"testing"
11+
"time"
1012

11-
"github.com/stretchr/testify/assert"
1213
"xorm.io/xorm/caches"
1314
"xorm.io/xorm/dialects"
1415
"xorm.io/xorm/names"
16+
"xorm.io/xorm/schemas"
17+
18+
"github.com/stretchr/testify/assert"
1519
)
1620

1721
type ParseTableName1 struct{}
@@ -80,21 +84,469 @@ func TestParseWithOtherIdentifier(t *testing.T) {
8084
parser := NewParser(
8185
"xorm",
8286
dialects.QueryDialect("mysql"),
83-
names.GonicMapper{},
87+
names.SameMapper{},
8488
names.SnakeMapper{},
8589
caches.NewManager(),
8690
)
8791

8892
type StructWithDBTag struct {
8993
FieldFoo string `db:"foo"`
9094
}
95+
9196
parser.SetIdentifier("db")
9297
table, err := parser.Parse(reflect.ValueOf(new(StructWithDBTag)))
9398
assert.NoError(t, err)
94-
assert.EqualValues(t, "struct_with_db_tag", table.Name)
99+
assert.EqualValues(t, "StructWithDBTag", table.Name)
95100
assert.EqualValues(t, 1, len(table.Columns()))
96101

97102
for _, col := range table.Columns() {
98103
assert.EqualValues(t, "foo", col.Name)
99104
}
100105
}
106+
107+
func TestParseWithIgnore(t *testing.T) {
108+
parser := NewParser(
109+
"db",
110+
dialects.QueryDialect("mysql"),
111+
names.SameMapper{},
112+
names.SnakeMapper{},
113+
caches.NewManager(),
114+
)
115+
116+
type StructWithIgnoreTag struct {
117+
FieldFoo string `db:"-"`
118+
}
119+
120+
table, err := parser.Parse(reflect.ValueOf(new(StructWithIgnoreTag)))
121+
assert.NoError(t, err)
122+
assert.EqualValues(t, "StructWithIgnoreTag", table.Name)
123+
assert.EqualValues(t, 0, len(table.Columns()))
124+
}
125+
126+
func TestParseWithAutoincrement(t *testing.T) {
127+
parser := NewParser(
128+
"db",
129+
dialects.QueryDialect("mysql"),
130+
names.SnakeMapper{},
131+
names.GonicMapper{},
132+
caches.NewManager(),
133+
)
134+
135+
type StructWithAutoIncrement struct {
136+
ID int64
137+
}
138+
139+
table, err := parser.Parse(reflect.ValueOf(new(StructWithAutoIncrement)))
140+
assert.NoError(t, err)
141+
assert.EqualValues(t, "struct_with_auto_increment", table.Name)
142+
assert.EqualValues(t, 1, len(table.Columns()))
143+
assert.EqualValues(t, "id", table.Columns()[0].Name)
144+
assert.True(t, table.Columns()[0].IsAutoIncrement)
145+
assert.True(t, table.Columns()[0].IsPrimaryKey)
146+
}
147+
148+
func TestParseWithAutoincrement2(t *testing.T) {
149+
parser := NewParser(
150+
"db",
151+
dialects.QueryDialect("mysql"),
152+
names.SnakeMapper{},
153+
names.GonicMapper{},
154+
caches.NewManager(),
155+
)
156+
157+
type StructWithAutoIncrement2 struct {
158+
ID int64 `db:"pk autoincr"`
159+
}
160+
161+
table, err := parser.Parse(reflect.ValueOf(new(StructWithAutoIncrement2)))
162+
assert.NoError(t, err)
163+
assert.EqualValues(t, "struct_with_auto_increment2", table.Name)
164+
assert.EqualValues(t, 1, len(table.Columns()))
165+
assert.EqualValues(t, "id", table.Columns()[0].Name)
166+
assert.True(t, table.Columns()[0].IsAutoIncrement)
167+
assert.True(t, table.Columns()[0].IsPrimaryKey)
168+
assert.False(t, table.Columns()[0].Nullable)
169+
}
170+
171+
func TestParseWithNullable(t *testing.T) {
172+
parser := NewParser(
173+
"db",
174+
dialects.QueryDialect("mysql"),
175+
names.SnakeMapper{},
176+
names.GonicMapper{},
177+
caches.NewManager(),
178+
)
179+
180+
type StructWithNullable struct {
181+
Name string `db:"notnull"`
182+
FullName string `db:"null comment('column comment,字段注释')"`
183+
}
184+
185+
table, err := parser.Parse(reflect.ValueOf(new(StructWithNullable)))
186+
assert.NoError(t, err)
187+
assert.EqualValues(t, "struct_with_nullable", table.Name)
188+
assert.EqualValues(t, 2, len(table.Columns()))
189+
assert.EqualValues(t, "name", table.Columns()[0].Name)
190+
assert.EqualValues(t, "full_name", table.Columns()[1].Name)
191+
assert.False(t, table.Columns()[0].Nullable)
192+
assert.True(t, table.Columns()[1].Nullable)
193+
assert.EqualValues(t, "column comment,字段注释", table.Columns()[1].Comment)
194+
}
195+
196+
func TestParseWithTimes(t *testing.T) {
197+
parser := NewParser(
198+
"db",
199+
dialects.QueryDialect("mysql"),
200+
names.SnakeMapper{},
201+
names.GonicMapper{},
202+
caches.NewManager(),
203+
)
204+
205+
type StructWithTimes struct {
206+
Name string `db:"notnull"`
207+
CreatedAt time.Time `db:"created"`
208+
UpdatedAt time.Time `db:"updated"`
209+
DeletedAt time.Time `db:"deleted"`
210+
}
211+
212+
table, err := parser.Parse(reflect.ValueOf(new(StructWithTimes)))
213+
assert.NoError(t, err)
214+
assert.EqualValues(t, "struct_with_times", table.Name)
215+
assert.EqualValues(t, 4, len(table.Columns()))
216+
assert.EqualValues(t, "name", table.Columns()[0].Name)
217+
assert.EqualValues(t, "created_at", table.Columns()[1].Name)
218+
assert.EqualValues(t, "updated_at", table.Columns()[2].Name)
219+
assert.EqualValues(t, "deleted_at", table.Columns()[3].Name)
220+
assert.False(t, table.Columns()[0].Nullable)
221+
assert.True(t, table.Columns()[1].Nullable)
222+
assert.True(t, table.Columns()[1].IsCreated)
223+
assert.True(t, table.Columns()[2].Nullable)
224+
assert.True(t, table.Columns()[2].IsUpdated)
225+
assert.True(t, table.Columns()[3].Nullable)
226+
assert.True(t, table.Columns()[3].IsDeleted)
227+
}
228+
229+
func TestParseWithExtends(t *testing.T) {
230+
parser := NewParser(
231+
"db",
232+
dialects.QueryDialect("mysql"),
233+
names.SnakeMapper{},
234+
names.GonicMapper{},
235+
caches.NewManager(),
236+
)
237+
238+
type StructWithEmbed struct {
239+
Name string
240+
CreatedAt time.Time `db:"created"`
241+
UpdatedAt time.Time `db:"updated"`
242+
DeletedAt time.Time `db:"deleted"`
243+
}
244+
245+
type StructWithExtends struct {
246+
SW StructWithEmbed `db:"extends"`
247+
}
248+
249+
table, err := parser.Parse(reflect.ValueOf(new(StructWithExtends)))
250+
assert.NoError(t, err)
251+
assert.EqualValues(t, "struct_with_extends", table.Name)
252+
assert.EqualValues(t, 4, len(table.Columns()))
253+
assert.EqualValues(t, "name", table.Columns()[0].Name)
254+
assert.EqualValues(t, "created_at", table.Columns()[1].Name)
255+
assert.EqualValues(t, "updated_at", table.Columns()[2].Name)
256+
assert.EqualValues(t, "deleted_at", table.Columns()[3].Name)
257+
assert.True(t, table.Columns()[0].Nullable)
258+
assert.True(t, table.Columns()[1].Nullable)
259+
assert.True(t, table.Columns()[1].IsCreated)
260+
assert.True(t, table.Columns()[2].Nullable)
261+
assert.True(t, table.Columns()[2].IsUpdated)
262+
assert.True(t, table.Columns()[3].Nullable)
263+
assert.True(t, table.Columns()[3].IsDeleted)
264+
}
265+
266+
func TestParseWithCache(t *testing.T) {
267+
parser := NewParser(
268+
"db",
269+
dialects.QueryDialect("mysql"),
270+
names.SnakeMapper{},
271+
names.GonicMapper{},
272+
caches.NewManager(),
273+
)
274+
275+
type StructWithCache struct {
276+
Name string `db:"cache"`
277+
}
278+
279+
table, err := parser.Parse(reflect.ValueOf(new(StructWithCache)))
280+
assert.NoError(t, err)
281+
assert.EqualValues(t, "struct_with_cache", table.Name)
282+
assert.EqualValues(t, 1, len(table.Columns()))
283+
assert.EqualValues(t, "name", table.Columns()[0].Name)
284+
assert.True(t, table.Columns()[0].Nullable)
285+
cacher := parser.cacherMgr.GetCacher(table.Name)
286+
assert.NotNil(t, cacher)
287+
}
288+
289+
func TestParseWithNoCache(t *testing.T) {
290+
parser := NewParser(
291+
"db",
292+
dialects.QueryDialect("mysql"),
293+
names.SnakeMapper{},
294+
names.GonicMapper{},
295+
caches.NewManager(),
296+
)
297+
298+
type StructWithNoCache struct {
299+
Name string `db:"nocache"`
300+
}
301+
302+
table, err := parser.Parse(reflect.ValueOf(new(StructWithNoCache)))
303+
assert.NoError(t, err)
304+
assert.EqualValues(t, "struct_with_no_cache", table.Name)
305+
assert.EqualValues(t, 1, len(table.Columns()))
306+
assert.EqualValues(t, "name", table.Columns()[0].Name)
307+
assert.True(t, table.Columns()[0].Nullable)
308+
cacher := parser.cacherMgr.GetCacher(table.Name)
309+
assert.Nil(t, cacher)
310+
}
311+
312+
func TestParseWithEnum(t *testing.T) {
313+
parser := NewParser(
314+
"db",
315+
dialects.QueryDialect("mysql"),
316+
names.SnakeMapper{},
317+
names.GonicMapper{},
318+
caches.NewManager(),
319+
)
320+
321+
type StructWithEnum struct {
322+
Name string `db:"enum('alice', 'bob')"`
323+
}
324+
325+
table, err := parser.Parse(reflect.ValueOf(new(StructWithEnum)))
326+
assert.NoError(t, err)
327+
assert.EqualValues(t, "struct_with_enum", table.Name)
328+
assert.EqualValues(t, 1, len(table.Columns()))
329+
assert.EqualValues(t, "name", table.Columns()[0].Name)
330+
assert.True(t, table.Columns()[0].Nullable)
331+
assert.EqualValues(t, schemas.Enum, strings.ToUpper(table.Columns()[0].SQLType.Name))
332+
assert.EqualValues(t, map[string]int{
333+
"alice": 0,
334+
"bob": 1,
335+
}, table.Columns()[0].EnumOptions)
336+
}
337+
338+
func TestParseWithSet(t *testing.T) {
339+
parser := NewParser(
340+
"db",
341+
dialects.QueryDialect("mysql"),
342+
names.SnakeMapper{},
343+
names.GonicMapper{},
344+
caches.NewManager(),
345+
)
346+
347+
type StructWithSet struct {
348+
Name string `db:"set('alice', 'bob')"`
349+
}
350+
351+
table, err := parser.Parse(reflect.ValueOf(new(StructWithSet)))
352+
assert.NoError(t, err)
353+
assert.EqualValues(t, "struct_with_set", table.Name)
354+
assert.EqualValues(t, 1, len(table.Columns()))
355+
assert.EqualValues(t, "name", table.Columns()[0].Name)
356+
assert.True(t, table.Columns()[0].Nullable)
357+
assert.EqualValues(t, schemas.Set, strings.ToUpper(table.Columns()[0].SQLType.Name))
358+
assert.EqualValues(t, map[string]int{
359+
"alice": 0,
360+
"bob": 1,
361+
}, table.Columns()[0].SetOptions)
362+
}
363+
364+
func TestParseWithIndex(t *testing.T) {
365+
parser := NewParser(
366+
"db",
367+
dialects.QueryDialect("mysql"),
368+
names.SnakeMapper{},
369+
names.GonicMapper{},
370+
caches.NewManager(),
371+
)
372+
373+
type StructWithIndex struct {
374+
Name string `db:"index"`
375+
Name2 string `db:"index(s)"`
376+
Name3 string `db:"unique"`
377+
}
378+
379+
table, err := parser.Parse(reflect.ValueOf(new(StructWithIndex)))
380+
assert.NoError(t, err)
381+
assert.EqualValues(t, "struct_with_index", table.Name)
382+
assert.EqualValues(t, 3, len(table.Columns()))
383+
assert.EqualValues(t, "name", table.Columns()[0].Name)
384+
assert.EqualValues(t, "name2", table.Columns()[1].Name)
385+
assert.EqualValues(t, "name3", table.Columns()[2].Name)
386+
assert.True(t, table.Columns()[0].Nullable)
387+
assert.True(t, table.Columns()[1].Nullable)
388+
assert.True(t, table.Columns()[2].Nullable)
389+
assert.EqualValues(t, 1, len(table.Columns()[0].Indexes))
390+
assert.EqualValues(t, 1, len(table.Columns()[1].Indexes))
391+
assert.EqualValues(t, 1, len(table.Columns()[2].Indexes))
392+
}
393+
394+
func TestParseWithVersion(t *testing.T) {
395+
parser := NewParser(
396+
"db",
397+
dialects.QueryDialect("mysql"),
398+
names.SnakeMapper{},
399+
names.GonicMapper{},
400+
caches.NewManager(),
401+
)
402+
403+
type StructWithVersion struct {
404+
Name string
405+
Version int `db:"version"`
406+
}
407+
408+
table, err := parser.Parse(reflect.ValueOf(new(StructWithVersion)))
409+
assert.NoError(t, err)
410+
assert.EqualValues(t, "struct_with_version", table.Name)
411+
assert.EqualValues(t, 2, len(table.Columns()))
412+
assert.EqualValues(t, "name", table.Columns()[0].Name)
413+
assert.EqualValues(t, "version", table.Columns()[1].Name)
414+
assert.True(t, table.Columns()[0].Nullable)
415+
assert.True(t, table.Columns()[1].Nullable)
416+
assert.True(t, table.Columns()[1].IsVersion)
417+
}
418+
419+
func TestParseWithLocale(t *testing.T) {
420+
parser := NewParser(
421+
"db",
422+
dialects.QueryDialect("mysql"),
423+
names.SnakeMapper{},
424+
names.GonicMapper{},
425+
caches.NewManager(),
426+
)
427+
428+
type StructWithLocale struct {
429+
UTCLocale time.Time `db:"utc"`
430+
LocalLocale time.Time `db:"local"`
431+
}
432+
433+
table, err := parser.Parse(reflect.ValueOf(new(StructWithLocale)))
434+
assert.NoError(t, err)
435+
assert.EqualValues(t, "struct_with_locale", table.Name)
436+
assert.EqualValues(t, 2, len(table.Columns()))
437+
assert.EqualValues(t, "utc_locale", table.Columns()[0].Name)
438+
assert.EqualValues(t, "local_locale", table.Columns()[1].Name)
439+
assert.EqualValues(t, time.UTC, table.Columns()[0].TimeZone)
440+
assert.EqualValues(t, time.Local, table.Columns()[1].TimeZone)
441+
}
442+
443+
func TestParseWithDefault(t *testing.T) {
444+
parser := NewParser(
445+
"db",
446+
dialects.QueryDialect("mysql"),
447+
names.SnakeMapper{},
448+
names.GonicMapper{},
449+
caches.NewManager(),
450+
)
451+
452+
type StructWithDefault struct {
453+
Default1 time.Time `db:"default '1970-01-01 00:00:00'"`
454+
Default2 time.Time `db:"default(CURRENT_TIMESTAMP)"`
455+
}
456+
457+
table, err := parser.Parse(reflect.ValueOf(new(StructWithDefault)))
458+
assert.NoError(t, err)
459+
assert.EqualValues(t, "struct_with_default", table.Name)
460+
assert.EqualValues(t, 2, len(table.Columns()))
461+
assert.EqualValues(t, "default1", table.Columns()[0].Name)
462+
assert.EqualValues(t, "default2", table.Columns()[1].Name)
463+
assert.EqualValues(t, "'1970-01-01 00:00:00'", table.Columns()[0].Default)
464+
assert.EqualValues(t, "CURRENT_TIMESTAMP", table.Columns()[1].Default)
465+
}
466+
467+
func TestParseWithOnlyToDB(t *testing.T) {
468+
parser := NewParser(
469+
"db",
470+
dialects.QueryDialect("mysql"),
471+
names.GonicMapper{
472+
"DB": true,
473+
},
474+
names.SnakeMapper{},
475+
caches.NewManager(),
476+
)
477+
478+
type StructWithOnlyToDB struct {
479+
Default1 time.Time `db:"->"`
480+
Default2 time.Time `db:"<-"`
481+
}
482+
483+
table, err := parser.Parse(reflect.ValueOf(new(StructWithOnlyToDB)))
484+
assert.NoError(t, err)
485+
assert.EqualValues(t, "struct_with_only_to_db", table.Name)
486+
assert.EqualValues(t, 2, len(table.Columns()))
487+
assert.EqualValues(t, "default1", table.Columns()[0].Name)
488+
assert.EqualValues(t, "default2", table.Columns()[1].Name)
489+
assert.EqualValues(t, schemas.ONLYTODB, table.Columns()[0].MapType)
490+
assert.EqualValues(t, schemas.ONLYFROMDB, table.Columns()[1].MapType)
491+
}
492+
493+
func TestParseWithJSON(t *testing.T) {
494+
parser := NewParser(
495+
"db",
496+
dialects.QueryDialect("mysql"),
497+
names.GonicMapper{
498+
"JSON": true,
499+
},
500+
names.SnakeMapper{},
501+
caches.NewManager(),
502+
)
503+
504+
type StructWithJSON struct {
505+
Default1 []string `db:"json"`
506+
}
507+
508+
table, err := parser.Parse(reflect.ValueOf(new(StructWithJSON)))
509+
assert.NoError(t, err)
510+
assert.EqualValues(t, "struct_with_json", table.Name)
511+
assert.EqualValues(t, 1, len(table.Columns()))
512+
assert.EqualValues(t, "default1", table.Columns()[0].Name)
513+
assert.True(t, table.Columns()[0].IsJSON)
514+
}
515+
516+
func TestParseWithSQLType(t *testing.T) {
517+
parser := NewParser(
518+
"db",
519+
dialects.QueryDialect("mysql"),
520+
names.GonicMapper{
521+
"SQL": true,
522+
},
523+
names.GonicMapper{
524+
"UUID": true,
525+
},
526+
caches.NewManager(),
527+
)
528+
529+
type StructWithSQLType struct {
530+
Col1 string `db:"varchar(32)"`
531+
Col2 string `db:"char(32)"`
532+
Int int64 `db:"bigint"`
533+
DateTime time.Time `db:"datetime"`
534+
UUID string `db:"uuid"`
535+
}
536+
537+
table, err := parser.Parse(reflect.ValueOf(new(StructWithSQLType)))
538+
assert.NoError(t, err)
539+
assert.EqualValues(t, "struct_with_sql_type", table.Name)
540+
assert.EqualValues(t, 5, len(table.Columns()))
541+
assert.EqualValues(t, "col1", table.Columns()[0].Name)
542+
assert.EqualValues(t, "col2", table.Columns()[1].Name)
543+
assert.EqualValues(t, "int", table.Columns()[2].Name)
544+
assert.EqualValues(t, "date_time", table.Columns()[3].Name)
545+
assert.EqualValues(t, "uuid", table.Columns()[4].Name)
546+
547+
assert.EqualValues(t, "VARCHAR", table.Columns()[0].SQLType.Name)
548+
assert.EqualValues(t, "CHAR", table.Columns()[1].SQLType.Name)
549+
assert.EqualValues(t, "BIGINT", table.Columns()[2].SQLType.Name)
550+
assert.EqualValues(t, "DATETIME", table.Columns()[3].SQLType.Name)
551+
assert.EqualValues(t, "UUID", table.Columns()[4].SQLType.Name)
552+
}

‎tags/tag.go

+98-49
Original file line numberDiff line numberDiff line change
@@ -14,30 +14,74 @@ import (
1414
"xorm.io/xorm/schemas"
1515
)
1616

17-
func splitTag(tag string) (tags []string) {
18-
tag = strings.TrimSpace(tag)
19-
var hasQuote = false
20-
var lastIdx = 0
21-
for i, t := range tag {
22-
if t == '\'' {
23-
hasQuote = !hasQuote
24-
} else if t == ' ' {
25-
if lastIdx < i && !hasQuote {
26-
tags = append(tags, strings.TrimSpace(tag[lastIdx:i]))
27-
lastIdx = i + 1
17+
type tag struct {
18+
name string
19+
params []string
20+
}
21+
22+
func splitTag(tagStr string) ([]tag, error) {
23+
tagStr = strings.TrimSpace(tagStr)
24+
var (
25+
inQuote bool
26+
inBigQuote bool
27+
lastIdx int
28+
curTag tag
29+
paramStart int
30+
tags []tag
31+
)
32+
for i, t := range tagStr {
33+
switch t {
34+
case '\'':
35+
inQuote = !inQuote
36+
case ' ':
37+
if !inQuote && !inBigQuote {
38+
if lastIdx < i {
39+
if curTag.name == "" {
40+
curTag.name = tagStr[lastIdx:i]
41+
}
42+
tags = append(tags, curTag)
43+
lastIdx = i + 1
44+
curTag = tag{}
45+
} else if lastIdx == i {
46+
lastIdx = i + 1
47+
}
48+
} else if inBigQuote && !inQuote {
49+
paramStart = i + 1
50+
}
51+
case ',':
52+
if !inQuote && !inBigQuote {
53+
return nil, fmt.Errorf("comma[%d] of %s should be in quote or big quote", i, tagStr)
54+
}
55+
if !inQuote && inBigQuote {
56+
curTag.params = append(curTag.params, strings.TrimSpace(tagStr[paramStart:i]))
57+
paramStart = i + 1
58+
}
59+
case '(':
60+
inBigQuote = true
61+
if !inQuote {
62+
curTag.name = tagStr[lastIdx:i]
63+
paramStart = i + 1
64+
}
65+
case ')':
66+
inBigQuote = false
67+
if !inQuote {
68+
curTag.params = append(curTag.params, tagStr[paramStart:i])
2869
}
2970
}
3071
}
31-
if lastIdx < len(tag) {
32-
tags = append(tags, strings.TrimSpace(tag[lastIdx:]))
72+
if lastIdx < len(tagStr) {
73+
if curTag.name == "" {
74+
curTag.name = tagStr[lastIdx:]
75+
}
76+
tags = append(tags, curTag)
3377
}
34-
return
78+
return tags, nil
3579
}
3680

3781
// Context represents a context for xorm tag parse.
3882
type Context struct {
39-
tagName string
40-
params []string
83+
tag
84+
tagUname string
4185
preTag, nextTag string
4286
table *schemas.Table
4387
col *schemas.Column
@@ -76,6 +120,7 @@ var (
76120
"CACHE": CacheTagHandler,
77121
"NOCACHE": NoCacheTagHandler,
78122
"COMMENT": CommentTagHandler,
123+
"EXTENDS": ExtendsTagHandler,
79124
}
80125
)
81126

@@ -124,6 +169,7 @@ func NotNullTagHandler(ctx *Context) error {
124169
// AutoIncrTagHandler describes autoincr tag handler
125170
func AutoIncrTagHandler(ctx *Context) error {
126171
ctx.col.IsAutoIncrement = true
172+
ctx.col.Nullable = false
127173
/*
128174
if len(ctx.params) > 0 {
129175
autoStartInt, err := strconv.Atoi(ctx.params[0])
@@ -225,41 +271,44 @@ func CommentTagHandler(ctx *Context) error {
225271

226272
// SQLTypeTagHandler describes SQL Type tag handler
227273
func SQLTypeTagHandler(ctx *Context) error {
228-
ctx.col.SQLType = schemas.SQLType{Name: ctx.tagName}
229-
if strings.EqualFold(ctx.tagName, "JSON") {
274+
ctx.col.SQLType = schemas.SQLType{Name: ctx.tagUname}
275+
if ctx.tagUname == "JSON" {
230276
ctx.col.IsJSON = true
231277
}
232-
if len(ctx.params) > 0 {
233-
if ctx.tagName == schemas.Enum {
234-
ctx.col.EnumOptions = make(map[string]int)
235-
for k, v := range ctx.params {
236-
v = strings.TrimSpace(v)
237-
v = strings.Trim(v, "'")
238-
ctx.col.EnumOptions[v] = k
278+
if len(ctx.params) == 0 {
279+
return nil
280+
}
281+
282+
switch ctx.tagUname {
283+
case schemas.Enum:
284+
ctx.col.EnumOptions = make(map[string]int)
285+
for k, v := range ctx.params {
286+
v = strings.TrimSpace(v)
287+
v = strings.Trim(v, "'")
288+
ctx.col.EnumOptions[v] = k
289+
}
290+
case schemas.Set:
291+
ctx.col.SetOptions = make(map[string]int)
292+
for k, v := range ctx.params {
293+
v = strings.TrimSpace(v)
294+
v = strings.Trim(v, "'")
295+
ctx.col.SetOptions[v] = k
296+
}
297+
default:
298+
var err error
299+
if len(ctx.params) == 2 {
300+
ctx.col.Length, err = strconv.Atoi(ctx.params[0])
301+
if err != nil {
302+
return err
239303
}
240-
} else if ctx.tagName == schemas.Set {
241-
ctx.col.SetOptions = make(map[string]int)
242-
for k, v := range ctx.params {
243-
v = strings.TrimSpace(v)
244-
v = strings.Trim(v, "'")
245-
ctx.col.SetOptions[v] = k
304+
ctx.col.Length2, err = strconv.Atoi(ctx.params[1])
305+
if err != nil {
306+
return err
246307
}
247-
} else {
248-
var err error
249-
if len(ctx.params) == 2 {
250-
ctx.col.Length, err = strconv.Atoi(ctx.params[0])
251-
if err != nil {
252-
return err
253-
}
254-
ctx.col.Length2, err = strconv.Atoi(ctx.params[1])
255-
if err != nil {
256-
return err
257-
}
258-
} else if len(ctx.params) == 1 {
259-
ctx.col.Length, err = strconv.Atoi(ctx.params[0])
260-
if err != nil {
261-
return err
262-
}
308+
} else if len(ctx.params) == 1 {
309+
ctx.col.Length, err = strconv.Atoi(ctx.params[0])
310+
if err != nil {
311+
return err
263312
}
264313
}
265314
}
@@ -293,7 +342,7 @@ func ExtendsTagHandler(ctx *Context) error {
293342
var tagPrefix = ctx.col.FieldName
294343
if len(ctx.params) > 0 {
295344
col.Nullable = isPtr
296-
tagPrefix = ctx.params[0]
345+
tagPrefix = strings.Trim(ctx.params[0], "'")
297346
if col.IsPrimaryKey {
298347
col.Name = ctx.col.FieldName
299348
col.IsPrimaryKey = false
@@ -315,7 +364,7 @@ func ExtendsTagHandler(ctx *Context) error {
315364
default:
316365
//TODO: warning
317366
}
318-
return nil
367+
return ErrIgnoreField
319368
}
320369

321370
// CacheTagHandler describes cache tag handler

‎tags/tag_test.go

+69-10
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,83 @@ package tags
77
import (
88
"testing"
99

10-
"xorm.io/xorm/internal/utils"
10+
"github.com/stretchr/testify/assert"
1111
)
1212

1313
func TestSplitTag(t *testing.T) {
1414
var cases = []struct {
1515
tag string
16-
tags []string
16+
tags []tag
1717
}{
18-
{"not null default '2000-01-01 00:00:00' TIMESTAMP", []string{"not", "null", "default", "'2000-01-01 00:00:00'", "TIMESTAMP"}},
19-
{"TEXT", []string{"TEXT"}},
20-
{"default('2000-01-01 00:00:00')", []string{"default('2000-01-01 00:00:00')"}},
21-
{"json binary", []string{"json", "binary"}},
18+
{"not null default '2000-01-01 00:00:00' TIMESTAMP", []tag{
19+
{
20+
name: "not",
21+
},
22+
{
23+
name: "null",
24+
},
25+
{
26+
name: "default",
27+
},
28+
{
29+
name: "'2000-01-01 00:00:00'",
30+
},
31+
{
32+
name: "TIMESTAMP",
33+
},
34+
},
35+
},
36+
{"TEXT", []tag{
37+
{
38+
name: "TEXT",
39+
},
40+
},
41+
},
42+
{"default('2000-01-01 00:00:00')", []tag{
43+
{
44+
name: "default",
45+
params: []string{
46+
"'2000-01-01 00:00:00'",
47+
},
48+
},
49+
},
50+
},
51+
{"json binary", []tag{
52+
{
53+
name: "json",
54+
},
55+
{
56+
name: "binary",
57+
},
58+
},
59+
},
60+
{"numeric(10, 2)", []tag{
61+
{
62+
name: "numeric",
63+
params: []string{"10", "2"},
64+
},
65+
},
66+
},
67+
{"numeric(10, 2) notnull", []tag{
68+
{
69+
name: "numeric",
70+
params: []string{"10", "2"},
71+
},
72+
{
73+
name: "notnull",
74+
},
75+
},
76+
},
2277
}
2378

2479
for _, kase := range cases {
25-
tags := splitTag(kase.tag)
26-
if !utils.SliceEq(tags, kase.tags) {
27-
t.Fatalf("[%d]%v is not equal [%d]%v", len(tags), tags, len(kase.tags), kase.tags)
28-
}
80+
t.Run(kase.tag, func(t *testing.T) {
81+
tags, err := splitTag(kase.tag)
82+
assert.NoError(t, err)
83+
assert.EqualValues(t, len(tags), len(kase.tags))
84+
for i := 0; i < len(tags); i++ {
85+
assert.Equal(t, tags[i], kase.tags[i])
86+
}
87+
})
2988
}
3089
}

0 commit comments

Comments
 (0)
Please sign in to comment.