diff --git a/clients/destination.go b/clients/destination.go index f213fd42e1..b53d3b90ed 100644 --- a/clients/destination.go +++ b/clients/destination.go @@ -249,7 +249,7 @@ func (c *DestinationClient) Migrate(ctx context.Context, tables []*schema.Table) // Write writes rows as they are received from the channel to the destination plugin. // resources is marshaled schema.Resource. We are not marshalling this inside the function -// because usually it is alreadun marshalled from the destination plugin. +// because usually it is already marshalled from the destination plugin. func (c *DestinationClient) Write(ctx context.Context, source string, syncTime time.Time, resources <-chan []byte) (uint64, error) { saveClient, err := c.pbClient.Write(ctx) if err != nil { @@ -261,6 +261,10 @@ func (c *DestinationClient) Write(ctx context.Context, source string, syncTime t Source: source, Timestamp: timestamppb.New(syncTime), }); err != nil { + if err == io.EOF { + // don't send write request if the channel is closed + break + } return 0, fmt.Errorf("failed to call Write.Send: %w", err) } } @@ -292,6 +296,10 @@ func (c *DestinationClient) Write2(ctx context.Context, tables schema.Tables, so if err := saveClient.Send(&pb.Write2_Request{ Resource: resource, }); err != nil { + if err == io.EOF { + // don't send write request if the channel is closed + break + } return fmt.Errorf("failed to call Write2.Send: %w", err) } } diff --git a/clients/destination_test.go b/clients/destination_test.go index a59adbbfb9..f0998a581f 100644 --- a/clients/destination_test.go +++ b/clients/destination_test.go @@ -2,12 +2,17 @@ package clients import ( "context" + "encoding/json" "os" + "path" "strings" "testing" + "time" + "github.com/cloudquery/plugin-sdk/schema" "github.com/cloudquery/plugin-sdk/specs" "github.com/rs/zerolog" + "github.com/stretchr/testify/require" ) var newDestinationClientTestCases = []specs.Source{ @@ -59,3 +64,51 @@ func TestDestinationClient(t *testing.T) { }) } } + +func TestDestinationClientWriteReturnsCorrectError(t *testing.T) { + ctx := context.Background() + l := zerolog.New(zerolog.NewTestWriter(t)).Output(zerolog.ConsoleWriter{Out: os.Stderr}).Level(zerolog.DebugLevel) + dirName := t.TempDir() + c, err := NewDestinationClient(ctx, specs.RegistryGithub, "cloudquery/sqlite", "v1.0.11", WithDestinationLogger(l), WithDestinationDirectory(dirName)) + if err != nil { + t.Fatal(err) + } + + defer func() { + if err := c.Terminate(); err != nil { + t.Logf("failed to terminate destination client: %v", err) + } + }() + sqliteSpec := struct { + connectionString string + }{connectionString: path.Join(dirName, "test.sql")} + if err := c.Initialize(ctx, specs.Destination{Spec: sqliteSpec}); err != nil { + t.Fatal(err) + } + + name, err := c.Name(ctx) + if err != nil { + t.Fatal("failed to get name", err) + } + + columns := []schema.Column{{Name: "int", Type: schema.TypeInt}} + tables := schema.Tables{&schema.Table{Name: "test-1", Columns: columns}, &schema.Table{Name: "test-2", Columns: columns}} + resource1 := schema.Resource{Item: map[string]any{"int": 1}, Table: tables[0]} + destResource1, _ := json.Marshal(resource1.ToDestinationResource()) + resource2 := schema.Resource{Item: map[string]any{"int": 1}, Table: tables[1]} + destResource2, _ := json.Marshal(resource2.ToDestinationResource()) + resourcesChannel := make(chan []byte) + go func() { + defer close(resourcesChannel) + // we need to stream enough data to the server so it at least starts processing it and return the relevant error + for i := 1; i < 100000; i++ { + resourcesChannel <- destResource1 + resourcesChannel <- destResource2 + resourcesChannel <- destResource1 + resourcesChannel <- destResource2 + } + }() + + err = c.Write2(ctx, tables, name, time.Now().UTC(), resourcesChannel) + require.ErrorContains(t, err, "context canceled") +}