diff --git a/.github/workflows/continuous-integration.yaml b/.github/workflows/continuous-integration.yaml index c639222e..fa5de61b 100644 --- a/.github/workflows/continuous-integration.yaml +++ b/.github/workflows/continuous-integration.yaml @@ -49,8 +49,8 @@ jobs: - name: Test code run: | - go test ./pkg/... - go test ./internal/... + go test -race ./pkg/... + go test -race ./internal/... shell: bash - diff --git a/go.mod b/go.mod index cea5c5a6..b278b776 100644 --- a/go.mod +++ b/go.mod @@ -30,6 +30,7 @@ require ( github.com/vimeo/go-util v1.4.1 golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df golang.org/x/net v0.29.0 + golang.org/x/sync v0.8.0 google.golang.org/api v0.199.0 google.golang.org/grpc v1.67.0 google.golang.org/protobuf v1.34.2 @@ -95,7 +96,6 @@ require ( go.opentelemetry.io/otel/trace v1.29.0 // indirect golang.org/x/crypto v0.27.0 // indirect golang.org/x/oauth2 v0.23.0 // indirect - golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.25.0 // indirect golang.org/x/text v0.18.0 // indirect golang.org/x/time v0.6.0 // indirect diff --git a/pkg/handler/body_reader.go b/pkg/handler/body_reader.go index ab9c8d3c..ed3a5f63 100644 --- a/pkg/handler/body_reader.go +++ b/pkg/handler/body_reader.go @@ -7,6 +7,7 @@ import ( "net/http" "os" "strings" + "sync" "sync/atomic" "time" ) @@ -28,8 +29,11 @@ type bodyReader struct { bytesCounter int64 ctx *httpContext reader io.ReadCloser - err error onReadDone func() + + // lock protects concurrent access to err. 
+ lock sync.RWMutex + err error } func newBodyReader(c *httpContext, maxSize int64) *bodyReader { @@ -41,7 +45,10 @@ func newBodyReader(c *httpContext, maxSize int64) *bodyReader { } func (r *bodyReader) Read(b []byte) (int, error) { - if r.err != nil { + r.lock.RLock() + hasErrored := r.err != nil + r.lock.RUnlock() + if hasErrored { return 0, io.EOF } @@ -99,20 +106,26 @@ func (r *bodyReader) Read(b []byte) (int, error) { // Other errors are stored for retrival with hasError, but is not returned // to the consumer. We do not overwrite an error if it has been set already. + r.lock.Lock() if r.err == nil { r.err = err } + r.lock.Unlock() } return n, nil } -func (r bodyReader) hasError() error { - if r.err == io.EOF { +func (r *bodyReader) hasError() error { + r.lock.RLock() + err := r.err + r.lock.RUnlock() + + if err == io.EOF { return nil } - return r.err + return err } func (r *bodyReader) bytesRead() int64 { @@ -120,7 +133,9 @@ func (r *bodyReader) bytesRead() int64 { } func (r *bodyReader) closeWithError(err error) { + r.lock.Lock() r.err = err + r.lock.Unlock() // SetReadDeadline with the current time causes concurrent reads to the body to time out, // so the body will be closed sooner with less delay. 
diff --git a/pkg/s3store/multi_error.go b/pkg/s3store/multi_error.go index a2d9e3db..78fd97ec 100644 --- a/pkg/s3store/multi_error.go +++ b/pkg/s3store/multi_error.go @@ -4,6 +4,7 @@ import ( "errors" ) +// TODO: Replace with errors.Join func newMultiError(errs []error) error { message := "Multiple errors occurred:\n" for _, err := range errs { diff --git a/pkg/s3store/s3store.go b/pkg/s3store/s3store.go index db4f29fd..e592a648 100644 --- a/pkg/s3store/s3store.go +++ b/pkg/s3store/s3store.go @@ -88,6 +88,7 @@ import ( "github.com/tus/tusd/v2/internal/uid" "github.com/tus/tusd/v2/pkg/handler" "golang.org/x/exp/slices" + "golang.org/x/sync/errgroup" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/s3" @@ -469,8 +470,7 @@ func (upload *s3Upload) uploadParts(ctx context.Context, offset int64, src io.Re }() go partProducer.produce(producerCtx, optimalPartSize) - var wg sync.WaitGroup - var uploadErr error + var eg errgroup.Group for { // We acquire the semaphore before starting the goroutine to avoid @@ -497,10 +497,8 @@ func (upload *s3Upload) uploadParts(ctx context.Context, offset int64, src io.Re } upload.parts = append(upload.parts, part) - wg.Add(1) - go func(file io.ReadSeeker, part *s3Part, closePart func() error) { + eg.Go(func() error { defer upload.store.releaseUploadSemaphore() - defer wg.Done() t := time.Now() uploadPartInput := &s3.UploadPartInput{ @@ -509,39 +507,46 @@ func (upload *s3Upload) uploadParts(ctx context.Context, offset int64, src io.Re UploadId: aws.String(upload.multipartId), PartNumber: aws.Int32(part.number), } - etag, err := upload.putPartForUpload(ctx, uploadPartInput, file, part.size) + etag, err := upload.putPartForUpload(ctx, uploadPartInput, partfile, part.size) store.observeRequestDuration(t, metricUploadPart) - if err != nil { - uploadErr = err - } else { + if err == nil { part.etag = etag } - if cerr := closePart(); cerr != nil && uploadErr == nil { - uploadErr = cerr + + cerr := closePart() + if err != 
nil { + return err } - }(partfile, part, closePart) + if cerr != nil { + return cerr + } + return nil + }) } else { - wg.Add(1) - go func(file io.ReadSeeker, closePart func() error) { + eg.Go(func() error { defer upload.store.releaseUploadSemaphore() - defer wg.Done() - if err := store.putIncompletePartForUpload(ctx, upload.objectId, file); err != nil { - uploadErr = err + err := store.putIncompletePartForUpload(ctx, upload.objectId, partfile) + if err == nil { + upload.incompletePartSize = partsize } - if cerr := closePart(); cerr != nil && uploadErr == nil { - uploadErr = cerr + + cerr := closePart() + if err != nil { + return err + } + if cerr != nil { + return cerr } - upload.incompletePartSize = partsize - }(partfile, closePart) + return nil + }) } bytesUploaded += partsize nextPartNum += 1 } - wg.Wait() - + uploadErr := eg.Wait() if uploadErr != nil { return 0, uploadErr } @@ -969,47 +974,42 @@ func (upload *s3Upload) concatUsingDownload(ctx context.Context, partialUploads func (upload *s3Upload) concatUsingMultipart(ctx context.Context, partialUploads []handler.Upload) error { store := upload.store - numPartialUploads := len(partialUploads) - errs := make([]error, 0, numPartialUploads) + upload.parts = make([]*s3Part, len(partialUploads)) // Copy partial uploads concurrently - var wg sync.WaitGroup - wg.Add(numPartialUploads) + var eg errgroup.Group for i, partialUpload := range partialUploads { + // Part numbers must be in the range of 1 to 10000, inclusive. Since // slice indexes start at 0, we add 1 to ensure that i >= 1. 
partNumber := int32(i + 1) partialS3Upload := partialUpload.(*s3Upload) - upload.parts = append(upload.parts, &s3Part{ - number: partNumber, - size: -1, - etag: "", - }) - - go func(partNumber int32, sourceObject string) { - defer wg.Done() - + eg.Go(func() error { res, err := store.Service.UploadPartCopy(ctx, &s3.UploadPartCopyInput{ Bucket: aws.String(store.Bucket), Key: store.keyWithPrefix(upload.objectId), UploadId: aws.String(upload.multipartId), PartNumber: aws.Int32(partNumber), - CopySource: aws.String(store.Bucket + "/" + *store.keyWithPrefix(sourceObject)), + CopySource: aws.String(store.Bucket + "/" + *store.keyWithPrefix(partialS3Upload.objectId)), }) if err != nil { - errs = append(errs, err) - return + return err } - upload.parts[partNumber-1].etag = *res.CopyPartResult.ETag - }(partNumber, partialS3Upload.objectId) - } + upload.parts[partNumber-1] = &s3Part{ + number: partNumber, + size: -1, // -1 is fine here because FinishUpload does not need this info. + etag: *res.CopyPartResult.ETag, + } - wg.Wait() + return nil + }) + } - if len(errs) > 0 { - return newMultiError(errs) + err := eg.Wait() + if err != nil { + return err } return upload.FinishUpload(ctx)