From 51248b62ccbf72633e4225f7a03e3c481dcee77f Mon Sep 17 00:00:00 2001
From: Tyler Rockwood
Date: Sun, 3 Nov 2024 02:29:49 +0000
Subject: [PATCH 1/7] snowflake: add strict schema enforcement

This is a mechanism that can be used to evolve the schema when extra
data is coming through.
---
 internal/impl/snowflake/streaming/parquet.go  | 17 +++--
 .../impl/snowflake/streaming/parquet_test.go  |  1 +
 .../impl/snowflake/streaming/schema_errors.go | 74 +++++++++++++++++++
 .../impl/snowflake/streaming/streaming.go     |  4 +-
 4 files changed, 90 insertions(+), 6 deletions(-)
 create mode 100644 internal/impl/snowflake/streaming/schema_errors.go

diff --git a/internal/impl/snowflake/streaming/parquet.go b/internal/impl/snowflake/streaming/parquet.go
index 873ac9f02c..494ba00618 100644
--- a/internal/impl/snowflake/streaming/parquet.go
+++ b/internal/impl/snowflake/streaming/parquet.go
@@ -13,6 +13,7 @@ package streaming
 import (
 	"bytes"
 	"encoding/binary"
+	"errors"
 	"fmt"
 
 	"github.com/parquet-go/parquet-go"
@@ -24,7 +25,7 @@ import (
 // messageToRow converts a message into columnar form using the provided name to index mapping.
 // We have to materialize the column into a row so that we can know if a column is null - the
 // msg can be sparse, but the row must not be sparse.
-func messageToRow(msg *service.Message, out []any, nameToPosition map[string]int) error {
+func messageToRow(msg *service.Message, out []any, nameToPosition map[string]int, allowExtraProperties bool) error {
 	v, err := msg.AsStructured()
 	if err != nil {
 		return fmt.Errorf("error extracting object from message: %w", err)
 	}
@@ -36,8 +37,9 @@ func messageToRow(msg *service.Message, out []any, nameToPosition map[string]int
 	for k, v := range row {
 		idx, ok := nameToPosition[normalizeColumnName(k)]
 		if !ok {
-			// TODO(schema): Unknown column, we just skip it.
-			// In the future we may evolve the schema based on the new data.
+			if !allowExtraProperties && v != nil {
+				return MissingColumnError{columnName: k, val: v}
+			}
 			continue
 		}
 		out[idx] = v
@@ -49,6 +51,7 @@ func constructRowGroup(
 	batch service.MessageBatch,
 	schema *parquet.Schema,
 	transformers []*dataTransformer,
+	allowExtraProperties bool,
 ) ([]parquet.Row, []*statsBuffer, error) {
 	// We write all of our data in a columnar fashion, but need to pivot that data so that we can feed it into
 	// out parquet library (which sadly will redo the pivot - maybe we need a lower level abstraction...).
@@ -76,7 +79,7 @@ func constructRowGroup(
 	// is needed
 	row := make([]any, rowWidth)
 	for _, msg := range batch {
-		err := messageToRow(msg, row, nameToPosition)
+		err := messageToRow(msg, row, nameToPosition, allowExtraProperties)
 		if err != nil {
 			return nil, nil, err
 		}
@@ -86,7 +89,11 @@ func constructRowGroup(
 			b := buffers[i]
 			err = t.converter.ValidateAndConvert(s, v, b)
 			if err != nil {
-				// TODO(schema): if this is a null value err then we can evolve the schema to mark it null.
+				if errors.Is(err, errNullValue) {
+					return nil, nil, NonNullColumnError{t.column.Name}
+				}
+				// There is no special typed error for a validation error; there really isn't
+				// anything we can do about it.
 				return nil, nil, fmt.Errorf("invalid data for column %s: %w", t.name, err)
 			}
 			// reset the column as nil for the next row
diff --git a/internal/impl/snowflake/streaming/parquet_test.go b/internal/impl/snowflake/streaming/parquet_test.go
index 27c2581345..f4e2ecb47d 100644
--- a/internal/impl/snowflake/streaming/parquet_test.go
+++ b/internal/impl/snowflake/streaming/parquet_test.go
@@ -60,6 +60,7 @@ func TestWriteParquet(t *testing.T) {
 		batch,
 		schema,
 		transformers,
+		false,
 	)
 	require.NoError(t, err)
 	b, err := writeParquetFile("latest", parquetFileData{
diff --git a/internal/impl/snowflake/streaming/schema_errors.go b/internal/impl/snowflake/streaming/schema_errors.go
new file mode 100644
index 0000000000..f7534cf37b
--- /dev/null
+++ b/internal/impl/snowflake/streaming/schema_errors.go
@@ -0,0 +1,74 @@
+/*
+ * Copyright 2024 Redpanda Data, Inc.
+ *
+ * Licensed as a Redpanda Enterprise file under the Redpanda Community
+ * License (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md
+ */
+
+package streaming
+
+import "fmt"
+
+// SchemaMismatchError occurs when the user provided data has data that
+// doesn't match the schema *and* the table can be evolved to accomidate
+//
+// This can be used as a mechanism to evolve the schema dynamically.
+type SchemaMismatchError interface {
+	ColumnName() string
+	Value() any
+}
+
+var _ error = NonNullColumnError{}
+var _ SchemaMismatchError = NonNullColumnError{}
+
+// NonNullColumnError occurs when a column with a NOT NULL constraint
+// receives a `NULL` value.
+type NonNullColumnError struct {
+	columnName string
+}
+
+// ColumnName returns the column name with the NOT NULL constraint
+func (e NonNullColumnError) ColumnName() string {
+	return e.columnName
+}
+
+// ColumnName returns nil
+func (e NonNullColumnError) Value() any {
+	return nil
+}
+
+// Error implements the error interface
+func (e NonNullColumnError) Error() string {
+	return fmt.Sprintf("column %q has a NOT NULL constraint and received a nil value", e.columnName)
+}
+
+var _ error = MissingColumnError{}
+var _ SchemaMismatchError = MissingColumnError{}
+
+// MissingColumnError occurs when a column that is not in the table is
+// found on a record
+type MissingColumnError struct {
+	columnName string
+	val        any
+}
+
+// ColumnName returns the column name of the data that was not in the table
+//
+// NOTE that this isn't escaped, so it's important that we don't just inject
+// it directly in the query.
+func (e MissingColumnError) ColumnName() string {
+	return e.columnName
+}
+
+// ColumnName returns the value that was associated with the missing column
+func (e MissingColumnError) Value() any {
+	return e.val
+}
+
+// Error implements the error interface
+func (e MissingColumnError) Error() string {
+	return fmt.Sprintf("new data %+v with the name %q does not have an associated column", e.val, e.columnName)
+}
diff --git a/internal/impl/snowflake/streaming/streaming.go b/internal/impl/snowflake/streaming/streaming.go
index 4f22879172..5c68aa6740 100644
--- a/internal/impl/snowflake/streaming/streaming.go
+++ b/internal/impl/snowflake/streaming/streaming.go
@@ -145,6 +145,8 @@ type ChannelOptions struct {
 	TableName string
 	// The max parallelism used to build parquet files and convert message batches into rows.
 	BuildParallelism int
+	// If set to true, don't ignore extra columns in user data, but raise an error.
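+	// When enforcement is on, unknown columns surface as a typed
+	// MissingColumnError, which callers can use to drive schema evolution.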
+ StrictSchemaEnforcement bool } type encryptionInfo struct { @@ -304,7 +306,7 @@ func (c *SnowflakeIngestionChannel) constructBdecPart(batch service.MessageBatch rowGroups = append(rowGroups, rowGroup{}) chunk := batch[i : i+end] wg.Go(func() error { - rows, stats, err := constructRowGroup(chunk, c.schema, c.transformers) + rows, stats, err := constructRowGroup(chunk, c.schema, c.transformers, !c.StrictSchemaEnforcement) rowGroups[j] = rowGroup{rows, stats} return err }) From e618ba35ea94b443216dade7a7be88ac90608e6c Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Mon, 4 Nov 2024 16:12:28 +0000 Subject: [PATCH 2/7] snowflake: support bindings --- internal/impl/snowflake/streaming/rest.go | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/internal/impl/snowflake/streaming/rest.go b/internal/impl/snowflake/streaming/rest.go index 737d64f00f..7df5dfa8a0 100644 --- a/internal/impl/snowflake/streaming/rest.go +++ b/internal/impl/snowflake/streaming/rest.go @@ -249,14 +249,21 @@ type ( Message string `json:"message"` Blobs []blobRegisterStatus `json:"blobs"` } + BindingValue struct { + // The binding data type, generally TEXT is what you want + // see: https://docs.snowflake.com/en/developer-guide/sql-api/submitting-requests#using-bind-variables-in-a-statement + Type string `json:"type"` + Value string `json:"value"` + } // RunSQLRequest is the way to run a SQL statement RunSQLRequest struct { - Statement string `json:"statement"` - Timeout int64 `json:"timeout"` - Database string `json:"database,omitempty"` - Schema string `json:"schema,omitempty"` - Warehouse string `json:"warehouse,omitempty"` - Role string `json:"role,omitempty"` + Statement string `json:"statement"` + Timeout int64 `json:"timeout"` + Database string `json:"database,omitempty"` + Schema string `json:"schema,omitempty"` + Warehouse string `json:"warehouse,omitempty"` + Role string `json:"role,omitempty"` + Bindings map[string]BindingValue `json:"bindings,omitempty"` // https://docs.snowflake.com/en/sql-reference/parameters Parameters map[string]string `json:"parameters,omitempty"` } From 7a90aa9ec8f51e2d8dfab8ec54e0a8e4d7634d71 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Mon, 4 Nov 2024 18:27:08 +0000 Subject: [PATCH 3/7] snowflake: add a function to quote an identifier Will be needed for the column name we use for schema evolution --- internal/impl/snowflake/streaming/compat.go | 21 +++++++++++++++++++ .../impl/snowflake/streaming/compat_test.go | 10 +++++++++ .../impl/snowflake/streaming/schema_errors.go | 7 ++++--- 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/internal/impl/snowflake/streaming/compat.go b/internal/impl/snowflake/streaming/compat.go index facbb9e735..188cbfc2d6 100644 --- a/internal/impl/snowflake/streaming/compat.go +++ b/internal/impl/snowflake/streaming/compat.go @@ -150,6 +150,27 @@ func normalizeColumnName(name string) string { return strings.ToUpper(strings.ReplaceAll(name, `\ `, ` `)) } +// quoteColumnName escapes an object identifier according to the +// rules in Snowflake. +// +// https://docs.snowflake.com/en/sql-reference/identifiers-syntax +func quoteColumnName(name string) string { + var quoted strings.Builder + // Default to assume we're just going to add quotes and there won't + // be any double quotes inside the string that need escaped. 
+	quoted.Grow(len(name) + 2)
+	quoted.WriteByte('"')
+	for _, r := range name {
+		if r == '"' {
+			quoted.WriteString(`""`)
+		} else {
+			quoted.WriteRune(r)
+		}
+	}
+	quoted.WriteByte('"')
+	return quoted.String()
+}
+
 // snowflakeTimestampInt computes the same result as the logic in TimestampWrapper
 // in the Java SDK. It converts a timestamp to the integer representation that
 // is used internally within Snowflake.
diff --git a/internal/impl/snowflake/streaming/compat_test.go b/internal/impl/snowflake/streaming/compat_test.go
index 4b4aaffa64..71a0f2bb6e 100644
--- a/internal/impl/snowflake/streaming/compat_test.go
+++ b/internal/impl/snowflake/streaming/compat_test.go
@@ -127,6 +127,16 @@ func TestColumnNormalization(t *testing.T) {
 	require.Equal(t, `foo" bar "baz`, normalizeColumnName(`"foo"" bar ""baz"`))
 }
 
+func TestColumnQuoting(t *testing.T) {
+	require.Equal(t, `""`, quoteColumnName(""))
+	require.Equal(t, `"foo"`, quoteColumnName("foo"))
+	require.Equal(t, `"""bar"""`, quoteColumnName(`"bar"`))
+	require.Equal(t, `"foo bar"`, quoteColumnName(`foo bar`))
+	require.Equal(t, `"foo\ bar"`, quoteColumnName(`foo\ bar`))
+	require.Equal(t, `"foo""bar"`, quoteColumnName(`foo"bar`))
+	require.Equal(t, `""""""""""`, quoteColumnName(`""""`))
+}
+
 func TestSnowflakeTimestamp(t *testing.T) {
 	type TestCase struct {
 		timestamp string
diff --git a/internal/impl/snowflake/streaming/schema_errors.go b/internal/impl/snowflake/streaming/schema_errors.go
index f7534cf37b..debecdf5d9 100644
--- a/internal/impl/snowflake/streaming/schema_errors.go
+++ b/internal/impl/snowflake/streaming/schema_errors.go
@@ -32,6 +32,7 @@ type NonNullColumnError struct {
 
 // ColumnName returns the column name with the NOT NULL constraint
 func (e NonNullColumnError) ColumnName() string {
+	// This name comes directly from the Snowflake API so I hope this is properly quoted...
 	return e.columnName
 }
 
@@ -57,10 +58,10 @@ type MissingColumnError struct {
 
 // ColumnName returns the column name of the data that was not in the table
 //
-// NOTE that this isn't escaped, so it's important that we don't just inject
-// it directly in the query.
+// NOTE this is escaped, so it's valid to use this directly in a SQL statement
+// but I wish that Snowflake would just allow `identifier` for ALTER column.
 func (e MissingColumnError) ColumnName() string {
-	return e.columnName
+	return quoteColumnName(e.columnName)
 }

From 2582de92af6d2d8e70a061a0b12f761887dd09a4 Mon Sep 17 00:00:00 2001
From: Tyler Rockwood
Date: Mon, 4 Nov 2024 21:37:39 +0000
Subject: [PATCH 4/7] snowflake: schema evolution

Enable schema evolution using a custom mapping function, to allow
customers to customize the evolution method.

Schema evolution needs to be a global pause so that we can reopen
channels to get new schemas. I'm not in love with how errors are being
hijacked here, but I have not come up with a better idea in the short
time working on it.
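To make the retry flow concrete, here is a minimal, self-contained sketch
of the error-hijacking pattern (hypothetical names, simplified from the
real write path in this diff):

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

// schemaMigrationNeededError carries a callback that performs the migration;
// the caller retries the write after running it.
type schemaMigrationNeededError struct {
	migrator func(ctx context.Context) error
}

func (schemaMigrationNeededError) Error() string {
	return "schema migration required; retry after migrating"
}

// writeBatch pretends the first attempt discovers a new column.
func writeBatch(ctx context.Context, attempt int) error {
	if attempt == 0 {
		return schemaMigrationNeededError{
			migrator: func(ctx context.Context) error {
				fmt.Println("running ALTER TABLE ... ADD COLUMN")
				return nil
			},
		}
	}
	fmt.Println("batch written")
	return nil
}

func main() {
	ctx := context.Background()
	var err error
	// Bounded retry loop, like WriteBatch in this commit: one column is
	// migrated per pass, capped so a bug cannot loop forever.
	for attempt := 0; attempt < 10; attempt++ {
		err = writeBatch(ctx, attempt)
		if err == nil {
			break
		}
		migrationErr := schemaMigrationNeededError{}
		if !errors.As(err, &migrationErr) {
			break
		}
		if err = migrationErr.migrator(ctx); err != nil {
			break
		}
	}
	if err != nil {
		fmt.Println("giving up:", err)
	}
}
```

The important property is that the migrator callback runs outside the read
lock, so reopening channels can take the write lock without deadlocking.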
---
 .../pages/outputs/snowflake_streaming.adoc    |  62 +++-
 .../snowflake/output_snowflake_streaming.go   | 336 ++++++++++++++----
 .../impl/snowflake/output_streaming_test.go   |  52 +++
 internal/impl/snowflake/streaming/compat.go   |   2 +-
 .../impl/snowflake/streaming/compat_test.go   |  11 +-
 .../impl/snowflake/streaming/schema_errors.go |   6 +
 6 files changed, 392 insertions(+), 77 deletions(-)
 create mode 100644 internal/impl/snowflake/output_streaming_test.go

diff --git a/docs/modules/components/pages/outputs/snowflake_streaming.adoc b/docs/modules/components/pages/outputs/snowflake_streaming.adoc
index 23f6a26f2c..a5a0c9f8e0 100644
--- a/docs/modules/components/pages/outputs/snowflake_streaming.adoc
+++ b/docs/modules/components/pages/outputs/snowflake_streaming.adoc
@@ -39,7 +39,7 @@ Common::
 output:
   label: ""
   snowflake_streaming:
-    account: AAAAAAA-AAAAAAA # No default (required)
+    account: ORG-ACCOUNT # No default (required)
     user: "" # No default (required)
     role: ACCOUNTADMIN # No default (required)
     database: "" # No default (required)
@@ -51,6 +51,17 @@ output:
     mapping: "" # No default (optional)
     init_statement: | # No default (optional)
       CREATE TABLE IF NOT EXISTS mytable (amount NUMBER);
+    schema_evolution:
+      enabled: false # No default (required)
+      new_column_type_mapping: |-
+        root = match this.value.type() {
+          this == "string" => "STRING"
+          this == "bytes" => "BINARY"
+          this == "number" => "DOUBLE"
+          this == "bool" => "BOOLEAN"
+          this == "timestamp" => "TIMESTAMP"
+          _ => "VARIANT"
+        }
     batching:
       count: 0
       byte_size: 0
@@ -69,7 +80,7 @@ Advanced::
 output:
   label: ""
   snowflake_streaming:
-    account: AAAAAAA-AAAAAAA # No default (required)
+    account: ORG-ACCOUNT # No default (required)
     user: "" # No default (required)
     role: ACCOUNTADMIN # No default (required)
     database: "" # No default (required)
@@ -81,6 +92,17 @@ output:
     mapping: "" # No default (optional)
     init_statement: | # No default (optional)
       CREATE TABLE IF NOT EXISTS mytable (amount NUMBER);
+    schema_evolution:
+      enabled: false # No default (required)
+      new_column_type_mapping: |-
+        root = match this.value.type() {
+          this == "string" => "STRING"
+          this == "bytes" => "BINARY"
+          this == "number" => "DOUBLE"
+          this == "bool" => "BOOLEAN"
+          this == "timestamp" => "TIMESTAMP"
+          _ => "VARIANT"
+        }
     build_parallelism: 1
     batching:
       count: 0
@@ -170,6 +192,8 @@ output:
       schema: "PUBLIC"
       table: "MYTABLE"
       private_key_file: "my/private/key.p8"
+      schema_evolution:
+        enabled: true
 ```
 --

@@ -214,10 +238,7 @@

 === `account`

-Account name, which is the same as the https://docs.snowflake.com/en/user-guide/admin-account-identifier.html#where-are-account-identifiers-used[Account Identifier^].
- However, when using an https://docs.snowflake.com/en/user-guide/admin-account-identifier.html#using-an-account-locator-as-an-identifier[Account Locator^],
- the Account Identifier is formatted as `<account_locator>.<region_id>.<cloud>` and this field needs to be
- populated using the `<account_locator>` part.
+The Snowflake https://docs.snowflake.com/en/user-guide/admin-account-identifier.html#using-an-account-locator-as-an-identifier[Account name^], which should be formatted as `<orgname>-<account_name>` where `<orgname>` is the name of your Snowflake organization and `<account_name>` is the unique name of your account within your organization.
 *Type*: `string`


 ```yml
 # Examples

-account: AAAAAAA-AAAAAAA
+account: ORG-ACCOUNT
 ```

 === `user`

@@ -336,6 +357,33 @@ init_statement: |2
   ALTER TABLE t1 ADD COLUMN a2 NUMBER;
 ```

+=== `schema_evolution`
+
+Options to control schema evolution within the pipeline as new columns are added to the data.
+
+
+*Type*: `object`
+
+
+=== `schema_evolution.enabled`
+
+Whether schema evolution is enabled.
+
+
+*Type*: `bool`
+
+
+=== `schema_evolution.new_column_type_mapping`
+
+The mapping function from Redpanda Connect type to column type in Snowflake. Overriding this allows you to customize the data type if you have specific knowledge about the data types in use. This mapping should result in the `root` variable being assigned a string with the data type for the new column in Snowflake.
+
+The input to this mapping is an object with the value and the name of the new column, for example: `{"value": 42.3, "name":"new_data_field"}`
+
+
+*Type*: `string`
+
+*Default*: `"root = match this.value.type() {\n  this == \"string\" =\u003e \"STRING\"\n  this == \"bytes\" =\u003e \"BINARY\"\n  this == \"number\" =\u003e \"DOUBLE\"\n  this == \"bool\" =\u003e \"BOOLEAN\"\n  this == \"timestamp\" =\u003e \"TIMESTAMP\"\n  _ =\u003e \"VARIANT\"\n}"`
+
 === `build_parallelism`

 The maximum amount of parallelism to use when building the output for Snowflake. The metric to watch to see if you need to change this is `snowflake_build_output_latency_ns`.

diff --git a/internal/impl/snowflake/output_snowflake_streaming.go b/internal/impl/snowflake/output_snowflake_streaming.go
index 9cdb34558a..679e758c12 100644
--- a/internal/impl/snowflake/output_snowflake_streaming.go
+++ b/internal/impl/snowflake/output_snowflake_streaming.go
@@ -11,7 +11,9 @@ package snowflake
 import (
 	"context"
 	"crypto/rsa"
+	"errors"
 	"fmt"
+	"regexp"
 	"sync"
 
 	"github.com/redpanda-data/benthos/v4/public/bloblang"
@@ -21,20 +23,32 @@ import (
 const (
-	ssoFieldAccount          = "account"
-	ssoFieldUser             = "user"
-	ssoFieldRole             = "role"
-	ssoFieldDB               = "database"
-	ssoFieldSchema           = "schema"
-	ssoFieldTable            = "table"
-	ssoFieldKey              = "private_key"
-	ssoFieldKeyFile          = "private_key_file"
-	ssoFieldKeyPass          = "private_key_pass"
-	ssoFieldInitStatement    = "init_statement"
-	ssoFieldBatching         = "batching"
-	ssoFieldChannelPrefix    = "channel_prefix"
-	ssoFieldMapping          = "mapping"
-	ssoFieldBuildParallelism = "build_parallelism"
+	ssoFieldAccount                             = "account"
+	ssoFieldUser                                = "user"
+	ssoFieldRole                                = "role"
+	ssoFieldDB                                  = "database"
+	ssoFieldSchema                              = "schema"
+	ssoFieldTable                               = "table"
+	ssoFieldKey                                 = "private_key"
+	ssoFieldKeyFile                             = "private_key_file"
+	ssoFieldKeyPass                             = "private_key_pass"
+	ssoFieldInitStatement                       = "init_statement"
+	ssoFieldBatching                            = "batching"
+	ssoFieldChannelPrefix                       = "channel_prefix"
+	ssoFieldMapping                             = "mapping"
+	ssoFieldBuildParallelism                    = "build_parallelism"
+	ssoFieldSchemaEvolution                     = "schema_evolution"
+	ssoFieldSchemaEvolutionEnabled              = "enabled"
+	ssoFieldSchemaEvolutionNewColumnTypeMapping = "new_column_type_mapping"
+
+	defaultSchemaEvolutionNewColumnMapping = `root = match this.value.type() {
+	this == "string" => "STRING"
+	this == "bytes" => "BINARY"
+	this == "number" => "DOUBLE"
+	this == "bool" => "BOOLEAN"
+	this == "timestamp" => "TIMESTAMP"
+	_ => "VARIANT"
+}`
 )
 
 func snowflakeStreamingOutputConfig() *service.ConfigSpec {
@@ -70,11 +84,8 @@ You can monitor the output batch size using the `+"`snowflake_compressed_output_
 `).
 		Fields(
 			service.NewStringField(ssoFieldAccount).
-				Description(`Account name, which is the same as the https://docs.snowflake.com/en/user-guide/admin-account-identifier.html#where-are-account-identifiers-used[Account Identifier^].
-	However, when using an https://docs.snowflake.com/en/user-guide/admin-account-identifier.html#using-an-account-locator-as-an-identifier[Account Locator^],
-	the Account Identifier is formatted as `+"`<account_locator>.<region_id>.<cloud>`"+` and this field needs to be
-	populated using the `+"`<account_locator>`"+` part.
-`).Example("AAAAAAA-AAAAAAA"),
+				Description(`The Snowflake https://docs.snowflake.com/en/user-guide/admin-account-identifier.html#using-an-account-locator-as-an-identifier[Account name^], which should be formatted as `+"`<orgname>-<account_name>`"+` where `+"`<orgname>`"+` is the name of your Snowflake organization and `+"`<account_name>`"+` is the unique name of your account within your organization.
+`).Example("ORG-ACCOUNT"),
 			service.NewStringField(ssoFieldUser).Description("The user to run the Snowpipe Stream as. See https://docs.snowflake.com/en/user-guide/admin-user-management[Snowflake Documentation^] on how to create a user."),
 			service.NewStringField(ssoFieldRole).Description("The role for the `user` field. The role must have the https://docs.snowflake.com/en/user-guide/data-load-snowpipe-streaming-overview#required-access-privileges[required privileges^] to call the Snowpipe Streaming APIs. See https://docs.snowflake.com/en/user-guide/admin-user-management#user-roles[Snowflake Documentation^] for more information about roles.").Example("ACCOUNTADMIN"),
 			service.NewStringField(ssoFieldDB).Description("The Snowflake database to ingest data into."),
@@ -92,6 +103,13 @@ CREATE TABLE IF NOT EXISTS mytable (amount NUMBER);
 ALTER TABLE t1 ALTER COLUMN c1 DROP NOT NULL;
 ALTER TABLE t1 ADD COLUMN a2 NUMBER;
 `),
+			service.NewObjectField(ssoFieldSchemaEvolution,
+				service.NewBoolField(ssoFieldSchemaEvolutionEnabled).Description("Whether schema evolution is enabled."),
+				service.NewBloblangField(ssoFieldSchemaEvolutionNewColumnTypeMapping).Description(`
+The mapping function from Redpanda Connect type to column type in Snowflake. Overriding this allows you to customize the data type if you have specific knowledge about the data types in use. This mapping should result in the `+"`root`"+` variable being assigned a string with the data type for the new column in Snowflake.
+
+The input to this mapping is an object with the value and the name of the new column, for example: `+"`"+`{"value": 42.3, "name":"new_data_field"}`+"`").Default(defaultSchemaEvolutionNewColumnMapping),
+			).Description(`Options to control schema evolution within the pipeline as new columns are added to the data.`).Optional(),
 			service.NewIntField(ssoFieldBuildParallelism).Description("The maximum amount of parallelism to use when building the output for Snowflake. The metric to watch to see if you need to change this is `snowflake_build_output_latency_ns`.").Default(1).Advanced(),
 			service.NewBatchPolicyField(ssoFieldBatching),
 			service.NewOutputMaxInFlightField(),
@@ -144,6 +162,8 @@ output:
   schema: "PUBLIC"
   table: "MYTABLE"
   private_key_file: "my/private/key.p8"
+  schema_evolution:
+    enabled: true
`,
		).
Example( @@ -268,6 +288,17 @@ func newSnowflakeStreamer( return nil, err } } + var schemaEvolutionMapping *bloblang.Executor + if conf.Contains(ssoFieldSchemaEvolution, ssoFieldSchemaEvolutionEnabled) { + enabled, err := conf.FieldBool(ssoFieldSchemaEvolution, ssoFieldSchemaEvolutionEnabled) + if err == nil && enabled { + schemaEvolutionMapping, err = conf.FieldBloblang(ssoFieldSchemaEvolution, ssoFieldSchemaEvolutionNewColumnTypeMapping) + } + if err != nil { + return nil, err + } + } + buildParallelism, err := conf.FieldInt(ssoFieldBuildParallelism) if err != nil { return nil, err @@ -284,19 +315,14 @@ func newSnowflakeStreamer( // stream to write to a single table. channelPrefix = fmt.Sprintf("Redpanda_Connect_%s.%s.%s", db, schema, table) } - var initStatementsFn func(context.Context) error + var initStatementsFn func(context.Context, *streaming.SnowflakeRestClient) error if conf.Contains(ssoFieldInitStatement) { initStatements, err := conf.FieldString(ssoFieldInitStatement) if err != nil { return nil, err } - initStatementsFn = func(ctx context.Context) error { - c, err := streaming.NewRestClient(account, user, mgr.EngineVersion(), channelPrefix, rsaKey, mgr.Logger()) - if err != nil { - return err - } - defer c.Close() - _, err = c.RunSQL(ctx, streaming.RunSQLRequest{ + initStatementsFn = func(ctx context.Context, client *streaming.SnowflakeRestClient) error { + _, err = client.RunSQL(ctx, streaming.RunSQLRequest{ Statement: initStatements, // Currently we set of timeout of 30 seconds so that we don't have to handle async operations // that need polling to wait until they finish (results are made async when execution is longer @@ -313,6 +339,10 @@ func newSnowflakeStreamer( return err } } + restClient, err := streaming.NewRestClient(account, user, mgr.EngineVersion(), channelPrefix, rsaKey, mgr.Logger()) + if err != nil { + return nil, fmt.Errorf("unable to create rest API client: %w", err) + } client, err := streaming.NewSnowflakeServiceClient( context.Background(), streaming.ClientOptions{ @@ -328,40 +358,46 @@ func newSnowflakeStreamer( return nil, err } o := &snowflakeStreamerOutput{ - channelPrefix: channelPrefix, - client: client, - db: db, - schema: schema, - table: table, - mapping: mapping, - logger: mgr.Logger(), - buildTime: mgr.Metrics().NewTimer("snowflake_build_output_latency_ns"), - uploadTime: mgr.Metrics().NewTimer("snowflake_upload_latency_ns"), - convertTime: mgr.Metrics().NewTimer("snowflake_convert_latency_ns"), - serializeTime: mgr.Metrics().NewTimer("snowflake_serialize_latency_ns"), - compressedOutput: mgr.Metrics().NewCounter("snowflake_compressed_output_size_bytes"), - initStatementsFn: initStatementsFn, - buildParallelism: buildParallelism, + channelPrefix: channelPrefix, + client: client, + db: db, + schema: schema, + table: table, + role: role, + mapping: mapping, + logger: mgr.Logger(), + buildTime: mgr.Metrics().NewTimer("snowflake_build_output_latency_ns"), + uploadTime: mgr.Metrics().NewTimer("snowflake_upload_latency_ns"), + convertTime: mgr.Metrics().NewTimer("snowflake_convert_latency_ns"), + serializeTime: mgr.Metrics().NewTimer("snowflake_serialize_latency_ns"), + compressedOutput: mgr.Metrics().NewCounter("snowflake_compressed_output_size_bytes"), + initStatementsFn: initStatementsFn, + buildParallelism: buildParallelism, + schemaEvolutionMapping: schemaEvolutionMapping, + restClient: restClient, } return o, nil } type snowflakeStreamerOutput struct { - client *streaming.SnowflakeServiceClient - channelPool sync.Pool - channelCreationMu 
sync.Mutex
-	poolSize          int
-	compressedOutput  *service.MetricCounter
-	uploadTime        *service.MetricTimer
-	buildTime         *service.MetricTimer
-	convertTime       *service.MetricTimer
-	serializeTime     *service.MetricTimer
-	buildParallelism  int
-
-	channelPrefix, db, schema, table string
-	mapping          *bloblang.Executor
-	logger           *service.Logger
-	initStatementsFn func(context.Context) error
+	client                 *streaming.SnowflakeServiceClient
+	channelPool            sync.Pool
+	channelCreationMu      sync.Mutex
+	poolSize               int
+	compressedOutput       *service.MetricCounter
+	uploadTime             *service.MetricTimer
+	buildTime              *service.MetricTimer
+	convertTime            *service.MetricTimer
+	serializeTime          *service.MetricTimer
+	buildParallelism       int
+	schemaEvolutionMapping *bloblang.Executor
+
+	schemaMigrationMu                      sync.RWMutex
+	channelPrefix, db, schema, table, role string
+	mapping                                *bloblang.Executor
+	logger                                 *service.Logger
+	initStatementsFn                       func(context.Context, *streaming.SnowflakeRestClient) error
+	restClient                             *streaming.SnowflakeRestClient
 }
 
 func (o *snowflakeStreamerOutput) openNewChannel(ctx context.Context) (*streaming.SnowflakeIngestionChannel, error) {
@@ -380,19 +416,20 @@ func (o *snowflakeStreamerOutput) openChannel(ctx context.Context, name string, id int16) (*streaming.SnowflakeIngestionChannel, error) {
 	o.logger.Debugf("opening snowflake streaming channel: %s", name)
 	return o.client.OpenChannel(ctx, streaming.ChannelOptions{
-		ID:               id,
-		Name:             name,
-		DatabaseName:     o.db,
-		SchemaName:       o.schema,
-		TableName:        o.table,
-		BuildParallelism: o.buildParallelism,
+		ID:                      id,
+		Name:                    name,
+		DatabaseName:            o.db,
+		SchemaName:              o.schema,
+		TableName:               o.table,
+		BuildParallelism:        o.buildParallelism,
+		StrictSchemaEnforcement: o.schemaEvolutionMapping != nil,
 	})
 }
 
 func (o *snowflakeStreamerOutput) Connect(ctx context.Context) error {
 	if o.initStatementsFn != nil {
-		if err := o.initStatementsFn(ctx); err != nil {
-			return err
+		if err := o.initStatementsFn(ctx, o.restClient); err != nil {
+			return fmt.Errorf("unable to run initialization statement: %w", err)
 		}
 		// We've already executed our init statement, we don't need to do that anymore
@@ -419,6 +456,29 @@ func (o *snowflakeStreamerOutput) WriteBatch(ctx context.Context, batch service.
 		}
 		batch = mapped
 	}
+	var err error
+	// We only migrate one column at a time, so tolerate up to 10 schema
+	// migrations for a single batch before giving up. This protects against
+	// any bugs causing an infinite loop.
+	for i := 0; i < 10; i++ {
+		err = o.WriteBatchInternal(ctx, batch)
+		if err == nil {
+			return nil
+		}
+		migrationErr := schemaMigrationNeededError{}
+		if !errors.As(err, &migrationErr) {
+			break
+		}
+		if err := migrationErr.migrator(ctx); err != nil {
+			return err
+		}
+	}
+	return err
+}
+
+func (o *snowflakeStreamerOutput) WriteBatchInternal(ctx context.Context, batch service.MessageBatch) error {
+	o.schemaMigrationMu.RLock()
+	defer o.schemaMigrationMu.RUnlock()
 	var channel *streaming.SnowflakeIngestionChannel
 	if maybeChan := o.channelPool.Get(); maybeChan != nil {
 		channel = maybeChan.(*streaming.SnowflakeIngestionChannel)
@@ -438,6 +498,31 @@ func (o *snowflakeStreamerOutput) WriteBatch(ctx context.Context, batch service.
 		o.convertTime.Timing(stats.ConvertTime.Nanoseconds())
 		o.serializeTime.Timing(stats.SerializeTime.Nanoseconds())
 	} else {
+		// Only evolve the schema if requested.
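+		// Schema evolution needs a global pause: the typed errors below are
+		// wrapped in a schemaMigrationNeededError so that WriteBatch can drop
+		// the read lock, take the write lock, run the migration, and reopen
+		// every channel with the new schema.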
+		if o.schemaEvolutionMapping != nil {
+			nullColumnErr := streaming.NonNullColumnError{}
+			if errors.As(err, &nullColumnErr) {
+				// put the channel back so that we can reopen it along with the rest of the channels to
+				// pick up the new schema.
+				o.channelPool.Put(channel)
+				// Return an error so that we release our read lock and can take the write lock
+				// to forcibly reopen all our channels to get a new schema.
+				return schemaMigrationNeededError{
+					migrator: func(ctx context.Context) error {
+						return o.MigrateNotNullColumn(ctx, nullColumnErr)
+					},
+				}
+			}
+			missingColumnErr := streaming.MissingColumnError{}
+			if errors.As(err, &missingColumnErr) {
+				o.channelPool.Put(channel)
+				return schemaMigrationNeededError{
+					migrator: func(ctx context.Context) error {
+						return o.MigrateMissingColumn(ctx, missingColumnErr)
+					},
+				}
+			}
+		}
 		reopened, reopenErr := o.openChannel(ctx, channel.Name, channel.ID)
 		if reopenErr == nil {
 			o.channelPool.Put(reopened)
@@ -456,6 +541,129 @@ func (o *snowflakeStreamerOutput) WriteBatch(ctx context.Context, batch service.
 	return err
 }
 
+type schemaMigrationNeededError struct {
+	migrator func(ctx context.Context) error
+}
+
+func (schemaMigrationNeededError) Error() string {
+	return "schema migration was required and the operation needs to be retried after the migration"
+}
+
+func (o *snowflakeStreamerOutput) MigrateMissingColumn(ctx context.Context, col streaming.MissingColumnError) error {
+	o.schemaMigrationMu.Lock()
+	defer o.schemaMigrationMu.Unlock()
+	msg := service.NewMessage(nil)
+	msg.SetStructuredMut(map[string]any{
+		"name":  col.RawName(),
+		"value": col.Value(),
+	})
+	out, err := msg.BloblangQuery(o.schemaEvolutionMapping)
+	if err != nil {
+		return fmt.Errorf("unable to compute new column type for %s: %w", col.ColumnName(), err)
+	}
+	v, err := out.AsBytes()
+	if err != nil {
+		return fmt.Errorf("unable to compute new column type for %s: %w", col.ColumnName(), err)
+	}
+	columnType := string(v)
+	if err := validateColumnType(columnType); err != nil {
+		return err
+	}
+	o.logger.Infof("identified new schema - attempting to alter table to add column: %s %s", col.ColumnName(), columnType)
+	_, err = o.restClient.RunSQL(ctx, streaming.RunSQLRequest{
+		// This looks very scary and it *should*. This is prone to SQL injection attacks. The column name here
+		// comes directly from the Snowflake API so it better be escaped correctly. This is also why we need to
+		// validate the data type, so that you can't sneak an injection attack in there.
+		Statement: fmt.Sprintf(`ALTER TABLE IDENTIFIER(?)
+			ADD COLUMN IF NOT EXISTS %s %s
+			COMMENT 'column created by schema evolution from Redpanda Connect'`,
+			col.ColumnName(),
+			columnType,
+		),
+		// Currently we set a timeout of 30 seconds so that we don't have to handle async operations
+		// that need polling to wait until they finish (results are made async when execution is longer
+		// than 45 seconds).
+		Timeout:  30,
+		Database: o.db,
+		Schema:   o.schema,
+		Role:     o.role,
+		Bindings: map[string]streaming.BindingValue{
+			"1": {Type: "TEXT", Value: o.table},
+		},
+	})
+	if err != nil {
+		o.logger.Warnf("unable to add new column, this may be due to a race with another request, error: %s", err)
+	}
+	return o.ReopenAllChannels(ctx)
+}
+
+func (o *snowflakeStreamerOutput) MigrateNotNullColumn(ctx context.Context, col streaming.NonNullColumnError) error {
+	o.schemaMigrationMu.Lock()
+	defer o.schemaMigrationMu.Unlock()
+	o.logger.Infof("identified new schema - attempting to alter table to remove null constraint on column: %s", col.ColumnName())
+	_, err := o.restClient.RunSQL(ctx, streaming.RunSQLRequest{
+		// This looks very scary and it *should*. This is prone to SQL injection attacks. The column name here
+		// comes directly from the Snowflake API so it better not have a SQL injection :)
+		Statement: fmt.Sprintf(`ALTER TABLE IDENTIFIER(?) ALTER
+			%s DROP NOT NULL,
+			%s COMMENT 'column altered to be nullable by schema evolution from Redpanda Connect'`,
+			col.ColumnName(),
+			col.ColumnName(),
+		),
+		// Currently we set a timeout of 30 seconds so that we don't have to handle async operations
+		// that need polling to wait until they finish (results are made async when execution is longer
+		// than 45 seconds).
+		Timeout:  30,
+		Database: o.db,
+		Schema:   o.schema,
+		Role:     o.role,
+		Bindings: map[string]streaming.BindingValue{
+			"1": {Type: "TEXT", Value: o.table},
+		},
+	})
+	if err != nil {
+		o.logger.Warnf("unable to mark column %s as null, this may be due to a race with another request, error: %s", col.ColumnName(), err)
+	}
+	return o.ReopenAllChannels(ctx)
+}
+
+// ReopenAllChannels should be called while holding schemaMigrationMu so that
+// all channels are actually processed
+func (o *snowflakeStreamerOutput) ReopenAllChannels(ctx context.Context) error {
+	all := []*streaming.SnowflakeIngestionChannel{}
+	for {
+		maybeChan := o.channelPool.Get()
+		if maybeChan == nil {
+			break
+		}
+		channel := maybeChan.(*streaming.SnowflakeIngestionChannel)
+		reopened, reopenErr := o.openChannel(ctx, channel.Name, channel.ID)
+		if reopenErr == nil {
+			channel = reopened
+		} else {
+			o.logger.Warnf("unable to reopen channel %q during schema migration: %v", channel.Name, reopenErr)
+			// Keep the existing channel so we don't drop any channels; we can retry reopening it later.
+		}
+		all = append(all, channel)
+	}
+	for _, c := range all {
+		o.channelPool.Put(c)
+	}
+	return nil
+}
+
 func (o *snowflakeStreamerOutput) Close(ctx context.Context) error {
+	o.restClient.Close()
 	return o.client.Close()
 }
+
+// This doesn't need to fully match, but it must be enough to prevent SQL injection as well as
+// catch common errors.
+var validColumnTypeRegex = regexp.MustCompile(`^\s*(?i:NUMBER|DECIMAL|NUMERIC|INT|INTEGER|BIGINT|SMALLINT|TINYINT|BYTEINT|FLOAT|FLOAT4|FLOAT8|DOUBLE|DOUBLE\s+PRECISION|REAL|VARCHAR|CHAR|CHARACTER|STRING|TEXT|BINARY|VARBINARY|BOOLEAN|DATE|DATETIMETIME|TIMESTAMP|TIEMSTAMP_LTZ|TIMESTAMP_NTZ|TIMESTAMP_TZ|VARIANT|OBJECT|ARRAY)\s*(?:\(\s*\d+\s*\)|\(\s*\d+\s*,\s*\d+\s*\))?\s*$`) + +func validateColumnType(v string) error { + if validColumnTypeRegex.MatchString(v) { + return nil + } + return fmt.Errorf("invalid Snowflake column data type: %s", v) +} diff --git a/internal/impl/snowflake/output_streaming_test.go b/internal/impl/snowflake/output_streaming_test.go new file mode 100644 index 0000000000..5a3f46c1cd --- /dev/null +++ b/internal/impl/snowflake/output_streaming_test.go @@ -0,0 +1,52 @@ +/* + * Copyright 2024 Redpanda Data, Inc. + * + * Licensed as a Redpanda Enterprise file under the Redpanda Community + * License (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md + */ + +package snowflake + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestValidColumnTypeRegex(t *testing.T) { + matches := []string{ + "INT", + "NUMBER", + "NUMBER ( 38, 0 )", + " NUMBER ( 38, 0 ) ", + "DOUBLE PRECISION", + "DOUBLE PRECISION", + " varchar ( 99 ) ", + " varchar ( 0 ) ", + } + for _, m := range matches { + m := m + t.Run(m, func(t *testing.T) { + require.Regexp(t, validColumnTypeRegex, m) + }) + } + nonMatches := []string{ + "VAR", + "N", + "VAR(1, 3)", + "VAR(1)", + "VARCHAR()", + "VARCHAR( )", + "GARBAGE VARCHAR(2)", + "VARCHAR(2) GARBAGE", + } + for _, m := range nonMatches { + m := m + t.Run(m, func(t *testing.T) { + require.NotRegexp(t, validColumnTypeRegex, m) + }) + } +} diff --git a/internal/impl/snowflake/streaming/compat.go b/internal/impl/snowflake/streaming/compat.go index 188cbfc2d6..5a9a603575 100644 --- a/internal/impl/snowflake/streaming/compat.go +++ b/internal/impl/snowflake/streaming/compat.go @@ -160,7 +160,7 @@ func quoteColumnName(name string) string { // be any double quotes inside the string that need escaped. 
 	quoted.Grow(len(name) + 2)
 	quoted.WriteByte('"')
-	for _, r := range name {
+	for _, r := range strings.ToUpper(name) {
 		if r == '"' {
 			quoted.WriteString(`""`)
 		} else {
 			quoted.WriteRune(r)
 		}
diff --git a/internal/impl/snowflake/streaming/compat_test.go b/internal/impl/snowflake/streaming/compat_test.go
index 4b4aaffa64..71a0f2bb6e 100644
--- a/internal/impl/snowflake/streaming/compat_test.go
+++ b/internal/impl/snowflake/streaming/compat_test.go
@@ -129,11 +129,12 @@ func TestColumnNormalization(t *testing.T) {
 
 func TestColumnQuoting(t *testing.T) {
 	require.Equal(t, `""`, quoteColumnName(""))
-	require.Equal(t, `"foo"`, quoteColumnName("foo"))
-	require.Equal(t, `"""bar"""`, quoteColumnName(`"bar"`))
-	require.Equal(t, `"foo bar"`, quoteColumnName(`foo bar`))
-	require.Equal(t, `"foo\ bar"`, quoteColumnName(`foo\ bar`))
-	require.Equal(t, `"foo""bar"`, quoteColumnName(`foo"bar`))
+	require.Equal(t, `"FOO"`, quoteColumnName("foo"))
+	require.Equal(t, `"""BAR"""`, quoteColumnName(`"bar"`))
+	require.Equal(t, `"FOO BAR"`, quoteColumnName(`foo bar`))
+	require.Equal(t, `"FOO\ BAR"`, quoteColumnName(`foo\ bar`))
+	require.Equal(t, `"FOO""BAR"`, quoteColumnName(`foo"bar`))
+	require.Equal(t, `"FOO""BAR1"`, quoteColumnName(`foo"bar1`))
 	require.Equal(t, `""""""""""`, quoteColumnName(`""""`))
 }
 
diff --git a/internal/impl/snowflake/streaming/schema_errors.go b/internal/impl/snowflake/streaming/schema_errors.go
index debecdf5d9..13a101c75e 100644
--- a/internal/impl/snowflake/streaming/schema_errors.go
+++ b/internal/impl/snowflake/streaming/schema_errors.go
@@ -64,6 +64,12 @@ func (e MissingColumnError) ColumnName() string {
 	return quoteColumnName(e.columnName)
 }
 
+// The raw name of the new column - DO NOT USE IN SQL!
+// This is the more intuitive name for users in the mapping function
+func (e MissingColumnError) RawName() string {
+	return e.columnName
+}
+
 // ColumnName returns the value that was associated with the missing column
 func (e MissingColumnError) Value() any {
 	return e.val

From f08d9f8e14a1c2f349ee639266ad7e475ef7fb0d Mon Sep 17 00:00:00 2001
From: Tyler Rockwood
Date: Tue, 5 Nov 2024 01:38:56 +0000
Subject: [PATCH 5/7] snowflake: make the linter happy

---
 internal/impl/snowflake/output_snowflake_streaming.go | 4 ++--
 internal/impl/snowflake/streaming/rest.go             | 1 +
 internal/impl/snowflake/streaming/schema_errors.go    | 6 +++---
 3 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/internal/impl/snowflake/output_snowflake_streaming.go b/internal/impl/snowflake/output_snowflake_streaming.go
index 679e758c12..0208261400 100644
--- a/internal/impl/snowflake/output_snowflake_streaming.go
+++ b/internal/impl/snowflake/output_snowflake_streaming.go
@@ -571,8 +571,8 @@ func (o *snowflakeStreamerOutput) MigrateMissingColumn(ctx context.Context, col
 	o.logger.Infof("identified new schema - attempting to alter table to add column: %s %s", col.ColumnName(), columnType)
 	_, err = o.restClient.RunSQL(ctx, streaming.RunSQLRequest{
-		// This looks very scary and it *should*. This is prone to SQL injection attacks. The column name here
-		// comes directly from the Snowflake API so it better be escaped correctly. This is also why we need to
+		// This looks very scary and it *should*. This is prone to SQL injection attacks. The column name is
+		// quoted according to the rules in Snowflake's documentation. This is also why we need to
 		// validate the data type, so that you can't sneak an injection attack in there.
 		Statement: fmt.Sprintf(`ALTER TABLE IDENTIFIER(?)
 			ADD COLUMN IF NOT EXISTS %s %s
diff --git a/internal/impl/snowflake/streaming/rest.go b/internal/impl/snowflake/streaming/rest.go
index 7df5dfa8a0..bca68015e8 100644
--- a/internal/impl/snowflake/streaming/rest.go
+++ b/internal/impl/snowflake/streaming/rest.go
@@ -249,6 +249,7 @@ type (
 		Message string              `json:"message"`
 		Blobs   []blobRegisterStatus `json:"blobs"`
 	}
+	// BindingValue is a value available as a binding variable in a SQL statement.
 	BindingValue struct {
 		// The binding data type, generally TEXT is what you want
diff --git a/internal/impl/snowflake/streaming/schema_errors.go b/internal/impl/snowflake/streaming/schema_errors.go
index 13a101c75e..d3b9b03ef5 100644
--- a/internal/impl/snowflake/streaming/schema_errors.go
+++ b/internal/impl/snowflake/streaming/schema_errors.go
@@ -36,7 +36,7 @@ func (e NonNullColumnError) ColumnName() string {
 	return e.columnName
 }
 
-// ColumnName returns nil
+// Value returns nil
 func (e NonNullColumnError) Value() any {
@@ -64,13 +64,13 @@ func (e MissingColumnError) ColumnName() string {
 	return quoteColumnName(e.columnName)
 }
 
-// The raw name of the new column - DO NOT USE IN SQL!
+// RawName is the unquoted name of the new column - DO NOT USE IN SQL!
 // This is the more intuitive name for users in the mapping function
 func (e MissingColumnError) RawName() string {
 	return e.columnName
 }
 
-// ColumnName returns the value that was associated with the missing column
+// Value returns the value that was associated with the missing column
 func (e MissingColumnError) Value() any {
 	return e.val

From e7b39d3cd82afaddab159a80c6c0b795fc0d972b Mon Sep 17 00:00:00 2001
From: Tyler Rockwood
Date: Tue, 5 Nov 2024 13:50:33 +0000
Subject: [PATCH 6/7] snowflake: fix data type regex

---
 internal/impl/snowflake/output_snowflake_streaming.go | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/internal/impl/snowflake/output_snowflake_streaming.go b/internal/impl/snowflake/output_snowflake_streaming.go
index 0208261400..9061064668 100644
--- a/internal/impl/snowflake/output_snowflake_streaming.go
+++ b/internal/impl/snowflake/output_snowflake_streaming.go
@@ -659,7 +659,7 @@ func (o *snowflakeStreamerOutput) Close(ctx context.Context) error {
 
 // This doesn't need to fully match, but it must be enough to prevent SQL injection as well as
 // catch common errors.
-var validColumnTypeRegex = regexp.MustCompile(`^\s*(?i:NUMBER|DECIMAL|NUMERIC|INT|INTEGER|BIGINT|SMALLINT|TINYINT|BYTEINT|FLOAT|FLOAT4|FLOAT8|DOUBLE|DOUBLE\s+PRECISION|REAL|VARCHAR|CHAR|CHARACTER|STRING|TEXT|BINARY|VARBINARY|BOOLEAN|DATE|DATETIMETIME|TIMESTAMP|TIEMSTAMP_LTZ|TIMESTAMP_NTZ|TIMESTAMP_TZ|VARIANT|OBJECT|ARRAY)\s*(?:\(\s*\d+\s*\)|\(\s*\d+\s*,\s*\d+\s*\))?\s*$`)
+var validColumnTypeRegex = regexp.MustCompile(`^\s*(?i:NUMBER|DECIMAL|NUMERIC|INT|INTEGER|BIGINT|SMALLINT|TINYINT|BYTEINT|FLOAT|FLOAT4|FLOAT8|DOUBLE|DOUBLE\s+PRECISION|REAL|VARCHAR|CHAR|CHARACTER|STRING|TEXT|BINARY|VARBINARY|BOOLEAN|DATE|DATETIME|TIME|TIMESTAMP|TIMESTAMP_LTZ|TIMESTAMP_NTZ|TIMESTAMP_TZ|VARIANT|OBJECT|ARRAY)\s*(?:\(\s*\d+\s*\)|\(\s*\d+\s*,\s*\d+\s*\))?\s*$`)
 
 func validateColumnType(v string) error {
 	if validColumnTypeRegex.MatchString(v) {

From f12ab28d301a848e4b2a484f176645eff4e60623 Mon Sep 17 00:00:00 2001
From: Tyler Rockwood
Date: Thu, 7 Nov 2024 02:35:04 +0000
Subject: [PATCH 7/7] snowflake: address review feedback

--- .../snowflake/output_snowflake_streaming.go | 39 +++++++++---------- internal/impl/snowflake/streaming/compat.go | 2 +- .../impl/snowflake/streaming/schema_errors.go | 2 +- 3 files changed, 21 insertions(+), 22 deletions(-) diff --git a/internal/impl/snowflake/output_snowflake_streaming.go b/internal/impl/snowflake/output_snowflake_streaming.go index 9061064668..df434988c3 100644 --- a/internal/impl/snowflake/output_snowflake_streaming.go +++ b/internal/impl/snowflake/output_snowflake_streaming.go @@ -563,34 +563,25 @@ func (o *snowflakeStreamerOutput) MigrateMissingColumn(ctx context.Context, col } v, err := out.AsBytes() if err != nil { - return fmt.Errorf("unable to compute new column type for %s: %w", col.ColumnName(), err) + return fmt.Errorf("unable to extract result from new column type mapping for %s: %w", col.ColumnName(), err) } columnType := string(v) if err := validateColumnType(columnType); err != nil { return err } o.logger.Infof("identified new schema - attempting to alter table to add column: %s %s", col.ColumnName(), columnType) - _, err = o.restClient.RunSQL(ctx, streaming.RunSQLRequest{ + err = o.RunSQLMigration( + ctx, // This looks very scary and it *should*. This is prone to SQL injection attacks. The column name is // quoted according to the rules in Snowflake's documentation. This is also why we need to // validate the data type, so that you can't sneak an injection attack in there. - Statement: fmt.Sprintf(`ALTER TABLE IDENTIFIER(?) + fmt.Sprintf(`ALTER TABLE IDENTIFIER(?) ADD COLUMN IF NOT EXISTS %s %s COMMENT 'column created by schema evolution from Redpanda Connect'`, col.ColumnName(), columnType, ), - // Currently we set of timeout of 30 seconds so that we don't have to handle async operations - // that need polling to wait until they finish (results are made async when execution is longer - // than 45 seconds). - Timeout: 30, - Database: o.db, - Schema: o.schema, - Role: o.role, - Bindings: map[string]streaming.BindingValue{ - "1": {Type: "TEXT", Value: o.table}, - }, - }) + ) if err != nil { o.logger.Warnf("unable to add new column, this maybe due to a race with another request, error: %s", err) } @@ -601,15 +592,26 @@ func (o *snowflakeStreamerOutput) MigrateNotNullColumn(ctx context.Context, col o.schemaMigrationMu.Lock() defer o.schemaMigrationMu.Unlock() o.logger.Infof("identified new schema - attempting to alter table to remove null constraint on column: %s", col.ColumnName()) - _, err := o.restClient.RunSQL(ctx, streaming.RunSQLRequest{ + err := o.RunSQLMigration( + ctx, // This looks very scary and it *should*. This is prone to SQL injection attacks. The column name here // comes directly from the Snowflake API so it better not have a SQL injection :) - Statement: fmt.Sprintf(`ALTER TABLE IDENTIFIER(?) ALTER + fmt.Sprintf(`ALTER TABLE IDENTIFIER(?) 
ALTER
 			%s DROP NOT NULL,
 			%s COMMENT 'column altered to be nullable by schema evolution from Redpanda Connect'`,
 			col.ColumnName(),
 			col.ColumnName(),
 		),
+	)
+	if err != nil {
+		o.logger.Warnf("unable to mark column %s as null, this may be due to a race with another request, error: %s", col.ColumnName(), err)
+	}
+	return o.ReopenAllChannels(ctx)
+}
+
+func (o *snowflakeStreamerOutput) RunSQLMigration(ctx context.Context, statement string) error {
+	_, err := o.restClient.RunSQL(ctx, streaming.RunSQLRequest{
+		Statement: statement,
 		// Currently we set a timeout of 30 seconds so that we don't have to handle async operations
 		// that need polling to wait until they finish (results are made async when execution is longer
 		// than 45 seconds).
 		Timeout:  30,
 		Database: o.db,
 		Schema:   o.schema,
 		Role:     o.role,
 		Bindings: map[string]streaming.BindingValue{
 			"1": {Type: "TEXT", Value: o.table},
 		},
 	})
-	if err != nil {
-		o.logger.Warnf("unable to mark column %s as null, this may be due to a race with another request, error: %s", col.ColumnName(), err)
-	}
-	return o.ReopenAllChannels(ctx)
+	return err
 }

diff --git a/internal/impl/snowflake/streaming/compat.go b/internal/impl/snowflake/streaming/compat.go
index 5a9a603575..94e66a33a2 100644
--- a/internal/impl/snowflake/streaming/compat.go
+++ b/internal/impl/snowflake/streaming/compat.go
@@ -157,7 +157,7 @@ func normalizeColumnName(name string) string {
 func quoteColumnName(name string) string {
 	var quoted strings.Builder
 	// Default to assume we're just going to add quotes and there won't
-	// be any double quotes inside the string that need escaped.
+	// be any double quotes inside the string that need escaping.
 	quoted.Grow(len(name) + 2)
 	quoted.WriteByte('"')
 	for _, r := range strings.ToUpper(name) {

diff --git a/internal/impl/snowflake/streaming/schema_errors.go b/internal/impl/snowflake/streaming/schema_errors.go
index d3b9b03ef5..cd594b0e8b 100644
--- a/internal/impl/snowflake/streaming/schema_errors.go
+++ b/internal/impl/snowflake/streaming/schema_errors.go
@@ -13,7 +13,7 @@ package streaming
 import "fmt"
 
 // SchemaMismatchError occurs when the user provided data has data that
-// doesn't match the schema *and* the table can be evolved to accomidate
+// doesn't match the schema *and* the table can be evolved to accommodate
 //
 // This can be used as a mechanism to evolve the schema dynamically.
 type SchemaMismatchError interface {
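
As a closing illustration, here is a standalone sketch of how the identifier
quoting and the `IDENTIFIER(?)` bind variable introduced in this series fit
together (the `quoteIdent` helper is a hypothetical stand-in mirroring
`quoteColumnName`):

```go
package main

import (
	"fmt"
	"strings"
)

// quoteIdent mirrors quoteColumnName from this series: uppercase the name,
// double any embedded double quotes, and wrap the result in double quotes.
func quoteIdent(name string) string {
	var quoted strings.Builder
	quoted.Grow(len(name) + 2)
	quoted.WriteByte('"')
	for _, r := range strings.ToUpper(name) {
		if r == '"' {
			quoted.WriteString(`""`)
		} else {
			quoted.WriteRune(r)
		}
	}
	quoted.WriteByte('"')
	return quoted.String()
}

func main() {
	// The table name travels separately as a bind variable (IDENTIFIER(?)),
	// but ALTER ... ADD COLUMN does not accept a binding for the column name,
	// so the column identifier is escaped client-side instead.
	column := `weird "column" name`
	stmt := fmt.Sprintf(
		"ALTER TABLE IDENTIFIER(?) ADD COLUMN IF NOT EXISTS %s DOUBLE",
		quoteIdent(column),
	)
	fmt.Println(stmt)
	// Output:
	// ALTER TABLE IDENTIFIER(?) ADD COLUMN IF NOT EXISTS "WEIRD ""COLUMN"" NAME" DOUBLE
}
```

Because the column type is interpolated rather than bound, it is validated
against `validColumnTypeRegex` before it ever reaches the statement.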