snowflake: add option to ignore null values in schema evolution
More options! This allows folks to include null values when evolving the schema. If you are managing the schema explicitly (i.e. using a schema from a schema registry), including nulls can make a lot of sense. By default we expect people to just throw data at the connector, so we drop null columns because we don't yet know what type the column will eventually be.
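For illustration, a minimal config sketch showing where the new option sits (the account, user, key path, and table values below are placeholders, not taken from this commit):

    output:
      snowflake_streaming:
        account: "MYACCOUNT"          # placeholder values
        user: MYUSER
        role: ACCOUNTADMIN
        database: MYDB
        schema: PUBLIC
        table: mytable
        private_key_file: "./rsa_key.p8"
        schema_evolution:
          enabled: true
          # defaults to true: null-valued columns are dropped rather than evolved;
          # set to false to let null columns trigger schema migrations too
          ignore_nulls: false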
rockwotj committed Feb 26, 2025
1 parent a078abc commit a9b8ed0
Showing 7 changed files with 155 additions and 53 deletions.
10 changes: 10 additions & 0 deletions docs/modules/components/pages/outputs/snowflake_streaming.adoc
@@ -86,6 +86,7 @@ output:
CREATE TABLE IF NOT EXISTS mytable (amount NUMBER);
schema_evolution:
enabled: false # No default (required)
ignore_nulls: true
processors: [] # No default (optional)
build_options:
parallelism: 1
@@ -459,6 +460,15 @@ Whether schema evolution is enabled.
*Type*: `bool`
=== `schema_evolution.ignore_nulls`
If `true`, then new columns that are `null` are ignored and schema evolution is not triggered. If `false`, then null columns trigger schema migrations in Snowflake. NOTE: unless you already know in advance what type the column will be, it's highly encouraged to ignore null values.
*Type*: `bool`
*Default*: `true`
=== `schema_evolution.processors`
A series of processors to execute when new columns are added to the table. Specifying this allows running side effects when the schema evolves, or enriching the message with additional data to guide the schema changes. For example, one could read the schema the message was produced with from the schema registry and use that to decide which type the new column in Snowflake should be.
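As a rough sketch (the column name and Snowflake type here are illustrative; the shape follows the integration test added in this commit, which matches on `this.name`), such a processor might look like:

    schema_evolution:
      enabled: true
      processors:
        - mapping: |
            # the new column's name is available as `this.name`;
            # the mapping should resolve to the Snowflake type for that column
            root = match {
              this.name == "created_at" => "TIMESTAMP_NTZ"
              _ => "variant"
            }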
61 changes: 61 additions & 0 deletions internal/impl/snowflake/integration_test.go
@@ -15,6 +15,7 @@ import (
"encoding/base64"
"encoding/pem"
"errors"
"fmt"
"os"
"strings"
"sync"
@@ -373,6 +374,66 @@ snowflake_streaming:
}, rows)
}

func TestIntegrationSchemaEvolutionNull(t *testing.T) {
integration.CheckSkip(t)
runTest := func(t *testing.T, ignoreNull bool) {
produce, stream := SetupSnowflakeStream(t, fmt.Sprintf(`
label: snowpipe_streaming
snowflake_streaming:
account: "$ACCOUNT"
user: "$USER"
role: $ROLE
database: "$DB"
schema: $SCHEMA
private_key_file: "$PRIVATE_KEY_FILE"
table: integration_test_auto_schema_evolution_with_null
init_statement: |
DROP TABLE IF EXISTS integration_test_auto_schema_evolution_with_null;
max_in_flight: 4
channel_name: "${!this.channel}"
schema_evolution:
enabled: true
ignore_nulls: %v
processors:
- mapping: |
root = match {
this.name == "null_a" || this.name == "null_b" => "NUMBER"
_ => "variant"
}
`, ignoreNull))
RunStreamInBackground(t, stream)
// Initial schema creation test
require.NoError(t, produce([]map[string]any{
{"foo": "bar", "null_a": nil},
}))
// Incremental schema migration test
require.NoError(t, produce([]map[string]any{
{"foo": "bar", "null_b": nil},
}))
rows := RunSQLQuery(
t,
stream,
`SELECT column_name, data_type, numeric_precision, numeric_scale
FROM $DB.information_schema.columns
WHERE table_name = 'INTEGRATION_TEST_AUTO_SCHEMA_EVOLUTION_WITH_NULL' AND table_schema = '$SCHEMA'
ORDER BY column_name`,
)
if ignoreNull {
require.Equal(t, [][]string{
{"FOO", "VARIANT", "", ""},
}, rows)
} else {
require.Equal(t, [][]string{
{"FOO", "VARIANT", "", ""},
{"NULL_A", "NUMBER", "38", "0"},
{"NULL_B", "NUMBER", "38", "0"},
}, rows)
}
}
t.Run("IgnoreNull", func(t *testing.T) { runTest(t, true) })
t.Run("IncludeNull", func(t *testing.T) { runTest(t, false) })
}

func TestIntegrationManualSchemaEvolution(t *testing.T) {
// This is sort of a stress test for race conditions when the schema changes separately
integration.CheckSkip(t)
103 changes: 58 additions & 45 deletions internal/impl/snowflake/output_snowflake_streaming.go
@@ -51,6 +51,7 @@ const (
ssoFieldBuildChunkSize = "chunk_size"
ssoFieldSchemaEvolution = "schema_evolution"
ssoFieldSchemaEvolutionEnabled = "enabled"
ssoFieldSchemaEvolutionIgnoreNulls = "ignore_nulls"
ssoFieldSchemaEvolutionNewColumnTypeMapping = "new_column_type_mapping"
ssoFieldSchemaEvolutionProcessors = "processors"
ssoFieldCommitTimeout = "commit_timeout"
@@ -121,6 +122,7 @@ ALTER TABLE t1 ADD COLUMN a2 NUMBER;
`),
service.NewObjectField(ssoFieldSchemaEvolution,
service.NewBoolField(ssoFieldSchemaEvolutionEnabled).Description("Whether schema evolution is enabled."),
service.NewBoolField(ssoFieldSchemaEvolutionIgnoreNulls).Description("If `true`, then new columns that are `null` are ignored and schema evolution is not triggered. If `false`, then null columns trigger schema migrations in Snowflake. NOTE: unless you already know in advance what type the column will be, it's highly encouraged to ignore null values.").Default(true).Advanced(),
service.NewBloblangField(ssoFieldSchemaEvolutionNewColumnTypeMapping).Description(`
The mapping function from Redpanda Connect type to column type in Snowflake. Overriding this can allow for customization of the datatype if there is specific information that you know about the data types in use. This mapping should result in the `+"`root`"+` variable being assigned a string with the data type for the new column in Snowflake.
@@ -430,15 +432,25 @@ func newSnowflakeStreamer(
return nil, err
}
}
var schemaEvolutionEnabled bool
schemaEvolutionMode := streaming.SchemaModeIgnoreExtra
var schemaEvolutionProcessors []*service.OwnedProcessor
var schemaEvolutionMapping *bloblang.Executor
if conf.Contains(ssoFieldSchemaEvolution, ssoFieldSchemaEvolutionEnabled) {
seConf := conf.Namespace(ssoFieldSchemaEvolution)
schemaEvolutionEnabled, err = seConf.FieldBool(ssoFieldSchemaEvolutionEnabled)
schemaEvolutionEnabled, err := seConf.FieldBool(ssoFieldSchemaEvolutionEnabled)
if err != nil {
return nil, err
}
ignoreNulls, err := seConf.FieldBool(ssoFieldSchemaEvolutionIgnoreNulls)
if err != nil {
return nil, err
}
if schemaEvolutionEnabled {
schemaEvolutionMode = streaming.SchemaModeStrict
if !ignoreNulls {
schemaEvolutionMode = streaming.SchemaModeStrictWithNulls
}
}
if seConf.Contains(ssoFieldSchemaEvolutionProcessors) {
schemaEvolutionProcessors, err = seConf.FieldProcessorList(ssoFieldSchemaEvolutionProcessors)
if err != nil {
@@ -567,8 +579,9 @@ func newSnowflakeStreamer(
mgr.SetGeneric(SnowflakeClientResourceForTesting, restClient)
makeImpl := func(table string) (*snowpipeSchemaEvolver, service.BatchOutput) {
var schemaEvolver *snowpipeSchemaEvolver
if schemaEvolutionEnabled {
if schemaEvolutionMode != streaming.SchemaModeIgnoreExtra {
schemaEvolver = &snowpipeSchemaEvolver{
mode: schemaEvolutionMode,
schemaEvolutionMapping: schemaEvolutionMapping,
pipeline: schemaEvolutionProcessors,
restClient: restClient,
@@ -582,18 +595,18 @@
var impl service.BatchOutput
if channelName != nil {
indexed := &snowpipeIndexedOutput{
channelName: channelName,
client: client,
db: db,
schema: schema,
table: table,
role: role,
logger: mgr.Logger(),
metrics: newSnowpipeMetrics(mgr.Metrics()),
buildOpts: buildOpts,
offsetToken: offsetToken,
schemaMigrationEnabled: schemaEvolver != nil,
commitTimeout: commitTimeout,
channelName: channelName,
client: client,
db: db,
schema: schema,
table: table,
role: role,
logger: mgr.Logger(),
metrics: newSnowpipeMetrics(mgr.Metrics()),
buildOpts: buildOpts,
offsetToken: offsetToken,
schemaMode: schemaEvolutionMode,
commitTimeout: commitTimeout,
}
indexed.channelPool = pool.NewIndexed(func(ctx context.Context, name string) (*streaming.SnowflakeIngestionChannel, error) {
hash := sha256.Sum256([]byte(name))
@@ -609,18 +622,18 @@
channelPrefix = fmt.Sprintf("Redpanda_Connect_%s.%s.%s", db, schema, table)
}
pooled := &snowpipePooledOutput{
channelPrefix: channelPrefix,
client: client,
db: db,
schema: schema,
table: table,
role: role,
logger: mgr.Logger(),
metrics: newSnowpipeMetrics(mgr.Metrics()),
buildOpts: buildOpts,
offsetToken: offsetToken,
schemaMigrationEnabled: schemaEvolver != nil,
commitTimeout: commitTimeout,
channelPrefix: channelPrefix,
client: client,
db: db,
schema: schema,
table: table,
role: role,
logger: mgr.Logger(),
metrics: newSnowpipeMetrics(mgr.Metrics()),
buildOpts: buildOpts,
offsetToken: offsetToken,
schemaMode: schemaEvolutionMode,
commitTimeout: commitTimeout,
}
pooled.channelPool = pool.NewCapped(maxInFlight, func(ctx context.Context, id int) (*streaming.SnowflakeIngestionChannel, error) {
name := fmt.Sprintf("%s_%d", pooled.channelPrefix, id)
@@ -878,19 +891,19 @@ type snowpipePooledOutput struct {
channelPrefix, db, schema, table, role string
offsetToken *service.InterpolatedString
logger *service.Logger
schemaMigrationEnabled bool
schemaMode streaming.SchemaMode
}

func (o *snowpipePooledOutput) openChannel(ctx context.Context, name string, id int16) (*streaming.SnowflakeIngestionChannel, error) {
o.logger.Debugf("opening snowflake streaming channel for table `%s.%s.%s`: %s", o.db, o.schema, o.table, name)
return o.client.OpenChannel(ctx, streaming.ChannelOptions{
ID: id,
Name: name,
DatabaseName: o.db,
SchemaName: o.schema,
TableName: o.table,
BuildOptions: o.buildOpts,
StrictSchemaEnforcement: o.schemaMigrationEnabled,
ID: id,
Name: name,
DatabaseName: o.db,
SchemaName: o.schema,
TableName: o.table,
BuildOptions: o.buildOpts,
SchemaMode: o.schemaMode,
})
}

Expand Down Expand Up @@ -918,7 +931,7 @@ func (o *snowpipePooledOutput) WriteBatch(ctx context.Context, batch service.Mes
if err != nil {
// Only evolve the schema if requested.
var schemaErr *schemaMigrationNeededError
if o.schemaMigrationEnabled {
if o.schemaMode != streaming.SchemaModeIgnoreExtra {
var ok bool
schemaErr, ok = asSchemaMigrationError(err)
if !ok {
@@ -975,19 +988,19 @@ type snowpipeIndexedOutput struct {
db, schema, table, role string
offsetToken, channelName *service.InterpolatedString
logger *service.Logger
schemaMigrationEnabled bool
schemaMode streaming.SchemaMode
}

func (o *snowpipeIndexedOutput) openChannel(ctx context.Context, name string, id int16) (*streaming.SnowflakeIngestionChannel, error) {
o.logger.Debugf("opening snowflake streaming channel for table `%s.%s.%s`: %s", o.db, o.schema, o.table, name)
return o.client.OpenChannel(ctx, streaming.ChannelOptions{
ID: id,
Name: name,
DatabaseName: o.db,
SchemaName: o.schema,
TableName: o.table,
BuildOptions: o.buildOpts,
StrictSchemaEnforcement: o.schemaMigrationEnabled,
ID: id,
Name: name,
DatabaseName: o.db,
SchemaName: o.schema,
TableName: o.table,
BuildOptions: o.buildOpts,
SchemaMode: o.schemaMode,
})
}

@@ -1019,7 +1032,7 @@ func (o *snowpipeIndexedOutput) WriteBatch(ctx context.Context, batch service.Me
if err != nil {
// Only evolve the schema if requested.
var schemaErr *schemaMigrationNeededError
if o.schemaMigrationEnabled {
if o.schemaMode != streaming.SchemaModeIgnoreExtra {
var ok bool
schemaErr, ok = asSchemaMigrationError(err)
if !ok {
4 changes: 4 additions & 0 deletions internal/impl/snowflake/schema_evolution.go
@@ -68,6 +68,7 @@ func asSchemaMigrationError(err error) (*schemaMigrationNeededError, bool) {
}

type snowpipeSchemaEvolver struct {
mode streaming.SchemaMode
schemaEvolutionMapping *bloblang.Executor
pipeline []*service.OwnedProcessor
logger *service.Logger
@@ -200,6 +201,9 @@ func (o *snowpipeSchemaEvolver) CreateOutputTable(ctx context.Context, batch ser
}
columns := []string{}
for k, v := range row {
if o.mode == streaming.SchemaModeStrict && v == nil {
continue
}
col := streaming.NewMissingColumnError(msg, k, v)
colType, err := o.ComputeMissingColumnType(ctx, col)
if err != nil {
22 changes: 18 additions & 4 deletions internal/impl/snowflake/streaming/parquet.go
@@ -22,10 +22,22 @@ import (
"github.com/segmentio/encoding/thrift"
)

// SchemaMode specifies how to handle schema mismatches when constructing parquet files
type SchemaMode int

const (
// SchemaModeIgnoreExtra is a mode where unknown properties in messages are ignored
SchemaModeIgnoreExtra SchemaMode = iota
// SchemaModeStrict is a mode where non-null unknown properties in a message result in errors
SchemaModeStrict
// SchemaModeStrictWithNulls is a mode where all unknown properties result in errors
SchemaModeStrictWithNulls
)

// messageToRow converts a message into columnar form using the provided name to index mapping.
// We have to materialize the column into a row so that we can know if a column is null - the
// msg can be sparse, but the row must not be sparse.
func messageToRow(msg *service.Message, out []any, nameToPosition map[string]int, allowExtraProperties bool) error {
func messageToRow(msg *service.Message, out []any, nameToPosition map[string]int, mode SchemaMode) error {
v, err := msg.AsStructured()
if err != nil {
return fmt.Errorf("error extracting object from message: %w", err)
Expand All @@ -38,7 +50,9 @@ func messageToRow(msg *service.Message, out []any, nameToPosition map[string]int
for k, v := range row {
idx, ok := nameToPosition[normalizeColumnName(k)]
if !ok {
if !allowExtraProperties && v != nil {
if mode == SchemaModeStrict && v != nil {
missingColumns = append(missingColumns, NewMissingColumnError(msg, k, v))
} else if mode == SchemaModeStrictWithNulls {
missingColumns = append(missingColumns, NewMissingColumnError(msg, k, v))
}
continue
@@ -55,7 +69,7 @@ func constructRowGroup(
batch service.MessageBatch,
schema *parquet.Schema,
transformers []*dataTransformer,
allowExtraProperties bool,
mode SchemaMode,
) ([]parquet.Row, []*statsBuffer, error) {
// We write all of our data in a columnar fashion, but need to pivot that data so that we can feed it into
// our parquet library (which sadly will redo the pivot - maybe we need a lower level abstraction...).
@@ -83,7 +97,7 @@
// is needed
row := make([]any, rowWidth)
for _, msg := range batch {
err := messageToRow(msg, row, nameToPosition, allowExtraProperties)
err := messageToRow(msg, row, nameToPosition, mode)
if err != nil {
return nil, nil, err
}
2 changes: 1 addition & 1 deletion internal/impl/snowflake/streaming/parquet_test.go
@@ -61,7 +61,7 @@ func TestWriteParquet(t *testing.T) {
batch,
schema,
transformers,
false,
SchemaModeIgnoreExtra,
)
require.NoError(t, err)
w := newParquetWriter("latest", schema)
6 changes: 3 additions & 3 deletions internal/impl/snowflake/streaming/streaming.go
@@ -162,8 +162,8 @@ type ChannelOptions struct {
TableName string
// The max parallelism used to build parquet files and convert message batches into rows.
BuildOptions BuildOptions
// If set to true, don't ignore extra columns in user data, but raise an error.
StrictSchemaEnforcement bool
// How to handle schema differences
SchemaMode SchemaMode
}

type encryptionInfo struct {
@@ -337,7 +337,7 @@ func (c *SnowflakeIngestionChannel) constructBdecPart(batch service.MessageBatch
rowGroups = append(rowGroups, rowGroup{})
chunk := batch[i : i+end]
wg.Go(func() error {
rows, stats, err := constructRowGroup(chunk, c.schema, c.transformers, !c.StrictSchemaEnforcement)
rows, stats, err := constructRowGroup(chunk, c.schema, c.transformers, c.SchemaMode)
rowGroups[j] = rowGroup{rows, stats}
return err
})
