diff --git a/fs.go b/fs.go
index 274595af..35a52dfb 100644
--- a/fs.go
+++ b/fs.go
@@ -221,6 +221,14 @@ type ArchiveFS struct {
 	Context context.Context // optional
 }
 
+// context always returns a context, preferring f.Context if not nil.
+func (f ArchiveFS) context() context.Context {
+	if f.Context != nil {
+		return f.Context
+	}
+	return context.Background()
+}
+
 // Open opens the named file from within the archive. If name is "." then
 // the archive file itself will be opened as a directory file.
 func (f ArchiveFS) Open(name string) (fs.File, error) {
@@ -312,7 +320,7 @@ func (f ArchiveFS) Open(name string) (fs.File, error) {
 		inputStream = io.NewSectionReader(f.Stream, 0, f.Stream.Size())
 	}
 
-	err = f.Format.Extract(f.Context, inputStream, []string{name}, handler)
+	err = f.Format.Extract(f.context(), inputStream, []string{name}, handler)
 	if err != nil && fsFile != nil {
 		if ef, ok := fsFile.(extractedFile); ok {
 			if ef.parentArchive != nil {
@@ -377,7 +385,7 @@ func (f ArchiveFS) Stat(name string) (fs.FileInfo, error) {
 	if f.Stream != nil {
 		inputStream = io.NewSectionReader(f.Stream, 0, f.Stream.Size())
 	}
-	err = f.Format.Extract(f.Context, inputStream, []string{name}, handler)
+	err = f.Format.Extract(f.context(), inputStream, []string{name}, handler)
 	if err != nil && result.FileInfo == nil {
 		return nil, err
 	}
@@ -446,7 +454,7 @@ func (f ArchiveFS) ReadDir(name string) ([]fs.DirEntry, error) {
 		inputStream = io.NewSectionReader(f.Stream, 0, f.Stream.Size())
 	}
 
-	err = f.Format.Extract(f.Context, inputStream, filter, handler)
+	err = f.Format.Extract(f.context(), inputStream, filter, handler)
 	return entries, err
 }
 
diff --git a/interfaces.go b/interfaces.go
index 0fc88930..782cae64 100644
--- a/interfaces.go
+++ b/interfaces.go
@@ -56,10 +56,22 @@ type Decompressor interface {
 type Archiver interface {
 	// Archive writes an archive file to output with the given files.
 	//
-	// Context is optional, but if given, cancellation must be honored.
+	// Context cancellation must be honored.
 	Archive(ctx context.Context, output io.Writer, files []File) error
 }
 
+// ArchiverAsync is an Archiver that can also create archives
+// asynchronously by pumping files into a channel as they are
+// discovered.
+type ArchiverAsync interface {
+	Archiver
+
+	// Use ArchiveAsync if you can't pre-assemble a list of all
+	// the files for the archive. Close the files channel after
+	// all the files have been sent.
+	ArchiveAsync(ctx context.Context, output io.Writer, files <-chan File) error
+}
+
 // Extractor can extract files from an archive.
 type Extractor interface {
 	// Extract reads the files at pathsInArchive from sourceArchive.
@@ -68,7 +80,7 @@ type Extractor interface {
 	// If a path refers to a directory, all files within it are extracted.
 	// Extracted files are passed to the handleFile callback for handling.
 	//
-	// Context is optional, but if given, cancellation must be honored.
+	// Context cancellation must be honored.
 	Extract(ctx context.Context, sourceArchive io.Reader, pathsInArchive []string, handleFile FileHandler) error
 }
 
@@ -76,6 +88,6 @@ type Extractor interface {
 type Inserter interface {
 	// Insert inserts the files into archive.
 	//
-	// Context is optional, but if given, cancellation must be honored.
+	// Context cancellation must be honored.
 	Insert(ctx context.Context, archive io.ReadWriteSeeker, files []File) error
 }
diff --git a/rar.go b/rar.go
index 213bed1b..e41a192e 100644
--- a/rar.go
+++ b/rar.go
@@ -56,10 +56,6 @@ func (r Rar) Archive(_ context.Context, _ io.Writer, _ []File) error {
 }
 
 func (r Rar) Extract(ctx context.Context, sourceArchive io.Reader, pathsInArchive []string, handleFile FileHandler) error {
-	if ctx == nil {
-		ctx = context.Background()
-	}
-
 	var options []rardecode.Option
 	if r.Password != "" {
 		options = append(options, rardecode.Password(r.Password))
diff --git a/tar.go b/tar.go
index dfe81ea7..803ea78b 100644
--- a/tar.go
+++ b/tar.go
@@ -42,19 +42,28 @@ func (t Tar) Match(filename string, stream io.Reader) (MatchResult, error) {
 }
 
 func (t Tar) Archive(ctx context.Context, output io.Writer, files []File) error {
-	if ctx == nil {
-		ctx = context.Background()
-	}
-
 	tw := tar.NewWriter(output)
 	defer tw.Close()
 
 	for _, file := range files {
-		if err := ctx.Err(); err != nil {
-			return err // honor context cancellation
+		if err := t.writeFileToArchive(ctx, tw, file); err != nil {
+			if t.ContinueOnError && ctx.Err() == nil { // context errors should always abort
+				log.Printf("[ERROR] %v", err)
+				continue
+			}
+			return err
 		}
-		err := t.writeFileToArchive(ctx, tw, file)
-		if err != nil {
+	}
+
+	return nil
+}
+
+func (t Tar) ArchiveAsync(ctx context.Context, output io.Writer, files <-chan File) error {
+	tw := tar.NewWriter(output)
+	defer tw.Close()
+
+	for file := range files {
+		if err := t.writeFileToArchive(ctx, tw, file); err != nil {
 			if t.ContinueOnError && ctx.Err() == nil { // context errors should always abort
 				log.Printf("[ERROR] %v", err)
 				continue
@@ -67,6 +76,10 @@ func (t Tar) Archive(ctx context.Context, output io.Writer, files []File) error
 }
 
 func (Tar) writeFileToArchive(ctx context.Context, tw *tar.Writer, file File) error {
+	if err := ctx.Err(); err != nil {
+		return err // honor context cancellation
+	}
+
 	hdr, err := tar.FileInfoHeader(file, file.LinkTarget)
 	if err != nil {
 		return fmt.Errorf("file %s: creating header: %w", file.NameInArchive, err)
@@ -91,10 +104,6 @@ func (Tar) writeFileToArchive(ctx context.Context, tw *tar.Writer, file File) er
 }
 
 func (t Tar) Insert(ctx context.Context, into io.ReadWriteSeeker, files []File) error {
-	if ctx == nil {
-		ctx = context.Background()
-	}
-
 	// Tar files may end with some, none, or a lot of zero-byte padding. The spec says
 	// it should end with two 512-byte trailer records consisting solely of null/0
 	// bytes: https://www.gnu.org/software/tar/manual/html_node/Standard.html. However,
@@ -165,10 +174,6 @@ func (t Tar) Insert(ctx context.Context, into io.ReadWriteSeeker, files []File)
 }
 
 func (t Tar) Extract(ctx context.Context, sourceArchive io.Reader, pathsInArchive []string, handleFile FileHandler) error {
-	if ctx == nil {
-		ctx = context.Background()
-	}
-
 	tr := tar.NewReader(sourceArchive)
 
 	// important to initialize to non-nil, empty value due to how fileIsIncluded works
diff --git a/zip.go b/zip.go
index fd4a4418..20fa4250 100644
--- a/zip.go
+++ b/zip.go
@@ -8,6 +8,7 @@ import (
 	"fmt"
 	"io"
 	"io/fs"
+	"log"
 	"path"
 	"strings"
 
@@ -101,54 +102,77 @@ func (z Zip) Match(filename string, stream io.Reader) (MatchResult, error) {
 }
 
 func (z Zip) Archive(ctx context.Context, output io.Writer, files []File) error {
-	if ctx == nil {
-		ctx = context.Background()
-	}
-
 	zw := zip.NewWriter(output)
 	defer zw.Close()
 
 	for i, file := range files {
-		if err := ctx.Err(); err != nil {
-			return err // honor context cancellation
+		if err := z.archiveOneFile(ctx, zw, i, file); err != nil {
+			return err
 		}
+	}
 
-		hdr, err := zip.FileInfoHeader(file)
-		if err != nil {
-			return fmt.Errorf("getting info for file %d: %s: %w", i, file.Name(), err)
-		}
-		hdr.Name = file.NameInArchive // complete path, since FileInfoHeader() only has base name
+	return nil
+}
 
-		// customize header based on file properties
-		if file.IsDir() {
-			if !strings.HasSuffix(hdr.Name, "/") {
-				hdr.Name += "/" // required
-			}
-			hdr.Method = zip.Store
-		} else if z.SelectiveCompression {
-			// only enable compression on compressable files
-			ext := strings.ToLower(path.Ext(hdr.Name))
-			if _, ok := compressedFormats[ext]; ok {
-				hdr.Method = zip.Store
-			} else {
-				hdr.Method = z.Compression
+func (z Zip) ArchiveAsync(ctx context.Context, output io.Writer, files <-chan File) error {
+	zw := zip.NewWriter(output)
+	defer zw.Close()
+
+	var i int
+	for file := range files {
+		if err := z.archiveOneFile(ctx, zw, i, file); err != nil {
+			if z.ContinueOnError && ctx.Err() == nil { // context errors should always abort
+				log.Printf("[ERROR] %v", err)
+				continue
 			}
+			return err
 		}
+		i++
+	}
 
-		w, err := zw.CreateHeader(hdr)
-		if err != nil {
-			return fmt.Errorf("creating header for file %d: %s: %w", i, file.Name(), err)
-		}
+	return nil
+}
 
-		// directories have no file body
-		if file.IsDir() {
-			continue
+func (z Zip) archiveOneFile(ctx context.Context, zw *zip.Writer, idx int, file File) error {
+	if err := ctx.Err(); err != nil {
+		return err // honor context cancellation
+	}
+
+	hdr, err := zip.FileInfoHeader(file)
+	if err != nil {
+		return fmt.Errorf("getting info for file %d: %s: %w", idx, file.Name(), err)
+	}
+	hdr.Name = file.NameInArchive // complete path, since FileInfoHeader() only has base name
+
+	// customize header based on file properties
+	if file.IsDir() {
+		if !strings.HasSuffix(hdr.Name, "/") {
+			hdr.Name += "/" // required
 		}
-		if err := openAndCopyFile(file, w); err != nil {
-			return fmt.Errorf("writing file %d: %s: %w", i, file.Name(), err)
+		hdr.Method = zip.Store
+	} else if z.SelectiveCompression {
+		// only enable compression on compressable files
+		ext := strings.ToLower(path.Ext(hdr.Name))
+		if _, ok := compressedFormats[ext]; ok {
+			hdr.Method = zip.Store
+		} else {
+			hdr.Method = z.Compression
 		}
 	}
 
+	w, err := zw.CreateHeader(hdr)
+	if err != nil {
+		return fmt.Errorf("creating header for file %d: %s: %w", idx, file.Name(), err)
+	}
+
+	// directories have no file body
+	if file.IsDir() {
+		return nil
+	}
+	if err := openAndCopyFile(file, w); err != nil {
+		return fmt.Errorf("writing file %d: %s: %w", idx, file.Name(), err)
+	}
+
 	return nil
 }
 
@@ -159,10 +183,6 @@ func (z Zip) Archive(ctx context.Context, output io.Writer, files []File) error
 // with. Due to the nature of the zip archive format, if sourceArchive is not an io.Seeker
 // and io.ReaderAt, an error is returned.
 func (z Zip) Extract(ctx context.Context, sourceArchive io.Reader, pathsInArchive []string, handleFile FileHandler) error {
-	if ctx == nil {
-		ctx = context.Background()
-	}
-
 	sra, ok := sourceArchive.(seekReaderAt)
 	if !ok {
 		return fmt.Errorf("input type must be an io.ReaderAt and io.Seeker because of zip format constraints")
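Not part of the diff above: a minimal usage sketch of the new ArchiverAsync interface, showing the intended calling pattern for reviewers. It assumes the v4 module path github.com/mholt/archiver/v4 and uses a hypothetical discover() producer; only Tar, File, ArchiverAsync, and the ArchiveAsync signature are taken from the diff itself.

package main

import (
	"context"
	"log"
	"os"

	archiver "github.com/mholt/archiver/v4" // assumed module path, not stated in the diff
)

// discover is a hypothetical stand-in for code that finds files
// incrementally (e.g. while walking a large tree) and therefore
// cannot pre-assemble a []archiver.File slice up front.
func discover(files chan<- archiver.File) {
	defer close(files) // close after all files have been sent, as the ArchiverAsync doc requires
	// ... send archiver.File values on the channel as they are discovered ...
}

func main() {
	out, err := os.Create("output.tar")
	if err != nil {
		log.Fatal(err)
	}
	defer out.Close()

	files := make(chan archiver.File)
	go discover(files)

	// Tar now implements ArchiverAsync; cancellation of ctx is honored
	// per file via the ctx.Err() check in writeFileToArchive.
	var format archiver.ArchiverAsync = archiver.Tar{}
	if err := format.ArchiveAsync(context.Background(), out, files); err != nil {
		log.Fatal(err)
	}
}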