diff --git a/CHANGELOG.md b/CHANGELOG.md index 6cf721e08..00cc0cf39 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ All notable changes to this project will be documented in this file. ### Fixed - The `code` and `file` fields on the `javascript` processor docs no longer erroneously mention interpolation support. (@mihaitodor) +- The `postgres_cdc` now correctly handles `null` values. (@rockwotj) ## 4.44.0 - 2024-12-13 diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index 36f904f6f..edf9f465c 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -537,6 +537,9 @@ func TestIntegrationPgCDCForPgOutputStreamComplexTypesPlugin(t *testing.T) { );`) require.NoError(t, err) + _, err = db.Exec(`INSERT INTO complex_types_example (json_data) VALUES ('{"nested":null}'::jsonb);`) + require.NoError(t, err) + databaseURL := fmt.Sprintf("user=user_name password=%s dbname=dbname sslmode=disable host=%s port=%s", password, hostAndPortSplited[0], hostAndPortSplited[1]) template := fmt.Sprintf(` pg_stream: @@ -557,7 +560,7 @@ file: `, tmpDir) streamOutBuilder := service.NewStreamBuilder() - require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: INFO`)) + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: TRACE`)) require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) require.NoError(t, streamOutBuilder.AddInputYAML(template)) @@ -585,29 +588,28 @@ file: require.Eventually(t, func() bool { outBatchMut.Lock() defer outBatchMut.Unlock() - return len(outBatches) == 1 + return len(outBatches) == 2 }, time.Second*25, time.Millisecond*100) - messageWithComplexTypes := outBatches[0] - // producing change to non-complex type to trigger replication and receive updated row so we can check the complex types again // but after they have been produced by replication to ensure the consistency - _, err = db.Exec("UPDATE complex_types_example SET id = 2 WHERE id = 1") + _, err = db.Exec("UPDATE complex_types_example SET id = 3 WHERE id = 1") + require.NoError(t, err) + _, err = db.Exec("UPDATE complex_types_example SET id = 4 WHERE id = 2") require.NoError(t, err) assert.Eventually(t, func() bool { outBatchMut.Lock() defer outBatchMut.Unlock() - return len(outBatches) == 2 + return len(outBatches) == 4 }, time.Second*25, time.Millisecond*100) // replacing update with insert to remove replication messages type differences // so we will be checking only the data - lastMessage := outBatches[len(outBatches)-1] - lastMessage = strings.Replace(lastMessage, "update", "insert", 1) - messageWithComplexTypes = strings.Replace(messageWithComplexTypes, "\"table_snapshot_progress\":0,", "", 1) - - require.Equal(t, messageWithComplexTypes, strings.Replace(lastMessage, ":2", ":1", 1)) + require.JSONEq(t, `{"id":1, "int_array":[1, 2, 3, 4, 5], "ip_addr":"192.168.1.1/32", "json_data":{"name":"test", "value":42}, "location": "(45.5,-122.6)", "search_text":"'brown':3 'dog':9 'fox':4 'jump':5 'lazi':8 'quick':2", "tags":["tag1", "tag2", "tag3"], "time_range": "[2024-01-01 00:00:00,2024-12-31 00:00:00)", "uuid_col":"a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"}`, outBatches[0]) + require.JSONEq(t, `{"id":2, "int_array":null, "ip_addr":null, "json_data":{"nested":null}, "location":null, "search_text":null, "tags":null, "time_range":null, "uuid_col":null}`, outBatches[1]) + require.JSONEq(t, `{"id":3, "int_array":[1, 2, 3, 4, 5], "ip_addr":"192.168.1.1/32", "json_data":{"name":"test", "value":42}, "location": "(45.5,-122.6)", "search_text":"'brown':3 'dog':9 'fox':4 'jump':5 'lazi':8 'quick':2", "tags":["tag1", "tag2", "tag3"], "time_range": "[2024-01-01 00:00:00,2024-12-31 00:00:00)", "uuid_col":"a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"}`, outBatches[2]) + require.JSONEq(t, `{"id":4, "int_array":null, "ip_addr":null, "json_data":{"nested":null}, "location":null, "search_text":null, "tags":null, "time_range":null, "uuid_col":null}`, outBatches[3]) require.NoError(t, streamOut.StopWithin(time.Second*10)) } diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 938502fc3..452574b77 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -569,7 +569,7 @@ func (s *Stream) processSnapshot() error { col := columnNames[i] var val any if val, err = getter(scanArgs[i]); err != nil { - return err + return fmt.Errorf("unable to decode column %s: %w", col, err) } data[col] = val normalized := sanitize.QuotePostgresIdentifier(col) diff --git a/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go b/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go index 5a9878980..4bfbbd79e 100644 --- a/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go +++ b/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go @@ -84,6 +84,8 @@ func decodePgOutput(WALData []byte, relations map[uint32]*RelationMessage, typeM return nil, fmt.Errorf("unable to decode column data: %w", err) } values[colName] = val + default: + return nil, fmt.Errorf("unable to decode column data, unknown data type: %d", col.DataType) } } message.Data = values @@ -149,7 +151,10 @@ func decodePgOutput(WALData []byte, relations map[uint32]*RelationMessage, typeM return message, nil } -func decodeTextColumnData(mi *pgtype.Map, data []byte, dataType uint32) (interface{}, error) { +func decodeTextColumnData(mi *pgtype.Map, data []byte, dataType uint32) (any, error) { + if data == nil { + return nil, nil + } if dt, ok := mi.TypeForOID(dataType); ok { val, err := dt.Codec.DecodeValue(mi, dataType, pgtype.TextFormatCode, data) if err != nil { @@ -177,6 +182,5 @@ func decodeTextColumnData(mi *pgtype.Map, data []byte, dataType uint32) (interfa return val, err } - return string(data), nil } diff --git a/internal/impl/postgresql/pglogicalstream/snapshotter.go b/internal/impl/postgresql/pglogicalstream/snapshotter.go index 1d0cbad2f..37b8d0c7b 100644 --- a/internal/impl/postgresql/pglogicalstream/snapshotter.go +++ b/internal/impl/postgresql/pglogicalstream/snapshotter.go @@ -152,17 +152,39 @@ func (s *Snapshotter) prepareScannersAndGetters(columnTypes []*sql.ColumnType) ( switch v.DatabaseTypeName() { case "VARCHAR", "TEXT", "UUID", "TIMESTAMP": scanArgs[i] = new(sql.NullString) - valueGetters[i] = func(v any) (any, error) { return v.(*sql.NullString).String, nil } + valueGetters[i] = func(v any) (any, error) { + str := v.(*sql.NullString) + if !str.Valid { + return nil, nil + } + return str.String, nil + } case "BOOL": scanArgs[i] = new(sql.NullBool) - valueGetters[i] = func(v any) (any, error) { return v.(*sql.NullBool).Bool, nil } + valueGetters[i] = func(v any) (any, error) { + val := v.(*sql.NullBool) + if !val.Valid { + return nil, nil + } + return val.Bool, nil + } case "INT4": scanArgs[i] = new(sql.NullInt64) - valueGetters[i] = func(v any) (any, error) { return v.(*sql.NullInt64).Int64, nil } + valueGetters[i] = func(v any) (any, error) { + val := v.(*sql.NullInt64) + if !val.Valid { + return nil, nil + } + return val.Int64, nil + } case "JSONB": scanArgs[i] = new(sql.NullString) valueGetters[i] = func(v any) (any, error) { - payload := v.(*sql.NullString).String + str := v.(*sql.NullString) + if !str.Valid { + return nil, nil + } + payload := str.String if payload == "" { return payload, nil } @@ -177,8 +199,11 @@ func (s *Snapshotter) prepareScannersAndGetters(columnTypes []*sql.ColumnType) ( scanArgs[i] = new(sql.NullString) valueGetters[i] = func(v any) (any, error) { inet := pgtype.Inet{} - val := v.(*sql.NullString).String - if err := inet.Scan(val); err != nil { + val := v.(*sql.NullString) + if !val.Valid { + return nil, nil + } + if err := inet.Scan(val.String); err != nil { return nil, err } @@ -188,8 +213,11 @@ func (s *Snapshotter) prepareScannersAndGetters(columnTypes []*sql.ColumnType) ( scanArgs[i] = new(sql.NullString) valueGetters[i] = func(v any) (any, error) { newArray := pgtype.Tsrange{} - val := v.(*sql.NullString).String - if err := newArray.Scan(val); err != nil { + val := v.(*sql.NullString) + if !val.Valid { + return nil, nil + } + if err := newArray.Scan(val.String); err != nil { return nil, err } @@ -200,8 +228,11 @@ func (s *Snapshotter) prepareScannersAndGetters(columnTypes []*sql.ColumnType) ( scanArgs[i] = new(sql.NullString) valueGetters[i] = func(v any) (any, error) { newArray := pgtype.Int4Array{} - val := v.(*sql.NullString).String - if err := newArray.Scan(val); err != nil { + val := v.(*sql.NullString) + if !val.Valid { + return nil, nil + } + if err := newArray.Scan(val.String); err != nil { return nil, err } @@ -211,8 +242,11 @@ func (s *Snapshotter) prepareScannersAndGetters(columnTypes []*sql.ColumnType) ( scanArgs[i] = new(sql.NullString) valueGetters[i] = func(v any) (any, error) { newArray := pgtype.TextArray{} - val := v.(*sql.NullString).String - if err := newArray.Scan(val); err != nil { + val := v.(*sql.NullString) + if !val.Valid { + return nil, nil + } + if err := newArray.Scan(val.String); err != nil { return nil, err } @@ -220,7 +254,13 @@ func (s *Snapshotter) prepareScannersAndGetters(columnTypes []*sql.ColumnType) ( } default: scanArgs[i] = new(sql.NullString) - valueGetters[i] = func(v any) (any, error) { return v.(*sql.NullString).String, nil } + valueGetters[i] = func(v any) (any, error) { + val := v.(*sql.NullString) + if !val.Valid { + return nil, nil + } + return val.String, nil + } } }