diff --git a/internal/impl/snowflake/integration_test.go b/internal/impl/snowflake/integration_test.go index a9550cf47c..ed921b0f6c 100644 --- a/internal/impl/snowflake/integration_test.go +++ b/internal/impl/snowflake/integration_test.go @@ -144,3 +144,61 @@ snowflake_streaming: {"zing", "6"}, }, rows) } + +func TestIntegrationNamedChannels(t *testing.T) { + integration.CheckSkip(t) + produce, stream := SetupSnowflakeStream(t, ` +label: snowpipe_streaming +snowflake_streaming: + account: "${SNOWFLAKE_ACCOUNT:WQKFXQQ-WI77362}" + user: "${SNOWFLAKE_USER:ROCKWOODREDPANDA}" + role: ACCOUNTADMIN + database: "${SNOWFLAKE_DB:BABY_DATABASE}" + schema: PUBLIC + table: integration_test_named_channels + init_statement: | + DROP TABLE IF EXISTS integration_test_named_channels; + private_key_file: "${SNOWFLAKE_PRIVATE_KEY:./streaming/resources/rsa_key.p8}" + max_in_flight: 1 + offset_token: "${!this.token}" + channel_name: "${!this.channel}" + schema_evolution: + enabled: true +`) + RunStreamInBackground(t, stream) + require.NoError(t, produce(context.Background(), Batch([]map[string]any{ + {"foo": "bar", "token": 1, "channel": "foo"}, + {"foo": "baz", "token": 2, "channel": "foo"}, + {"foo": "qux", "token": 3, "channel": "foo"}, + {"foo": "zoom", "token": 4, "channel": "foo"}, + }))) + require.NoError(t, produce(context.Background(), Batch([]map[string]any{ + {"foo": "qux", "token": 3, "channel": "bar"}, + {"foo": "zoom", "token": 4, "channel": "bar"}, + {"foo": "thud", "token": 5, "channel": "bar"}, + {"foo": "zing", "token": 6, "channel": "bar"}, + }))) + require.NoError(t, produce(context.Background(), Batch([]map[string]any{ + {"foo": "thud", "token": 5, "channel": "bar"}, + {"foo": "zing", "token": 6, "channel": "bar"}, + {"foo": "bizz", "token": 7, "channel": "bar"}, + {"foo": "bang", "token": 8, "channel": "bar"}, + }))) + rows := RunSQLQuery( + t, + stream, + `SELECT foo, token, channel FROM integration_test_named_channels ORDER BY channel, token`, + ) + require.Equal(t, [][]string{ + {"qux", "3", "bar"}, + {"zoom", "4", "bar"}, + {"thud", "5", "bar"}, + {"zing", "6", "bar"}, + {"bizz", "7", "bar"}, + {"bang", "8", "bar"}, + {"bar", "1", "foo"}, + {"baz", "2", "foo"}, + {"qux", "3", "foo"}, + {"zoom", "4", "foo"}, + }, rows) +} diff --git a/internal/impl/snowflake/output_snowflake_streaming.go b/internal/impl/snowflake/output_snowflake_streaming.go index ed94eb608c..db18847ac0 100644 --- a/internal/impl/snowflake/output_snowflake_streaming.go +++ b/internal/impl/snowflake/output_snowflake_streaming.go @@ -564,6 +564,7 @@ func newSnowflakeStreamer( } foo := &snowpipeStreamingOutput{ initStatementsFn: initStatementsFn, + client: client, restClient: restClient, mapping: mapping, logger: mgr.Logger(), @@ -637,16 +638,8 @@ func (o *snowpipeStreamingOutput) WriteBatch(ctx context.Context, batch service. batch = mapped } if o.needsTableCreation.Load() { // Check outside of lock - o.mu.Lock() - defer o.mu.Unlock() - if o.needsTableCreation.Load() { - if err := o.schemaEvolver.CreateOutputTable(ctx, batch); err != nil { - return err - } - if err := o.impl.Connect(ctx); err != nil { - return err - } - o.needsTableCreation.Store(false) + if err := o.createTable(ctx, batch); err != nil { + return err } // Now we can proceed writing } @@ -675,6 +668,22 @@ func (o *snowpipeStreamingOutput) WriteBatch(ctx context.Context, batch service. return err } +func (o *snowpipeStreamingOutput) createTable(ctx context.Context, batch service.MessageBatch) error { + o.mu.Lock() + defer o.mu.Unlock() + if !o.needsTableCreation.Load() { + return nil + } + if err := o.schemaEvolver.CreateOutputTable(ctx, batch); err != nil { + return err + } + if err := o.impl.Connect(ctx); err != nil { + return err + } + o.needsTableCreation.Store(false) + return nil +} + func (o *snowpipeStreamingOutput) runMigration(ctx context.Context, needsMigrationErr schemaMigrationNeededError) error { if err := needsMigrationErr.runMigration(ctx, o.schemaEvolver); err != nil { return err @@ -694,8 +703,8 @@ func (o *snowpipeStreamingOutput) Close(ctx context.Context) error { if err := o.impl.Close(ctx); err != nil { return err } - o.restClient.Close() o.client.Close() + o.restClient.Close() return nil }