Skip to content

Commit 9eb2bfd

Browse files
authored
Abort multi part download if the object is modified during download
* add version control of downloader * add changelog
1 parent 8d203cc commit 9eb2bfd

File tree

3 files changed

+187
-2
lines changed

3 files changed

+187
-2
lines changed

CHANGELOG_PENDING.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,5 @@
55
### SDK Bugs
66
* Fix improper use of printf-style functions.
77
* Required for Go 1.24.
8+
* `service/s3/s3manager`: Abort multipart download if object is modified during download
9+
* Fixes [4986](https://github.com/aws/aws-sdk-go/issues/4986)

service/s3/s3manager/download.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,13 +290,15 @@ type downloader struct {
290290
in *s3.GetObjectInput
291291
w io.WriterAt
292292

293-
wg sync.WaitGroup
294-
m sync.Mutex
293+
wg sync.WaitGroup
294+
m sync.Mutex
295+
once sync.Once
295296

296297
pos int64
297298
totalBytes int64
298299
written int64
299300
err error
301+
etag string
300302

301303
partBodyMaxRetries int
302304
}
@@ -424,6 +426,9 @@ func (d *downloader) downloadChunk(chunk dlchunk) error {
424426

425427
// Get the next byte range of data
426428
in.Range = aws.String(chunk.ByteRange())
429+
if in.VersionId == nil && d.etag != "" {
430+
in.IfMatch = aws.String(d.etag)
431+
}
427432

428433
var n int64
429434
var err error
@@ -466,6 +471,9 @@ func (d *downloader) tryDownloadChunk(in *s3.GetObjectInput, w io.Writer) (int64
466471
return 0, err
467472
}
468473
d.setTotalBytes(resp) // Set total if not yet set.
474+
d.once.Do(func() {
475+
d.etag = aws.StringValue(resp.ETag)
476+
})
469477

470478
var src io.Reader = resp.Body
471479
if d.cfg.BufferProvider != nil {

service/s3/s3manager/download_test.go

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ import (
3030
"github.com/aws/aws-sdk-go/service/s3/s3manager"
3131
)
3232

33+
var etag string = "myetag"
34+
3335
func dlLoggingSvc(data []byte) (*s3.S3, *[]string, *[]string) {
3436
var m sync.Mutex
3537
names := []string{}
@@ -203,6 +205,93 @@ func dlLoggingSvcWithErrReader(cases []testErrReader) (*s3.S3, *[]string) {
203205
return svc, &names
204206
}
205207

208+
func dlLoggingSvcWithVersions(data []byte) (*s3.S3, *[]string, *[]string, *[]string) {
209+
var m sync.Mutex
210+
versions := []string{}
211+
etags := []string{}
212+
names := []string{}
213+
214+
svc := s3.New(unit.Session)
215+
svc.Handlers.Send.Clear()
216+
svc.Handlers.Send.PushBack(func(r *request.Request) {
217+
m.Lock()
218+
defer m.Unlock()
219+
220+
names = append(names, r.Operation.Name)
221+
versions = append(versions, aws.StringValue(r.Params.(*s3.GetObjectInput).VersionId))
222+
etags = append(etags, aws.StringValue(r.Params.(*s3.GetObjectInput).IfMatch))
223+
224+
rerng := regexp.MustCompile(`bytes=(\d+)-(\d+)`)
225+
rng := rerng.FindStringSubmatch(r.HTTPRequest.Header.Get("Range"))
226+
start, _ := strconv.ParseInt(rng[1], 10, 64)
227+
fin, _ := strconv.ParseInt(rng[2], 10, 64)
228+
fin++
229+
230+
if fin > int64(len(data)) {
231+
fin = int64(len(data))
232+
}
233+
234+
bodyBytes := data[start:fin]
235+
r.HTTPResponse = &http.Response{
236+
StatusCode: 200,
237+
Body: ioutil.NopCloser(bytes.NewReader(bodyBytes)),
238+
Header: http.Header{},
239+
}
240+
r.HTTPResponse.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d",
241+
start, fin-1, len(data)))
242+
r.HTTPResponse.Header.Set("Content-Length", fmt.Sprintf("%d", len(bodyBytes)))
243+
r.HTTPResponse.Header.Set("Etag", etag)
244+
})
245+
246+
return svc, &names, &versions, &etags
247+
}
248+
249+
func dlLoggingSvcWithVersionMismatch(t *testing.T) *s3.S3 {
250+
var m sync.Mutex
251+
reqCount := int64(0)
252+
body := bytes.NewReader(make([]byte, s3manager.DefaultDownloadPartSize))
253+
svc := s3.New(unit.Session)
254+
svc.Handlers.Send.Clear()
255+
svc.Handlers.Send.PushBack(func(r *request.Request) {
256+
m.Lock()
257+
defer m.Unlock()
258+
259+
statusCode := http.StatusOK
260+
var eTag *string
261+
switch atomic.LoadInt64(&reqCount) {
262+
case 0:
263+
if a := r.Params.(*s3.GetObjectInput).IfMatch; a != nil {
264+
t.Errorf("expect no Etag in first request, got %s", aws.StringValue(a))
265+
statusCode = http.StatusBadRequest
266+
} else {
267+
eTag = aws.String(etag)
268+
}
269+
case 1:
270+
// Give a chance for the multipart chunks to be queued up
271+
time.Sleep(1 * time.Second)
272+
// mock the precondition error when object is synchronously updated
273+
statusCode = http.StatusPreconditionFailed
274+
default:
275+
if a := aws.StringValue(r.Params.(*s3.GetObjectInput).IfMatch); a != etag {
276+
t.Errorf("expect subrequests' IfMatch header to be %s, got %s", etag, a)
277+
statusCode = http.StatusBadRequest
278+
}
279+
}
280+
r.HTTPResponse = &http.Response{
281+
StatusCode: statusCode,
282+
Body: ioutil.NopCloser(body),
283+
Header: http.Header{},
284+
}
285+
r.HTTPResponse.Header.Set("Content-Range", fmt.Sprintf("bytes 0-%d/%d",
286+
body.Len()-1, body.Len()*10))
287+
r.HTTPResponse.Header.Set("Content-Length", fmt.Sprintf("%d", body.Len()))
288+
r.HTTPResponse.Header.Set("Etag", aws.StringValue(eTag))
289+
atomic.AddInt64(&reqCount, 1)
290+
})
291+
292+
return svc
293+
}
294+
206295
func TestDownloadOrder(t *testing.T) {
207296
s, names, ranges := dlLoggingSvc(buf12MB)
208297

@@ -526,6 +615,92 @@ func TestDownloadPartBodyRetry_FailRetry(t *testing.T) {
526615
}
527616
}
528617

618+
func TestDownloadWithVersionID(t *testing.T) {
619+
s, names, versions, etags := dlLoggingSvcWithVersions(buf12MB)
620+
d := s3manager.NewDownloaderWithClient(s)
621+
622+
w := &aws.WriteAtBuffer{}
623+
n, err := d.Download(w, &s3.GetObjectInput{
624+
Bucket: aws.String("bucket"),
625+
Key: aws.String("key"),
626+
VersionId: aws.String("vid"),
627+
})
628+
629+
if err != nil {
630+
t.Fatalf("expect no error, got %v", err)
631+
}
632+
633+
if e, a := int64(len(buf12MB)), n; e != a {
634+
t.Errorf("expect %d buffer length, got %d", e, a)
635+
}
636+
637+
expectCalls := []string{"GetObject", "GetObject", "GetObject"}
638+
if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
639+
t.Errorf("expect %v API calls, got %v", e, a)
640+
}
641+
642+
expectVersions := []string{"vid", "vid", "vid"}
643+
if e, a := expectVersions, *versions; !reflect.DeepEqual(e, a) {
644+
t.Errorf("expect %v version ids, got %v", e, a)
645+
}
646+
647+
expectETags := []string{"", "", ""}
648+
if e, a := expectETags, *etags; !reflect.DeepEqual(e, a) {
649+
t.Errorf("expect %v etags, got %v", e, a)
650+
}
651+
}
652+
653+
func TestDownloadWithETags(t *testing.T) {
654+
s, names, versions, etags := dlLoggingSvcWithVersions(buf12MB)
655+
d := s3manager.NewDownloaderWithClient(s)
656+
657+
w := &aws.WriteAtBuffer{}
658+
n, err := d.Download(w, &s3.GetObjectInput{
659+
Bucket: aws.String("bucket"),
660+
Key: aws.String("key"),
661+
})
662+
663+
if err != nil {
664+
t.Fatalf("expect no error, got %v", err)
665+
}
666+
667+
if e, a := int64(len(buf12MB)), n; e != a {
668+
t.Errorf("expect %d buffer length, got %d", e, a)
669+
}
670+
671+
expectCalls := []string{"GetObject", "GetObject", "GetObject"}
672+
if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
673+
t.Errorf("expect %v API calls, got %v", e, a)
674+
}
675+
676+
expectVersions := []string{"", "", ""}
677+
if e, a := expectVersions, *versions; !reflect.DeepEqual(e, a) {
678+
t.Errorf("expect %v version ids, got %v", e, a)
679+
}
680+
681+
expectETags := []string{"", etag, etag}
682+
if e, a := expectETags, *etags; !reflect.DeepEqual(e, a) {
683+
t.Errorf("expect %v etags, got %v", e, a)
684+
}
685+
}
686+
687+
func TestDownloadWithVersionMismatch(t *testing.T) {
688+
s := dlLoggingSvcWithVersionMismatch(t)
689+
d := s3manager.NewDownloaderWithClient(s)
690+
691+
w := &aws.WriteAtBuffer{}
692+
_, err := d.Download(w, &s3.GetObjectInput{
693+
Bucket: aws.String("bucket"),
694+
Key: aws.String("key"),
695+
})
696+
697+
if err == nil {
698+
t.Fatalf("expect error, got none")
699+
} else if e, a := "PreconditionFailed", err.Error(); !strings.Contains(a, e) {
700+
t.Fatalf("expect error message to contain %s, but did not %s", e, a)
701+
}
702+
}
703+
529704
func TestDownloadWithContextCanceled(t *testing.T) {
530705
d := s3manager.NewDownloader(unit.Session)
531706

0 commit comments

Comments
 (0)