diff --git a/enterprise/reporting/event_sampler/badger_event_sampler.go b/enterprise/reporting/event_sampler/badger_event_sampler.go new file mode 100644 index 00000000000..ac0e1ba79b4 --- /dev/null +++ b/enterprise/reporting/event_sampler/badger_event_sampler.go @@ -0,0 +1,156 @@ +package event_sampler + +import ( + "context" + "fmt" + "os" + "sync" + "time" + + "github.com/dgraph-io/badger/v4" + "github.com/dgraph-io/badger/v4/options" + + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/logger" + "github.com/rudderlabs/rudder-server/rruntime" + "github.com/rudderlabs/rudder-server/utils/misc" +) + +type BadgerEventSampler struct { + db *badger.DB + mu sync.Mutex + ttl config.ValueLoader[time.Duration] + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup +} + +func DefaultPath(pathName string) (string, error) { + tmpDirPath, err := misc.CreateTMPDIR() + if err != nil { + return "", err + } + return fmt.Sprintf(`%v%v`, tmpDirPath, pathName), nil +} + +func NewBadgerEventSampler(ctx context.Context, pathName string, ttl config.ValueLoader[time.Duration], conf *config.Config, log logger.Logger) (*BadgerEventSampler, error) { + dbPath, err := DefaultPath(pathName) + if err != nil || dbPath == "" { + return nil, err + } + + opts := badger.DefaultOptions(dbPath). + WithLogger(badgerLogger{log}). + WithCompression(options.None). + WithIndexCacheSize(16 << 20). // 16mb + WithNumGoroutines(1). + WithBlockCacheSize(0). + WithNumVersionsToKeep(1). + WithNumMemtables(conf.GetInt("Reporting.eventSampling.badgerDB.numMemtable", 5)). + WithValueThreshold(conf.GetInt64("Reporting.eventSampling.badgerDB.valueThreshold", 1048576)). + WithNumLevelZeroTables(conf.GetInt("Reporting.eventSampling.badgerDB.numLevelZeroTables", 5)). + WithNumLevelZeroTablesStall(conf.GetInt("Reporting.eventSampling.badgerDB.numLevelZeroTablesStall", 15)). + WithSyncWrites(conf.GetBool("Reporting.eventSampling.badgerDB.syncWrites", false)). + WithDetectConflicts(conf.GetBool("Reporting.eventSampling.badgerDB.detectConflicts", false)) + + ctx, cancel := context.WithCancel(ctx) + + db, err := badger.Open(opts) + + es := &BadgerEventSampler{ + db: db, + ttl: ttl, + ctx: ctx, + cancel: cancel, + wg: sync.WaitGroup{}, + } + + if err != nil { + return nil, err + } + + es.wg.Add(1) + rruntime.Go(func() { + defer es.wg.Done() + es.gcLoop() + }) + + return es, nil +} + +func (es *BadgerEventSampler) Get(key string) (bool, error) { + es.mu.Lock() + defer es.mu.Unlock() + + var found bool + + err := es.db.View(func(txn *badger.Txn) error { + item, err := txn.Get([]byte(key)) + if err != nil { + return err + } + + found = item != nil + return nil + }) + + if err == badger.ErrKeyNotFound { + return false, nil + } else if err != nil { + return false, err + } + + return found, nil +} + +func (es *BadgerEventSampler) Put(key string) error { + es.mu.Lock() + defer es.mu.Unlock() + + return es.db.Update(func(txn *badger.Txn) error { + entry := badger.NewEntry([]byte(key), []byte{1}).WithTTL(es.ttl.Load()) + return txn.SetEntry(entry) + }) +} + +func (es *BadgerEventSampler) gcLoop() { + for { + select { + case <-es.ctx.Done(): + _ = es.db.RunValueLogGC(0.5) + return + case <-time.After(5 * time.Minute): + } + again: + if es.ctx.Err() != nil { + return + } + // One call would only result in removal of at max one log file. + // As an optimization, you could also immediately re-run it whenever it returns nil error + // (this is why `goto again` is used). 
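+ // The 0.5 argument is badger's discardRatio: a value log file is only rewritten if at least half of its space can be reclaimed.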
+ err := es.db.RunValueLogGC(0.5) + if err == nil { + goto again + } + } +} + +func (es *BadgerEventSampler) Close() { + es.cancel() + es.wg.Wait() + if es.db != nil { + _ = es.db.Close() + } +} + +type badgerLogger struct { + logger.Logger +} + +func (badgerLogger) Errorf(format string, a ...interface{}) { + _, _ = fmt.Fprintf(os.Stderr, format, a...) +} + +func (badgerLogger) Warningf(format string, a ...interface{}) { + _, _ = fmt.Fprintf(os.Stderr, format, a...) +} diff --git a/enterprise/reporting/event_sampler/event_sampler.go b/enterprise/reporting/event_sampler/event_sampler.go new file mode 100644 index 00000000000..8e3da3fd0e6 --- /dev/null +++ b/enterprise/reporting/event_sampler/event_sampler.go @@ -0,0 +1,45 @@ +package event_sampler + +import ( + "context" + "time" + + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/logger" +) + +const ( + BadgerTypeEventSampler = "badger" + InMemoryCacheTypeEventSampler = "in_memory_cache" + BadgerEventSamplerPathName = "/reporting-badger" +) + +type EventSampler interface { + Put(key string) error + Get(key string) (bool, error) + Close() +} + +func NewEventSampler( + ctx context.Context, + ttl config.ValueLoader[time.Duration], + eventSamplerType config.ValueLoader[string], + eventSamplingCardinality config.ValueLoader[int], + conf *config.Config, + log logger.Logger, +) (es EventSampler, err error) { + switch eventSamplerType.Load() { + case BadgerTypeEventSampler: + es, err = NewBadgerEventSampler(ctx, BadgerEventSamplerPathName, ttl, conf, log) + case InMemoryCacheTypeEventSampler: + es, err = NewInMemoryCacheEventSampler(ctx, ttl, eventSamplingCardinality) + default: + log.Warnf("invalid event sampler type: %s. Using default badger event sampler", eventSamplerType.Load()) + es, err = NewBadgerEventSampler(ctx, BadgerEventSamplerPathName, ttl, conf, log) + } + + if err != nil { + return nil, err + } + return es, nil +} diff --git a/enterprise/reporting/event_sampler/event_sampler_test.go b/enterprise/reporting/event_sampler/event_sampler_test.go new file mode 100644 index 00000000000..4d47f241234 --- /dev/null +++ b/enterprise/reporting/event_sampler/event_sampler_test.go @@ -0,0 +1,178 @@ +package event_sampler + +import ( + "context" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/logger" +) + +func TestBadger(t *testing.T) { + ctx := context.Background() + conf := config.New() + ttl := conf.GetReloadableDurationVar(3000, time.Millisecond, "Reporting.eventSampling.durationInMinutes") + eventSamplerType := conf.GetReloadableStringVar("badger", "Reporting.eventSampling.type") + eventSamplingCardinality := conf.GetReloadableIntVar(10, 1, "Reporting.eventSampling.cardinality") + log := logger.NewLogger() + + t.Run("should put and get keys", func(t *testing.T) { + assert.Equal(t, 3000*time.Millisecond, ttl.Load()) + es, _ := NewEventSampler(ctx, ttl, eventSamplerType, eventSamplingCardinality, conf, log) + _ = es.Put("key1") + _ = es.Put("key2") + _ = es.Put("key3") + val1, _ := es.Get("key1") + val2, _ := es.Get("key2") + val3, _ := es.Get("key3") + val4, _ := es.Get("key4") + + assert.True(t, val1, "Expected key1 to be present") + assert.True(t, val2, "Expected key2 to be present") + assert.True(t, val3, "Expected key3 to be present") + assert.False(t, val4, "Expected key4 to not be present") + es.Close() + }) + + t.Run("should not get 
evicted keys", func(t *testing.T) { + conf.Set("Reporting.eventSampling.durationInMinutes", 100) + assert.Equal(t, 100*time.Millisecond, ttl.Load()) + + es, _ := NewEventSampler(ctx, ttl, eventSamplerType, eventSamplingCardinality, conf, log) + defer es.Close() + + _ = es.Put("key1") + + require.Eventually(t, func() bool { + val1, _ := es.Get("key1") + return !val1 + }, 1*time.Second, 50*time.Millisecond) + }) +} + +func TestInMemoryCache(t *testing.T) { + ctx := context.Background() + conf := config.New() + eventSamplerType := conf.GetReloadableStringVar("in_memory_cache", "Reporting.eventSampling.type") + eventSamplingCardinality := conf.GetReloadableIntVar(3, 1, "Reporting.eventSampling.cardinality") + ttl := conf.GetReloadableDurationVar(3000, time.Millisecond, "Reporting.eventSampling.durationInMinutes") + log := logger.NewLogger() + + t.Run("should put and get keys", func(t *testing.T) { + assert.Equal(t, 3000*time.Millisecond, ttl.Load()) + es, _ := NewEventSampler(ctx, ttl, eventSamplerType, eventSamplingCardinality, conf, log) + _ = es.Put("key1") + _ = es.Put("key2") + _ = es.Put("key3") + val1, _ := es.Get("key1") + val2, _ := es.Get("key2") + val3, _ := es.Get("key3") + val4, _ := es.Get("key4") + + assert.True(t, val1, "Expected key1 to be present") + assert.True(t, val2, "Expected key2 to be present") + assert.True(t, val3, "Expected key3 to be present") + assert.False(t, val4, "Expected key4 to not be present") + }) + + t.Run("should not get evicted keys", func(t *testing.T) { + conf.Set("Reporting.eventSampling.durationInMinutes", 100) + assert.Equal(t, 100*time.Millisecond, ttl.Load()) + es, _ := NewEventSampler(ctx, ttl, eventSamplerType, eventSamplingCardinality, conf, log) + _ = es.Put("key1") + + require.Eventually(t, func() bool { + val1, _ := es.Get("key1") + return !val1 + }, 1*time.Second, 50*time.Millisecond) + }) + + t.Run("should not add keys if length exceeds", func(t *testing.T) { + conf.Set("Reporting.eventSampling.durationInMinutes", 3000) + assert.Equal(t, 3000*time.Millisecond, ttl.Load()) + es, _ := NewEventSampler(ctx, ttl, eventSamplerType, eventSamplingCardinality, conf, log) + _ = es.Put("key1") + _ = es.Put("key2") + _ = es.Put("key3") + _ = es.Put("key4") + _ = es.Put("key5") + + val1, _ := es.Get("key1") + val2, _ := es.Get("key2") + val3, _ := es.Get("key3") + val4, _ := es.Get("key4") + val5, _ := es.Get("key5") + + assert.True(t, val1, "Expected key1 to be present") + assert.True(t, val2, "Expected key2 to be present") + assert.True(t, val3, "Expected key3 to be present") + assert.False(t, val4, "Expected key4 to not be added") + assert.False(t, val5, "Expected key5 to not be added") + }) +} + +func BenchmarkEventSampler(b *testing.B) { + testCases := []struct { + name string + eventSamplerType string + }{ + { + name: "Badger", + eventSamplerType: "badger", + }, + { + name: "InMemoryCache", + eventSamplerType: "in_memory_cache", + }, + } + + ctx := context.Background() + conf := config.New() + ttl := conf.GetReloadableDurationVar(1, time.Minute, "Reporting.eventSampling.durationInMinutes") + eventSamplerType := conf.GetReloadableStringVar("default", "Reporting.eventSampling.type") + eventSamplingCardinality := conf.GetReloadableIntVar(10, 1, "Reporting.eventSampling.cardinality") + log := logger.NewLogger() + + for _, tc := range testCases { + b.Run(tc.name, func(b *testing.B) { + conf.Set("Reporting.eventSampling.type", tc.eventSamplerType) + + eventSampler, err := NewEventSampler( + ctx, + ttl, + eventSamplerType, + eventSamplingCardinality, + 
conf, + log, + ) + require.NoError(b, err) + + b.Run("Put", func(b *testing.B) { + for i := 0; i < b.N; i++ { + key := uuid.New().String() + err := eventSampler.Put(key) + require.NoError(b, err) + } + }) + + b.Run("Get", func(b *testing.B) { + for i := 0; i < b.N; i++ { + key := uuid.New().String() + + err := eventSampler.Put(key) + require.NoError(b, err) + + _, err = eventSampler.Get(key) + require.NoError(b, err) + } + }) + + eventSampler.Close() + }) + } +} diff --git a/enterprise/reporting/event_sampler/in_memory_cache_event_sampler.go b/enterprise/reporting/event_sampler/in_memory_cache_event_sampler.go new file mode 100644 index 00000000000..a8e413fc8ef --- /dev/null +++ b/enterprise/reporting/event_sampler/in_memory_cache_event_sampler.go @@ -0,0 +1,57 @@ +package event_sampler + +import ( + "context" + "time" + + "github.com/rudderlabs/rudder-go-kit/cachettl" + "github.com/rudderlabs/rudder-go-kit/config" +) + +type InMemoryCacheEventSampler struct { + ctx context.Context + cancel context.CancelFunc + cache *cachettl.Cache[string, bool] + ttl config.ValueLoader[time.Duration] + limit config.ValueLoader[int] + length int +} + +func NewInMemoryCacheEventSampler(ctx context.Context, ttl config.ValueLoader[time.Duration], limit config.ValueLoader[int]) (*InMemoryCacheEventSampler, error) { + c := cachettl.New[string, bool](cachettl.WithNoRefreshTTL) + ctx, cancel := context.WithCancel(ctx) + + es := &InMemoryCacheEventSampler{ + ctx: ctx, + cancel: cancel, + cache: c, + ttl: ttl, + limit: limit, + length: 0, + } + + es.cache.OnEvicted(func(key string, value bool) { + es.length-- + }) + + return es, nil +} + +func (es *InMemoryCacheEventSampler) Get(key string) (bool, error) { + value := es.cache.Get(key) + return value, nil +} + +func (es *InMemoryCacheEventSampler) Put(key string) error { + if es.length >= es.limit.Load() { + return nil + } + + es.cache.Put(key, true, es.ttl.Load()) + es.length++ + return nil +} + +func (es *InMemoryCacheEventSampler) Close() { + es.cancel() +} diff --git a/enterprise/reporting/label_set.go b/enterprise/reporting/label_set.go new file mode 100644 index 00000000000..8910763212c --- /dev/null +++ b/enterprise/reporting/label_set.go @@ -0,0 +1,70 @@ +package reporting + +import ( + "encoding/hex" + "strconv" + + "github.com/spaolacci/murmur3" + + "github.com/rudderlabs/rudder-server/utils/types" +) + +type LabelSet struct { + WorkspaceID string + SourceDefinitionID string + SourceCategory string + SourceID string + DestinationDefinitionID string + DestinationID string + SourceTaskRunID string + SourceJobID string + SourceJobRunID string + TransformationID string + TransformationVersionID string + TrackingPlanID string + TrackingPlanVersion int + InPU string + PU string + Status string + TerminalState bool + InitialState bool + StatusCode int + EventName string + EventType string + ErrorType string + Bucket int64 +} + +func NewLabelSet(metric types.PUReportedMetric, bucket int64) LabelSet { + return LabelSet{ + WorkspaceID: metric.ConnectionDetails.SourceID, + SourceDefinitionID: metric.ConnectionDetails.SourceDefinitionID, + SourceCategory: metric.ConnectionDetails.SourceCategory, + SourceID: metric.ConnectionDetails.SourceID, + DestinationDefinitionID: metric.ConnectionDetails.DestinationDefinitionID, + DestinationID: metric.ConnectionDetails.DestinationID, + SourceTaskRunID: metric.ConnectionDetails.SourceTaskRunID, + SourceJobID: metric.ConnectionDetails.SourceJobID, + SourceJobRunID: metric.ConnectionDetails.SourceJobRunID, + TransformationID: 
metric.ConnectionDetails.TransformationID, + TransformationVersionID: metric.ConnectionDetails.TransformationVersionID, + TrackingPlanID: metric.ConnectionDetails.TrackingPlanID, + TrackingPlanVersion: metric.ConnectionDetails.TrackingPlanVersion, + InPU: metric.PUDetails.InPU, + PU: metric.PUDetails.PU, + Status: metric.StatusDetail.Status, + TerminalState: metric.PUDetails.TerminalPU, + InitialState: metric.PUDetails.InitialPU, + StatusCode: metric.StatusDetail.StatusCode, + EventName: metric.StatusDetail.EventName, + EventType: metric.StatusDetail.EventType, + ErrorType: metric.StatusDetail.ErrorType, + Bucket: bucket, + } +} + +func (labelSet LabelSet) generateHash() string { + data := labelSet.WorkspaceID + labelSet.SourceDefinitionID + labelSet.SourceCategory + labelSet.SourceID + labelSet.DestinationDefinitionID + labelSet.DestinationID + labelSet.SourceTaskRunID + labelSet.SourceJobID + labelSet.SourceJobRunID + labelSet.TransformationID + labelSet.TransformationVersionID + labelSet.TrackingPlanID + strconv.Itoa(labelSet.TrackingPlanVersion) + labelSet.InPU + labelSet.PU + labelSet.Status + strconv.FormatBool(labelSet.TerminalState) + strconv.FormatBool(labelSet.InitialState) + strconv.Itoa(labelSet.StatusCode) + labelSet.EventName + labelSet.EventType + labelSet.ErrorType + strconv.FormatInt(labelSet.Bucket, 10) + hash := murmur3.Sum64([]byte(data)) + return hex.EncodeToString([]byte(strconv.FormatUint(hash, 16))) +} diff --git a/enterprise/reporting/label_set_test.go b/enterprise/reporting/label_set_test.go new file mode 100644 index 00000000000..c4722761a07 --- /dev/null +++ b/enterprise/reporting/label_set_test.go @@ -0,0 +1,81 @@ +package reporting + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/rudderlabs/rudder-server/utils/types" +) + +func createMetricObject(eventName string) types.PUReportedMetric { + return types.PUReportedMetric{ + ConnectionDetails: types.ConnectionDetails{ + SourceID: "some-source-id", + DestinationID: "some-destination-id", + }, + PUDetails: types.PUDetails{ + InPU: "some-in-pu", + PU: "some-pu", + }, + StatusDetail: &types.StatusDetail{ + Status: "some-status", + Count: 3, + StatusCode: 0, + SampleResponse: `{"some-sample-response-key": "some-sample-response-value"}`, + SampleEvent: []byte(`{"some-sample-event-key": "some-sample-event-value"}`), + EventName: eventName, + EventType: "some-event-type", + }, + } +} + +func TestNewLabelSet(t *testing.T) { + t.Run("should create the correct LabelSet from types.PUReportedMetric", func(t *testing.T) { + inputMetric := createMetricObject("some-event-name") + bucket := int64(28889820) + labelSet := NewLabelSet(inputMetric, bucket) + + assert.Equal(t, "some-source-id", labelSet.SourceID) + assert.Equal(t, "some-event-name", labelSet.EventName) // Default value + }) + + t.Run("should create the correct LabelSet with custom EventName", func(t *testing.T) { + inputMetric := createMetricObject("custom-event-name") + bucket := int64(28889820) + labelSet := NewLabelSet(inputMetric, bucket) + + assert.Equal(t, "some-source-id", labelSet.SourceID) + assert.Equal(t, "custom-event-name", labelSet.EventName) // Custom event name + }) +} + +func TestGenerateHash(t *testing.T) { + t.Run("same hash for same LabelSet", func(t *testing.T) { + inputMetric1 := createMetricObject("some-event-name") + bucket := int64(28889820) + labelSet1 := NewLabelSet(inputMetric1, bucket) + + inputMetric2 := createMetricObject("some-event-name") + labelSet2 := NewLabelSet(inputMetric2, bucket) + + hash1 := 
labelSet1.generateHash() + hash2 := labelSet2.generateHash() + + assert.Equal(t, hash1, hash2) + }) + + t.Run("different hash for different LabelSet", func(t *testing.T) { + inputMetric1 := createMetricObject("some-event-name-1") + bucket := int64(28889820) + labelSet1 := NewLabelSet(inputMetric1, bucket) + + inputMetric2 := createMetricObject("some-event-name-2") + labelSet2 := NewLabelSet(inputMetric2, bucket) + + hash1 := labelSet1.generateHash() + hash2 := labelSet2.generateHash() + + assert.NotEqual(t, hash1, hash2) + }) +} diff --git a/enterprise/reporting/reporting.go b/enterprise/reporting/reporting.go index e411c1079fa..b8f5cf8d157 100644 --- a/enterprise/reporting/reporting.go +++ b/enterprise/reporting/reporting.go @@ -27,6 +27,7 @@ import ( "github.com/rudderlabs/rudder-go-kit/stats" obskit "github.com/rudderlabs/rudder-observability-kit/go/labels" + "github.com/rudderlabs/rudder-server/enterprise/reporting/event_sampler" migrator "github.com/rudderlabs/rudder-server/services/sql-migrator" "github.com/rudderlabs/rudder-server/utils/httputil" . "github.com/rudderlabs/rudder-server/utils/tx" //nolint:staticcheck @@ -75,10 +76,15 @@ type DefaultReporter struct { requestLatency stats.Measurement stats stats.Stats maxReportsCountInARequest config.ValueLoader[int] + + eventSamplingEnabled config.ValueLoader[bool] + eventSamplingDuration config.ValueLoader[time.Duration] + eventSampler event_sampler.EventSampler } func NewDefaultReporter(ctx context.Context, conf *config.Config, log logger.Logger, configSubscriber *configSubscriber, stats stats.Stats) *DefaultReporter { var dbQueryTimeout *config.Reloadable[time.Duration] + var eventSampler event_sampler.EventSampler reportingServiceURL := config.GetString("REPORTING_URL", "https://reporting.rudderstack.com/") reportingServiceURL = strings.TrimSuffix(reportingServiceURL, "/") @@ -90,11 +96,23 @@ func NewDefaultReporter(ctx context.Context, conf *config.Config, log logger.Log maxOpenConnections := config.GetIntVar(32, 1, "Reporting.maxOpenConnections") dbQueryTimeout = config.GetReloadableDurationVar(60, time.Second, "Reporting.dbQueryTimeout") maxReportsCountInARequest := conf.GetReloadableIntVar(10, 1, "Reporting.maxReportsCountInARequest") + eventSamplingEnabled := conf.GetReloadableBoolVar(false, "Reporting.eventSampling.enabled") + eventSamplingDuration := conf.GetReloadableDurationVar(60, time.Minute, "Reporting.eventSampling.durationInMinutes") + eventSamplerType := conf.GetReloadableStringVar("badger", "Reporting.eventSampling.type") + eventSamplingCardinality := conf.GetReloadableIntVar(100000, 1, "Reporting.eventSampling.cardinality") // only send reports for wh actions sources if whActionsOnly is configured whActionsOnly := config.GetBool("REPORTING_WH_ACTIONS_ONLY", false) if whActionsOnly { log.Info("REPORTING_WH_ACTIONS_ONLY enabled.only sending reports relevant to wh actions.") } + + if eventSamplingEnabled.Load() { + var err error + eventSampler, err = event_sampler.NewEventSampler(ctx, eventSamplingDuration, eventSamplerType, eventSamplingCardinality, conf, log) + if err != nil { + panic(err) + } + } ctx, cancel := context.WithCancel(ctx) g, ctx := errgroup.WithContext(ctx) return &DefaultReporter{ @@ -118,6 +136,9 @@ func NewDefaultReporter(ctx context.Context, conf *config.Config, log logger.Log dbQueryTimeout: dbQueryTimeout, maxReportsCountInARequest: maxReportsCountInARequest, stats: stats, + eventSamplingEnabled: eventSamplingEnabled, + eventSamplingDuration: eventSamplingDuration, + eventSampler: eventSampler, } 
} @@ -191,7 +212,7 @@ func (r *DefaultReporter) getDBHandle(syncerKey string) (*sql.DB, error) { return nil, fmt.Errorf("DBHandle not found for syncer key: %s", syncerKey) } -func (r *DefaultReporter) getReports(currentMs int64, syncerKey string) (reports []*types.ReportByStatus, reportedAt int64, err error) { +func (r *DefaultReporter) getReports(currentMs, aggregationIntervalMin int64, syncerKey string) (reports []*types.ReportByStatus, reportedAt int64, err error) { sqlStatement := fmt.Sprintf(`SELECT min(reported_at) FROM %s WHERE reported_at < $1`, ReportsTable) var queryMin sql.NullInt64 dbHandle, err := r.getDBHandle(syncerKey) @@ -216,11 +237,35 @@ func (r *DefaultReporter) getReports(currentMs int64, syncerKey string) (reports return nil, 0, nil } - groupByColumns := "workspace_id, namespace, instance_id, source_definition_id, source_category, source_id, destination_definition_id, destination_id, source_task_run_id, source_job_id, source_job_run_id, transformation_id, transformation_version_id, tracking_plan_id, tracking_plan_version, in_pu, pu, reported_at, status, terminal_state, initial_state, status_code, event_name, event_type, error_type" - sqlStatement = fmt.Sprintf(`SELECT %s, (ARRAY_AGG(sample_response order by id))[1], (ARRAY_AGG(sample_event order by id))[1], SUM(count), SUM(violation_count) FROM %s WHERE reported_at = $1 GROUP BY %s`, groupByColumns, ReportsTable, groupByColumns) + bucketStart, bucketEnd := r.getAggregationBucketMinute(queryMin.Int64, aggregationIntervalMin) + // we don't want to flush partial buckets, so we wait for the current bucket to be complete + if bucketEnd > currentMs { + return nil, 0, nil + } + + groupByColumns := "workspace_id, namespace, instance_id, source_definition_id, source_category, source_id, destination_definition_id, destination_id, source_task_run_id, source_job_id, source_job_run_id, transformation_id, transformation_version_id, tracking_plan_id, tracking_plan_version, in_pu, pu, status, terminal_state, initial_state, status_code, event_name, event_type, error_type" + sqlStatement = fmt.Sprintf(` + SELECT + %s, MAX(reported_at), + COALESCE( + (ARRAY_AGG(sample_response ORDER BY id DESC) FILTER (WHERE sample_event != '{}'::jsonb))[1], + '' + ) AS sample_response, + COALESCE( + (ARRAY_AGG(sample_event ORDER BY id DESC) FILTER (WHERE sample_event != '{}'::jsonb))[1], + '{}'::jsonb + ) AS sample_event, + SUM(count), + SUM(violation_count) + FROM + %s + WHERE + reported_at >= $1 and reported_at < $2 + GROUP BY + %s`, groupByColumns, ReportsTable, groupByColumns) var rows *sql.Rows queryStart = time.Now() - rows, err = dbHandle.Query(sqlStatement, queryMin.Int64) + rows, err = dbHandle.Query(sqlStatement, bucketStart, bucketEnd) if err != nil { panic(err) } @@ -246,12 +291,12 @@ func (r *DefaultReporter) getReports(currentMs int64, syncerKey string) (reports &metricReport.ConnectionDetails.TrackingPlanID, &metricReport.ConnectionDetails.TrackingPlanVersion, &metricReport.PUDetails.InPU, &metricReport.PUDetails.PU, - &metricReport.ReportedAt, &metricReport.StatusDetail.Status, &metricReport.PUDetails.TerminalPU, &metricReport.PUDetails.InitialPU, &metricReport.StatusDetail.StatusCode, &metricReport.StatusDetail.EventName, &metricReport.StatusDetail.EventType, &metricReport.StatusDetail.ErrorType, + &metricReport.ReportedAt, &metricReport.StatusDetail.SampleResponse, &metricReport.StatusDetail.SampleEvent, &metricReport.StatusDetail.Count, &metricReport.StatusDetail.ViolationCount, ) @@ -271,6 +316,7 @@ func (r *DefaultReporter) 
getReports(currentMs int64, syncerKey string) (reports func (r *DefaultReporter) getAggregatedReports(reports []*types.ReportByStatus) []*types.Metric { metricsByGroup := map[string]*types.Metric{} maxReportsCountInARequest := r.maxReportsCountInARequest.Load() + sampleEventBucket, _ := r.getAggregationBucketMinute(reports[0].ReportedAt, int64(r.eventSamplingDuration.Load().Minutes())) var values []*types.Metric reportIdentifier := func(report *types.ReportByStatus) string { @@ -286,6 +332,8 @@ func (r *DefaultReporter) getAggregatedReports(reports []*types.ReportByStatus) report.ConnectionDetails.TrackingPlanID, strconv.Itoa(report.ConnectionDetails.TrackingPlanVersion), report.PUDetails.InPU, report.PUDetails.PU, + strconv.FormatBool(report.TerminalPU), strconv.FormatBool(report.InitialPU), + strconv.FormatInt(report.ReportedAt, 10), } return strings.Join(groupingIdentifiers, `::`) } @@ -320,7 +368,8 @@ func (r *DefaultReporter) getAggregatedReports(reports []*types.ReportByStatus) InitialPU: report.InitialPU, }, ReportMetadata: types.ReportMetadata{ - ReportedAt: report.ReportedAt * 60 * 1000, // send reportedAt in milliseconds + ReportedAt: report.ReportedAt * 60 * 1000, // send reportedAt in milliseconds + SampleEventBucket: sampleEventBucket * 60 * 1000, }, } values = append(values, metricsByGroup[identifier]) @@ -342,6 +391,30 @@ func (r *DefaultReporter) getAggregatedReports(reports []*types.ReportByStatus) return values } +func (*DefaultReporter) getAggregationBucketMinute(timeMs, intervalMs int64) (int64, int64) { + // If the interval is not a factor of 60, the bucket start will not be aligned to the hour start. + // For example, if intervalMs is 7 and timeMs is 28891085 (6:05), then the bucket start will be 28891079 (5:59) + // and the current bucket will contain data from 2 different hourly buckets, which should not happen. + // To avoid this, we round intervalMs down to the nearest factor of 60.
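+ // e.g. an interval of 7 rounds down to 6 and 14 down to 12 (see TestGetAggregationBucket); non-positive intervals fall back to 1.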
+ if intervalMs <= 0 || 60%intervalMs != 0 { + factors := []int64{1, 2, 3, 4, 5, 6, 10, 12, 15, 20, 30, 60} + closestFactor := factors[0] + for _, factor := range factors { + if factor < intervalMs { + closestFactor = factor + } else { + break + } + } + intervalMs = closestFactor + } + + bucketStart := timeMs - (timeMs % intervalMs) + bucketEnd := bucketStart + intervalMs + + return bucketStart, bucketEnd +} + func (r *DefaultReporter) emitLagMetric(ctx context.Context, c types.SyncerConfig, lastReportedAtTime *atomic.Time) error { // for monitoring reports pileups reportingLag := r.stats.NewTaggedStat( @@ -390,6 +463,7 @@ func (r *DefaultReporter) mainLoop(ctx context.Context, c types.SyncerConfig) { lastVacuum time.Time vacuumInterval = config.GetReloadableDurationVar(15, time.Minute, "Reporting.vacuumInterval") vacuumThresholdBytes = config.GetReloadableInt64Var(10*bytesize.GB, 1, "Reporting.vacuumThresholdBytes") + aggregationInterval = config.GetReloadableDurationVar(1, time.Minute, "Reporting.aggregationIntervalMinutes") // Values should be a factor of 60 (1, 2, 3, 4, 5, 6, 10, 12, 15, 20, 30, 60); otherwise the interval is rounded down to the nearest factor ) for { if ctx.Err() != nil { @@ -401,7 +475,8 @@ func (r *DefaultReporter) mainLoop(ctx context.Context, c types.SyncerConfig) { currentMin := time.Now().UTC().Unix() / 60 getReportsStart := time.Now() - reports, reportedAt, err := r.getReports(currentMin, c.ConnInfo) + aggregationIntervalMin := int64(aggregationInterval.Load().Minutes()) + reports, reportedAt, err := r.getReports(currentMin, aggregationIntervalMin, c.ConnInfo) if err != nil { r.log.Errorw("getting reports", "error", err) select { @@ -460,7 +535,9 @@ func (r *DefaultReporter) mainLoop(ctx context.Context, c types.SyncerConfig) { if err != nil { return err } - _, err = dbHandle.Exec(`DELETE FROM `+ReportsTable+` WHERE reported_at = $1`, reportedAt) + // Use the same aggregationIntervalMin value that was used to query the reports in getReports() + bucketStart, bucketEnd := r.getAggregationBucketMinute(reportedAt, aggregationIntervalMin) + _, err = dbHandle.Exec(`DELETE FROM `+ReportsTable+` WHERE reported_at >= $1 and reported_at < $2`, bucketStart, bucketEnd) if err != nil { r.log.Errorf(`[ Reporting ]: Error deleting local reports from %s: %v`, ReportsTable, err) } else { @@ -612,6 +689,34 @@ func transformMetricForPII(metric types.PUReportedMetric, piiColumns []string) t return metric } +func (r *DefaultReporter) transformMetricWithEventSampling(metric types.PUReportedMetric, reportedAt int64) (types.PUReportedMetric, error) { + if r.eventSampler == nil { + return metric, nil + } + + isValidSampleEvent := metric.StatusDetail.SampleEvent != nil && string(metric.StatusDetail.SampleEvent) != "{}" + + if isValidSampleEvent { + sampleEventBucket, _ := r.getAggregationBucketMinute(reportedAt, int64(r.eventSamplingDuration.Load().Minutes())) + hash := NewLabelSet(metric, sampleEventBucket).generateHash() + found, err := r.eventSampler.Get(hash) + if err != nil { + return metric, err + } + + if found { + metric.StatusDetail.SampleEvent = json.RawMessage(`{}`) + metric.StatusDetail.SampleResponse = "" + } else { + err := r.eventSampler.Put(hash) + if err != nil { + return metric, err + } + } + } + return metric, nil +} + func (r *DefaultReporter) Report(ctx context.Context, metrics []*types.PUReportedMetric, txn *Tx) error { if len(metrics) == 0 { return nil } @@ -659,6 +764,13 @@ func (r *DefaultReporter) Report(ctx context.Context, metrics []*types.PUReporte metric = transformMetricForPII(metric,
getPIIColumnsToExclude()) } + if r.eventSamplingEnabled.Load() { + metric, err = r.transformMetricWithEventSampling(metric, reportedAt) + if err != nil { + return err + } + } + runeEventName := []rune(metric.StatusDetail.EventName) if len(runeEventName) > 50 { metric.StatusDetail.EventName = fmt.Sprintf("%s...%s", string(runeEventName[:40]), string(runeEventName[len(runeEventName)-10:])) @@ -710,4 +822,8 @@ func (r *DefaultReporter) getTags(label string) stats.Tags { func (r *DefaultReporter) Stop() { r.cancel() _ = r.g.Wait() + + if r.eventSampler != nil { + r.eventSampler.Close() + } } diff --git a/enterprise/reporting/reporting_test.go b/enterprise/reporting/reporting_test.go index 8bfbf7b3b79..63464de5c66 100644 --- a/enterprise/reporting/reporting_test.go +++ b/enterprise/reporting/reporting_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "testing" + "time" "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/stats" @@ -162,12 +163,14 @@ func TestGetAggregatedReports(t *testing.T) { }, } conf := config.New() + conf.Set("Reporting.eventSampling.durationInMinutes", 10) configSubscriber := newConfigSubscriber(logger.NOP) reportHandle := NewDefaultReporter(context.Background(), conf, logger.NOP, configSubscriber, stats.NOP) t.Run("Should provide aggregated reports when batch size is 1", func(t *testing.T) { conf.Set("Reporting.maxReportsCountInARequest", 1) assert.Equal(t, 1, reportHandle.maxReportsCountInARequest.Load()) + bucket, _ := reportHandle.getAggregationBucketMinute(28017690, 10) expectedResponse := []*types.Metric{ { InstanceDetails: types.InstanceDetails{ @@ -184,7 +187,8 @@ func TestGetAggregatedReports(t *testing.T) { PU: "some-pu", }, ReportMetadata: types.ReportMetadata{ - ReportedAt: 28017690 * 60 * 1000, + ReportedAt: 28017690 * 60 * 1000, + SampleEventBucket: bucket * 60 * 1000, }, StatusDetails: []*types.StatusDetail{ { @@ -213,7 +217,8 @@ func TestGetAggregatedReports(t *testing.T) { PU: "some-pu", }, ReportMetadata: types.ReportMetadata{ - ReportedAt: 28017690 * 60 * 1000, + ReportedAt: 28017690 * 60 * 1000, + SampleEventBucket: bucket * 60 * 1000, }, StatusDetails: []*types.StatusDetail{ { @@ -242,7 +247,8 @@ func TestGetAggregatedReports(t *testing.T) { PU: "some-pu", }, ReportMetadata: types.ReportMetadata{ - ReportedAt: 28017690 * 60 * 1000, + ReportedAt: 28017690 * 60 * 1000, + SampleEventBucket: bucket * 60 * 1000, }, StatusDetails: []*types.StatusDetail{ { @@ -265,6 +271,7 @@ func TestGetAggregatedReports(t *testing.T) { t.Run("Should provide aggregated reports when batch size more than 1", func(t *testing.T) { conf.Set("Reporting.maxReportsCountInARequest", 10) assert.Equal(t, 10, reportHandle.maxReportsCountInARequest.Load()) + bucket, _ := reportHandle.getAggregationBucketMinute(28017690, 10) expectedResponse := []*types.Metric{ { InstanceDetails: types.InstanceDetails{ @@ -281,7 +288,8 @@ func TestGetAggregatedReports(t *testing.T) { PU: "some-pu", }, ReportMetadata: types.ReportMetadata{ - ReportedAt: 28017690 * 60 * 1000, + ReportedAt: 28017690 * 60 * 1000, + SampleEventBucket: bucket * 60 * 1000, }, StatusDetails: []*types.StatusDetail{ { @@ -319,7 +327,8 @@ func TestGetAggregatedReports(t *testing.T) { PU: "some-pu", }, ReportMetadata: types.ReportMetadata{ - ReportedAt: 28017690 * 60 * 1000, + ReportedAt: 28017690 * 60 * 1000, + SampleEventBucket: bucket * 60 * 1000, }, StatusDetails: []*types.StatusDetail{ { @@ -342,6 +351,7 @@ func TestGetAggregatedReports(t *testing.T) { t.Run("Should provide aggregated reports when 
batch size is more than 1 and reports with same identifier are more then batch size", func(t *testing.T) { conf.Set("Reporting.maxReportsCountInARequest", 2) assert.Equal(t, 2, reportHandle.maxReportsCountInARequest.Load()) + bucket, _ := reportHandle.getAggregationBucketMinute(28017690, 10) extraReport := &types.ReportByStatus{ InstanceDetails: types.InstanceDetails{ WorkspaceID: "some-workspace-id", @@ -386,7 +396,8 @@ func TestGetAggregatedReports(t *testing.T) { PU: "some-pu", }, ReportMetadata: types.ReportMetadata{ - ReportedAt: 28017690 * 60 * 1000, + ReportedAt: 28017690 * 60 * 1000, + SampleEventBucket: bucket * 60 * 1000, }, StatusDetails: []*types.StatusDetail{ { @@ -424,7 +435,8 @@ func TestGetAggregatedReports(t *testing.T) { PU: "some-pu", }, ReportMetadata: types.ReportMetadata{ - ReportedAt: 28017690 * 60 * 1000, + ReportedAt: 28017690 * 60 * 1000, + SampleEventBucket: bucket * 60 * 1000, }, StatusDetails: []*types.StatusDetail{ { @@ -453,7 +465,8 @@ func TestGetAggregatedReports(t *testing.T) { PU: "some-pu", }, ReportMetadata: types.ReportMetadata{ - ReportedAt: 28017690 * 60 * 1000, + ReportedAt: 28017690 * 60 * 1000, + SampleEventBucket: bucket * 60 * 1000, }, StatusDetails: []*types.StatusDetail{ { @@ -962,3 +975,218 @@ func TestAggregationLogic(t *testing.T) { require.Equal(t, reportResults, reportingMetrics) } + +func TestGetAggregationBucket(t *testing.T) { + conf := config.New() + configSubscriber := newConfigSubscriber(logger.NOP) + reportHandle := NewDefaultReporter(context.Background(), conf, logger.NOP, configSubscriber, stats.NOP) + t.Run("should return the correct aggregation bucket with default interval of 1 mintue", func(t *testing.T) { + cases := []struct { + reportedAt int64 + bucketStart int64 + bucketEnd int64 + }{ + { + reportedAt: time.Date(2022, 1, 1, 10, 5, 10, 40, time.UTC).Unix() / 60, + bucketStart: time.Date(2022, 1, 1, 10, 5, 0, 0, time.UTC).Unix() / 60, + bucketEnd: time.Date(2022, 1, 1, 10, 6, 0, 0, time.UTC).Unix() / 60, + }, + { + reportedAt: time.Date(2022, 2, 4, 11, 5, 59, 10, time.UTC).Unix() / 60, + bucketStart: time.Date(2022, 2, 4, 11, 5, 0, 0, time.UTC).Unix() / 60, + bucketEnd: time.Date(2022, 2, 4, 11, 6, 0, 0, time.UTC).Unix() / 60, + }, + { + reportedAt: time.Date(2022, 3, 5, 12, 59, 59, 59, time.UTC).Unix() / 60, + bucketStart: time.Date(2022, 3, 5, 12, 59, 0, 0, time.UTC).Unix() / 60, + bucketEnd: time.Date(2022, 3, 5, 13, 0, 0, 0, time.UTC).Unix() / 60, + }, + { + reportedAt: time.Date(2022, 4, 6, 13, 0, 0, 0, time.UTC).Unix() / 60, + bucketStart: time.Date(2022, 4, 6, 13, 0, 0, 0, time.UTC).Unix() / 60, + bucketEnd: time.Date(2022, 4, 6, 13, 1, 0, 0, time.UTC).Unix() / 60, + }, + } + + for _, c := range cases { + bs, be := reportHandle.getAggregationBucketMinute(c.reportedAt, 1) + require.Equal(t, c.bucketStart, bs) + require.Equal(t, c.bucketEnd, be) + } + }) + + t.Run("should return the correct aggregation bucket with aggregation interval of 5 mintue", func(t *testing.T) { + cases := []struct { + reportedAt int64 + bucketStart int64 + bucketEnd int64 + }{ + { + reportedAt: time.Date(2022, 1, 1, 10, 5, 10, 40, time.UTC).Unix() / 60, + bucketStart: time.Date(2022, 1, 1, 10, 5, 0, 0, time.UTC).Unix() / 60, + bucketEnd: time.Date(2022, 1, 1, 10, 10, 0, 0, time.UTC).Unix() / 60, + }, + { + reportedAt: time.Date(2022, 2, 4, 11, 5, 59, 10, time.UTC).Unix() / 60, + bucketStart: time.Date(2022, 2, 4, 11, 5, 0, 0, time.UTC).Unix() / 60, + bucketEnd: time.Date(2022, 2, 4, 11, 10, 0, 0, time.UTC).Unix() / 60, + }, + { + reportedAt: 
time.Date(2022, 4, 6, 13, 7, 30, 11, time.UTC).Unix() / 60, + bucketStart: time.Date(2022, 4, 6, 13, 5, 0, 0, time.UTC).Unix() / 60, + bucketEnd: time.Date(2022, 4, 6, 13, 10, 0, 0, time.UTC).Unix() / 60, + }, + { + reportedAt: time.Date(2022, 4, 6, 13, 8, 50, 30, time.UTC).Unix() / 60, + bucketStart: time.Date(2022, 4, 6, 13, 5, 0, 0, time.UTC).Unix() / 60, + bucketEnd: time.Date(2022, 4, 6, 13, 10, 0, 0, time.UTC).Unix() / 60, + }, + { + reportedAt: time.Date(2022, 4, 6, 13, 9, 5, 15, time.UTC).Unix() / 60, + bucketStart: time.Date(2022, 4, 6, 13, 5, 0, 0, time.UTC).Unix() / 60, + bucketEnd: time.Date(2022, 4, 6, 13, 10, 0, 0, time.UTC).Unix() / 60, + }, + { + reportedAt: time.Date(2022, 3, 5, 12, 55, 53, 1, time.UTC).Unix() / 60, + bucketStart: time.Date(2022, 3, 5, 12, 55, 0, 0, time.UTC).Unix() / 60, + bucketEnd: time.Date(2022, 3, 5, 13, 0, 0, 0, time.UTC).Unix() / 60, + }, + { + reportedAt: time.Date(2022, 3, 5, 12, 57, 53, 1, time.UTC).Unix() / 60, + bucketStart: time.Date(2022, 3, 5, 12, 55, 0, 0, time.UTC).Unix() / 60, + bucketEnd: time.Date(2022, 3, 5, 13, 0, 0, 0, time.UTC).Unix() / 60, + }, + { + reportedAt: time.Date(2022, 3, 5, 12, 59, 59, 59, time.UTC).Unix() / 60, + bucketStart: time.Date(2022, 3, 5, 12, 55, 0, 0, time.UTC).Unix() / 60, + bucketEnd: time.Date(2022, 3, 5, 13, 0, 0, 0, time.UTC).Unix() / 60, + }, + { + reportedAt: time.Date(2022, 4, 6, 13, 0, 0, 0, time.UTC).Unix() / 60, + bucketStart: time.Date(2022, 4, 6, 13, 0, 0, 0, time.UTC).Unix() / 60, + bucketEnd: time.Date(2022, 4, 6, 13, 5, 0, 0, time.UTC).Unix() / 60, + }, + } + + for _, c := range cases { + bs, be := reportHandle.getAggregationBucketMinute(c.reportedAt, 5) + require.Equal(t, c.bucketStart, bs) + require.Equal(t, c.bucketEnd, be) + } + }) + + t.Run("should return the correct aggregation bucket with aggregation interval of 15 mintue", func(t *testing.T) { + cases := []struct { + reportedAt int64 + bucketStart int64 + bucketEnd int64 + }{ + { + reportedAt: time.Date(2022, 1, 1, 10, 5, 10, 40, time.UTC).Unix() / 60, + bucketStart: time.Date(2022, 1, 1, 10, 0, 0, 0, time.UTC).Unix() / 60, + bucketEnd: time.Date(2022, 1, 1, 10, 15, 0, 0, time.UTC).Unix() / 60, + }, + { + reportedAt: time.Date(2022, 2, 4, 11, 17, 59, 10, time.UTC).Unix() / 60, + bucketStart: time.Date(2022, 2, 4, 11, 15, 0, 0, time.UTC).Unix() / 60, + bucketEnd: time.Date(2022, 2, 4, 11, 30, 0, 0, time.UTC).Unix() / 60, + }, + { + reportedAt: time.Date(2022, 4, 6, 13, 39, 10, 59, time.UTC).Unix() / 60, + bucketStart: time.Date(2022, 4, 6, 13, 30, 0, 0, time.UTC).Unix() / 60, + bucketEnd: time.Date(2022, 4, 6, 13, 45, 0, 0, time.UTC).Unix() / 60, + }, + { + reportedAt: time.Date(2022, 4, 6, 13, 59, 50, 30, time.UTC).Unix() / 60, + bucketStart: time.Date(2022, 4, 6, 13, 45, 0, 0, time.UTC).Unix() / 60, + bucketEnd: time.Date(2022, 4, 6, 14, 0, 0, 0, time.UTC).Unix() / 60, + }, + } + + for _, c := range cases { + bs, be := reportHandle.getAggregationBucketMinute(c.reportedAt, 15) + require.Equal(t, c.bucketStart, bs) + require.Equal(t, c.bucketEnd, be) + } + }) + + t.Run("should choose closest factor of 60 if interval is non positive and return the correct aggregation bucket", func(t *testing.T) { + cases := []struct { + reportedAt int64 + interval int64 + bucketStart int64 + bucketEnd int64 + }{ + { + reportedAt: time.Date(2022, 1, 1, 12, 5, 10, 40, time.UTC).Unix() / 60, + interval: -1, // it should round to 1 + bucketStart: time.Date(2022, 1, 1, 12, 5, 0, 0, time.UTC).Unix() / 60, + bucketEnd: time.Date(2022, 1, 1, 12, 6, 0, 0, 
time.UTC).Unix() / 60, + }, + { + reportedAt: time.Date(2022, 2, 29, 10, 0, 2, 59, time.UTC).Unix() / 60, + interval: -1, // it should round to 1 + bucketStart: time.Date(2022, 2, 29, 10, 0, 0, 0, time.UTC).Unix() / 60, + bucketEnd: time.Date(2022, 2, 29, 10, 1, 0, 0, time.UTC).Unix() / 60, + }, + { + reportedAt: time.Date(2022, 2, 10, 0, 0, 0, 40, time.UTC).Unix() / 60, + interval: 0, // it should round to 1 + bucketStart: time.Date(2022, 2, 10, 0, 0, 0, 0, time.UTC).Unix() / 60, + bucketEnd: time.Date(2022, 2, 10, 0, 1, 0, 0, time.UTC).Unix() / 60, + }, + { + reportedAt: time.Date(2022, 11, 27, 23, 59, 59, 40, time.UTC).Unix() / 60, + interval: 0, // it should round to 1 + bucketStart: time.Date(2022, 11, 27, 23, 59, 59, 0, time.UTC).Unix() / 60, + bucketEnd: time.Date(2022, 11, 28, 0, 0, 0, 0, time.UTC).Unix() / 60, + }, + } + + for _, c := range cases { + bs, be := reportHandle.getAggregationBucketMinute(c.reportedAt, c.interval) + require.Equal(t, c.bucketStart, bs) + require.Equal(t, c.bucketEnd, be) + } + }) + + t.Run("should choose closest factor of 60 if interval is not a factor of 60 and return the correct aggregation bucket", func(t *testing.T) { + cases := []struct { + reportedAt int64 + interval int64 + bucketStart int64 + bucketEnd int64 + }{ + { + reportedAt: time.Date(2022, 1, 1, 10, 23, 10, 40, time.UTC).Unix() / 60, + interval: 7, // it should round to 6 + bucketStart: time.Date(2022, 1, 1, 10, 18, 0, 0, time.UTC).Unix() / 60, + bucketEnd: time.Date(2022, 1, 1, 10, 24, 0, 0, time.UTC).Unix() / 60, + }, + { + reportedAt: time.Date(2022, 1, 1, 10, 5, 10, 40, time.UTC).Unix() / 60, + interval: 14, // it should round to 12 + bucketStart: time.Date(2022, 1, 1, 10, 0, 0, 0, time.UTC).Unix() / 60, + bucketEnd: time.Date(2022, 1, 1, 10, 12, 0, 0, time.UTC).Unix() / 60, + }, + { + reportedAt: time.Date(2022, 1, 1, 10, 39, 10, 40, time.UTC).Unix() / 60, + interval: 59, // it should round to 30 + bucketStart: time.Date(2022, 1, 1, 10, 30, 0, 0, time.UTC).Unix() / 60, + bucketEnd: time.Date(2022, 1, 1, 11, 0, 0, 0, time.UTC).Unix() / 60, + }, + { + reportedAt: time.Date(2022, 1, 1, 10, 5, 10, 40, time.UTC).Unix() / 60, + interval: 63, // it should round to 60 + bucketStart: time.Date(2022, 1, 1, 10, 0, 0, 0, time.UTC).Unix() / 60, + bucketEnd: time.Date(2022, 1, 1, 11, 0, 0, 0, time.UTC).Unix() / 60, + }, + } + + for _, c := range cases { + bs, be := reportHandle.getAggregationBucketMinute(c.reportedAt, c.interval) + require.Equal(t, c.bucketStart, bs) + require.Equal(t, c.bucketEnd, be) + } + }) +} diff --git a/go.mod b/go.mod index 8974b7b16b8..ff3e77c14cf 100644 --- a/go.mod +++ b/go.mod @@ -24,9 +24,9 @@ replace ( ) require ( - cloud.google.com/go/bigquery v1.64.0 + cloud.google.com/go/bigquery v1.65.0 cloud.google.com/go/pubsub v1.45.3 - cloud.google.com/go/storage v1.47.0 + cloud.google.com/go/storage v1.48.0 github.com/Azure/azure-storage-blob-go v0.15.0 github.com/ClickHouse/clickhouse-go v1.5.4 github.com/DATA-DOG/go-sqlmock v1.5.2 @@ -86,7 +86,7 @@ require ( github.com/segmentio/go-hll v1.0.1 github.com/segmentio/kafka-go v0.4.47 github.com/segmentio/ksuid v1.0.4 - github.com/snowflakedb/gosnowflake v1.12.0 + github.com/snowflakedb/gosnowflake v1.12.1 github.com/sony/gobreaker v1.0.0 github.com/spaolacci/murmur3 v1.1.0 github.com/spf13/cast v1.7.0 @@ -116,6 +116,7 @@ require ( require ( github.com/BurntSushi/toml v1.4.0 // indirect github.com/apache/arrow-go/v18 v18.0.0 // indirect + github.com/apache/arrow/go/v16 v16.0.0 // indirect github.com/dgraph-io/ristretto/v2 v2.0.0 
// indirect github.com/spf13/viper v1.19.0 // indirect ) @@ -343,7 +344,7 @@ require ( golang.org/x/sys v0.27.0 // indirect golang.org/x/term v0.25.0 // indirect golang.org/x/text v0.20.0 // indirect - golang.org/x/time v0.8.0 // indirect + golang.org/x/time v0.8.0 golang.org/x/tools v0.26.0 // indirect golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect google.golang.org/genproto v0.0.0-20241118233622-e639e219e697 // indirect diff --git a/go.sum b/go.sum index 4005fe1b4fb..9d8f12ec864 100644 --- a/go.sum +++ b/go.sum @@ -44,8 +44,8 @@ cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvf cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUMb4Nv6dBIg= cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc= cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ= -cloud.google.com/go/bigquery v1.64.0 h1:vSSZisNyhr2ioJE1OuYBQrnrpB7pIhRQm4jfjc7E/js= -cloud.google.com/go/bigquery v1.64.0/go.mod h1:gy8Ooz6HF7QmA+TRtX8tZmXBKH5mCFBwUApGAb3zI7Y= +cloud.google.com/go/bigquery v1.65.0 h1:ZZ1EOJMHTYf6R9lhxIXZJic1qBD4/x9loBIS+82moUs= +cloud.google.com/go/bigquery v1.65.0/go.mod h1:9WXejQ9s5YkTW4ryDYzKXBooL78u5+akWGXgJqQkY6A= cloud.google.com/go/compute v0.1.0/go.mod h1:GAesmwr110a34z04OlxYkATPBEfVhkymfTBXtfbBFow= cloud.google.com/go/compute v1.2.0/go.mod h1:xlogom/6gr8RJGBe7nT2eGsQYAFUbbv8dbC29qE3Xmw= cloud.google.com/go/compute v1.3.0/go.mod h1:cCZiE1NHEtai4wiufUhW8I8S1JKkAnhnQJWM7YD99wM= @@ -89,8 +89,8 @@ cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RX cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= cloud.google.com/go/storage v1.12.0/go.mod h1:fFLk2dp2oAhDz8QFKwqrjdJvxSp/W2g7nillojlL5Ho= cloud.google.com/go/storage v1.21.0/go.mod h1:XmRlxkgPjlBONznT2dDUU/5XlpU2OjMnKuqnZI01LAA= -cloud.google.com/go/storage v1.47.0 h1:ajqgt30fnOMmLfWfu1PWcb+V9Dxz6n+9WKjdNg5R4HM= -cloud.google.com/go/storage v1.47.0/go.mod h1:Ks0vP374w0PW6jOUameJbapbQKXqkjGd/OJRp2fb9IQ= +cloud.google.com/go/storage v1.48.0 h1:FhBDHACbVtdPx7S/AbcKujPWiHvfO6F8OXGgCEbB2+o= +cloud.google.com/go/storage v1.48.0/go.mod h1:aFoDYNMAjv67lp+xcuZqjUKv/ctmplzQ3wJgodA7b+M= cloud.google.com/go/trace v1.0.0/go.mod h1:4iErSByzxkyHWzzlAj63/Gmjz0NH1ASqhJguHpGcr6A= cloud.google.com/go/trace v1.2.0/go.mod h1:Wc8y/uYyOhPy12KEnXG9XGrvfMz5F5SrYecQlbW1rwM= cloud.google.com/go/trace v1.11.2 h1:4ZmaBdL8Ng/ajrgKqY5jfvzqMXbrDcBsUGXOT9aqTtI= @@ -236,6 +236,8 @@ github.com/apache/arrow/go/v12 v12.0.1 h1:JsR2+hzYYjgSUkBSaahpqCetqZMr76djX80fF/ github.com/apache/arrow/go/v12 v12.0.1/go.mod h1:weuTY7JvTG/HDPtMQxEUp7pU73vkLWMLpY67QwZ/WWw= github.com/apache/arrow/go/v15 v15.0.2 h1:60IliRbiyTWCWjERBCkO1W4Qun9svcYoZrSLcyOsMLE= github.com/apache/arrow/go/v15 v15.0.2/go.mod h1:DGXsR3ajT524njufqf95822i+KTh+yea1jass9YXgjA= +github.com/apache/arrow/go/v16 v16.0.0 h1:qRLbJRPj4zaseZrjbDHa7mUoZDDIU+4pu+mE2Lucs5g= +github.com/apache/arrow/go/v16 v16.0.0/go.mod h1:9wnc9mn6vEDTRIm4+27pEjQpRKuTvBaessPoEXQzxWA= github.com/apache/pulsar-client-go v0.14.0 h1:P7yfAQhQ52OCAu8yVmtdbNQ81vV8bF54S2MLmCPJC9w= github.com/apache/pulsar-client-go v0.14.0/go.mod h1:PNUE29x9G1EHMvm41Bs2vcqwgv7N8AEjeej+nEVYbX8= github.com/apache/thrift v0.14.2/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= @@ -1229,8 +1231,8 @@ github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykE github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod 
h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= -github.com/snowflakedb/gosnowflake v1.12.0 h1:Saez8egtn5xAoVMBxFaMu9MYfAG9SS9dpAEXD1/ECIo= -github.com/snowflakedb/gosnowflake v1.12.0/go.mod h1:wHfYmZi3zvtWItojesAhWWXBN7+niex2R1h/S7QCZYg= +github.com/snowflakedb/gosnowflake v1.12.1 h1:IpYK9Wr1dYwPiMSG9RNudAJV0rI0ZOgcNEMXOUiPFX8= +github.com/snowflakedb/gosnowflake v1.12.1/go.mod h1:SYLNMBZ4LXTJfTfJt+M4N40DwabGUx3gkH7VT8hu3Rw= github.com/sony/gobreaker v1.0.0 h1:feX5fGGXSl3dYd4aHZItw+FpHLvvoaqkawKjVNiFMNQ= github.com/sony/gobreaker v1.0.0/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= diff --git a/router/batchrouter/asyncdestinationmanager/klaviyobulkupload/apiService.go b/router/batchrouter/asyncdestinationmanager/klaviyobulkupload/apiService.go index 34b45de4b2f..ddfeddc7263 100644 --- a/router/batchrouter/asyncdestinationmanager/klaviyobulkupload/apiService.go +++ b/router/batchrouter/asyncdestinationmanager/klaviyobulkupload/apiService.go @@ -7,6 +7,8 @@ import ( "net/http" "time" + "golang.org/x/time/rate" + "github.com/rudderlabs/rudder-go-kit/logger" "github.com/rudderlabs/rudder-go-kit/stats" backendconfig "github.com/rudderlabs/rudder-server/backend-config" @@ -16,14 +18,35 @@ const ( KlaviyoAPIURL = "https://a.klaviyo.com/api/profile-bulk-import-jobs/" ) +type RateLimiterHTTPClient struct { + client *http.Client + Ratelimiter *rate.Limiter +} + +func (c *RateLimiterHTTPClient) Do(req *http.Request) (*http.Response, error) { + if err := c.Ratelimiter.Wait(req.Context()); err != nil { + return nil, err + } + return c.client.Do(req) +} + type KlaviyoAPIServiceImpl struct { - client *http.Client + client *RateLimiterHTTPClient PrivateAPIKey string logger logger.Logger statsFactory stats.Stats statLabels stats.Tags } +func newRateLimiterClient() *RateLimiterHTTPClient { + rlc := &RateLimiterHTTPClient{ + client: http.DefaultClient, + // Doc: https://developers.klaviyo.com/en/reference/bulk_import_profiles + Ratelimiter: rate.NewLimiter(rate.Every(400*time.Millisecond), 10), + } + return rlc +} + func setRequestHeaders(req *http.Request, apiKey string) { req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Klaviyo-API-Key "+apiKey) @@ -59,6 +82,10 @@ func (k *KlaviyoAPIServiceImpl) UploadProfiles(profiles Payload) (*UploadResp, e if len(uploadResp.Errors) > 0 { return &uploadResp, fmt.Errorf("upload failed with errors: %+v", uploadResp.Errors) } + if uploadResp.Data.Id == "" { + k.logger.Error("[klaviyo bulk upload] upload failed with empty importId", string(uploadBodyBytes)) + return &uploadResp, fmt.Errorf("upload failed with empty importId") + } uploadTimeStat := k.statsFactory.NewTaggedStat("async_upload_time", stats.TimerType, k.statLabels) uploadTimeStat.Since(startTime) @@ -66,6 +93,9 @@ func (k *KlaviyoAPIServiceImpl) UploadProfiles(profiles Payload) (*UploadResp, e } func (k *KlaviyoAPIServiceImpl) GetUploadStatus(importId string) (*PollResp, error) { + if importId == "" { + return nil, fmt.Errorf("importId is empty") + } pollUrl := KlaviyoAPIURL + importId req, err := http.NewRequest("GET", pollUrl, nil) if err != nil { @@ -122,13 +152,13 @@ func NewKlaviyoAPIService(destination *backendconfig.DestinationT, logger logger return nil, fmt.Errorf("privateApiKey not found or not 
a string") } return &KlaviyoAPIServiceImpl{ - client: http.DefaultClient, + client: newRateLimiterClient(), PrivateAPIKey: privateApiKey, logger: logger, statsFactory: statsFactory, statLabels: stats.Tags{ "module": "batch_router", - "destType": destination.Name, + "destType": destination.DestinationDefinition.Name, "destID": destination.ID, }, }, nil diff --git a/router/batchrouter/asyncdestinationmanager/klaviyobulkupload/klaviyobulkupload.go b/router/batchrouter/asyncdestinationmanager/klaviyobulkupload/klaviyobulkupload.go index cb1c95a5dec..aaff4819f08 100644 --- a/router/batchrouter/asyncdestinationmanager/klaviyobulkupload/klaviyobulkupload.go +++ b/router/batchrouter/asyncdestinationmanager/klaviyobulkupload/klaviyobulkupload.go @@ -22,7 +22,7 @@ import ( const ( BATCHSIZE = 10000 MAXALLOWEDPROFILESIZE = 512000 - MAXPAYLOADSIZE = 4900000 + MAXPAYLOADSIZE = 4600000 IMPORT_ID_SEPARATOR = ":" ) @@ -106,7 +106,9 @@ func (kbu *KlaviyoBulkUploader) Poll(pollInput common.AsyncPoll) common.PollStat importStatuses := make(map[string]string) failedImports := make([]string, 0) for _, importId := range importIds { - importStatuses[importId] = "queued" + if importId != "" { + importStatuses[importId] = "queued" + } } for { @@ -158,7 +160,6 @@ func (kbu *KlaviyoBulkUploader) Poll(pollInput common.AsyncPoll) common.PollStat func (kbu *KlaviyoBulkUploader) GetUploadStats(UploadStatsInput common.GetUploadStatsInput) common.GetUploadStatsResponse { pollResultImportIds := strings.Split(UploadStatsInput.FailedJobParameters, IMPORT_ID_SEPARATOR) - // make a map of jobId to error reason jobIdToErrorMap := make(map[int64]string) @@ -284,8 +285,8 @@ func (kbu *KlaviyoBulkUploader) Upload(asyncDestStruct *common.AsyncDestinationS // if profileStructure length is more than 500 kB, throw an error profileStructureJSON, _ := json.Marshal(profileStructure) profileSize := float64(len(profileStructureJSON)) - profileSizeStat.Observe(float64(profileSize)) // Record the size in the histogram - if float64(len(profileStructureJSON)) >= MAXALLOWEDPROFILESIZE { + profileSizeStat.Observe(profileSize) // Record the size in the histogram + if len(profileStructureJSON) >= MAXALLOWEDPROFILESIZE { abortReason = "Error while marshaling profiles. The profile size exceeds Klaviyo's limit of 500 kB for a single profile." 
abortedJobs = append(abortedJobs, int64(metadata.JobID)) continue @@ -304,7 +305,7 @@ func (kbu *KlaviyoBulkUploader) Upload(asyncDestStruct *common.AsyncDestinationS uploadResp, err := kbu.KlaviyoAPIService.UploadProfiles(combinedPayload) if err != nil { failedJobs = append(failedJobs, importingJobIDs[idx]) - kbu.Logger.Error("Error while uploading profiles", err, uploadResp.Errors) + kbu.Logger.Error("Error while uploading profiles", err, uploadResp.Errors, destinationID) continue } diff --git a/router/batchrouter/asyncdestinationmanager/klaviyobulkupload/klaviyobulkupload_test.go b/router/batchrouter/asyncdestinationmanager/klaviyobulkupload/klaviyobulkupload_test.go index 8d7e11ae9fd..0ec8a9c65db 100644 --- a/router/batchrouter/asyncdestinationmanager/klaviyobulkupload/klaviyobulkupload_test.go +++ b/router/batchrouter/asyncdestinationmanager/klaviyobulkupload/klaviyobulkupload_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/tidwall/gjson" "go.uber.org/mock/gomock" "github.com/rudderlabs/rudder-go-kit/stats" @@ -17,9 +18,9 @@ import ( backendconfig "github.com/rudderlabs/rudder-server/backend-config" "github.com/rudderlabs/rudder-server/jobsdb" - mockklaviyoservice "github.com/rudderlabs/rudder-server/mocks/router/klaviyobulkupload" + mockAPIService "github.com/rudderlabs/rudder-server/mocks/router/klaviyobulkupload" "github.com/rudderlabs/rudder-server/router/batchrouter/asyncdestinationmanager/common" - "github.com/rudderlabs/rudder-server/router/batchrouter/asyncdestinationmanager/klaviyobulkupload" + klaviyobulkupload "github.com/rudderlabs/rudder-server/router/batchrouter/asyncdestinationmanager/klaviyobulkupload" ) var currentDir, _ = os.Getwd() @@ -48,7 +49,7 @@ func TestUpload(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockKlaviyoAPIService := mockklaviyoservice.NewMockKlaviyoAPIService(ctrl) + mockKlaviyoAPIService := mockAPIService.NewMockKlaviyoAPIService(ctrl) testLogger := logger.NewLogger().Child("klaviyo-bulk-upload-test") uploader := klaviyobulkupload.KlaviyoBulkUploader{ @@ -192,20 +193,29 @@ func TestUploadIntegration(t *testing.T) { ImportingJobIDs: []int64{1, 2, 3}, } - output := kbu.Upload(asyncDestStruct) - assert.NotNil(t, output) - assert.Equal(t, destination.ID, output.DestinationID) - assert.Empty(t, output.FailedJobIDs) - assert.Empty(t, output.AbortJobIDs) - assert.Empty(t, output.AbortReason) - assert.NotEmpty(t, output.ImportingJobIDs) + uploadResp := kbu.Upload(asyncDestStruct) + assert.NotNil(t, uploadResp) + assert.Equal(t, destination.ID, uploadResp.DestinationID) + assert.Empty(t, uploadResp.FailedJobIDs) + assert.Empty(t, uploadResp.AbortJobIDs) + assert.Empty(t, uploadResp.AbortReason) + assert.NotEmpty(t, uploadResp.ImportingJobIDs) + assert.NotNil(t, uploadResp.ImportingParameters) + + importId := gjson.GetBytes(uploadResp.ImportingParameters, "importId").String() + pollResp := kbu.Poll(common.AsyncPoll{ImportId: importId}) + assert.NotNil(t, pollResp) + assert.Equal(t, http.StatusOK, pollResp.StatusCode) + assert.True(t, pollResp.Complete) + assert.False(t, pollResp.HasFailed) + assert.False(t, pollResp.HasWarning) } func TestPoll(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockKlaviyoAPIService := mockklaviyoservice.NewMockKlaviyoAPIService(ctrl) + mockKlaviyoAPIService := mockAPIService.NewMockKlaviyoAPIService(ctrl) testLogger := logger.NewLogger().Child("klaviyo-bulk-upload-test") uploader := klaviyobulkupload.KlaviyoBulkUploader{ @@ -313,7 
+323,7 @@ func TestGetUploadStats(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockKlaviyoAPIService := mockklaviyoservice.NewMockKlaviyoAPIService(ctrl) + mockKlaviyoAPIService := mockAPIService.NewMockKlaviyoAPIService(ctrl) testLogger := logger.NewLogger().Child("klaviyo-bulk-upload-test") uploader := klaviyobulkupload.KlaviyoBulkUploader{ diff --git a/utils/types/reporting_types.go b/utils/types/reporting_types.go index 5168c4d0383..f7334968bb4 100644 --- a/utils/types/reporting_types.go +++ b/utils/types/reporting_types.go @@ -57,7 +57,7 @@ type StatusDetail struct { EventType string `json:"eventType"` ErrorType string `json:"errorType"` ViolationCount int64 `json:"violationCount"` - StatTags map[string]string `json:"statTags"` + StatTags map[string]string `json:"-"` FailedMessages []*FailedMessage `json:"-"` } @@ -81,7 +81,8 @@ type InstanceDetails struct { } type ReportMetadata struct { - ReportedAt int64 `json:"reportedAt"` + ReportedAt int64 `json:"reportedAt"` + SampleEventBucket int64 `json:"bucket"` } type Metric struct { diff --git a/utils/types/reporting_types_test.go b/utils/types/reporting_types_test.go new file mode 100644 index 00000000000..01501902c5c --- /dev/null +++ b/utils/types/reporting_types_test.go @@ -0,0 +1,134 @@ +package types_test + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/rudder-server/utils/types" +) + +func TestMetricJSONMarshaling(t *testing.T) { + expectedJSON := `{ + "workspaceId": "SomeWorkspaceId", + "namespace": "SomeNamespace", + "instanceId": "1", + "sourceId": "SomeSourceId", + "destinationId": "SomeDestinationId", + "DestinationDefinitionId": "SomeDestinationDefinitionId", + "sourceDefinitionId": "SomeSourceDefinitionId", + "sourceTaskRunId": "", + "sourceJobId": "", + "sourceJobRunId": "", + "sourceCategory": "SomeSourceCategory", + "inReportedBy": "router", + "reportedBy": "router", + "transformationId": "", + "transformationVersionId": "2", + "terminalState": true, + "initialState": false, + "reportedAt": 1730712600000, + "trackingPlanId": "1", + "trackingPlanVersion": 1, + "bucket": 1730712600000, + "reports": [ + { + "state": "failed", + "count": 20, + "errorType": "this is errorType", + "statusCode": 400, + "violationCount": 12, + "sampleResponse": "error email not valid", + "sampleEvent": {"key": "value-1"}, + "eventName": "SomeEventName1", + "eventType": "SomeEventType" + }, + { + "state": "failed", + "count": 20, + "errorType": "this is errorType", + "statusCode": 400, + "violationCount": 12, + "sampleResponse": "error email not valid", + "sampleEvent": {"key": "value-1"}, + "eventName": "SomeEventName2", + "eventType": "SomeEventType" + } + ] + }` + + // Populate the Metric struct + metric := types.Metric{ + InstanceDetails: types.InstanceDetails{ + WorkspaceID: "SomeWorkspaceId", + Namespace: "SomeNamespace", + InstanceID: "1", + }, + ConnectionDetails: types.ConnectionDetails{ + SourceID: "SomeSourceId", + DestinationID: "SomeDestinationId", + SourceDefinitionID: "SomeSourceDefinitionId", + DestinationDefinitionID: "SomeDestinationDefinitionId", + SourceTaskRunID: "", + SourceJobID: "", + SourceJobRunID: "", + SourceCategory: "SomeSourceCategory", + TransformationID: "", + TransformationVersionID: "2", + TrackingPlanID: "1", + TrackingPlanVersion: 1, + }, + PUDetails: types.PUDetails{ + InPU: "router", + PU: "router", + TerminalPU: true, + InitialPU: false, + }, + ReportMetadata: types.ReportMetadata{ + ReportedAt: 1730712600000, + 
SampleEventBucket: 1730712600000, + }, + StatusDetails: []*types.StatusDetail{ + { + Status: "failed", + Count: 20, + StatusCode: 400, + SampleResponse: "error email not valid", + SampleEvent: json.RawMessage(`{"key": "value-1"}`), + EventName: "SomeEventName1", + EventType: "SomeEventType", + ErrorType: "this is errorType", + ViolationCount: 12, + StatTags: map[string]string{ + "category": "validation", + }, + FailedMessages: []*types.FailedMessage{ + { + MessageID: "1", + ReceivedAt: time.Now(), + }, + }, + }, + { + Status: "failed", + Count: 20, + StatusCode: 400, + SampleResponse: "error email not valid", + SampleEvent: json.RawMessage(`{"key": "value-1"}`), + EventName: "SomeEventName2", + EventType: "SomeEventType", + ErrorType: "this is errorType", + ViolationCount: 12, + StatTags: map[string]string{ + "category": "authentication", + }, + }, + }, + } + + marshaledJSON, err := json.Marshal(metric) + require.NoError(t, err) + require.JSONEq(t, expectedJSON, string(marshaledJSON)) +} diff --git a/warehouse/router/state_export_data.go b/warehouse/router/state_export_data.go index d83c80f39d8..e768f4707ad 100644 --- a/warehouse/router/state_export_data.go +++ b/warehouse/router/state_export_data.go @@ -6,6 +6,7 @@ import ( "slices" "strings" "sync" + "sync/atomic" "github.com/samber/lo" @@ -242,7 +243,7 @@ func (job *UploadJob) loadUserTables(loadFilesTableMap map[tableNameT]bool) ([]e LastExecTime: &lastExecTime, }) - _, err = job.updateSchema(job.identifiesTableName()) + alteredIdentitySchema, err := job.updateSchema(job.identifiesTableName()) if err != nil { status := model.TableUploadUpdatingSchemaFailed errorsString := misc.QuoteLiteral(err.Error()) @@ -252,6 +253,7 @@ func (job *UploadJob) loadUserTables(loadFilesTableMap map[tableNameT]bool) ([]e }) return job.processLoadTableResponse(map[string]error{job.identifiesTableName(): err}) } + var alteredUserSchema bool if _, ok := job.upload.UploadSchema[job.usersTableName()]; ok { status := model.TableUploadExecuting lastExecTime := job.now() @@ -259,7 +261,7 @@ func (job *UploadJob) loadUserTables(loadFilesTableMap map[tableNameT]bool) ([]e Status: &status, LastExecTime: &lastExecTime, }) - _, err = job.updateSchema(job.usersTableName()) + alteredUserSchema, err = job.updateSchema(job.usersTableName()) if err != nil { status = model.TableUploadUpdatingSchemaFailed errorsString := misc.QuoteLiteral(err.Error()) @@ -277,6 +279,10 @@ func (job *UploadJob) loadUserTables(loadFilesTableMap map[tableNameT]bool) ([]e } errorMap := job.whManager.LoadUserTables(job.ctx) + if alteredIdentitySchema || alteredUserSchema { + job.logger.Infof("loadUserTables: schema changed - updating local schema for %s", job.warehouse.Identifier) + _ = job.schemaHandle.UpdateLocalSchemaWithWarehouse(job.ctx) + } return job.processLoadTableResponse(errorMap) } @@ -491,7 +497,7 @@ func (job *UploadJob) loadIdentityTables(populateHistoricIdentities bool) (loadE return job.processLoadTableResponse(errorMap) } } - + var alteredSchema bool for _, tableName := range identityTables { if _, loaded := currentJobSucceededTables[tableName]; loaded { continue @@ -500,6 +506,7 @@ func (job *UploadJob) loadIdentityTables(populateHistoricIdentities bool) (loadE errorMap[tableName] = nil tableSchemaDiff := job.schemaHandle.TableSchemaDiff(tableName, job.GetTableSchemaInUpload(tableName)) + job.logger.Infof("Table schema diff for %s: %v", tableName, tableSchemaDiff) if tableSchemaDiff.Exists { err := job.UpdateTableSchema(tableName, tableSchemaDiff) if err != nil { @@ -518,6
+525,7 @@ func (job *UploadJob) loadIdentityTables(populateHistoricIdentities bool) (loadE _ = job.tableUploadsRepo.Set(job.ctx, job.upload.ID, tableName, repo.TableUploadSetOptions{ Status: &status, }) + alteredSchema = true } status := model.TableUploadExecuting @@ -543,6 +551,10 @@ func (job *UploadJob) loadIdentityTables(populateHistoricIdentities bool) (loadE break } } + if alteredSchema { + job.logger.Infof("loadIdentityTables: schema changed - updating local schema for %s", job.warehouse.Identifier) + _ = job.schemaHandle.UpdateLocalSchemaWithWarehouse(job.ctx) // TODO check error + } return job.processLoadTableResponse(errorMap) } @@ -623,7 +635,7 @@ func (job *UploadJob) loadAllTablesExcept(skipLoadForTables []string, loadFilesT var wg sync.WaitGroup wg.Add(len(uploadSchema)) - + var alteredSchemaInAtLeastOneTable atomic.Bool concurrencyGuard := make(chan struct{}, parallelLoads) var ( @@ -664,7 +676,10 @@ func (job *UploadJob) loadAllTablesExcept(skipLoadForTables []string, loadFilesT tableName := tableName concurrencyGuard <- struct{}{} rruntime.GoForWarehouse(func() { - _, err := job.loadTable(tableName) + alteredSchema, err := job.loadTable(tableName) + if alteredSchema { + alteredSchemaInAtLeastOneTable.Store(true) + } if err != nil { loadErrorLock.Lock() loadErrors = append(loadErrors, err) @@ -676,6 +691,11 @@ func (job *UploadJob) loadAllTablesExcept(skipLoadForTables []string, loadFilesT }) } wg.Wait() + + if alteredSchemaInAtLeastOneTable.Load() { + job.logger.Infof("loadAllTablesExcept: schema changed - updating local schema for %s", job.warehouse.Identifier) + _ = job.schemaHandle.UpdateLocalSchemaWithWarehouse(job.ctx) // TODO check error + } return loadErrors } diff --git a/warehouse/router/state_generate_upload_schema.go b/warehouse/router/state_generate_upload_schema.go index 2e741f34f90..ceb58e67c73 100644 --- a/warehouse/router/state_generate_upload_schema.go +++ b/warehouse/router/state_generate_upload_schema.go @@ -30,6 +30,5 @@ func (job *UploadJob) generateUploadSchema() error { } job.upload.UploadSchema = uploadSchema - mergedSchema := job.schemaHandle.MergeUploadSchemaWithLocalSchema(uploadSchema) - return job.schemaHandle.UpdateLocalSchema(job.ctx, mergedSchema) + return nil } diff --git a/warehouse/schema/schema.go b/warehouse/schema/schema.go index 97afc8b0006..92f74dc40cb 100644 --- a/warehouse/schema/schema.go +++ b/warehouse/schema/schema.go @@ -413,18 +413,3 @@ func (sh *Schema) GetColumnsCountInWarehouseSchema(tableName string) int { defer sh.schemaInWarehouseMu.RUnlock() return len(sh.schemaInWarehouse[tableName]) } - -func (sh *Schema) MergeUploadSchemaWithLocalSchema(schema model.Schema) model.Schema { - sh.localSchemaMu.RLock() - defer sh.localSchemaMu.RUnlock() - var mergedSchema = sh.localSchema - for tableName, columnMap := range schema { - if _, ok := mergedSchema[tableName]; !ok { - mergedSchema[tableName] = model.TableSchema{} - } - for columnName, columnType := range columnMap { - mergedSchema[tableName][columnName] = columnType - } - } - return mergedSchema -}
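Note on the warehouse schema changes above: generateUploadSchema no longer merges the upload schema into the locally cached schema up front; instead, each load path (loadUserTables, loadIdentityTables, loadAllTablesExcept) records whether any table load actually altered the warehouse schema and, if so, refreshes the local schema from the warehouse once after all loads finish. The following is a minimal, self-contained sketch of that pattern, not the rudder-server implementation: loadTable and syncLocalSchema are hypothetical stand-ins for job.loadTable and job.schemaHandle.UpdateLocalSchemaWithWarehouse, and the bounded-concurrency guard mirrors the concurrencyGuard channel used in loadAllTablesExcept.

package main

import (
	"context"
	"fmt"
	"sync"
	"sync/atomic"
)

// loadTable is a hypothetical stand-in for job.loadTable: it loads one table
// and reports whether doing so altered the warehouse schema.
func loadTable(table string) (bool, error) {
	fmt.Println("loading table", table)
	return true, nil // pretend every load added a column
}

// syncLocalSchema is a hypothetical stand-in for
// job.schemaHandle.UpdateLocalSchemaWithWarehouse.
func syncLocalSchema(ctx context.Context) error {
	fmt.Println("refreshing local schema from warehouse")
	return nil
}

func main() {
	ctx := context.Background()
	tables := []string{"identifies", "users", "tracks"}

	var (
		wg      sync.WaitGroup
		altered atomic.Bool // set by any goroutine whose load changed the schema
	)
	guard := make(chan struct{}, 2) // bounded concurrency, like concurrencyGuard

	for _, table := range tables {
		wg.Add(1)
		guard <- struct{}{}
		go func(table string) {
			defer wg.Done()
			defer func() { <-guard }()
			if changed, err := loadTable(table); err == nil && changed {
				altered.Store(true)
			}
		}(table)
	}
	wg.Wait()

	// Refresh the locally cached schema at most once, and only if needed.
	if altered.Load() {
		_ = syncLocalSchema(ctx)
	}
}

Syncing once after wg.Wait() rather than inside each goroutine avoids fetching the warehouse schema repeatedly when several tables change within the same upload.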