Skip to content

Commit

Permalink
Extract common test setup methods for cross-package utilization
Browse files Browse the repository at this point in the history
  • Loading branch information
imyousuf committed Nov 6, 2024
1 parent 7344ebb commit 8155dd4
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 94 deletions.
6 changes: 4 additions & 2 deletions storage/consumerrepo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ const (

var (
channel1, channel2, channel3 *data.Channel // Used by messagerepo_test as well
callbackURL, relativeURL *url.URL
relativeURL *url.URL
)

func getConsumerRepo() ConsumerRepository {
Expand All @@ -45,7 +45,9 @@ func SetupForConsumerTests() {
channel1 = createTestChannel("channel1-for-consumer", "sampletoken", channelRepo)
channel2 = createTestChannel("channel2-for-consumer", "sampletoken", channelRepo)
channel3 = createTestChannel("channel3-for-no-consumers", "sampletoken", channelRepo)
callbackURL = parseTestURL("https://imytech.net/")
if callbackURL == nil {
callbackURL = parseTestURL("https://imytech.net/")
}
relativeURL = parseTestURL("./test/")
}

Expand Down
60 changes: 3 additions & 57 deletions storage/deliveryjobrepo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"database/sql"
"errors"
"sort"
"strconv"
"testing"
"time"

Expand All @@ -23,65 +22,12 @@ var (
)

const (
consumerIDPrefix = "test-consumer-for-dj-"
messagePriority = 5
messagePriority = 5
)

type DeliveryJobSetupOptions struct {
ConsumerCount int
ConsumerIDPrefix string
ConsumerRepo ConsumerRepository
IgnoreSettingConsumers bool
ConsumerChannel *data.Channel
}

func (opt *DeliveryJobSetupOptions) GetConsumerCount() int {
if opt.ConsumerCount == 0 {
return 10
}
return opt.ConsumerCount
}

func (opt *DeliveryJobSetupOptions) GetConsumerIDPrefix() string {
if opt.ConsumerIDPrefix == "" {
return consumerIDPrefix
}
return opt.ConsumerIDPrefix
}

func (opt *DeliveryJobSetupOptions) GetConsumerRepo() ConsumerRepository {
if opt.ConsumerRepo == nil {
opt.ConsumerRepo = getConsumerRepo()
}
return opt.ConsumerRepo
}

func (opt *DeliveryJobSetupOptions) GetConsumerChannel() *data.Channel {
if opt.ConsumerChannel == nil {
return channel1
}
return opt.ConsumerChannel
}

func SetupForDeliveryJobTests() {
SetupForDeliveryJobTestsWithOptions(&DeliveryJobSetupOptions{})
}

func SetupForDeliveryJobTestsWithOptions(options *DeliveryJobSetupOptions) []*data.Consumer {
testConsumers := options.GetConsumerCount()
consumerRepo := getConsumerRepo()
internalConsumers := make([]*data.Consumer, 0, testConsumers)
for i := 0; i < testConsumers; i++ {
consumer, _ := data.NewConsumer(options.GetConsumerChannel(), options.GetConsumerIDPrefix()+strconv.Itoa(i),
successfulGetTestToken, callbackURL, "")
consumer.QuickFix()
consumerRepo.Store(consumer)
internalConsumers = append(internalConsumers, consumer)
}
if !options.IgnoreSettingConsumers {
consumers = internalConsumers
}
return internalConsumers
consumerRepo := NewConsumerRepository(testDB, NewChannelRepository(testDB))
consumers = SetupForDeliveryJobTestsWithOptions(&DeliveryJobSetupOptions{ConsumerRepo: consumerRepo, ConsumerChannel: channel1})
}

func getDeliverJobRepository() DeliveryJobRepository {
Expand Down
18 changes: 18 additions & 0 deletions storage/messagerepo_nix_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
//go:build linux || darwin

package storage

import (
"testing"

"github.com/go-sql-driver/mysql"
sqlite "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"
)

func TestNormalizeMySQLError(t *testing.T) {
assert.Equal(t, ErrDuplicateMessageIDForChannel, normalizeDBError(&mysql.MySQLError{Number: 1062}, mysqlErrorMap))
assert.Nil(t, normalizeDBError(nil, mysqlErrorMap))
assert.Equal(t, ErrDuplicateMessageIDForChannel, normalizeDBError(&sqlite.ErrConstraint, mysqlErrorMap))
assert.Equal(t, ErrDuplicateMessageIDForChannel, normalizeDBError(&sqlite.ErrConstraintUnique, mysqlErrorMap))
}
92 changes: 58 additions & 34 deletions storage/messagerepo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,13 @@ import (
"github.com/rs/zerolog/log"

"github.com/DATA-DOG/go-sqlmock"
"github.com/go-sql-driver/mysql"
sqlite "github.com/mattn/go-sqlite3"
"github.com/newscred/webhook-broker/storage/data"
"github.com/newscred/webhook-broker/utils"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

const (
samplePayload = "some payload"
sampleContentType = "a content type"
duplicateMessageID = "a-duplicate-message-id"
consumerIDPrefixForPrune = "test-consumer-for-prune-"
)
Expand All @@ -34,13 +30,9 @@ var (

func SetupForMessageTests() {
producerRepo := NewProducerRepository(testDB)
producer, _ := data.NewProducer("producer1-for-message", successfulGetTestToken)
producer.QuickFix()
producer1, _ = producerRepo.Store(producer)
channelRepo := NewChannelRepository(testDB)
channelForPrune = createTestChannel("channel-for-prune", "sampletoken", channelRepo)
pruneConsumers = SetupForDeliveryJobTestsWithOptions(&DeliveryJobSetupOptions{IgnoreSettingConsumers: true,
ConsumerCount: 1, ConsumerIDPrefix: consumerIDPrefixForPrune, ConsumerChannel: channelForPrune})
consumerRepo := NewConsumerRepository(testDB, channelRepo)
producer1, channelForPrune, pruneConsumers = SetupMessageDependencyFixture(producerRepo, channelRepo, consumerRepo, consumerIDPrefixForPrune)
}

func getMessageRepository() MessageRepository {
Expand Down Expand Up @@ -235,13 +227,6 @@ func TestMessageSetDispatched(t *testing.T) {
})
}

func TestNormalizeMySQLError(t *testing.T) {
assert.Equal(t, ErrDuplicateMessageIDForChannel, normalizeDBError(&mysql.MySQLError{Number: 1062}, mysqlErrorMap))
assert.Nil(t, normalizeDBError(nil, mysqlErrorMap))
assert.Equal(t, ErrDuplicateMessageIDForChannel, normalizeDBError(&sqlite.ErrConstraint, mysqlErrorMap))
assert.Equal(t, ErrDuplicateMessageIDForChannel, normalizeDBError(&sqlite.ErrConstraintUnique, mysqlErrorMap))
}

func TestGetMessagesNotDispatchedForCertainPeriod(t *testing.T) {
t.Run("Success", func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -286,34 +271,73 @@ func getPruneDeliveryJobsInFixture(msg *data.Message) []*data.DeliveryJob {
return jobs
}

// MockedDataAccessor is a mock implementation of DataAccessor for testing purposes.
type MockedDataAccessor struct {
mock.Mock
}

func (m *MockedDataAccessor) GetLockRepository() LockRepository {
args := m.Called()
return args.Get(0).(LockRepository)
}

func (m *MockedDataAccessor) GetAppRepository() AppRepository {
args := m.Called()
return args.Get(0).(AppRepository)
}

// GetMessageRepository mocks the GetMessageRepository method.
func (m *MockedDataAccessor) GetMessageRepository() MessageRepository {
args := m.Called()
return args.Get(0).(MessageRepository)
}

// GetDeliveryJobRepository mocks the GetDeliveryJobRepository method.
func (m *MockedDataAccessor) GetDeliveryJobRepository() DeliveryJobRepository {
args := m.Called()
return args.Get(0).(DeliveryJobRepository)
}

// GetProducerRepository mocks the GetProducerRepository method.
func (m *MockedDataAccessor) GetProducerRepository() ProducerRepository {
args := m.Called()
return args.Get(0).(ProducerRepository)
}

// GetChannelRepository mocks the GetChannelRepository method.
func (m *MockedDataAccessor) GetChannelRepository() ChannelRepository {
args := m.Called()
return args.Get(0).(ChannelRepository)
}

// GetConsumerRepository mocks the GetConsumerRepository method.
func (m *MockedDataAccessor) GetConsumerRepository() ConsumerRepository {
args := m.Called()
return args.Get(0).(ConsumerRepository)
}

// Close mocks the Close method.
func (m *MockedDataAccessor) Close() {
m.Called()
}

func TestGetMessagesFromBeforeDurationThatAreCompletelyDelivered(t *testing.T) {
deliverJobRepo := getDeliverJobRepository()
msgRepo := getMessageRepository()

t.Run("Success", func(t *testing.T) {
t.Parallel()
msg, _ := data.NewMessage(channelForPrune, producer1, samplePayload, sampleContentType, data.HeadersMap{})
msg.ReceivedAt = msg.ReceivedAt.Add(-50 * time.Second)
msgRepo.Create(msg)
jobs := getPruneDeliveryJobsInFixture(msg)
deliverJobRepo.DispatchMessage(msg, jobs...)
for index := range jobs {
markJobDelivered(deliverJobRepo, jobs[index])
}
dataAccessor := new(MockedDataAccessor)
dataAccessor.On("GetMessageRepository").Return(msgRepo)
dataAccessor.On("GetDeliveryJobRepository").Return(getDeliverJobRepository())
msg, _ := SetupPruneableMessageFixture(dataAccessor, channelForPrune, producer1, pruneConsumers, 50)
pruneAbleMessages := msgRepo.GetMessagesFromBeforeDurationThatAreCompletelyDelivered(40*time.Second, 1000)
assert.Equal(t, 1, len(pruneAbleMessages))
assert.Equal(t, msg.MessageID, pruneAbleMessages[0].MessageID)
// create such that pagination query gets triggered
iterLength := 110
msgIds := make([]string, iterLength+1)
for i := 0; i < iterLength; i++ {
msg, _ = data.NewMessage(channelForPrune, producer1, samplePayload, sampleContentType, data.HeadersMap{})
msg.ReceivedAt = msg.ReceivedAt.Add(-50 * time.Second)
msgRepo.Create(msg)
jobs = getPruneDeliveryJobsInFixture(msg)
deliverJobRepo.DispatchMessage(msg, jobs...)
for index := range jobs {
markJobDelivered(deliverJobRepo, jobs[index])
}
msg, _ := SetupPruneableMessageFixture(dataAccessor, channelForPrune, producer1, pruneConsumers, 50)
msgIds[i] = msg.MessageID
}
msgIds[iterLength] = pruneAbleMessages[0].MessageID
Expand Down
1 change: 0 additions & 1 deletion storage/producerrepo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (

const (
successfulGetTestProducerID = "get-test"
successfulGetTestToken = "sometokenforget"
nonExistingGetTestProducerID = "get-test-ne"
successfulInsertTestProducerID = "s-insert-test"
invalidStateUpdateTestProducerID = "i-update-test"
Expand Down
109 changes: 109 additions & 0 deletions storage/testutils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package storage

import (
"net/url"
"strconv"
"time"

data "github.com/newscred/webhook-broker/storage/data"
"github.com/rs/zerolog/log"
)

const (
samplePayload = "some payload"
sampleContentType = "a content type"
successfulGetTestToken = "sometokenforget"
consumerIDPrefix = "test-consumer-for-dj-"
)

var (
callbackURL *url.URL
)

type DeliveryJobSetupOptions struct {
ConsumerCount int
ConsumerIDPrefix string
ConsumerRepo ConsumerRepository
IgnoreSettingConsumers bool
ConsumerChannel *data.Channel
}

func (opt *DeliveryJobSetupOptions) GetConsumerCount() int {
if opt.ConsumerCount == 0 {
return 10
}
return opt.ConsumerCount
}

func (opt *DeliveryJobSetupOptions) GetConsumerIDPrefix() string {
if opt.ConsumerIDPrefix == "" {
return consumerIDPrefix
}
return opt.ConsumerIDPrefix
}

func (opt *DeliveryJobSetupOptions) GetConsumerRepo() ConsumerRepository {
return opt.ConsumerRepo
}

func (opt *DeliveryJobSetupOptions) GetConsumerChannel() *data.Channel {
return opt.ConsumerChannel
}

func SetupForDeliveryJobTestsWithOptions(options *DeliveryJobSetupOptions) []*data.Consumer {
testConsumers := options.GetConsumerCount()
consumerRepo := options.GetConsumerRepo()
internalConsumers := make([]*data.Consumer, 0, testConsumers)
if callbackURL == nil {
callbackURL, _ = url.Parse("https://imytech.net/")
}
for i := 0; i < testConsumers; i++ {
consumer, _ := data.NewConsumer(options.GetConsumerChannel(), options.GetConsumerIDPrefix()+strconv.Itoa(i),
successfulGetTestToken, callbackURL, "")
consumer.QuickFix()
consumerRepo.Store(consumer)
internalConsumers = append(internalConsumers, consumer)
}
return internalConsumers
}

func SetupMessageDependencyFixture(producerRepo ProducerRepository, channelRepo ChannelRepository,
consumerRepo ConsumerRepository, consumerIDPrefix string) (*data.Producer, *data.Channel, []*data.Consumer) {
producer, _ := data.NewProducer("producer1-for-message", successfulGetTestToken)
producer.QuickFix()
localProducer, _ := producerRepo.Store(producer)
var err error
localChannel, _ := data.NewChannel("channel-for-prune", "sampletoken")
if localChannel, err = channelRepo.Store(localChannel); err != nil {
log.Fatal().Err(err)
}
thisConsumers := SetupForDeliveryJobTestsWithOptions(&DeliveryJobSetupOptions{IgnoreSettingConsumers: true,
ConsumerCount: 1, ConsumerIDPrefix: consumerIDPrefix, ConsumerChannel: localChannel, ConsumerRepo: consumerRepo})
return localProducer, localChannel, thisConsumers
}

func SetupPruneableMessageFixture(dataAccessor DataAccessor, channel *data.Channel, producer *data.Producer,
pruneConsumers []*data.Consumer, lag int) (*data.Message, []*data.DeliveryJob) {
msg, _ := data.NewMessage(channel, producer, samplePayload, sampleContentType, data.HeadersMap{})
msg.ReceivedAt = msg.ReceivedAt.Add(time.Duration(-1*lag) * time.Second)
msgRepo := dataAccessor.GetMessageRepository()
msgRepo.Create(msg)
jobs := make([]*data.DeliveryJob, 0, len(pruneConsumers))
for _, consumer := range pruneConsumers {
job, _ := data.NewDeliveryJob(msg, consumer)
jobs = append(jobs, job)
}
deliverJobRepo := dataAccessor.GetDeliveryJobRepository()
deliverJobRepo.DispatchMessage(msg, jobs...)
for index := range jobs {
err := deliverJobRepo.MarkJobInflight(jobs[index])
if err != nil {
log.Error().Err(err).Msg("Error marking job inflight")
}
err = deliverJobRepo.MarkJobDelivered(jobs[index])
if err != nil {
log.Error().Err(err).Msg("Error marking job delivered")
}
}
return msg, jobs
}

0 comments on commit 8155dd4

Please sign in to comment.