diff --git a/pkg/downloader/uri.go b/pkg/downloader/uri.go index 9acbb621737d..2e0363c836dc 100644 --- a/pkg/downloader/uri.go +++ b/pkg/downloader/uri.go @@ -2,7 +2,9 @@ package downloader import ( "crypto/sha256" + "errors" "fmt" + "hash" "io" "net/http" "net/url" @@ -204,6 +206,25 @@ func removePartialFile(tmpFilePath string) error { return nil } +func calculateHashForPartialFile(file *os.File) (hash.Hash, error) { + hash := sha256.New() + _, err := io.Copy(hash, file) + if err != nil { + return nil, err + } + return hash, nil +} + +func (uri URI) checkSeverSupportsRangeHeader() (bool, error) { + url := uri.ResolveURL() + resp, err := http.Head(url) + if err != nil { + return false, err + } + defer resp.Body.Close() + return resp.Header.Get("Accept-Ranges") == "bytes", nil +} + func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStatus func(string, string, string, float64)) error { url := uri.ResolveURL() if uri.LooksLikeOCI() { @@ -266,8 +287,34 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat log.Info().Msgf("Downloading %q", url) - // Download file - resp, err := http.Get(url) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return fmt.Errorf("failed to create request for %q: %v", filePath, err) + } + + // save partial download to dedicated file + tmpFilePath := filePath + ".partial" + tmpFileInfo, err := os.Stat(tmpFilePath) + if err == nil { + support, err := uri.checkSeverSupportsRangeHeader() + if err != nil { + return fmt.Errorf("failed to check if uri server supports range header: %v", err) + } + if support { + startPos := tmpFileInfo.Size() + req.Header.Set("Range", fmt.Sprintf("bytes=%d-", startPos)) + } else { + err := removePartialFile(tmpFilePath) + if err != nil { + return err + } + } + } else if !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("failed to check file %q existence: %v", filePath, err) + } + + // Start the request + resp, err := http.DefaultClient.Do(req) if err != nil { return fmt.Errorf("failed to download file %q: %v", filePath, err) } @@ -283,26 +330,20 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat return fmt.Errorf("failed to create parent directory for file %q: %v", filePath, err) } - // save partial download to dedicated file - tmpFilePath := filePath + ".partial" - - // remove tmp file - err = removePartialFile(tmpFilePath) + // Create and write file + outFile, err := os.OpenFile(tmpFilePath, os.O_APPEND|os.O_RDWR|os.O_CREATE, 0644) if err != nil { - return err + return fmt.Errorf("failed to create / open file %q: %v", tmpFilePath, err) } - - // Create and write file content - outFile, err := os.Create(tmpFilePath) + defer outFile.Close() + hash, err := calculateHashForPartialFile(outFile) if err != nil { - return fmt.Errorf("failed to create file %q: %v", tmpFilePath, err) + return fmt.Errorf("failed to calculate hash for partial file") } - defer outFile.Close() - progress := &progressWriter{ fileName: tmpFilePath, total: resp.ContentLength, - hash: sha256.New(), + hash: hash, fileNo: fileN, totalFiles: total, downloadStatus: downloadStatus, diff --git a/pkg/downloader/uri_test.go b/pkg/downloader/uri_test.go index 3b7a80b3ee9a..6976c9b44bf7 100644 --- a/pkg/downloader/uri_test.go +++ b/pkg/downloader/uri_test.go @@ -1,6 +1,15 @@ package downloader_test import ( + "crypto/rand" + "crypto/sha256" + "fmt" + "net/http" + "net/http/httptest" + "os" + "regexp" + "strconv" + . "github.com/mudler/LocalAI/pkg/downloader" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -38,3 +47,139 @@ var _ = Describe("Gallery API tests", func() { }) }) }) + +type RangeHeaderError struct { + msg string +} + +func (e *RangeHeaderError) Error() string { return e.msg } + +var _ = Describe("Download Test", func() { + var mockData []byte + var mockDataSha string + var filePath string + + extractRangeHeader := func(rangeString string) (int, int, error) { + regex := regexp.MustCompile(`^bytes=(\d+)-(\d+|)$`) + matches := regex.FindStringSubmatch(rangeString) + rangeErr := RangeHeaderError{msg: "invalid / ill-formatted range"} + if matches == nil { + return -1, -1, &rangeErr + } + startPos, err := strconv.Atoi(matches[1]) + if err != nil { + return -1, -1, err + } + + endPos := -1 + if matches[2] != "" { + endPos, err = strconv.Atoi(matches[2]) + if err != nil { + return -1, -1, err + } + endPos += 1 // because range is inclusive in rangeString + } + return startPos, endPos, nil + } + + getMockServer := func(supportsRangeHeader bool) *httptest.Server { + mockServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "HEAD" && r.Method != "GET" { + w.WriteHeader(http.StatusNotFound) + return + } + if r.Method == "HEAD" { + if supportsRangeHeader { + w.Header().Add("Accept-Ranges", "bytes") + } + w.WriteHeader(http.StatusOK) + return + } + // GET method + startPos := 0 + endPos := len(mockData) + var err error + var respData []byte + rangeString := r.Header.Get("Range") + if rangeString != "" { + startPos, endPos, err = extractRangeHeader(rangeString) + if err != nil { + if _, ok := err.(*RangeHeaderError); ok { + w.WriteHeader(http.StatusBadRequest) + return + } + Expect(err).ToNot(HaveOccurred()) + } + if endPos == -1 { + endPos = len(mockData) + } + if startPos < 0 || startPos >= len(mockData) || endPos < 0 || endPos > len(mockData) || startPos > endPos { + w.WriteHeader(http.StatusBadRequest) + return + } + } + respData = mockData[startPos:endPos] + w.WriteHeader(http.StatusOK) + w.Write(respData) + })) + mockServer.EnableHTTP2 = true + mockServer.Start() + return mockServer + } + + BeforeEach(func() { + mockData = make([]byte, 20000) + _, err := rand.Read(mockData) + Expect(err).ToNot(HaveOccurred()) + _mockDataSha := sha256.New() + _, err = _mockDataSha.Write(mockData) + Expect(err).ToNot(HaveOccurred()) + mockDataSha = fmt.Sprintf("%x", _mockDataSha.Sum(nil)) + dir, err := os.Getwd() + filePath = dir + "/my_supercool_model" + Expect(err).NotTo(HaveOccurred()) + }) + + Context("URI DownloadFile", func() { + It("fetches files from mock server", func() { + mockServer := getMockServer(true) + defer mockServer.Close() + uri := URI(mockServer.URL) + err := uri.DownloadFile(filePath, mockDataSha, 1, 1, func(s1, s2, s3 string, f float64) {}) + Expect(err).ToNot(HaveOccurred()) + }) + + It("resumes partially downloaded files", func() { + mockServer := getMockServer(true) + defer mockServer.Close() + uri := URI(mockServer.URL) + // Create a partial file + tmpFilePath := filePath + ".partial" + file, err := os.OpenFile(tmpFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + Expect(err).ToNot(HaveOccurred()) + _, err = file.Write(mockData[0:10000]) + Expect(err).ToNot(HaveOccurred()) + err = uri.DownloadFile(filePath, mockDataSha, 1, 1, func(s1, s2, s3 string, f float64) {}) + Expect(err).ToNot(HaveOccurred()) + }) + + It("restarts download from 0 if server doesn't support Range header", func() { + mockServer := getMockServer(false) + defer mockServer.Close() + uri := URI(mockServer.URL) + // Create a partial file + tmpFilePath := filePath + ".partial" + file, err := os.OpenFile(tmpFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + Expect(err).ToNot(HaveOccurred()) + _, err = file.Write(mockData[0:10000]) + Expect(err).ToNot(HaveOccurred()) + err = uri.DownloadFile(filePath, mockDataSha, 1, 1, func(s1, s2, s3 string, f float64) {}) + Expect(err).ToNot(HaveOccurred()) + }) + }) + + AfterEach(func() { + os.Remove(filePath) // cleanup, also checks existance of filePath` + os.Remove(filePath + ".partial") + }) +})