Skip to content

Commit

Permalink
Update tar decompress (#200)
Browse files Browse the repository at this point in the history
* Update stream decompressor

Make the stream decompressor a bit better and cleaner. This removes some
additional code that was no longer needed once we moved to bufio.Reader
for tar peeking

* Emit log  on compressed tars

Compressed tar files for model weights can significantly impact
performance. The log will help identify when that happens so that the
user can opt to use uncompressed tar files where relevant.
  • Loading branch information
tempusfrangit authored Apr 28, 2024
1 parent 14a5144 commit 6193d8a
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 26 deletions.
3 changes: 2 additions & 1 deletion pkg/consumer/tar_extractor.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package consumer

import (
"bufio"
"fmt"
"io"

Expand All @@ -14,7 +15,7 @@ type TarExtractor struct {
var _ Consumer = &TarExtractor{}

func (f *TarExtractor) Consume(reader io.Reader, destPath string) error {
err := extract.TarFile(reader, destPath, f.Overwrite)
err := extract.TarFile(bufio.NewReader(reader), destPath, f.Overwrite)
if err != nil {
return fmt.Errorf("error extracting file: %w", err)
}
Expand Down
19 changes: 3 additions & 16 deletions pkg/extract/compression.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ var _ decompressor = bzip2Decompressor{}
var _ decompressor = xzDecompressor{}
var _ decompressor = lzwDecompressor{}
var _ decompressor = lz4Decompressor{}
var _ decompressor = noOpDecompressor{}

// decompressor represents different compression formats.
type decompressor interface {
Expand All @@ -43,20 +42,13 @@ func detectFormat(input []byte) decompressor {
inputSize := len(input)

if inputSize < 2 {
return noOpDecompressor{}
return nil
}
// pad to 8 bytes
if inputSize < 8 {
input = append(input, make([]byte, peekSize-inputSize)...)
}

// magic16 := binary.BigEndian.Uint16(input)
// magic32 := binary.BigEndian.Uint32(input)
// // We need to pre-pend the padding since we're reading into something bigendian and exceeding the
// // 48bits size of the magic number bytes. The 16 and 32 bit magic numbers are complete bytes and
// // therefore do not need any padding.
// magic48 := binary.BigEndian.Uint64(append(make([]byte, 2), input[0:6]...))

switch true {
case bytes.HasPrefix(input, gzipMagic):
log.Debug().
Expand Down Expand Up @@ -95,8 +87,9 @@ func detectFormat(input []byte) decompressor {
log.Debug().
Str("type", "none").
Msg("Compression Format")
return noOpDecompressor{}
return nil
}

}

type gzipDecompressor struct{}
Expand Down Expand Up @@ -131,9 +124,3 @@ type lz4Decompressor struct{}
func (d lz4Decompressor) decompress(r io.Reader) (io.Reader, error) {
return lz4.NewReader(r), nil
}

type noOpDecompressor struct{}

func (d noOpDecompressor) decompress(r io.Reader) (io.Reader, error) {
return r, nil
}
4 changes: 2 additions & 2 deletions pkg/extract/compression_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ func TestDetectFormat(t *testing.T) {
{
name: "Less than 2 bytes",
input: []byte{0x1f},
expectType: "extract.noOpDecompressor",
expectType: "",
},
{
name: "UNKNOWN",
input: []byte{0xde, 0xad},
expectType: "extract.noOpDecompressor",
expectType: "",
},
}

Expand Down
20 changes: 13 additions & 7 deletions pkg/extract/tar.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,25 @@ type link struct {
newName string
}

func TarFile(r io.Reader, destDir string, overwrite bool) error {
func TarFile(r *bufio.Reader, destDir string, overwrite bool) error {
var links []*link
var reader io.Reader = r

log := logging.GetLogger()

startTime := time.Now()
peekableReader := bufio.NewReader(r)
peekData, err := peekableReader.Peek(peekSize)
peekData, err := r.Peek(peekSize)
if err != nil {
return fmt.Errorf("error reading peek data: %w", err)
}
decompressor := detectFormat(peekData)
reader, err := decompressor.decompress(peekableReader)
if err != nil {
return fmt.Errorf("error creating decompressed stream: %w", err)
if decompressor := detectFormat(peekData); decompressor != nil {
reader, err = decompressor.decompress(reader)
if err != nil {
return fmt.Errorf("error creating decompressed stream: %w", err)
}
log.Info().
Str("decompressor", fmt.Sprintf("%T", decompressor)).
Msg("Tar Compression Detected: Compression can significantly slowdown pget (e.g. for model weights)")
}
tarReader := tar.NewReader(reader)
logger := logging.GetLogger()
Expand Down

0 comments on commit 6193d8a

Please sign in to comment.