Skip to content

Commit

Permalink
Refactor getting column name to actually use schema from model (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
ActiveChooN authored Oct 9, 2024
1 parent 68c44a5 commit 34a41a4
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 59 deletions.
52 changes: 24 additions & 28 deletions gin-gorm-filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ import (
"reflect"
"regexp"
"strings"
"sync"

"github.com/gin-gonic/gin"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)

type queryParams struct {
Expand All @@ -36,8 +38,7 @@ const (
)

var (
columnNameRegexp = regexp.MustCompile(`(?m)column:(\w{1,}).*`)
paramNameRegexp = regexp.MustCompile(`(?m)param:(\w{1,}).*`)
paramNameRegexp = regexp.MustCompile(`(?m)param:(\w{1,}).*`)
)

func orderBy(db *gorm.DB, params queryParams) *gorm.DB {
Expand Down Expand Up @@ -67,33 +68,23 @@ func paginate(db *gorm.DB, params queryParams) *gorm.DB {
return db.Offset(offset).Limit(params.PageSize)
}

func getColumnNameForField(field reflect.StructField) string {
fieldTag := field.Tag.Get("gorm")
res := columnNameRegexp.FindStringSubmatch(fieldTag)
if len(res) == 2 {
return res[1]
}
return field.Name
}

func searchField(field reflect.StructField, phrase string) clause.Expression {
func searchField(columnName string, field reflect.StructField, phrase string) clause.Expression {
filterTag := field.Tag.Get(tagKey)
columnName := getColumnNameForField(field)

if strings.Contains(filterTag, "searchable") {
return clause.Like{
Column: clause.Expr{SQL: "LOWER(?)", Vars: []interface{}{columnName}},
Column: clause.Expr{SQL: "LOWER(?)", Vars: []interface{}{clause.Column{Table: clause.CurrentTable, Name: columnName}}},
Value: "%" + strings.ToLower(phrase) + "%",
}
}
return nil
}

func filterField(field reflect.StructField, phrase string) clause.Expression {
func filterField(columnName string, field reflect.StructField, phrase string) clause.Expression {
var paramName string
if !strings.Contains(field.Tag.Get(tagKey), "filterable") {
return nil
}
columnName := getColumnNameForField(field)
paramMatch := paramNameRegexp.FindStringSubmatch(field.Tag.Get(tagKey))
if len(paramMatch) == 2 {
paramName = paramMatch[1]
Expand All @@ -112,37 +103,42 @@ func filterField(field reflect.StructField, phrase string) clause.Expression {
if len(filterSubPhraseMatch) == 3 {
switch filterSubPhraseMatch[1] {
case ">=":
return clause.Gte{Column: columnName, Value: filterSubPhraseMatch[2]}
return clause.Gte{Column: clause.Column{Table: clause.CurrentTable, Name: columnName}, Value: filterSubPhraseMatch[2]}
case "<=":
return clause.Lte{Column: columnName, Value: filterSubPhraseMatch[2]}
return clause.Lte{Column: clause.Column{Table: clause.CurrentTable, Name: columnName}, Value: filterSubPhraseMatch[2]}
case "!=":
return clause.Neq{Column: columnName, Value: filterSubPhraseMatch[2]}
return clause.Neq{Column: clause.Column{Table: clause.CurrentTable, Name: columnName}, Value: filterSubPhraseMatch[2]}
case ">":
return clause.Gt{Column: columnName, Value: filterSubPhraseMatch[2]}
return clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: columnName}, Value: filterSubPhraseMatch[2]}
case "<":
return clause.Lt{Column: columnName, Value: filterSubPhraseMatch[2]}
return clause.Lt{Column: clause.Column{Table: clause.CurrentTable, Name: columnName}, Value: filterSubPhraseMatch[2]}
case "~":
return clause.Like{Column: columnName, Value: filterSubPhraseMatch[2]}
return clause.Like{Column: clause.Column{Table: clause.CurrentTable, Name: columnName}, Value: filterSubPhraseMatch[2]}
default:
return clause.Eq{Column: columnName, Value: filterSubPhraseMatch[2]}
return clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: columnName}, Value: filterSubPhraseMatch[2]}
}
}
return nil
}

func expressionByField(
db *gorm.DB, phrases []string, modelType reflect.Type,
operator func(reflect.StructField, string) clause.Expression,
db *gorm.DB, phrases []string,
operator func(string, reflect.StructField, string) clause.Expression,
predicate func(...clause.Expression) clause.Expression,
) *gorm.DB {
modelType := reflect.TypeOf(db.Statement.Model).Elem()
numFields := modelType.NumField()
modelSchema, err := schema.Parse(db.Statement.Model, &sync.Map{}, db.NamingStrategy)
if err != nil {
return db
}
var allExpressions []clause.Expression

for _, phrase := range phrases {
expressions := make([]clause.Expression, 0, numFields)
for i := 0; i < numFields; i++ {
field := modelType.Field(i)
expression := operator(field, phrase)
expression := operator(modelSchema.LookUpField(field.Name).DBName, field, phrase)
if expression != nil {
expressions = append(expressions, expression)
}
Expand Down Expand Up @@ -189,10 +185,10 @@ func FilterByQuery(c *gin.Context, config int) func(db *gorm.DB) *gorm.DB {
modelType := reflect.TypeOf(model)
if model != nil && modelType.Kind() == reflect.Ptr && modelType.Elem().Kind() == reflect.Struct {
if config&SEARCH > 0 && params.Search != "" {
db = expressionByField(db, []string{params.Search}, modelType.Elem(), searchField, clause.Or)
db = expressionByField(db, []string{params.Search}, searchField, clause.Or)
}
if config&FILTER > 0 && len(params.Filter) > 0 {
db = expressionByField(db, params.Filter, modelType.Elem(), filterField, clause.And)
db = expressionByField(db, params.Filter, filterField, clause.And)
}
}

Expand Down
87 changes: 56 additions & 31 deletions gin-gorm-filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,18 @@ import (
"gorm.io/gorm"
)

type Organization struct {
Id uint `filter:"param:id;filterable"`
Name string `filter:"param:name;searchable"`
}

type User struct {
Id int64 `filter:"param:id;filterable"`
Username string `filter:"param:login;searchable;filterable"`
FullName string `filter:"param:name;searchable"`
Email string `filter:"filterable"`
Id uint `filter:"param:id;filterable"`
Username string `filter:"param:login;searchable;filterable"`
FullName string `filter:"param:name;searchable"`
Email string `filter:"filterable"`
OrganizationId uint
Organization Organization
// This field is not filtered.
Password string
}
Expand Down Expand Up @@ -73,9 +80,9 @@ func (s *TestSuite) TestFiltersBasic() {
},
}

s.mock.ExpectQuery(`^SELECT \* FROM "users" WHERE "Username" = \$1 ORDER BY "id" DESC LIMIT \$2$`).
s.mock.ExpectQuery(`^SELECT \* FROM "users" WHERE "users"."username" = \$1 ORDER BY "id" DESC LIMIT \$2$`).
WithArgs("sampleUser", 10).
WillReturnRows(sqlmock.NewRows([]string{"id", "Username", "FullName", "Email", "Password"}))
WillReturnRows(sqlmock.NewRows([]string{"id", "username", "full_name", "email", "password"}))
err := s.db.Model(&User{}).Scopes(FilterByQuery(&ctx, ALL)).Find(&users).Error
s.NoError(err)
}
Expand All @@ -90,9 +97,9 @@ func (s *TestSuite) TestFiltersLike() {
},
}

s.mock.ExpectQuery(`^SELECT \* FROM "users" WHERE "Username" LIKE \$1 ORDER BY "id" DESC LIMIT \$2$`).
s.mock.ExpectQuery(`^SELECT \* FROM "users" WHERE "users"."username" LIKE \$1 ORDER BY "id" DESC LIMIT \$2$`).
WithArgs("samp", 10).
WillReturnRows(sqlmock.NewRows([]string{"id", "Username", "FullName", "Email", "Password"}))
WillReturnRows(sqlmock.NewRows([]string{"id", "username", "full_name", "email", "password"}))
err := s.db.Model(&User{}).Scopes(FilterByQuery(&ctx, ALL)).Find(&users).Error
s.NoError(err)
}
Expand All @@ -107,7 +114,7 @@ func (s *TestSuite) TestFiltersNotFilterable() {
},
}
s.mock.ExpectQuery(`^SELECT \* FROM "users"$`).
WillReturnRows(sqlmock.NewRows([]string{"id", "Username", "FullName", "Email", "Password"}))
WillReturnRows(sqlmock.NewRows([]string{"id", "username", "full_name", "email", "password"}))
err := s.db.Model(&User{}).Scopes(FilterByQuery(&ctx, FILTER)).Find(&users).Error
s.NoError(err)
}
Expand All @@ -123,7 +130,7 @@ func (s *TestSuite) TestFiltersNoFilterConfig() {
}

s.mock.ExpectQuery(`^SELECT \* FROM "users"$`).
WillReturnRows(sqlmock.NewRows([]string{"id", "Username", "FullName", "Email", "Password"}))
WillReturnRows(sqlmock.NewRows([]string{"id", "username", "full_name", "email", "password"}))
err := s.db.Model(&User{}).Scopes(FilterByQuery(&ctx, SEARCH)).Find(&users).Error
s.NoError(err)
}
Expand All @@ -138,8 +145,8 @@ func (s *TestSuite) TestFiltersNotEqualTo() {
},
}

s.mock.ExpectQuery(`^SELECT \* FROM "users" WHERE "Id" <> \$1`).
WillReturnRows(sqlmock.NewRows([]string{"id", "Username", "FullName", "Email", "Password"}))
s.mock.ExpectQuery(`^SELECT \* FROM "users" WHERE "users"."id" <> \$1`).
WillReturnRows(sqlmock.NewRows([]string{"id", "username", "full_name", "email", "password"}))
err := s.db.Model(&User{}).Scopes(FilterByQuery(&ctx, FILTER)).Find(&users).Error
s.NoError(err)
}
Expand All @@ -153,8 +160,8 @@ func (s *TestSuite) TestFiltersLessThan() {
},
}

s.mock.ExpectQuery(`^SELECT \* FROM "users" WHERE "Username" < \$1`).
WillReturnRows(sqlmock.NewRows([]string{"id", "Username", "FullName", "Email", "Password"}))
s.mock.ExpectQuery(`^SELECT \* FROM "users" WHERE "users"."username" < \$1`).
WillReturnRows(sqlmock.NewRows([]string{"id", "username", "full_name", "email", "password"}))
err := s.db.Model(&User{}).Scopes(FilterByQuery(&ctx, FILTER)).Find(&users).Error
s.NoError(err)
}
Expand All @@ -169,8 +176,8 @@ func (s *TestSuite) TestFiltersLessThanOrEqualTo() {
},
}

s.mock.ExpectQuery(`^SELECT \* FROM "users" WHERE "Id" <= \$1`).
WillReturnRows(sqlmock.NewRows([]string{"id", "Username", "FullName", "Email", "Password"}))
s.mock.ExpectQuery(`^SELECT \* FROM "users" WHERE "users"."id" <= \$1`).
WillReturnRows(sqlmock.NewRows([]string{"id", "username", "full_name", "email", "password"}))
err := s.db.Model(&User{}).Scopes(FilterByQuery(&ctx, FILTER)).Find(&users).Error
s.NoError(err)
}
Expand All @@ -184,8 +191,8 @@ func (s *TestSuite) TestFiltersGreaterThan() {
},
}

s.mock.ExpectQuery(`^SELECT \* FROM "users" WHERE "Id" > \$1`).
WillReturnRows(sqlmock.NewRows([]string{"id", "Username", "FullName", "Email", "Password"}))
s.mock.ExpectQuery(`^SELECT \* FROM "users" WHERE "users"."id" > \$1`).
WillReturnRows(sqlmock.NewRows([]string{"id", "username", "full_name", "email", "password"}))
err := s.db.Model(&User{}).Scopes(FilterByQuery(&ctx, FILTER)).Find(&users).Error
s.NoError(err)
}
Expand All @@ -199,8 +206,8 @@ func (s *TestSuite) TestFiltersGreaterThanOrEqualTo() {
},
}

s.mock.ExpectQuery(`^SELECT \* FROM "users" WHERE "Id" >= \$1`).
WillReturnRows(sqlmock.NewRows([]string{"id", "Username", "FullName", "Email", "Password"}))
s.mock.ExpectQuery(`^SELECT \* FROM "users" WHERE "users"."id" >= \$1`).
WillReturnRows(sqlmock.NewRows([]string{"id", "username", "full_name", "email", "password"}))
err := s.db.Model(&User{}).Scopes(FilterByQuery(&ctx, FILTER)).Find(&users).Error
s.NoError(err)
}
Expand All @@ -215,9 +222,9 @@ func (s *TestSuite) TestFiltersSearchable() {
},
}

s.mock.ExpectQuery(`^SELECT \* FROM "users" WHERE \(LOWER\(\$1\) LIKE \$2 OR LOWER\(\$3\) LIKE \$4\)$`).
WithArgs("Username", "%john%", "FullName", "%john%").
WillReturnRows(sqlmock.NewRows([]string{"id", "Username", "FullName", "Email", "Password"}))
s.mock.ExpectQuery(`^SELECT \* FROM "users" WHERE \(LOWER\("users"."username"\) LIKE \$1 OR LOWER\("users"."full_name"\) LIKE \$2\)$`).
WithArgs("%john%", "%john%").
WillReturnRows(sqlmock.NewRows([]string{"id", "username", "full_name", "email", "password"}))
err := s.db.Model(&User{}).Scopes(FilterByQuery(&ctx, SEARCH)).Find(&users).Error
s.NoError(err)
}
Expand All @@ -234,7 +241,7 @@ func (s *TestSuite) TestFiltersPaginateOnly() {

s.mock.ExpectQuery(`^SELECT \* FROM "users" ORDER BY "id" DESC LIMIT \$1 OFFSET \$2$`).
WithArgs(10, 10).
WillReturnRows(sqlmock.NewRows([]string{"id", "Username", "FullName", "Email", "Password"}))
WillReturnRows(sqlmock.NewRows([]string{"id", "username", "full_name", "email", "password"}))
err := s.db.Model(&User{}).Scopes(FilterByQuery(&ctx, ALL)).Find(&users).Error
s.NoError(err)
}
Expand All @@ -250,7 +257,7 @@ func (s *TestSuite) TestFiltersOrderBy() {
}

s.mock.ExpectQuery(`^SELECT \* FROM "users" ORDER BY "Email"$`).
WillReturnRows(sqlmock.NewRows([]string{"id", "Username", "FullName", "Email", "Password"}))
WillReturnRows(sqlmock.NewRows([]string{"id", "username", "full_name", "email", "password"}))
err := s.db.Model(&User{}).Scopes(FilterByQuery(&ctx, ORDER_BY)).Find(&users).Error
s.NoError(err)
}
Expand All @@ -265,9 +272,9 @@ func (s *TestSuite) TestFiltersAndSearch() {
},
}

s.mock.ExpectQuery(`^SELECT \* FROM "users" WHERE \(LOWER\(\$1\) LIKE \$2 OR LOWER\(\$3\) LIKE \$4\) AND "Username" = \$5$`).
WithArgs("Username", "%john%", "FullName", "%john%", "sampleUser").
WillReturnRows(sqlmock.NewRows([]string{"id", "Username", "FullName", "Email", "Password"}))
s.mock.ExpectQuery(`^SELECT \* FROM "users" WHERE \(LOWER\("users"."username"\) LIKE \$1 OR LOWER\("users"."full_name"\) LIKE \$2\) AND "users"."username" = \$3$`).
WithArgs("%john%", "%john%", "sampleUser").
WillReturnRows(sqlmock.NewRows([]string{"id", "username", "full_name", "email", "password"}))

err := s.db.Model(&User{}).Scopes(FilterByQuery(&ctx, FILTER|SEARCH)).Find(&users).Error
s.NoError(err)
Expand All @@ -279,18 +286,36 @@ func (s *TestSuite) TestFiltersMultipleColumns() {
ctx := gin.Context{}
ctx.Request = &http.Request{
URL: &url.URL{
RawQuery: "filter=login:sampleUser&filter=Email:[email protected]",
RawQuery: "filter=login:sampleUser&filter=email:[email protected]",
},
}

s.mock.ExpectQuery(`SELECT \* FROM "users" WHERE "Username" = \$1 AND "Email" = \$2$`).
s.mock.ExpectQuery(`SELECT \* FROM "users" WHERE "users"."username" = \$1 AND "users"."email" = \$2$`).
WithArgs("sampleUser", "[email protected]").
WillReturnRows(sqlmock.NewRows([]string{"id", "Username", "FullName", "Email", "Password"}))
WillReturnRows(sqlmock.NewRows([]string{"id", "username", "full_name", "email", "password"}))

err := s.db.Model(&User{}).Scopes(FilterByQuery(&ctx, FILTER)).Find(&users).Error
s.NoError(err)
}

// TestFiltersWithJoin is a test for filtering with join.
func (s *TestSuite) TestFiltersWithJoin() {
var users []User
ctx := gin.Context{}
ctx.Request = &http.Request{
URL: &url.URL{
RawQuery: "filter=id!=22",
},
}

s.mock.ExpectQuery(`SELECT "users"."id","users"."username","users"."full_name","users"."email","users"."organization_id","users"."password","Organization"."id" AS "Organization__id","Organization"."name" AS "Organization__name" FROM "users" LEFT JOIN "organizations" "Organization" ON "users"."organization_id" = "Organization"."id" WHERE "users"."id" <> \$1$`).
WithArgs("22").
WillReturnRows(sqlmock.NewRows([]string{"id", "username", "full_name", "email", "password"}))

err := s.db.Model(&User{}).Scopes(FilterByQuery(&ctx, FILTER)).Joins("Organization").Find(&users).Error
s.NoError(err)
}

func TestRunSuite(t *testing.T) {
suite.Run(t, new(TestSuite))
}

0 comments on commit 34a41a4

Please sign in to comment.