Skip to content

Commit 00bbda8

Browse files
authored
feat: support batch inserts (#22)
1 parent f8f37ec commit 00bbda8

File tree

5 files changed

+91
-0
lines changed

5 files changed

+91
-0
lines changed

create.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ package crud
33
import (
44
"context"
55
stdsql "database/sql"
6+
"errors"
67
"fmt"
78
"maps"
9+
"reflect"
810
"slices"
911

1012
"github.com/azer/crud/v2/sql"
@@ -33,6 +35,38 @@ func createAndRead(ctx context.Context, exec ExecFn, query QueryFn, record inter
3335
return readLastInsert(ctx, query, record, result)
3436
}
3537

38+
func bulkCreate(ctx context.Context, exec ExecFn, value any) error {
39+
v := reflect.ValueOf(value)
40+
if v.Kind() != reflect.Slice {
41+
return errors.New("records must be a slice")
42+
}
43+
44+
records := make([]any, 0, v.Len())
45+
for i := 0; i < v.Len(); i++ {
46+
records = append(records, v.Index(i).Interface())
47+
}
48+
49+
row, columns, _, err := valuesForRecord(records[0])
50+
if err != nil {
51+
return err
52+
}
53+
54+
query := sql.InsertBulkQuery(row.SQLTableName, columns, len(records))
55+
values := make([]interface{}, 0, len(records)*len(columns))
56+
57+
for _, record := range records {
58+
_, _, v, err := valuesForRecord(record)
59+
if err != nil {
60+
return err
61+
}
62+
63+
values = append(values, v...)
64+
}
65+
66+
_, err = exec(ctx, query, values...)
67+
return err
68+
}
69+
3670
func replaceAndGetResult(ctx context.Context, exec ExecFn, record interface{}) (stdsql.Result, error) {
3771
row, columns, values, err := valuesForRecord(record)
3872
if err != nil {

create_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,43 @@ func TestCreatingRenamedTableRow(t *testing.T) {
8181

8282
DB.DropTables(ctx, Post{})
8383
}
84+
85+
func TestCreateBulk(t *testing.T) {
86+
ctx := context.Background()
87+
88+
DB.ResetTables(ctx, UserProfile{})
89+
90+
users := []UserProfile{
91+
{
92+
Name: "Azer",
93+
Bio: "I like photography",
94+
95+
},
96+
{
97+
Name: "Azer2",
98+
Bio: "I like photography2",
99+
100+
},
101+
{
102+
Name: "Azer3",
103+
Bio: "I like photography3",
104+
105+
},
106+
}
107+
108+
err := DB.BulkCreate(ctx, users)
109+
assert.Nil(t, err)
110+
111+
var users2 []UserProfile
112+
err = DB.Read(ctx, &users2, "SELECT * FROM user_profiles")
113+
assert.Nil(t, err)
114+
assert.Equal(t, len(users), 3)
115+
116+
for i, user := range users2 {
117+
assert.Equal(t, user.Name, users[i].Name)
118+
assert.Equal(t, user.Bio, users[i].Bio)
119+
assert.Equal(t, user.Email, users[i].Email)
120+
}
121+
122+
DB.DropTables(ctx, UserProfile{})
123+
}

db.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,3 +192,7 @@ func Connect(driver, url string, logger *slog.Logger) (*DB, error) {
192192
Logger: logger,
193193
}, nil
194194
}
195+
196+
func (db *DB) BulkCreate(ctx context.Context, records any) error {
197+
return bulkCreate(ctx, db.Exec, records)
198+
}

sql/table.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package sql
22

33
import (
44
"fmt"
5+
"slices"
56
"strings"
67
)
78

@@ -115,6 +116,14 @@ func InsertQuery(tableName string, columnNames []string) string {
115116
tableName, strings.Join(quoteColumnNames(columnNames), ","), questionMarks)
116117
}
117118

119+
func InsertBulkQuery(tableName string, columnNames []string, numRecords int) string {
120+
pattern := fmt.Sprintf("(%s)", repeatComma(len(columnNames), "?"))
121+
questionMarks := slices.Repeat([]string{pattern}, numRecords)
122+
123+
return fmt.Sprintf("INSERT INTO `%s` (%s) VALUES %s",
124+
tableName, strings.Join(quoteColumnNames(columnNames), ","), strings.Join(questionMarks, ","))
125+
}
126+
118127
func ReplaceQuery(tableName string, columnNames []string) string {
119128
questionMarks := repeatComma(len(columnNames), "?")
120129

sql/table_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,10 @@ func TestInsertQuery(t *testing.T) {
9090
assert.Equal(t, sql.InsertQuery("yolo", []string{"name", "email", "age"}), "INSERT INTO `yolo` (`name`,`email`,`age`) VALUES (?,?,?)")
9191
}
9292

93+
func TestInsertBulkQuery(t *testing.T) {
94+
assert.Equal(t, sql.InsertBulkQuery("yolo", []string{"name", "email", "age"}, 3), "INSERT INTO `yolo` (`name`,`email`,`age`) VALUES (?,?,?),(?,?,?),(?,?,?)")
95+
}
96+
9397
func TestUpdateQuery(t *testing.T) {
9498
assert.Equal(t, sql.UpdateQuery("yolo", "id", []string{"name", "email", "age"}), "UPDATE `yolo` SET `name`=?, `email`=?, `age`=? WHERE id=?")
9599
}

0 commit comments

Comments
 (0)