diff --git a/v2/writer.go b/v2/writer.go index 8ed5756f..72675cb2 100644 --- a/v2/writer.go +++ b/v2/writer.go @@ -1,12 +1,15 @@ package car import ( + "bytes" "errors" "fmt" "io" "os" + "github.com/ipfs/go-cid" "github.com/ipld/go-car/v2/index" + "github.com/ipld/go-car/v2/internal/carv1" internalio "github.com/ipld/go-car/v2/internal/io" ) @@ -210,3 +213,100 @@ func AttachIndex(path string, idx index.Index, offset uint64) error { indexWriter := internalio.NewOffsetWriter(out, int64(offset)) return index.WriteTo(idx, indexWriter) } + +// ReplaceRootsInFile replaces the root CIDs in CAR file at given path with the given roots. +// This function accepts both CARv1 and CARv2 files. +// +// Note that the roots are only replaced if their total serialized size exactly matches the total +// serialized size of existing roots in CAR file. +func ReplaceRootsInFile(path string, roots []cid.Cid) (err error) { + f, err := os.OpenFile(path, os.O_RDWR, 0o666) + if err != nil { + return err + } + defer func() { + // Close file and override return error type if it is nil. + if cerr := f.Close(); err == nil { + err = cerr + } + }() + + // Read header or pragma; note that both are a valid CARv1 header. + header, err := carv1.ReadHeader(f) + if err != nil { + return err + } + + var currentSize int64 + var newHeaderOffset int64 + switch header.Version { + case 1: + // When the given file is a CARv1 : + // 1. The offset at which the new header should be written is zero (newHeaderOffset = 0) + // 2. The current header size is equal to the number of bytes read, and + // + // Note that we explicitly avoid using carv1.HeaderSize to determine the current header size. + // This is based on the fact that carv1.ReadHeader does not read any extra bytes. + // Therefore, we can avoid extra allocations of carv1.HeaderSize to determine size by simply + // counting the bytes read so far. + currentSize, err = f.Seek(0, io.SeekCurrent) + if err != nil { + return err + } + case 2: + // When the given file is a CARv2 : + // 1. The offset at which the new header should be written is carv2.Header.DataOffset + // 2. The inner CARv1 header size is equal to the number of bytes read minus carv2.Header.DataOffset + var v2h Header + if _, err = v2h.ReadFrom(f); err != nil { + return err + } + newHeaderOffset = int64(v2h.DataOffset) + if _, err = f.Seek(newHeaderOffset, io.SeekStart); err != nil { + return err + } + var innerV1Header *carv1.CarHeader + innerV1Header, err = carv1.ReadHeader(f) + if err != nil { + return err + } + if innerV1Header.Version != 1 { + err = fmt.Errorf("invalid data payload header: expected version 1, got %d", innerV1Header.Version) + } + var readSoFar int64 + readSoFar, err = f.Seek(0, io.SeekCurrent) + if err != nil { + return err + } + currentSize = readSoFar - newHeaderOffset + default: + err = fmt.Errorf("invalid car version: %d", header.Version) + return err + } + + newHeader := &carv1.CarHeader{ + Roots: roots, + Version: 1, + } + // Serialize the new header straight up instead of using carv1.HeaderSize. + // Because, carv1.HeaderSize serialises it to calculate size anyway. + // By serializing straight up we get the replacement bytes and size. + // Otherwise, we end up serializing the new header twice: + // once through carv1.HeaderSize, and + // once to write it out. + var buf bytes.Buffer + if err = carv1.WriteHeader(newHeader, &buf); err != nil { + return err + } + // Assert the header sizes are consistent. + newSize := int64(buf.Len()) + if currentSize != newSize { + return fmt.Errorf("current header size (%d) must match replacement header size (%d)", currentSize, newSize) + } + // Seek to the offset at which the new header should be written. + if _, err = f.Seek(newHeaderOffset, io.SeekStart); err != nil { + return err + } + _, err = f.Write(buf.Bytes()) + return err +} diff --git a/v2/writer_test.go b/v2/writer_test.go index c29b4339..11b5ff12 100644 --- a/v2/writer_test.go +++ b/v2/writer_test.go @@ -131,3 +131,149 @@ func assertAddNodes(t *testing.T, adder format.NodeAdder, nds ...format.Node) { assert.NoError(t, adder.Add(context.Background(), nd)) } } + +func TestReplaceRootsInFile(t *testing.T) { + tests := []struct { + name string + path string + roots []cid.Cid + wantErrMsg string + }{ + { + name: "CorruptPragmaIsRejected", + path: "testdata/sample-corrupt-pragma.car", + wantErrMsg: "unexpected EOF", + }, + { + name: "CARv42IsRejected", + path: "testdata/sample-rootless-v42.car", + wantErrMsg: "invalid car version: 42", + }, + { + name: "CARv1RootsOfDifferentSizeAreNotReplaced", + path: "testdata/sample-v1.car", + wantErrMsg: "current header size (61) must match replacement header size (18)", + }, + { + name: "CARv2RootsOfDifferentSizeAreNotReplaced", + path: "testdata/sample-wrapped-v2.car", + wantErrMsg: "current header size (61) must match replacement header size (18)", + }, + { + name: "CARv1NonEmptyRootsOfDifferentSizeAreNotReplaced", + path: "testdata/sample-v1.car", + roots: []cid.Cid{requireDecodedCid(t, "QmdfTbBqBPQ7VNxZEYEj14VmRuZBkqFbiwReogJgS1zR1n")}, + wantErrMsg: "current header size (61) must match replacement header size (57)", + }, + { + name: "CARv1ZeroLenNonEmptyRootsOfDifferentSizeAreNotReplaced", + path: "testdata/sample-v1-with-zero-len-section.car", + roots: []cid.Cid{merkledag.NewRawNode([]byte("fish")).Cid()}, + wantErrMsg: "current header size (61) must match replacement header size (59)", + }, + { + name: "CARv2NonEmptyRootsOfDifferentSizeAreNotReplaced", + path: "testdata/sample-wrapped-v2.car", + roots: []cid.Cid{merkledag.NewRawNode([]byte("fish")).Cid()}, + wantErrMsg: "current header size (61) must match replacement header size (59)", + }, + { + name: "CARv2IndexlessNonEmptyRootsOfDifferentSizeAreNotReplaced", + path: "testdata/sample-v2-indexless.car", + roots: []cid.Cid{merkledag.NewRawNode([]byte("fish")).Cid()}, + wantErrMsg: "current header size (61) must match replacement header size (59)", + }, + { + name: "CARv1SameSizeRootsAreReplaced", + path: "testdata/sample-v1.car", + roots: []cid.Cid{requireDecodedCid(t, "bafy2bzaced4ueelaegfs5fqu4tzsh6ywbbpfk3cxppupmxfdhbpbhzawfw5od")}, + }, + { + name: "CARv2SameSizeRootsAreReplaced", + path: "testdata/sample-wrapped-v2.car", + roots: []cid.Cid{requireDecodedCid(t, "bafy2bzaced4ueelaegfs5fqu4tzsh6ywbbpfk3cxppupmxfdhbpbhzawfw5oi")}, + }, + { + name: "CARv2IndexlessSameSizeRootsAreReplaced", + path: "testdata/sample-v2-indexless.car", + roots: []cid.Cid{requireDecodedCid(t, "bafy2bzaced4ueelaegfs5fqu4tzsh6ywbbpfk3cxppupmxfdhbpbhzawfw5oi")}, + }, + { + name: "CARv1ZeroLenSameSizeRootsAreReplaced", + path: "testdata/sample-v1-with-zero-len-section.car", + roots: []cid.Cid{requireDecodedCid(t, "bafy2bzaced4ueelaegfs5fqu4tzsh6ywbbpfk3cxppupmxfdhbpbhzawfw5o5")}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Make a copy of input files to preserve original for comparison. + // This also avoids modification files in testdata. + tmpCopy := requireTmpCopy(t, tt.path) + err := ReplaceRootsInFile(tmpCopy, tt.roots) + if tt.wantErrMsg != "" { + require.EqualError(t, err, tt.wantErrMsg) + return + } + require.NoError(t, err) + + original, err := os.Open(tt.path) + require.NoError(t, err) + defer func() { require.NoError(t, original.Close()) }() + + target, err := os.Open(tmpCopy) + require.NoError(t, err) + defer func() { require.NoError(t, target.Close()) }() + + // Assert file size has not changed. + wantStat, err := original.Stat() + require.NoError(t, err) + gotStat, err := target.Stat() + require.NoError(t, err) + require.Equal(t, wantStat.Size(), gotStat.Size()) + + wantReader, err := NewBlockReader(original, ZeroLengthSectionAsEOF(true)) + require.NoError(t, err) + gotReader, err := NewBlockReader(target, ZeroLengthSectionAsEOF(true)) + require.NoError(t, err) + + // Assert roots are replaced. + require.Equal(t, tt.roots, gotReader.Roots) + + // Assert data blocks are identical. + for { + wantNext, wantErr := wantReader.Next() + gotNext, gotErr := gotReader.Next() + if wantErr == io.EOF { + require.Equal(t, io.EOF, gotErr) + break + } + require.NoError(t, wantErr) + require.NoError(t, gotErr) + require.Equal(t, wantNext, gotNext) + } + }) + } +} + +func requireDecodedCid(t *testing.T, s string) cid.Cid { + decoded, err := cid.Decode(s) + require.NoError(t, err) + return decoded +} + +func requireTmpCopy(t *testing.T, src string) string { + srcF, err := os.Open(src) + require.NoError(t, err) + defer func() { require.NoError(t, srcF.Close()) }() + stats, err := srcF.Stat() + require.NoError(t, err) + + dst := filepath.Join(t.TempDir(), stats.Name()) + dstF, err := os.Create(dst) + require.NoError(t, err) + defer func() { require.NoError(t, dstF.Close()) }() + + _, err = io.Copy(dstF, srcF) + require.NoError(t, err) + return dst +}