From 4896efd286cee5a2ecce986f2fb489848bd4cf29 Mon Sep 17 00:00:00 2001 From: pierre Date: Wed, 4 Nov 2020 11:43:29 +0100 Subject: [PATCH] Reader: concurrency working except for WriteTo --- go.mod | 2 - internal/lz4stream/block.go | 85 +++++++----- internal/lz4stream/frame.go | 5 +- options.go | 5 + reader.go | 25 ++-- reader_test.go | 251 ++++++++++++++++++++---------------- 6 files changed, 213 insertions(+), 160 deletions(-) diff --git a/go.mod b/go.mod index 2263e7f9..42229b29 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,3 @@ module github.com/pierrec/lz4/v4 go 1.14 - -require github.com/pierrec/lz4 v2.6.0+incompatible // indirect diff --git a/internal/lz4stream/block.go b/internal/lz4stream/block.go index 720bd710..f3f1a975 100644 --- a/internal/lz4stream/block.go +++ b/internal/lz4stream/block.go @@ -3,12 +3,11 @@ package lz4stream import ( "encoding/binary" "fmt" - "io" - "sync" - "github.com/pierrec/lz4/v4/internal/lz4block" "github.com/pierrec/lz4/v4/internal/lz4errors" "github.com/pierrec/lz4/v4/internal/xxh32" + "io" + "sync" ) type Blocks struct { @@ -87,6 +86,11 @@ func (b *Blocks) ErrorR() error { return b.err } +type BlockResult struct { + Data []byte + Err error +} + // initR returns a channel that streams the uncompressed blocks if in concurrent // mode and no error. When the channel is closed, check for any error with b.ErrorR. // @@ -105,11 +109,16 @@ func (b *Blocks) initR(f *Frame, num int, src io.Reader) (chan []byte, error) { data := make(chan []byte) // Read blocks from the source sequentially // and uncompress them concurrently. + + // In legacy mode, accrue the uncompress sizes in cum. + var cum uint32 go func() { + var cumx uint32 + var err error for b.ErrorR() == nil { block := NewFrameDataBlock(size) - if err := block.Read(f, src); err != nil { - b.closeR(err) + cumx, err = block.Read(f, src, cum) + if err != nil { break } // Recheck for an error as reading may be slow and uncompressing is expensive. @@ -119,7 +128,7 @@ func (b *Blocks) initR(f *Frame, num int, src io.Reader) (chan []byte, error) { c := make(chan []byte) blocks <- c go func() { - data, err := block.Uncompress(f, size.Get()) + data, err := block.Uncompress(f, size.Get(), false) if err != nil { b.closeR(err) } else { @@ -132,11 +141,15 @@ func (b *Blocks) initR(f *Frame, num int, src io.Reader) (chan []byte, error) { blocks <- c c <- nil // signal the collection loop that we are done <-c // wait for the collect loop to complete + if f.isLegacy() && cum == cumx { + err = io.EOF + } + b.closeR(err) close(data) }() // Collect the uncompressed blocks and make them available // on the returned channel. - go func() { + go func(leg bool) { defer close(blocks) for c := range blocks { buf := <-c @@ -145,11 +158,18 @@ func (b *Blocks) initR(f *Frame, num int, src io.Reader) (chan []byte, error) { close(c) return } + // Perform checksum now as the blocks are received in order. + if f.Descriptor.Flags.ContentChecksum() { + _, _ = f.checksum.Write(buf) + } + if leg { + cum += uint32(len(buf)) + } data <- buf - size.Put(buf) + //size.Put(buf) close(c) } - }() + }(f.isLegacy()) return data, nil } @@ -173,14 +193,12 @@ type FrameDataBlock struct { Checksum uint32 data []byte // buffer for compressed data src []byte // uncompressed data - done bool // for legacy support err error // used in concurrent mode } func (b *FrameDataBlock) Close(f *Frame) { b.Size = 0 b.Checksum = 0 - b.done = false b.err = nil if b.data != nil { // Block was not already closed. @@ -224,6 +242,8 @@ func (b *FrameDataBlock) Compress(f *Frame, src []byte, level lz4block.Compressi } func (b *FrameDataBlock) Write(f *Frame, dst io.Writer) error { + // Write is called in the same order as blocks are compressed, + // so content checksum must be done here. if f.Descriptor.Flags.ContentChecksum() { _, _ = f.checksum.Write(b.src) } @@ -246,45 +266,47 @@ func (b *FrameDataBlock) Write(f *Frame, dst io.Writer) error { } // Read updates b with the next block data, size and checksum if available. -func (b *FrameDataBlock) Read(f *Frame, src io.Reader) error { +func (b *FrameDataBlock) Read(f *Frame, src io.Reader, cum uint32) (uint32, error) { x, err := f.readUint32(src) if err != nil { - return err + return 0, err } - switch leg := f.isLegacy(); { - case leg && x == frameMagicLegacy: - // Concatenated legacy frame. - return b.Read(f, src) - case leg && b.done: - // In legacy mode, all blocks are of size 8Mb. - // When a uncompressed block size is less than 8Mb, - // then it means the end of the stream is reached. - return io.EOF - case !leg && x == 0: + if f.isLegacy() { + switch x { + case frameMagicLegacy: + // Concatenated legacy frame. + return b.Read(f, src, cum) + case cum: + // Only works in non concurrent mode, for concurrent mode + // it is handled separately. + // Linux kernel format appends the total uncompressed size at the end. + return 0, io.EOF + } + } else if x == 0 { // Marker for end of stream. - return io.EOF + return 0, io.EOF } b.Size = DataBlockSize(x) size := b.Size.size() if size > cap(b.data) { - return lz4errors.ErrOptionInvalidBlockSize + return x, lz4errors.ErrOptionInvalidBlockSize } b.data = b.data[:size] if _, err := io.ReadFull(src, b.data); err != nil { - return err + return x, err } if f.Descriptor.Flags.BlockChecksum() { sum, err := f.readUint32(src) if err != nil { - return err + return 0, err } b.Checksum = sum } - return nil + return x, nil } -func (b *FrameDataBlock) Uncompress(f *Frame, dst []byte) ([]byte, error) { +func (b *FrameDataBlock) Uncompress(f *Frame, dst []byte, sum bool) ([]byte, error) { if b.Size.Uncompressed() { n := copy(dst, b.data) dst = dst[:n] @@ -294,9 +316,6 @@ func (b *FrameDataBlock) Uncompress(f *Frame, dst []byte) ([]byte, error) { return nil, err } dst = dst[:n] - if f.isLegacy() && uint32(n) < lz4block.Block8Mb { - b.done = true - } } if f.Descriptor.Flags.BlockChecksum() { if c := xxh32.ChecksumZero(dst); c != b.Checksum { @@ -304,7 +323,7 @@ func (b *FrameDataBlock) Uncompress(f *Frame, dst []byte) ([]byte, error) { return nil, err } } - if f.Descriptor.Flags.ContentChecksum() { + if sum && f.Descriptor.Flags.ContentChecksum() { _, _ = f.checksum.Write(dst) } return dst, nil diff --git a/internal/lz4stream/frame.go b/internal/lz4stream/frame.go index 9c1b121a..cfbd5674 100644 --- a/internal/lz4stream/frame.go +++ b/internal/lz4stream/frame.go @@ -4,11 +4,12 @@ package lz4stream import ( "encoding/binary" "fmt" + "io" + "io/ioutil" + "github.com/pierrec/lz4/v4/internal/lz4block" "github.com/pierrec/lz4/v4/internal/lz4errors" "github.com/pierrec/lz4/v4/internal/xxh32" - "io" - "io/ioutil" ) //go:generate go run gen.go diff --git a/options.go b/options.go index 6ab82214..4e1b6703 100644 --- a/options.go +++ b/options.go @@ -193,6 +193,11 @@ func OnBlockDoneOption(handler func(size int)) Option { // LegacyOption provides support for writing LZ4 frames in the legacy format. // // See https://github.com/lz4/lz4/blob/dev/doc/lz4_Frame_format.md#legacy-frame. +// +// NB. compressed Linux kernel images use a tweaked LZ4 legacy format where +// the compressed stream is followed by the original (uncompressed) size of +// the kernel (https://events.static.linuxfound.org/sites/events/files/lcjpcojp13_klee.pdf). +// This is also supported as a special case. func LegacyOption(legacy bool) Option { return func(a applier) error { switch rw := a.(type) { diff --git a/reader.go b/reader.go index 764d101e..a4d44c4b 100644 --- a/reader.go +++ b/reader.go @@ -38,6 +38,7 @@ type Reader struct { reads chan []byte // pending data idx int // size of pending data handler func(int) + cum uint32 } func (*Reader) private() {} @@ -83,6 +84,7 @@ func (r *Reader) init() error { r.idx = 0 size := r.frame.Descriptor.Flags.BlockSizeIndex() r.data = size.Get() + r.cum = 0 return nil } @@ -100,14 +102,17 @@ func (r *Reader) Read(buf []byte) (n int, err error) { default: return 0, r.state.fail() } - var bn int for len(buf) > 0 { + var bn int if r.idx == 0 { - if !r.isNotConcurrent() { - r.data = <-r.reads - err = r.frame.Blocks.ErrorR() - } else { + if r.isNotConcurrent() { bn, err = r.read(buf) + } else { + r.data = <-r.reads + if len(r.data) == 0 { + // No uncompressed data: something went wrong or we are done. + err = r.frame.Blocks.ErrorR() + } } switch err { case nil: @@ -145,7 +150,8 @@ func (r *Reader) Read(buf []byte) (n int, err error) { // - else, the uncompress data is stored in r.data and 0 is returned func (r *Reader) read(buf []byte) (int, error) { block := r.frame.Blocks.Block - if err := block.Read(r.frame, r.src); err != nil { + _, err := block.Read(r.frame, r.src, r.cum) + if err != nil { return 0, err } var direct bool @@ -155,10 +161,11 @@ func (r *Reader) read(buf []byte) (int, error) { direct = true dst = buf } - dst, err := block.Uncompress(r.frame, dst) + dst, err = block.Uncompress(r.frame, dst, true) if err != nil { return 0, err } + r.cum += uint32(len(dst)) if direct { return len(dst), nil } @@ -202,7 +209,7 @@ func (r *Reader) WriteTo_(w io.Writer) (n int64, err error) { size := r.frame.Descriptor.Flags.BlockSizeIndex() data := size.Get() for { - err = block.Read(r.frame, r.src) + _, err = block.Read(r.frame, r.src, 0) switch err { case nil: case io.EOF: @@ -212,7 +219,7 @@ func (r *Reader) WriteTo_(w io.Writer) (n int64, err error) { return } dst := data - dst, err = block.Uncompress(r.frame, dst) + dst, err = block.Uncompress(r.frame, dst, true) if err != nil { return } diff --git a/reader_test.go b/reader_test.go index d6bf1874..010f3862 100644 --- a/reader_test.go +++ b/reader_test.go @@ -3,6 +3,7 @@ package lz4_test import ( "bytes" "errors" + "fmt" "io" "io/ioutil" "os" @@ -13,6 +14,10 @@ import ( "github.com/pierrec/lz4/v4" ) +func _o(s ...lz4.Option) []lz4.Option { + return s +} + func TestReader(t *testing.T) { goldenFiles := []string{ "testdata/e.txt.lz4", @@ -27,57 +32,66 @@ func TestReader(t *testing.T) { } for _, fname := range goldenFiles { - t.Run(fname, func(t *testing.T) { - fname := fname - t.Parallel() - - f, err := os.Open(fname) - if err != nil { - t.Fatal(err) - } - defer f.Close() - - rawfile := strings.TrimSuffix(fname, ".lz4") - raw, err := ioutil.ReadFile(rawfile) - if err != nil { - t.Fatal(err) - } - - out := new(bytes.Buffer) - zr := lz4.NewReader(f) - n, err := io.Copy(out, zr) - if err != nil { - t.Fatal(err) - } - - if got, want := int(n), len(raw); got != want { - t.Errorf("invalid size: got %d; want %d", got, want) - } - - if got, want := out.Bytes(), raw; !reflect.DeepEqual(got, want) { - t.Fatal("uncompressed data does not match original") - } - - if len(raw) < 20 { - return - } - - f2, err := os.Open(fname) - if err != nil { - t.Fatal(err) - } - defer f2.Close() - - out.Reset() - zr = lz4.NewReader(f2) - _, err = io.CopyN(out, zr, 10) - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(out.Bytes(), raw[:10]) { - t.Fatal("partial read does not match original") - } - }) + for _, opts := range [][]lz4.Option{ + nil, + _o(lz4.ConcurrencyOption(-1)), + } { + label := fmt.Sprintf("%s %v", fname, opts) + t.Run(label, func(t *testing.T) { + fname := fname + t.Parallel() + + f, err := os.Open(fname) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + rawfile := strings.TrimSuffix(fname, ".lz4") + raw, err := ioutil.ReadFile(rawfile) + if err != nil { + t.Fatal(err) + } + + out := new(bytes.Buffer) + zr := lz4.NewReader(f) + if err := zr.Apply(opts...); err != nil { + t.Fatal(err) + } + n, err := io.Copy(out, zr) + if err != nil { + t.Error(err) + } + + if got, want := int(n), len(raw); got != want { + t.Errorf("invalid size: got %d; want %d", got, want) + } + + if got, want := out.Bytes(), raw; !reflect.DeepEqual(got, want) { + t.Fatal("uncompressed data does not match original") + } + + if len(raw) < 20 { + return + } + + f2, err := os.Open(fname) + if err != nil { + t.Fatal(err) + } + defer f2.Close() + + out.Reset() + zr = lz4.NewReader(f2) + _, err = io.CopyN(out, zr, 10) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(out.Bytes(), raw[:10]) { + t.Fatal("partial read does not match original") + } + }) + } } } @@ -135,72 +149,81 @@ func TestWriteToBrokenWriter(t *testing.T) { func TestReaderLegacy(t *testing.T) { goldenFiles := []string{ - //"testdata/vmlinux_LZ4_19377.lz4", + "testdata/vmlinux_LZ4_19377.lz4", "testdata/bzImage_lz4_isolated.lz4", } for _, fname := range goldenFiles { - t.Run(fname, func(t *testing.T) { - fname := fname - t.Parallel() - - var out bytes.Buffer - rawfile := strings.TrimSuffix(fname, ".lz4") - raw, err := ioutil.ReadFile(rawfile) - if err != nil { - t.Fatal(err) - } - - f, err := os.Open(fname) - if err != nil { - t.Fatal(err) - } - defer f.Close() - - zr := lz4.NewReader(f) - n, err := io.Copy(&out, zr) - if err != nil { - t.Fatal(err, n) - } - - if got, want := int(n), len(raw); got != want { - t.Errorf("invalid sizes: got %d; want %d", got, want) - } - - if got, want := out.Bytes(), raw; !bytes.Equal(got, want) { - t.Fatal("uncompressed data does not match original") - } - - if len(raw) < 20 { - return - } - - f2, err := os.Open(fname) - if err != nil { - t.Fatal(err) - } - defer f2.Close() - - out.Reset() - zr = lz4.NewReader(f2) - _, err = io.CopyN(&out, zr, 10) - if err != nil { - t.Fatal(err) - } - - if !bytes.Equal(out.Bytes(), raw[:10]) { - t.Fatal("partial read does not match original") - } - - out.Reset() - _, err = io.CopyN(&out, zr, 10) - if err != nil { - t.Fatal(err) - } - - if !bytes.Equal(out.Bytes(), raw[10:20]) { - t.Fatal("after seek, partial read does not match original") - } - }) + for _, opts := range [][]lz4.Option{ + nil, + _o(lz4.ConcurrencyOption(-1)), + } { + label := fmt.Sprintf("%s %v", fname, opts) + t.Run(label, func(t *testing.T) { + fname := fname + t.Parallel() + + var out bytes.Buffer + rawfile := strings.TrimSuffix(fname, ".lz4") + raw, err := ioutil.ReadFile(rawfile) + if err != nil { + t.Fatal(err) + } + + f, err := os.Open(fname) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + zr := lz4.NewReader(f) + if err := zr.Apply(opts...); err != nil { + t.Fatal(err) + } + n, err := io.Copy(&out, zr) + if err != nil { + t.Fatal(err, n) + } + + if got, want := int(n), len(raw); got != want { + t.Errorf("invalid sizes: got %d; want %d", got, want) + } + + if got, want := out.Bytes(), raw; !bytes.Equal(got, want) { + t.Fatal("uncompressed data does not match original") + } + + if len(raw) < 20 { + return + } + + f2, err := os.Open(fname) + if err != nil { + t.Fatal(err) + } + defer f2.Close() + + out.Reset() + zr = lz4.NewReader(f2) + _, err = io.CopyN(&out, zr, 10) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(out.Bytes(), raw[:10]) { + t.Fatal("partial read does not match original") + } + + out.Reset() + _, err = io.CopyN(&out, zr, 10) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(out.Bytes(), raw[10:20]) { + t.Fatal("after seek, partial read does not match original") + } + }) + } } }