Skip to content

Commit 07d6c33

Browse files
authored
Improve memory usage when extracting zip archives (#21)
* zip: if the reader is capable of seeking, do not buffer the entire archive
* lint: Remove usage of the deprecated io/ioutil package
* Avoid buffering in archive detection if the stream is seekable
* Slightly increase test limits
1 parent 40e27c6 commit 07d6c33

File tree

2 files changed

+46
-25
lines changed

2 files changed

+46
-25
lines changed

extract_test.go

+14-17
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"context"
66
"fmt"
77
"io"
8-
"io/ioutil"
98
"net/http"
109
"os"
1110
"path/filepath"
@@ -187,9 +186,9 @@ func TestArchiveFailure(t *testing.T) {
187186

188187
func TestExtract(t *testing.T) {
189188
for _, test := range ExtractCases {
190-
dir, _ := ioutil.TempDir("", "")
189+
dir, _ := os.MkdirTemp("", "")
191190
dir = filepath.Join(dir, "test")
192-
data, err := ioutil.ReadFile(test.Archive)
191+
data, err := os.ReadFile(test.Archive)
193192
if err != nil {
194193
t.Fatal(err)
195194
}
@@ -222,8 +221,8 @@ func TestExtract(t *testing.T) {
222221
}
223222

224223
func BenchmarkArchive(b *testing.B) {
225-
dir, _ := ioutil.TempDir("", "")
226-
data, _ := ioutil.ReadFile("testdata/archive.tar.bz2")
224+
dir, _ := os.MkdirTemp("", "")
225+
data, _ := os.ReadFile("testdata/archive.tar.bz2")
227226

228227
b.StartTimer()
229228

@@ -244,8 +243,8 @@ func BenchmarkArchive(b *testing.B) {
244243
}
245244

246245
func BenchmarkTarBz2(b *testing.B) {
247-
dir, _ := ioutil.TempDir("", "")
248-
data, _ := ioutil.ReadFile("testdata/archive.tar.bz2")
246+
dir, _ := os.MkdirTemp("", "")
247+
data, _ := os.ReadFile("testdata/archive.tar.bz2")
249248

250249
b.StartTimer()
251250

@@ -266,8 +265,8 @@ func BenchmarkTarBz2(b *testing.B) {
266265
}
267266

268267
func BenchmarkTarGz(b *testing.B) {
269-
dir, _ := ioutil.TempDir("", "")
270-
data, _ := ioutil.ReadFile("testdata/archive.tar.gz")
268+
dir, _ := os.MkdirTemp("", "")
269+
data, _ := os.ReadFile("testdata/archive.tar.gz")
271270

272271
b.StartTimer()
273272

@@ -288,8 +287,8 @@ func BenchmarkTarGz(b *testing.B) {
288287
}
289288

290289
func BenchmarkZip(b *testing.B) {
291-
dir, _ := ioutil.TempDir("", "")
292-
data, _ := ioutil.ReadFile("testdata/archive.zip")
290+
dir, _ := os.MkdirTemp("", "")
291+
data, _ := os.ReadFile("testdata/archive.zip")
293292

294293
b.StartTimer()
295294

@@ -319,7 +318,7 @@ func testWalk(t *testing.T, dir string, testFiles Files) {
319318
} else if info.Mode()&os.ModeSymlink != 0 {
320319
files[path] = "link"
321320
} else {
322-
data, err := ioutil.ReadFile(filepath.Join(dir, path))
321+
data, err := os.ReadFile(filepath.Join(dir, path))
323322
require.NoError(t, err)
324323
files[path] = strings.TrimSpace(string(data))
325324
}
@@ -370,7 +369,7 @@ func TestTarGzMemoryConsumption(t *testing.T) {
370369
runtime.GC()
371370
runtime.ReadMemStats(&m)
372371

373-
err = extract.Gz(context.Background(), f, tmpDir.String(), nil)
372+
err = extract.Archive(context.Background(), f, tmpDir.String(), nil)
374373
require.NoError(t, err)
375374

376375
runtime.ReadMemStats(&m2)
@@ -398,7 +397,7 @@ func TestZipMemoryConsumption(t *testing.T) {
398397
runtime.GC()
399398
runtime.ReadMemStats(&m)
400399

401-
err = extract.Zip(context.Background(), f, tmpDir.String(), nil)
400+
err = extract.Archive(context.Background(), f, tmpDir.String(), nil)
402401
require.NoError(t, err)
403402

404403
runtime.ReadMemStats(&m2)
@@ -407,9 +406,7 @@ func TestZipMemoryConsumption(t *testing.T) {
407406
heapUsed = 0
408407
}
409408
fmt.Println("Heap memory used during the test:", heapUsed)
410-
// the .zip file require random access, so the full io.Reader content must be cached, since
411-
// the test file is 130MB, that's the reason for the high memory consumed.
412-
require.True(t, heapUsed < 250000000, "heap consumption should be less than 250M but is %d", heapUsed)
409+
require.True(t, heapUsed < 10000000, "heap consumption should be less than 10M but is %d", heapUsed)
413410
}
414411

415412
func download(t require.TestingT, url string, file *paths.Path) error {

extractor.go

+32-8
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ import (
77
"compress/bzip2"
88
"compress/gzip"
99
"context"
10+
"fmt"
1011
"io"
11-
"io/ioutil"
1212
"os"
1313
"path/filepath"
1414
"strings"
@@ -237,11 +237,27 @@ func (e *Extractor) Tar(ctx context.Context, body io.Reader, location string, re
237237
// Zip extracts a .zip archived stream of data in the specified location.
238238
// It accepts a rename function to handle the names of the files (see the example).
239239
func (e *Extractor) Zip(ctx context.Context, body io.Reader, location string, rename Renamer) error {
240-
// read the whole body into a buffer. Not sure this is the best way to do it
241-
buffer := bytes.NewBuffer([]byte{})
242-
copyCancel(ctx, buffer, body)
243-
244-
archive, err := zip.NewReader(bytes.NewReader(buffer.Bytes()), int64(buffer.Len()))
240+
var bodySize int64
241+
bodyReaderAt, isReaderAt := (body).(io.ReaderAt)
242+
if bodySeeker, isSeeker := (body).(io.Seeker); isReaderAt && isSeeker {
243+
// get the size by seeking to the end
244+
endPos, err := bodySeeker.Seek(0, io.SeekEnd)
245+
if err != nil {
246+
return fmt.Errorf("failed to seek to the end of the body: %s", err)
247+
}
248+
// reset the reader to the beginning
249+
if _, err := bodySeeker.Seek(0, io.SeekStart); err != nil {
250+
return fmt.Errorf("failed to seek to the beginning of the body: %w", err)
251+
}
252+
bodySize = endPos
253+
} else {
254+
// read the whole body into a buffer. Not sure this is the best way to do it
255+
buffer := bytes.NewBuffer([]byte{})
256+
copyCancel(ctx, buffer, body)
257+
bodyReaderAt = bytes.NewReader(buffer.Bytes())
258+
bodySize = int64(buffer.Len())
259+
}
260+
archive, err := zip.NewReader(bodyReaderAt, bodySize)
245261
if err != nil {
246262
return errors.Annotatef(err, "Read the zip file")
247263
}
@@ -290,7 +306,7 @@ func (e *Extractor) Zip(ctx context.Context, body io.Reader, location string, re
290306
case info.Mode()&os.ModeSymlink != 0:
291307
if f, err := header.Open(); err != nil {
292308
return errors.Annotatef(err, "Open link %s", path)
293-
} else if name, err := ioutil.ReadAll(f); err != nil {
309+
} else if name, err := io.ReadAll(f); err != nil {
294310
return errors.Annotatef(err, "Read address of link %s", path)
295311
} else {
296312
links = append(links, link{Path: path, Name: string(name)})
@@ -347,7 +363,15 @@ func match(r io.Reader) (io.Reader, types.Type, error) {
347363
return nil, types.Unknown, err
348364
}
349365

350-
r = io.MultiReader(bytes.NewBuffer(buffer[:n]), r)
366+
if seeker, ok := r.(io.Seeker); ok {
367+
// if the stream is seekable, we just rewind it
368+
if _, err := seeker.Seek(0, io.SeekStart); err != nil {
369+
return nil, types.Unknown, err
370+
}
371+
} else {
372+
// otherwise we create a new reader that will prepend the buffer
373+
r = io.MultiReader(bytes.NewBuffer(buffer[:n]), r)
374+
}
351375

352376
typ, err := filetype.Match(buffer)
353377

0 commit comments

Comments
 (0)