Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit cd34c1a

Browse files
authored
Merge pull request #543 from theodesp/feature/replace_repeat_reverse
Add Reverse, Repeat, Replace
2 parents 685f4e7 + 7486067 commit cd34c1a

File tree

3 files changed

+340
-0
lines changed

3 files changed

+340
-0
lines changed

sql/expression/function/registry.go

+3
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,7 @@ var Defaults = sql.Functions{
6161
"ltrim": sql.Function1(NewTrimFunc(lTrimType)),
6262
"rtrim": sql.Function1(NewTrimFunc(rTrimType)),
6363
"trim": sql.Function1(NewTrimFunc(bTrimType)),
64+
"reverse": sql.Function1(NewReverse),
65+
"repeat": sql.Function2(NewRepeat),
66+
"replace": sql.Function3(NewReplace),
6467
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
package function
2+
3+
import (
4+
"fmt"
5+
"strings"
6+
7+
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
8+
"gopkg.in/src-d/go-mysql-server.v0/sql"
9+
"gopkg.in/src-d/go-errors.v1"
10+
)
11+
12+
// Reverse is a function that returns the reverse of the text provided.
13+
type Reverse struct {
14+
expression.UnaryExpression
15+
}
16+
17+
// NewReverse creates a new Reverse expression.
18+
func NewReverse(e sql.Expression) sql.Expression {
19+
return &Reverse{expression.UnaryExpression{Child: e}}
20+
}
21+
22+
// Eval implements the Expression interface.
23+
func (r *Reverse) Eval(
24+
ctx *sql.Context,
25+
row sql.Row,
26+
) (interface{}, error) {
27+
v, err := r.Child.Eval(ctx, row)
28+
if v == nil || err != nil {
29+
return nil, err
30+
}
31+
32+
v, err = sql.Text.Convert(v)
33+
if err != nil {
34+
return nil, err
35+
}
36+
37+
return reverseString(v.(string)), nil
38+
}
39+
40+
func reverseString(s string) string {
41+
r := []rune(s)
42+
for i, j := 0, len(r) - 1; i < j; i, j = i+1, j-1 {
43+
r[i], r[j] = r[j], r[i]
44+
}
45+
return string(r)
46+
}
47+
48+
func (r *Reverse) String() string {
49+
return fmt.Sprintf("reverse(%s)", r.Child)
50+
}
51+
52+
// TransformUp implements the Expression interface.
53+
func (r *Reverse) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) {
54+
child, err := r.Child.TransformUp(f)
55+
if err != nil {
56+
return nil, err
57+
}
58+
return f(NewReverse(child))
59+
}
60+
61+
// Type implements the Expression interface.
62+
func (r *Reverse) Type() sql.Type {
63+
return r.Child.Type()
64+
}
65+
66+
var ErrNegativeRepeatCount = errors.NewKind("negative Repeat count: %v")
67+
68+
// Repeat is a function that returns the string repeated n times.
69+
type Repeat struct {
70+
expression.BinaryExpression
71+
}
72+
73+
// NewRepeat creates a new Repeat expression.
74+
func NewRepeat(str sql.Expression, count sql.Expression) sql.Expression {
75+
return &Repeat{expression.BinaryExpression{Left: str, Right: count}}
76+
}
77+
78+
func (r *Repeat) String() string {
79+
return fmt.Sprintf("repeat(%s, %s)", r.Left, r.Right)
80+
}
81+
82+
// Type implements the Expression interface.
83+
func (r *Repeat) Type() sql.Type {
84+
return sql.Text
85+
}
86+
87+
// TransformUp implements the Expression interface.
88+
func (r *Repeat) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) {
89+
left, err := r.Left.TransformUp(f)
90+
if err != nil {
91+
return nil, err
92+
}
93+
94+
right, err := r.Right.TransformUp(f)
95+
if err != nil {
96+
return nil, err
97+
}
98+
return f(NewRepeat(left, right))
99+
}
100+
101+
// Eval implements the Expression interface.
102+
func (r *Repeat) Eval(
103+
ctx *sql.Context,
104+
row sql.Row,
105+
) (interface{}, error) {
106+
str, err := r.Left.Eval(ctx, row)
107+
if str == nil || err != nil {
108+
return nil, err
109+
}
110+
111+
str, err = sql.Text.Convert(str)
112+
if err != nil {
113+
return nil, err
114+
}
115+
116+
count, err := r.Right.Eval(ctx, row)
117+
if count == nil || err != nil {
118+
return nil, err
119+
}
120+
121+
count, err = sql.Int32.Convert(count)
122+
if err != nil {
123+
return nil, err
124+
}
125+
if count.(int32) < 0 {
126+
return nil, ErrNegativeRepeatCount.New(count)
127+
}
128+
return strings.Repeat(str.(string), int(count.(int32))), nil
129+
}
130+
131+
// Replace is a function that returns a string with all occurrences of fromStr replaced by the
132+
// string toStr
133+
type Replace struct {
134+
str sql.Expression
135+
fromStr sql.Expression
136+
toStr sql.Expression
137+
}
138+
139+
// NewReplace creates a new Replace expression.
140+
func NewReplace(str sql.Expression, fromStr sql.Expression, toStr sql.Expression) sql.Expression {
141+
return &Replace{str, fromStr, toStr}
142+
}
143+
144+
// Children implements the Expression interface.
145+
func (r *Replace) Children() []sql.Expression {
146+
return []sql.Expression{r.str, r.fromStr, r.toStr}
147+
}
148+
149+
// Resolved implements the Expression interface.
150+
func (r *Replace) Resolved() bool {
151+
return r.str.Resolved() && r.fromStr.Resolved() && r.toStr.Resolved()
152+
}
153+
154+
// IsNullable implements the Expression interface.
155+
func (r *Replace) IsNullable() bool {
156+
return r.str.IsNullable() || r.fromStr.IsNullable() || r.toStr.IsNullable()
157+
}
158+
159+
func (r *Replace) String() string {
160+
return fmt.Sprintf("replace(%s, %s, %s)", r.str, r.fromStr, r.toStr)
161+
}
162+
163+
// Type implements the Expression interface.
164+
func (r *Replace) Type() sql.Type {
165+
return sql.Text
166+
}
167+
168+
// TransformUp implements the Expression interface.
169+
func (r *Replace) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) {
170+
str, err := r.str.TransformUp(f)
171+
if err != nil {
172+
return nil, err
173+
}
174+
175+
fromStr, err := r.fromStr.TransformUp(f)
176+
if err != nil {
177+
return nil, err
178+
}
179+
180+
toStr, err := r.toStr.TransformUp(f)
181+
if err != nil {
182+
return nil, err
183+
}
184+
return f(NewReplace(str, fromStr, toStr))
185+
}
186+
187+
// Eval implements the Expression interface.
188+
func (r *Replace) Eval(
189+
ctx *sql.Context,
190+
row sql.Row,
191+
) (interface{}, error) {
192+
str, err := r.str.Eval(ctx, row)
193+
if str == nil || err != nil {
194+
return nil, err
195+
}
196+
197+
str, err = sql.Text.Convert(str)
198+
if err != nil {
199+
return nil, err
200+
}
201+
202+
fromStr, err := r.fromStr.Eval(ctx, row)
203+
if fromStr == nil || err != nil {
204+
return nil, err
205+
}
206+
207+
fromStr, err = sql.Text.Convert(fromStr)
208+
if err != nil {
209+
return nil, err
210+
}
211+
212+
toStr, err := r.toStr.Eval(ctx, row)
213+
if toStr == nil || err != nil {
214+
return nil, err
215+
}
216+
217+
toStr, err = sql.Text.Convert(toStr)
218+
if err != nil {
219+
return nil, err
220+
}
221+
222+
if fromStr.(string) == "" {
223+
return str, nil
224+
}
225+
226+
return strings.Replace(str.(string), fromStr.(string), toStr.(string), -1), nil
227+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
package function
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
"gopkg.in/src-d/go-mysql-server.v0/sql"
8+
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
9+
)
10+
11+
func TestReverse(t *testing.T) {
12+
f := NewReverse(expression.NewGetField(0, sql.Text, "", false))
13+
testCases := []struct {
14+
name string
15+
row sql.Row
16+
expected interface{}
17+
err bool
18+
}{
19+
{"null input", sql.NewRow(nil), nil, false},
20+
{"empty string", sql.NewRow(""), "", false},
21+
{"handles numbers as strings", sql.NewRow(123), "321", false},
22+
{"valid string", sql.NewRow("foobar"), "raboof", false},
23+
}
24+
for _, tt := range testCases {
25+
t.Run(tt.name, func(t *testing.T) {
26+
t.Helper()
27+
require := require.New(t)
28+
ctx := sql.NewEmptyContext()
29+
30+
v, err := f.Eval(ctx, tt.row)
31+
if tt.err {
32+
require.Error(err)
33+
} else {
34+
require.NoError(err)
35+
require.Equal(tt.expected, v)
36+
}
37+
})
38+
}
39+
}
40+
41+
func TestRepeat(t *testing.T) {
42+
f := NewRepeat(
43+
expression.NewGetField(0, sql.Text, "", false),
44+
expression.NewGetField(1, sql.Int32, "", false),
45+
)
46+
47+
testCases := []struct {
48+
name string
49+
row sql.Row
50+
expected interface{}
51+
err bool
52+
}{
53+
{"null input", sql.NewRow(nil), nil, false},
54+
{"empty string", sql.NewRow("", 2), "", false},
55+
{"count is zero", sql.NewRow("foo", 0), "", false},
56+
{"count is negative", sql.NewRow("foo", -2), "foo", true},
57+
{"valid string", sql.NewRow("foobar", 2), "foobarfoobar", false},
58+
}
59+
for _, tt := range testCases {
60+
t.Run(tt.name, func(t *testing.T) {
61+
t.Helper()
62+
require := require.New(t)
63+
ctx := sql.NewEmptyContext()
64+
65+
v, err := f.Eval(ctx, tt.row)
66+
if tt.err {
67+
require.Error(err)
68+
} else {
69+
require.NoError(err)
70+
require.Equal(tt.expected, v)
71+
}
72+
})
73+
}
74+
}
75+
76+
func TestReplace(t *testing.T) {
77+
f := NewReplace(
78+
expression.NewGetField(0, sql.Text, "", false),
79+
expression.NewGetField(1, sql.Text, "", false),
80+
expression.NewGetField(2, sql.Text, "", false),
81+
)
82+
83+
testCases := []struct {
84+
name string
85+
row sql.Row
86+
expected interface{}
87+
err bool
88+
}{
89+
{"null inputs", sql.NewRow(nil), nil, false},
90+
{"empty str", sql.NewRow("", "foo", "bar"), "", false},
91+
{"empty fromStr", sql.NewRow("foobarfoobar", "", "car"), "foobarfoobar", false},
92+
{"empty toStr", sql.NewRow("foobarfoobar", "bar", ""), "foofoo", false},
93+
{"valid strings", sql.NewRow("foobarfoobar", "bar", "car"), "foocarfoocar", false},
94+
}
95+
for _, tt := range testCases {
96+
t.Run(tt.name, func(t *testing.T) {
97+
t.Helper()
98+
require := require.New(t)
99+
ctx := sql.NewEmptyContext()
100+
101+
v, err := f.Eval(ctx, tt.row)
102+
if tt.err {
103+
require.Error(err)
104+
} else {
105+
require.NoError(err)
106+
require.Equal(tt.expected, v)
107+
}
108+
})
109+
}
110+
}

0 commit comments

Comments
 (0)