Merge pull request #3008 from redpanda-data/batch-migrate
snowpipe: batch schema migration
rockwotj authored Nov 20, 2024
2 parents 8adeeb6 + ff8ee78 commit cbea8d8
Showing 3 changed files with 73 additions and 22 deletions.
67 changes: 47 additions & 20 deletions internal/impl/snowflake/output_snowflake_streaming.go
@@ -564,27 +564,12 @@ func (o *snowflakeStreamerOutput) WriteBatchInternal(ctx context.Context, batch
 } else {
     // Only evolve the schema if requested.
     if o.schemaEvolutionEnabled() {
-        nullColumnErr := streaming.NonNullColumnError{}
-        if errors.As(err, &nullColumnErr) {
+        schemaErr, ok := asSchemaMigrationError(o, err)
+        if ok {
             // put the channel back so that we can reopen it along with the rest of the channels to
             // pick up the new schema.
             o.channelPool.Put(channel)
-            // Return an error so that we release our read lock and can take the write lock
-            // to forcibly reopen all our channels to get a new schema.
-            return schemaMigrationNeededError{
-                migrator: func(ctx context.Context) error {
-                    return o.MigrateNotNullColumn(ctx, nullColumnErr)
-                },
-            }
-        }
-        missingColumnErr := streaming.MissingColumnError{}
-        if errors.As(err, &missingColumnErr) {
-            o.channelPool.Put(channel)
-            return schemaMigrationNeededError{
-                migrator: func(ctx context.Context) error {
-                    return o.MigrateMissingColumn(ctx, missingColumnErr)
-                },
-            }
+            return schemaErr
         }
     }
     reopened, reopenErr := o.openChannel(ctx, channel.Name, channel.ID)
@@ -605,6 +590,48 @@ func (o *snowflakeStreamerOutput) WriteBatchInternal(ctx context.Context, batch
     return err
 }

+func asSchemaMigrationError(o *snowflakeStreamerOutput, err error) (schemaMigrationNeededError, bool) {
+    nullColumnErr := streaming.NonNullColumnError{}
+    if errors.As(err, &nullColumnErr) {
+        // Return an error so that we release our read lock and can take the write lock
+        // to forcibly reopen all our channels to get a new schema.
+        return schemaMigrationNeededError{
+            migrator: func(ctx context.Context) error {
+                if err := o.MigrateNotNullColumn(ctx, nullColumnErr); err != nil {
+                    return err
+                }
+                return o.ReopenAllChannels(ctx)
+            },
+        }, true
+    }
+    missingColumnErr := streaming.MissingColumnError{}
+    if errors.As(err, &missingColumnErr) {
+        return schemaMigrationNeededError{
+            migrator: func(ctx context.Context) error {
+                if err := o.MigrateMissingColumn(ctx, missingColumnErr); err != nil {
+                    return err
+                }
+                return o.ReopenAllChannels(ctx)
+            },
+        }, true
+    }
+    batchErr := streaming.BatchSchemaMismatchError[streaming.MissingColumnError]{}
+    if errors.As(err, &batchErr) {
+        return schemaMigrationNeededError{
+            migrator: func(ctx context.Context) error {
+                for _, missingCol := range batchErr.Errors {
+                    // TODO(rockwood): Consider a batch SQL statement that adds N columns at a time
+                    if err := o.MigrateMissingColumn(ctx, missingCol); err != nil {
+                        return err
+                    }
+                }
+                return o.ReopenAllChannels(ctx)
+            },
+        }, true
+    }
+    return schemaMigrationNeededError{}, false
+}
+
 type schemaMigrationNeededError struct {
     migrator func(ctx context.Context) error
 }
@@ -657,7 +684,7 @@ func (o *snowflakeStreamerOutput) MigrateMissingColumn(ctx context.Context, col
     if err != nil {
         o.logger.Warnf("unable to add new column, this may be due to a race with another request, error: %s", err)
     }
-    return o.ReopenAllChannels(ctx)
+    return nil
 }

 func (o *snowflakeStreamerOutput) MigrateNotNullColumn(ctx context.Context, col streaming.NonNullColumnError) error {
@@ -678,7 +705,7 @@ func (o *snowflakeStreamerOutput) MigrateNotNullColumn(ctx context.Context, col
     if err != nil {
         o.logger.Warnf("unable to mark column %s as null, this may be due to a race with another request, error: %s", col.ColumnName(), err)
     }
-    return o.ReopenAllChannels(ctx)
+    return nil
 }

 func (o *snowflakeStreamerOutput) CreateOutputTable(ctx context.Context, batch service.MessageBatch) error {
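The comments in the hunk above describe the intended control flow: the writer holds a read lock, and a schemaMigrationNeededError signals that it should upgrade to the write lock, run the migrator (which now also reopens all channels), and retry. The caller is not part of this diff, so the following is only a rough sketch of that flow; the method name writeBatchWithMigration, the o.mu mutex field, the retry, and the exact WriteBatchInternal signature are assumptions for illustration, and the file's existing context/errors/service imports are assumed.

// Hypothetical caller-side sketch, NOT part of this commit. Assumes the
// output struct has a sync.RWMutex field (called mu here) guarding the
// channel pool, and that WriteBatchInternal takes (ctx, batch) as suggested
// by the truncated hunk header above.
func (o *snowflakeStreamerOutput) writeBatchWithMigration(ctx context.Context, batch service.MessageBatch) error {
    o.mu.RLock()
    err := o.WriteBatchInternal(ctx, batch)
    o.mu.RUnlock()

    migrationErr := schemaMigrationNeededError{}
    if !errors.As(err, &migrationErr) {
        return err
    }

    // Upgrade to the exclusive lock so every channel can be reopened with
    // the evolved schema, then run the migrator and retry the batch once.
    o.mu.Lock()
    migErr := migrationErr.migrator(ctx)
    o.mu.Unlock()
    if migErr != nil {
        return migErr
    }
    return o.WriteBatchInternal(ctx, batch)
}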
6 changes: 5 additions & 1 deletion internal/impl/snowflake/streaming/parquet.go
@@ -34,16 +34,20 @@ func messageToRow(msg *service.Message, out []any, nameToPosition map[string]int
     if !ok {
         return fmt.Errorf("expected object, got: %T", v)
     }
+    var missingColumns []MissingColumnError
     for k, v := range row {
         idx, ok := nameToPosition[normalizeColumnName(k)]
         if !ok {
             if !allowExtraProperties && v != nil {
-                return MissingColumnError{columnName: k, val: v}
+                missingColumns = append(missingColumns, MissingColumnError{columnName: k, val: v})
             }
             continue
         }
         out[idx] = v
     }
+    if len(missingColumns) > 0 {
+        return BatchSchemaMismatchError[MissingColumnError]{missingColumns}
+    }
     return nil
 }

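For illustration only (not part of this commit): a sketch, written as if it lived inside the streaming package, of what the aggregation above means for a caller. The tail of messageToRow's signature is truncated in the hunk header, so the trailing out/nameToPosition/allowExtraProperties arguments, the example column names, and the benthos service import path are assumptions.

package streaming

import (
    "errors"
    "fmt"

    "github.com/redpanda-data/benthos/v4/public/service"
)

// exampleCollectMissingColumns is a hypothetical helper, NOT part of this
// commit: a message with several unknown keys now surfaces a single
// BatchSchemaMismatchError listing every missing column, instead of failing
// on the first unknown key as before.
func exampleCollectMissingColumns() error {
    msg := service.NewMessage(nil)
    msg.SetStructured(map[string]any{
        "known_col": 1,   // already in the table schema (assumed)
        "new_col_a": "x", // unknown key
        "new_col_b": "y", // unknown key
    })

    out := make([]any, 1)
    nameToPosition := map[string]int{normalizeColumnName("known_col"): 0}

    // Call shape assumed from the hunk above.
    err := messageToRow(msg, out, nameToPosition, false /* allowExtraProperties */)
    batchErr := BatchSchemaMismatchError[MissingColumnError]{}
    if errors.As(err, &batchErr) {
        // Both unknown keys are reported together, so the output can evolve
        // the schema for all of them before retrying the batch.
        fmt.Printf("need to add %d columns\n", len(batchErr.Errors))
    }
    return err
}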
22 changes: 21 additions & 1 deletion internal/impl/snowflake/streaming/schema_errors.go
@@ -10,17 +10,37 @@

 package streaming

-import "fmt"
+import (
+    "errors"
+    "fmt"
+)

 // SchemaMismatchError occurs when the user provided data has data that
 // doesn't match the schema *and* the table can be evolved to accommodate
 //
 // This can be used as a mechanism to evolve the schema dynamically.
 type SchemaMismatchError interface {
     error
     ColumnName() string
     Value() any
 }

+var _ error = BatchSchemaMismatchError[SchemaMismatchError]{}
+
+// BatchSchemaMismatchError is when multiple schema mismatch errors happen at once
+type BatchSchemaMismatchError[T SchemaMismatchError] struct {
+    Errors []T
+}
+
+// Error implements the error interface
+func (e BatchSchemaMismatchError[T]) Error() string {
+    errs := []error{}
+    for _, err := range e.Errors {
+        errs = append(errs, err)
+    }
+    return errors.Join(errs...).Error()
+}
+
 var _ error = NonNullColumnError{}
 var _ SchemaMismatchError = NonNullColumnError{}

Expand Down
