diff --git a/frame.go b/frame.go index 5f421c95..bbf89265 100644 --- a/frame.go +++ b/frame.go @@ -31,11 +31,11 @@ func (f *Frame) closeW(w *Writer) error { return err } buf := w.buf[:0] + // End mark (data block size of uint32(0)). + buf = append(buf, 0, 0, 0, 0) if f.Descriptor.Flags.ContentChecksum() { buf = f.checksum.Sum(buf) } - // End mark (data block size of uint32(0)). - buf = append(buf, 0, 0, 0, 0) _, err := w.src.Write(buf) return err } @@ -62,7 +62,7 @@ newFrame: } goto newFrame default: - return ErrInvalid + return ErrInvalidFrame } if err := f.Descriptor.initR(r); err != nil { return err @@ -103,14 +103,16 @@ func (fd *FrameDescriptor) write(w *Writer) error { return nil } - buf := w.buf[:2] - binary.LittleEndian.PutUint16(buf, uint16(fd.Flags)) + buf := w.buf[:4+2] + // Write the magic number here even though it belongs to the Frame. + binary.LittleEndian.PutUint32(buf, w.frame.Magic) + binary.LittleEndian.PutUint16(buf[4:], uint16(fd.Flags)) if fd.Flags.Size() { - buf = buf[:10] - binary.LittleEndian.PutUint64(buf[2:], fd.ContentSize) + buf = buf[:4+2+8] + binary.LittleEndian.PutUint64(buf[4+2:], fd.ContentSize) } - fd.Checksum = descriptorChecksum(buf) + fd.Checksum = descriptorChecksum(buf[4:]) buf = append(buf, fd.Checksum) _, err := w.src.Write(buf) diff --git a/lz4.go b/lz4.go index a7649fd9..079f117e 100644 --- a/lz4.go +++ b/lz4.go @@ -28,28 +28,24 @@ const ( // ErrInvalidSourceShortBuffer is returned by UncompressBlock or CompressBLock when a compressed // block is corrupted or the destination buffer is not large enough for the uncompressed data. ErrInvalidSourceShortBuffer _error = "lz4: invalid source or destination buffer too short" - // ErrClosed is returned when calling Write/Read or Close on an already closed Writer/Reader. - ErrClosed _error = "lz4: closed Writer" - // ErrInvalid is returned when reading an invalid LZ4 archive. 
- ErrInvalid _error = "lz4: bad magic number" - // ErrBlockDependency is returned when attempting to decompress an archive created with block dependency. - ErrBlockDependency _error = "lz4: block dependency not supported" + // ErrInvalidFrame is returned when reading an invalid LZ4 archive. + ErrInvalidFrame _error = "lz4: bad magic number" // ErrUnsupportedSeek is returned when attempting to Seek any way but forward from the current position. ErrUnsupportedSeek _error = "lz4: can only seek forward from io.SeekCurrent" // ErrInternalUnhandledState is an internal error. ErrInternalUnhandledState _error = "lz4: unhandled state" - // ErrInvalidHeaderChecksum + // ErrInvalidHeaderChecksum is returned when reading a frame. ErrInvalidHeaderChecksum _error = "lz4: invalid header checksum" - // ErrInvalidBlockChecksum + // ErrInvalidBlockChecksum is returned when reading a frame. ErrInvalidBlockChecksum _error = "lz4: invalid block checksum" - // ErrInvalidFrameChecksum + // ErrInvalidFrameChecksum is returned when reading a frame. ErrInvalidFrameChecksum _error = "lz4: invalid frame checksum" - // ErrInvalidCompressionLevel - ErrInvalidCompressionLevel _error = "lz4: invalid compression level" - // ErrCannotApplyOptions - ErrCannotApplyOptions _error = "lz4: cannot apply options" - // ErrInvalidBlockSize - ErrInvalidBlockSize _error = "lz4: invalid block size" - // ErrOptionNotApplicable + // ErrOptionInvalidCompressionLevel is returned when the supplied compression level is invalid. + ErrOptionInvalidCompressionLevel _error = "lz4: invalid compression level" + // ErrOptionClosedOrError is returned when an option is applied to a closed or in error object. + ErrOptionClosedOrError _error = "lz4: cannot apply options on closed or in error object" + // ErrOptionInvalidBlockSize is returned when the supplied block size is invalid. + ErrOptionInvalidBlockSize _error = "lz4: invalid block size" + // ErrOptionNotApplicable is returned when trying to apply an option to an object not supporting it. 
ErrOptionNotApplicable _error = "lz4: option not applicable" ) diff --git a/options.go b/options.go index 1d04fb01..7acaea64 100644 --- a/options.go +++ b/options.go @@ -2,6 +2,7 @@ package lz4 import ( "fmt" + "reflect" "runtime" "sync" ) @@ -17,6 +18,11 @@ type ( Option func(Applier) error ) +func (o Option) String() string { + //TODO proper naming of options + return reflect.TypeOf(o).String() +} + // Default options. var ( defaultBlockSizeOption = BlockSizeOption(Block4Mb) @@ -98,7 +104,7 @@ func BlockSizeOption(size BlockSize) Option { return ErrOptionNotApplicable } if !size.isValid() { - return fmt.Errorf("%w: %d", ErrInvalidBlockSize, size) + return fmt.Errorf("%w: %d", ErrOptionInvalidBlockSize, size) } w.frame.Descriptor.Flags.BlockSizeIndexSet(size.index()) return nil @@ -188,7 +194,7 @@ func CompressionLevelOption(level CompressionLevel) Option { switch level { case Fast, Level1, Level2, Level3, Level4, Level5, Level6, Level7, Level8, Level9: default: - return fmt.Errorf("%w: %d", ErrInvalidCompressionLevel, level) + return fmt.Errorf("%w: %d", ErrOptionInvalidCompressionLevel, level) } w.level = level return nil diff --git a/reader.go b/reader.go index f880b799..20b52150 100644 --- a/reader.go +++ b/reader.go @@ -7,8 +7,7 @@ import ( var readerStates = []aState{ noState: newState, errorState: newState, - newState: headerState, - headerState: readState, + newState: readState, readState: closedState, closedState: newState, } @@ -40,7 +39,7 @@ func (r *Reader) Apply(options ...Option) (err error) { case errorState: return r.state.err default: - return ErrCannotApplyOptions + return ErrOptionClosedOrError } for _, o := range options { if err = o(r); err != nil { @@ -69,7 +68,6 @@ func (r *Reader) Read(buf []byte) (n int, err error) { return 0, r.state.err case newState: // First initialization. 
- r.state.next(nil) if err = r.frame.initR(r); r.state.next(err) { return } diff --git a/reader_test.go b/reader_test.go index 8c2db94d..7b1ff485 100644 --- a/reader_test.go +++ b/reader_test.go @@ -42,9 +42,9 @@ func TestReader(t *testing.T) { t.Fatal(err) } - var out bytes.Buffer + out := new(bytes.Buffer) zr := lz4.NewReader(f) - n, err := io.Copy(&out, zr) + n, err := io.Copy(out, zr) if err != nil { t.Fatal(err) } @@ -69,7 +69,7 @@ func TestReader(t *testing.T) { out.Reset() zr = lz4.NewReader(f2) - _, err = io.CopyN(&out, zr, 10) + _, err = io.CopyN(out, zr, 10) if err != nil { t.Fatal(err) } @@ -78,6 +78,7 @@ func TestReader(t *testing.T) { } return + //TODO add Reader.Seek pos, err := zr.Seek(-1, io.SeekCurrent) if err == nil { t.Fatal("expected error from invalid seek") @@ -109,7 +110,7 @@ func TestReader(t *testing.T) { } out.Reset() - _, err = io.CopyN(&out, zr, 10) + _, err = io.CopyN(out, zr, 10) if err != nil { t.Fatal(err) } diff --git a/state.go b/state.go index 2566e670..f9f9cc50 100644 --- a/state.go +++ b/state.go @@ -12,7 +12,6 @@ const ( noState aState = iota // uninitialized reader errorState // unrecoverable error encountered newState // instantiated object - headerState // processing header readState // reading data writeState // writing data closedState // all done diff --git a/state_gen.go b/state_gen.go index 76ce43d6..75fb8289 100644 --- a/state_gen.go +++ b/state_gen.go @@ -11,15 +11,14 @@ func _() { _ = x[noState-0] _ = x[errorState-1] _ = x[newState-2] - _ = x[headerState-3] - _ = x[readState-4] - _ = x[writeState-5] - _ = x[closedState-6] + _ = x[readState-3] + _ = x[writeState-4] + _ = x[closedState-5] } -const _aState_name = "noStateerrorStatenewStateheaderStatereadStatewriteStateclosedState" +const _aState_name = "noStateerrorStatenewStatereadStatewriteStateclosedState" -var _aState_index = [...]uint8{0, 7, 17, 25, 36, 45, 55, 66} +var _aState_index = [...]uint8{0, 7, 17, 25, 34, 44, 55} func (i aState) String() string { if i >= 
aState(len(_aState_index)-1) { diff --git a/writer.go b/writer.go index 5b3f57b1..43957ac3 100644 --- a/writer.go +++ b/writer.go @@ -4,8 +4,7 @@ import "io" var writerStates = []aState{ noState: newState, - newState: headerState, - headerState: writeState, + newState: writeState, writeState: closedState, closedState: newState, errorState: newState, @@ -21,7 +20,7 @@ func NewWriter(w io.Writer) *Writer { type Writer struct { state _State - buf [11]byte // frame descriptor needs at most 4+8+1=11 bytes + buf [15]byte // frame descriptor needs at most 4(magic)+2(flags)+8(size)+1(checksum)=15 bytes src io.Writer // destination writer level CompressionLevel // how hard to try num int // concurrency level @@ -41,7 +40,7 @@ func (w *Writer) Apply(options ...Option) (err error) { case errorState: return w.state.err default: - return ErrCannotApplyOptions + return ErrOptionClosedOrError } for _, o := range options { if err = o(w); err != nil { @@ -62,7 +61,6 @@ func (w *Writer) Write(buf []byte) (n int, err error) { case closedState, errorState: return 0, w.state.err case newState: - w.state.next(nil) if err = w.frame.Descriptor.write(w); w.state.next(err) { return } @@ -74,7 +72,7 @@ func (w *Writer) Write(buf []byte) (n int, err error) { for len(buf) > 0 { if w.idx == 0 && len(buf) >= zn { // Avoid a copy as there is enough data for a block. - if err = w.write(); err != nil { + if err = w.write(buf[:zn], false); err != nil { return } n += zn @@ -93,7 +91,7 @@ func (w *Writer) Write(buf []byte) (n int, err error) { } // Buffer full. 
- if err = w.write(); err != nil { + if err = w.write(w.data, true); err != nil { return } w.idx = 0 @@ -101,10 +99,11 @@ func (w *Writer) Write(buf []byte) (n int, err error) { return } -func (w *Writer) write() error { +func (w *Writer) write(data []byte, direct bool) error { if w.isNotConcurrent() { - defer w.handler(len(w.data)) - return w.frame.Blocks.Block.compress(w, w.data, w.ht).write(w) + defer w.handler(len(data)) + block := w.frame.Blocks.Block + return block.compress(w, data, w.ht).write(w) } size := w.frame.Descriptor.Flags.BlockSizeIndex() c := make(chan *FrameDataBlock) @@ -122,20 +121,18 @@ func (w *Writer) write() error { size.put(data) <-c size.put(zdata) - }(c, w.data, size) + }(c, data, size) - if w.idx > 0 { - // Not closed. + if direct { w.data = size.get() } - w.idx = 0 return nil } // Close closes the Writer, flushing any unwritten data to the underlying io.Writer, // but does not close the underlying io.Writer. -func (w *Writer) Close() error { +func (w *Writer) Close() (err error) { switch w.state.state { case writeState: case errorState: @@ -143,21 +140,19 @@ func (w *Writer) Close() error { default: return nil } - var err error defer func() { w.state.next(err) }() - if idx := w.idx; idx > 0 { + if w.idx > 0 { // Flush pending data. 
- w.data = w.data[:idx] - w.idx = 0 - if err = w.write(); err != nil { + if err = w.write(w.data[:w.idx], false); err != nil { return err } - w.data = nil + w.idx = 0 } if w.isNotConcurrent() { htPool.Put(w.ht) size := w.frame.Descriptor.Flags.BlockSizeIndex() size.put(w.data) + w.data = nil } return w.frame.closeW(w) } diff --git a/_writer_test.go b/writer_test.go similarity index 54% rename from _writer_test.go rename to writer_test.go index 29dad5a2..96e07e99 100644 --- a/_writer_test.go +++ b/writer_test.go @@ -8,6 +8,8 @@ import ( "os" "reflect" "testing" + + "github.com/pierrec/lz4" ) func TestWriter(t *testing.T) { @@ -23,72 +25,67 @@ func TestWriter(t *testing.T) { } for _, fname := range goldenFiles { - for _, size := range []int{0, 4} { - for _, header := range []Header{ - {}, // Default header. - {BlockChecksum: true}, - {NoChecksum: true}, - {BlockMaxSize: 64 << 10}, // 64Kb - {CompressionLevel: 10}, - {Size: 123}, - } { - label := fmt.Sprintf("%s/%s", fname, header) - t.Run(label, func(t *testing.T) { - fname := fname - header := header - t.Parallel() - - raw, err := ioutil.ReadFile(fname) - if err != nil { - t.Fatal(err) - } - r := bytes.NewReader(raw) - - // Compress. - var zout bytes.Buffer - zw := NewWriter(&zout) - zw.Header = header - zw.WithConcurrency(size) - _, err = io.Copy(zw, r) - if err != nil { - t.Fatal(err) - } - err = zw.Close() - if err != nil { - t.Fatal(err) - } - - // Uncompress. - var out bytes.Buffer - zr := NewReader(&zout) - n, err := io.Copy(&out, zr) - if err != nil { - t.Fatal(err) - } - - // The uncompressed data must be the same as the initial input. 
- if got, want := int(n), len(raw); got != want { - t.Errorf("invalid sizes: got %d; want %d", got, want) - } - - if got, want := out.Bytes(), raw; !reflect.DeepEqual(got, want) { - t.Fatal("uncompressed data does not match original") - } - }) - } + for _, option := range []lz4.Option{ + lz4.ConcurrencyOption(1), + //lz4.BlockChecksumOption(true), + //lz4.SizeOption(123), + //lz4.ConcurrencyOption(2), + } { + label := fmt.Sprintf("%s/%s", fname, option) + t.Run(label, func(t *testing.T) { + fname := fname + t.Parallel() + + raw, err := ioutil.ReadFile(fname) + if err != nil { + t.Fatal(err) + } + r := bytes.NewReader(raw) + + // Compress. + zout := new(bytes.Buffer) + zw := lz4.NewWriter(zout) + if err := zw.Apply(option); err != nil { + t.Fatal(err) + } + _, err = io.Copy(zw, r) + if err != nil { + t.Fatal(err) + } + err = zw.Close() + if err != nil { + t.Fatal(err) + } + + // Uncompress. + out := new(bytes.Buffer) + zr := lz4.NewReader(zout) + n, err := io.Copy(out, zr) + if err != nil { + t.Fatal(err) + } + + // The uncompressed data must be the same as the initial input. 
+ if got, want := int(n), len(raw); got != want { + t.Errorf("invalid sizes: got %d; want %d", got, want) + } + + if got, want := out.Bytes(), raw; !reflect.DeepEqual(got, want) { + t.Fatal("uncompressed data does not match original") + } + }) } } } func TestIssue41(t *testing.T) { r, w := io.Pipe() - zw := NewWriter(w) - zr := NewReader(r) + zw := lz4.NewWriter(w) + zr := lz4.NewReader(r) data := "x" go func() { _, _ = fmt.Fprint(zw, data) - _ = zw.Flush() _ = zw.Close() _ = w.Close() }() @@ -110,7 +107,7 @@ func TestIssue43(t *testing.T) { } defer f.Close() - zw := NewWriter(w) + zw := lz4.NewWriter(w) defer zw.Close() _, err = io.Copy(zw, f) @@ -118,7 +115,7 @@ func TestIssue43(t *testing.T) { t.Fatal(err) } }() - _, err := io.Copy(ioutil.Discard, NewReader(r)) + _, err := io.Copy(ioutil.Discard, lz4.NewReader(r)) if err != nil { t.Fatal(err) } @@ -131,16 +128,15 @@ func TestIssue51(t *testing.T) { } zbuf := make([]byte, 8192) - ht := make([]int, htSize) - n, err := CompressBlock(data, zbuf, ht) + n, err := lz4.CompressBlock(data, zbuf, nil) if err != nil { t.Fatal(err) } zbuf = zbuf[:n] buf := make([]byte, 8192) - n, err = UncompressBlock(zbuf, buf) + n, err = lz4.UncompressBlock(zbuf, buf) if err != nil { t.Fatal(err) } @@ -157,11 +153,11 @@ func TestIssue71(t *testing.T) { } { t.Run(tc, func(t *testing.T) { src := []byte(tc) - bound := CompressBlockBound(len(tc)) + bound := lz4.CompressBlockBound(len(tc)) // Small buffer. zSmall := make([]byte, bound-1) - n, err := CompressBlock(src, zSmall, nil) + n, err := lz4.CompressBlock(src, zSmall, nil) if err != nil { t.Fatal(err) } @@ -171,7 +167,7 @@ func TestIssue71(t *testing.T) { // Large enough buffer. zLarge := make([]byte, bound) - n, err = CompressBlock(src, zLarge, nil) + n, err = lz4.CompressBlock(src, zLarge, nil) if err != nil { t.Fatal(err) }