Skip to content

Commit

Permalink
Improve error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
mgomes committed Dec 8, 2024
1 parent 100f9e5 commit 0cb873d
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 49 deletions.
19 changes: 16 additions & 3 deletions cmd/dl/dl_part.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,42 @@ func (p *downloadPart) downloadPartFilename() string {
return path.Join(p.dir, fmt.Sprintf("download.part%d", p.index))
}

func (p *downloadPart) fetchPart(wg *sync.WaitGroup, bar *progressbar.ProgressBar) {
func (p *downloadPart) fetchPart(wg *sync.WaitGroup, bar *progressbar.ProgressBar, errCh chan<- error) {
defer wg.Done()

byteRange := fmt.Sprintf("bytes=%d-%d", p.startByte, p.endByte)
req, _ := http.NewRequest("GET", p.uri, nil)
req, err := http.NewRequest("GET", p.uri, nil)
if err != nil {
errCh <- fmt.Errorf("failed to create request for part %d: %w", p.index, err)
return
}
req.Header.Set("Range", byteRange)
req.Header.Set("User-Agent", "dl/1.0")

client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
errCh <- fmt.Errorf("failed to download part %d: %w", p.index, err)
return
}
defer resp.Body.Close()

if resp.StatusCode < 200 || resp.StatusCode > 299 {
errCh <- fmt.Errorf("non-2xx status (%d) for part %d", resp.StatusCode, p.index)
return
}

// Create the file
filename := p.downloadPartFilename()
out, err := os.Create(filename)
if err != nil {
errCh <- fmt.Errorf("failed to create file for part %d: %w", p.index, err)
return
}
defer out.Close()

// Write the body to file
_, _ = io.Copy(io.MultiWriter(out, bar), resp.Body)
if _, copyErr := io.Copy(io.MultiWriter(out, bar), resp.Body); copyErr != nil {
errCh <- fmt.Errorf("error writing part %d to file: %w", p.index, copyErr)
}
}
161 changes: 115 additions & 46 deletions cmd/dl/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@ import (
)

// download tracks the state of a single file download: the source URI, the
// metadata discovered via the HEAD request (size, filename, range support),
// the on-disk destination, and the per-part download state.
type download struct {
	uri           string         // source URI to fetch
	filesize      uint64         // total size in bytes, parsed from Content-Length
	filename      string         // output filename (from header, flag, or URI path)
	workingDir    string         // directory holding part files and the final file
	boost         int            // number of concurrent download parts
	parts         []downloadPart // per-part state used for fetching and reassembly
	supportsRange bool           // true when the server advertises Accept-Ranges: bytes
}

func main() {
Expand All @@ -32,79 +33,105 @@ func main() {

flag.Parse()

file_uris := flag.Args()

var err error
fileURIs := flag.Args()
if len(fileURIs) == 0 {
fmt.Fprintln(os.Stderr, "No URI provided.")
os.Exit(1)
}

for _, uri := range file_uris {
for _, uri := range fileURIs {
var dl download
dl.uri = uri
dl.boost = *boostPtr

err = dl.FetchMetadata()
if err != nil {
panic(err)
// Fetch file metadata
if err := dl.FetchMetadata(); err != nil {
fmt.Fprintf(os.Stderr, "Error fetching metadata: %v\n", err)
os.Exit(1)
}

// Use filename from args if specified
if *filenamePtr != "" {
dl.filename = *filenamePtr
}

// Determine working directory
if *workingDirPtr != "" {
dl.workingDir = *workingDirPtr
} else {
dl.workingDir, err = os.Getwd()
wd, err := os.Getwd()
if err != nil {
panic(err)
fmt.Fprintf(os.Stderr, "Error getting working directory: %v\n", err)
os.Exit(1)
}
dl.workingDir = wd
}

// Signal handling for cleanup
sigc := make(chan os.Signal, 1)
signal.Notify(sigc,
syscall.SIGHUP,
syscall.SIGINT,
syscall.SIGTERM,
syscall.SIGQUIT)
signal.Notify(sigc, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
go func() {
sig := <-sigc
fmt.Printf("\n%s; cleaning up...\n", sig)
fmt.Printf("\nReceived signal %s; cleaning up...\n", sig)
dl.cleanupParts()
os.Exit(0)
os.Exit(1)
}()

fmt.Println(dl.filename)
fmt.Println("Downloading:", dl.filename)

// If the server does not support partial downloads and boost > 1, fallback to a single download stream
if !dl.supportsRange && dl.boost > 1 {
fmt.Println("Server does not support partial content. Falling back to single-threaded download.")
dl.boost = 1
}

if err := dl.Fetch(); err != nil {
fmt.Fprintf(os.Stderr, "Error while downloading parts: %v\n", err)
dl.cleanupParts()
os.Exit(1)
}

if err := dl.ConcatFiles(); err != nil {
fmt.Fprintf(os.Stderr, "Error combining files: %v\n", err)
dl.cleanupParts()
os.Exit(1)
}

dl.Fetch()
dl.ConcatFiles()
fmt.Println("Download completed:", dl.filename)
}
}

func (dl *download) FetchMetadata() error {
resp, err := http.Head(dl.uri)
if err != nil {
return err
return fmt.Errorf("HEAD request failed: %w", err)
}
defer resp.Body.Close()

contentLength := resp.Header.Get("Content-Length")
if contentLength == "" {
return fmt.Errorf("missing Content-Length header, cannot determine file size")
}

dl.filesize, err = strconv.ParseUint(contentLength, 0, 64)
if err != nil {
return err
return fmt.Errorf("invalid Content-Length: %w", err)
}

// Check if server supports range requests
acceptRanges := resp.Header.Get("Accept-Ranges")
dl.supportsRange = (strings.ToLower(acceptRanges) == "bytes")

contentDisposition := resp.Header.Get("Content-Disposition")
_, params, err := mime.ParseMediaType(contentDisposition)
if err != nil {
// If we fail to parse or filename not found, fallback to extracting from URI
dl.filename = dl.filenameFromURI()
return err
} else {
dl.filename = params["filename"]
}

// No filename specified in the header; use the pathname
if dl.filename == "" {
dl.filename = dl.filenameFromURI()
if dl.filename == "" {
dl.filename = dl.filenameFromURI()
}
}

return nil
Expand All @@ -113,11 +140,39 @@ func (dl *download) FetchMetadata() error {
func (dl *download) Fetch() error {
var wg sync.WaitGroup

errCh := make(chan error, dl.boost) // Collect errors from goroutines
defer close(errCh)

bar := progressbar.DefaultBytes(
int64(dl.filesize),
"Downloading",
)

// If boost == 1 or server does not support ranges, just download whole file at once
if dl.boost == 1 {
part := downloadPart{
index: 0,
uri: dl.uri,
dir: dl.workingDir,
startByte: 0,
endByte: dl.filesize - 1,
}
part.filename = part.downloadPartFilename()
dl.parts = append(dl.parts, part)

wg.Add(1)
go part.fetchPart(&wg, bar, errCh)
wg.Wait()

select {
case err := <-errCh:
return err
default:
return nil
}
}

// Multi-part download
for i := 0; i < dl.boost; i++ {
start, end := dl.calculatePartBoundary(i)
wg.Add(1)
Expand All @@ -130,23 +185,26 @@ func (dl *download) Fetch() error {
}
dlPart.filename = dlPart.downloadPartFilename()
dl.parts = append(dl.parts, dlPart)
go dlPart.fetchPart(&wg, bar)
go dlPart.fetchPart(&wg, bar, errCh)
}

wg.Wait()
return nil

// Check for errors
select {
case err := <-errCh:
return err
default:
return nil
}
}

func (dl *download) calculatePartBoundary(part int) (startByte uint64, endByte uint64) {
chunkSize := dl.filesize / uint64(dl.boost)
var previousEndByte uint64

if part == 0 {
startByte = 0
previousEndByte = 0
} else {
previousEndByte = uint64(part)*chunkSize - 1
startByte = previousEndByte + 1
startByte = uint64(part) * chunkSize
}

// For the last part, pick up all remaining bytes
Expand All @@ -164,10 +222,15 @@ func (dl *download) filenameFromURI() string {
return splitURI[len(splitURI)-1]
}

func (dl *download) ConcatFiles() {
var readers []io.Reader
func (dl *download) ConcatFiles() error {
// Verify that all parts exist
for _, part := range dl.parts {
if _, err := os.Stat(part.downloadPartFilename()); err != nil {
return fmt.Errorf("missing part file: %s, error: %w", part.downloadPartFilename(), err)
}
}

defer dl.cleanupParts()
var readers []io.Reader

bar := progressbar.DefaultBytes(
int64(dl.filesize),
Expand All @@ -177,7 +240,7 @@ func (dl *download) ConcatFiles() {
for _, part := range dl.parts {
downloadPart, err := os.Open(part.downloadPartFilename())
if err != nil {
panic(err)
return fmt.Errorf("error opening part file: %w", err)
}
defer downloadPart.Close()
readers = append(readers, downloadPart)
Expand All @@ -187,13 +250,19 @@ func (dl *download) ConcatFiles() {

outFile, err := os.Create(dl.filename)
if err != nil {
panic(err)
return fmt.Errorf("error creating output file: %w", err)
}
defer outFile.Close()

_, err = io.Copy(io.MultiWriter(outFile, bar), inputFiles)
if err != nil {
panic(err)
return fmt.Errorf("error concatenating files: %w", err)
}

// Cleanup only after successful concatenation
dl.cleanupParts()

return nil
}

func (dl *download) cleanupParts() {
Expand Down

0 comments on commit 0cb873d

Please sign in to comment.