Skip to content

Commit

Permalink
rework: arguments injection safe
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-burghardt committed Jun 26, 2024
1 parent 4174316 commit b3aac4f
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 11 deletions.
46 changes: 35 additions & 11 deletions internal/data/payments.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,28 +104,39 @@ func (m *PaymentModel) GetPaymentsPaginated(ctx context.Context, address string,

const filteredSetCTE = `
WITH filtered_set AS (
SELECT * FROM ingest_payments WHERE $1 = '' OR $1 IN (from_address, to_address)
SELECT * FROM ingest_payments WHERE :address = '' OR :address IN (from_address, to_address)
)
`

var selectQ string
if beforeID != 0 && sort == DESC {
selectQ = fmt.Sprintf("SELECT * FROM (SELECT * FROM filtered_set WHERE operation_id > %d ORDER BY operation_id ASC LIMIT $2) AS reverse_set ORDER BY operation_id DESC", beforeID)
selectQ = "SELECT * FROM (SELECT * FROM filtered_set WHERE operation_id > :before_id ORDER BY operation_id ASC LIMIT :limit) AS reverse_set ORDER BY operation_id DESC"
} else if beforeID != 0 && sort == ASC {
selectQ = fmt.Sprintf("SELECT * FROM (SELECT * FROM filtered_set WHERE operation_id < %d ORDER BY operation_id DESC LIMIT $2) AS reverse_set ORDER BY operation_id ASC", beforeID)
selectQ = "SELECT * FROM (SELECT * FROM filtered_set WHERE operation_id < :before_id ORDER BY operation_id DESC LIMIT :limit) AS reverse_set ORDER BY operation_id ASC"
} else if afterID != 0 && sort == DESC {
selectQ = fmt.Sprintf("SELECT * FROM filtered_set WHERE operation_id < %d ORDER BY operation_id DESC LIMIT $2", afterID)
selectQ = "SELECT * FROM filtered_set WHERE operation_id < :after_id ORDER BY operation_id DESC LIMIT :limit"
} else if afterID != 0 && sort == ASC {
selectQ = fmt.Sprintf("SELECT * FROM filtered_set WHERE operation_id > %d ORDER BY operation_id ASC LIMIT $2", afterID)
selectQ = "SELECT * FROM filtered_set WHERE operation_id > :after_id ORDER BY operation_id ASC LIMIT :limit"
} else if sort == ASC {
selectQ = "SELECT * FROM filtered_set ORDER BY operation_id ASC LIMIT $2"
selectQ = "SELECT * FROM filtered_set ORDER BY operation_id ASC LIMIT :limit"
} else {
selectQ = "SELECT * FROM filtered_set ORDER BY operation_id DESC LIMIT $2"
selectQ = "SELECT * FROM filtered_set ORDER BY operation_id DESC LIMIT :limit"
}

argumentsMap := map[string]interface{}{
"address": address,
"limit": limit,
"before_id": beforeID,
"after_id": afterID,
}

payments := make([]Payment, 0)
query := fmt.Sprintf("%s %s", filteredSetCTE, selectQ)
err := m.DB.SelectContext(ctx, &payments, query, address, limit)
query, args, err := PrepareNamedQuery(ctx, m.DB, query, argumentsMap)
if err != nil {
return nil, false, false, fmt.Errorf("preparing named query: %w", err)
}
err = m.DB.SelectContext(ctx, &payments, query, args...)
if err != nil {
return nil, false, false, fmt.Errorf("fetching payments: %w", err)
}
Expand All @@ -146,14 +157,27 @@ func (m *PaymentModel) existsPrevNext(ctx context.Context, filteredSetCTE string
%s
SELECT
EXISTS(
SELECT 1 FROM filtered_set WHERE CASE WHEN $2 = 'ASC' THEN operation_id < $3 WHEN $2 = 'DESC' THEN operation_id > $3 END LIMIT 1
SELECT 1 FROM filtered_set WHERE CASE WHEN :sort = 'ASC' THEN operation_id < :first_element_id WHEN :sort = 'DESC' THEN operation_id > :first_element_id END LIMIT 1
) AS prev_exists,
EXISTS(
SELECT 1 FROM filtered_set WHERE CASE WHEN $2 = 'ASC' THEN operation_id > $4 WHEN $2 = 'DESC' THEN operation_id < $4 END LIMIT 1
SELECT 1 FROM filtered_set WHERE CASE WHEN :sort = 'ASC' THEN operation_id > :last_element_id WHEN :sort = 'DESC' THEN operation_id < :last_element_id END LIMIT 1
) AS next_exists
`, filteredSetCTE)

argumentsMap := map[string]interface{}{
"address": address,
"first_element_id": firstElementID,
"last_element_id": lastElementID,
"sort": sort,
}

query, args, err := PrepareNamedQuery(ctx, m.DB, query, argumentsMap)
if err != nil {
return false, false, fmt.Errorf("preparing named query: %w", err)
}

var prevExists, nextExists bool
err := m.DB.QueryRowxContext(ctx, query, address, sort, firstElementID, lastElementID).Scan(&prevExists, &nextExists)
err = m.DB.QueryRowxContext(ctx, query, args...).Scan(&prevExists, &nextExists)
if err != nil {
return false, false, fmt.Errorf("fetching prev and next exists: %w", err)
}
Expand Down
22 changes: 22 additions & 0 deletions internal/data/query_utils.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
package data

import (
"context"
"fmt"

"github.com/jmoiron/sqlx"
"github.com/stellar/wallet-backend/internal/db"
)

type SortOrder string

const (
Expand All @@ -10,3 +18,17 @@ const (
func (o SortOrder) IsValid() bool {
return o == ASC || o == DESC
}

func PrepareNamedQuery(ctx context.Context, connectionPool db.ConnectionPool, namedQuery string, argsMap map[string]interface{}) (string, []interface{}, error) {
query, args, err := sqlx.Named(namedQuery, argsMap)
if err != nil {
return "", nil, fmt.Errorf("replacing attributes with bindvars: %w", err)
}
query, args, err = sqlx.In(query, args...)
if err != nil {
return "", nil, fmt.Errorf("expanding slice arguments: %w", err)
}
query = connectionPool.Rebind(query)

return query, args, nil
}

0 comments on commit b3aac4f

Please sign in to comment.