From fa04b2491e45b77fc97491fa02c7c65683e7065a Mon Sep 17 00:00:00 2001 From: "Jorge E. Moreira" Date: Wed, 25 Oct 2023 11:39:43 -0700 Subject: [PATCH] Make host orchestrator client standalone so that it can be used outside the context of the cloud orchestrator The follwing changes were implemented: - Be more specific on what requests require retries - Make the credentials header configurable - Make file upload configuration optional - Reduce the timeout for retries on some tests --- pkg/cli/cli.go | 15 +- pkg/cli/cli_test.go | 8 +- pkg/cli/cvd.go | 8 +- pkg/client/client.go | 61 +++++--- pkg/client/client_test.go | 2 + pkg/client/host_orchestrator_client.go | 127 ++++++++++----- pkg/client/host_orchestrator_client_test.go | 41 +++-- pkg/client/httputils.go | 165 +++++++++++--------- 8 files changed, 256 insertions(+), 171 deletions(-) diff --git a/pkg/cli/cli.go b/pkg/cli/cli.go index c9dc94bc..d0498e1c 100644 --- a/pkg/cli/cli.go +++ b/pkg/cli/cli.go @@ -1093,14 +1093,13 @@ func buildServiceBuilder(builder client.ServiceBuilder) serviceBuilder { dumpOut = c.ErrOrStderr() } opts := &client.ServiceOptions{ - RootEndpoint: buildServiceRootEndpoint(flags.ServiceURL, flags.Zone), - ProxyURL: proxyURL, - DumpOut: dumpOut, - ErrOut: c.ErrOrStderr(), - RetryAttempts: 3, - RetryDelay: 5 * time.Second, - ChunkSizeBytes: chunkSizeBytes, - ChunkUploadBackOffOpts: client.DefaultChunkUploadBackOffOpts(), + RootEndpoint: buildServiceRootEndpoint(flags.ServiceURL, flags.Zone), + ProxyURL: proxyURL, + DumpOut: dumpOut, + ErrOut: c.ErrOrStderr(), + RetryAttempts: 3, + RetryDelay: 5 * time.Second, + ChunkSizeBytes: chunkSizeBytes, } return builder(opts) } diff --git a/pkg/cli/cli_test.go b/pkg/cli/cli_test.go index 03e6202a..b0705fb0 100644 --- a/pkg/cli/cli_test.go +++ b/pkg/cli/cli_test.go @@ -126,11 +126,11 @@ func (fakeHostService) ConnectWebRTC(device string, observer wclient.Observer, l return nil, nil } -func (fakeHostService) FetchArtifacts(req *hoapi.FetchArtifactsRequest) (*hoapi.FetchArtifactsResponse, error) { +func (fakeHostService) FetchArtifacts(req *hoapi.FetchArtifactsRequest, creds string) (*hoapi.FetchArtifactsResponse, error) { return &hoapi.FetchArtifactsResponse{AndroidCIBundle: &hoapi.AndroidCIBundle{}}, nil } -func (fakeHostService) CreateCVD(req *hoapi.CreateCVDRequest) (*hoapi.CreateCVDResponse, error) { +func (fakeHostService) CreateCVD(req *hoapi.CreateCVDRequest, creds string) (*hoapi.CreateCVDResponse, error) { return &hoapi.CreateCVDResponse{CVDs: []*hoapi.CVD{{Name: "cvd-1"}}}, nil } @@ -146,6 +146,10 @@ func (fakeHostService) UploadFiles(uploadDir string, filenames []string) error { return nil } +func (fakeHostService) UploadFilesWithOptions(uploadDir string, filenames []string, options client.UploadOptions) error { + return nil +} + func (fakeHostService) DownloadRuntimeArtifacts(dst io.Writer) error { return nil } diff --git a/pkg/cli/cvd.go b/pkg/cli/cvd.go index 3e7e8ff8..a800fc7b 100644 --- a/pkg/cli/cvd.go +++ b/pkg/cli/cvd.go @@ -148,7 +148,7 @@ func (c *cvdCreator) createCVDFromLocalBuild() ([]*hoapi.CVD, error) { }, AdditionalInstancesNum: c.opts.AdditionalInstancesNum(), } - res, err := c.service.HostService(c.opts.Host).CreateCVD(&req) + res, err := c.service.HostService(c.opts.Host).CreateCVD(&req, client.InjectedCredentials) if err != nil { return nil, err } @@ -173,7 +173,7 @@ func (c *cvdCreator) createWithCanonicalConfig() ([]*hoapi.CVD, error) { EnvConfig: c.opts.EnvConfig, } c.statePrinter.Print(stateMsgFetchAndStart) - res, err := c.service.HostService(c.opts.Host).CreateCVD(createReq) + res, err := c.service.HostService(c.opts.Host).CreateCVD(createReq, client.InjectedCredentials) c.statePrinter.PrintDone(stateMsgFetchAndStart, err) if err != nil { return nil, err @@ -197,7 +197,7 @@ func (c *cvdCreator) createWithOpts() ([]*hoapi.CVD, error) { AndroidCIBundle: &hoapi.AndroidCIBundle{Build: mainBuild, Type: hoapi.MainBundleType}, } c.statePrinter.Print(stateMsgFetchMainBundle) - fetchMainBuildRes, err := c.service.HostService(c.opts.Host).FetchArtifacts(fetchReq) + fetchMainBuildRes, err := c.service.HostService(c.opts.Host).FetchArtifacts(fetchReq, client.InjectedCredentials) c.statePrinter.PrintDone(stateMsgFetchMainBundle, err) if err != nil { return nil, err @@ -216,7 +216,7 @@ func (c *cvdCreator) createWithOpts() ([]*hoapi.CVD, error) { AdditionalInstancesNum: c.opts.AdditionalInstancesNum(), } c.statePrinter.Print(stateMsgStartCVD) - res, err := c.service.HostService(c.opts.Host).CreateCVD(createReq) + res, err := c.service.HostService(c.opts.Host).CreateCVD(createReq, client.InjectedCredentials) c.statePrinter.PrintDone(stateMsgStartCVD, err) if err != nil { return nil, err diff --git a/pkg/client/client.go b/pkg/client/client.go index 130c0398..740980d8 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -28,11 +28,11 @@ import ( "github.com/hashicorp/go-multierror" ) -type OpTimeoutError string - -func (s OpTimeoutError) Error() string { - return fmt.Sprintf("waiting for operation %q timed out", string(s)) -} +const ( + // Value to pass as credentials to the Host Orchestrator service endpoints. Any non-empty value is enough. + InjectedCredentials = "inject" + headerNameCOInjectBuildAPICreds = "X-Cutf-Cloud-Orchestrator-Inject-BuildAPI-Creds" +) type ApiCallError struct { Code int `json:"code,omitempty"` @@ -54,14 +54,13 @@ func (e *ApiCallError) Is(target error) bool { } type ServiceOptions struct { - RootEndpoint string - ProxyURL string - DumpOut io.Writer - ErrOut io.Writer - RetryAttempts int - RetryDelay time.Duration - ChunkSizeBytes int64 - ChunkUploadBackOffOpts BackOffOpts + RootEndpoint string + ProxyURL string + DumpOut io.Writer + ErrOut io.Writer + RetryAttempts int + RetryDelay time.Duration + ChunkSizeBytes int64 } type Service interface { @@ -98,8 +97,6 @@ func NewService(opts *ServiceOptions) (Service, error) { httpHelper: HTTPHelper{ Client: httpClient, RootEndpoint: opts.RootEndpoint, - Retries: uint(opts.RetryAttempts), - RetryDelay: opts.RetryDelay, Dumpster: opts.DumpOut, }, }, nil @@ -114,6 +111,20 @@ func (c *serviceImpl) CreateHost(req *apiv1.CreateHostRequest) (*apiv1.HostInsta if err := c.waitForOperation(&op, ins); err != nil { return nil, err } + + // There is a short delay between the creation of the host and the availability of the host + // orchestrator. This call ensures the host orchestrator had time to start before returning + // from the this function. + retryOpts := RetryOptions{ + StatusCodes: []int{http.StatusBadGateway}, + NumRetries: 3, + RetryDelay: 5 * time.Second, + } + hostPath := fmt.Sprintf("/hosts/%s/", ins.Name) + if err := c.httpHelper.NewGetRequest(hostPath).DoWithRetries(nil, retryOpts); err != nil { + return nil, fmt.Errorf("Unable to communicate with host orchestrator: %w", err) + } + return ins, nil } @@ -146,22 +157,26 @@ func (c *serviceImpl) DeleteHosts(names []string) error { func (c *serviceImpl) waitForOperation(op *apiv1.Operation, res any) error { path := "/operations/" + op.Name + "/:wait" - return c.httpHelper.NewPostRequest(path, nil).Do(res) + retryOpts := RetryOptions{ + []int{http.StatusServiceUnavailable}, + uint(c.ServiceOptions.RetryAttempts), + c.RetryDelay, + } + return c.httpHelper.NewPostRequest(path, nil).DoWithRetries(res, retryOpts) } -const headerNameCOInjectBuildAPICreds = "X-Cutf-Cloud-Orchestrator-Inject-BuildAPI-Creds" - func (s *serviceImpl) RootURI() string { return s.RootEndpoint } func (s *serviceImpl) HostService(host string) HostOrchestratorService { - hs := &hostOrchestratorServiceImpl{ - httpHelper: s.httpHelper, - ChunkSizeBytes: s.ChunkSizeBytes, - ChunkUploadBackOffOpts: s.ChunkUploadBackOffOpts, + hs := &HostOrchestratorServiceImpl{ + HTTPHelper: s.httpHelper, + WaitRetries: uint(s.ServiceOptions.RetryAttempts), + WaitRetryDelay: s.ServiceOptions.RetryDelay, + BuildAPICredentialsHeader: headerNameCOInjectBuildAPICreds, } - hs.httpHelper.RootEndpoint = s.httpHelper.RootEndpoint + "/hosts/" + host + hs.HTTPHelper.RootEndpoint = s.httpHelper.RootEndpoint + "/hosts/" + host return hs } diff --git a/pkg/client/client_test.go b/pkg/client/client_test.go index 185f611d..42db40ab 100644 --- a/pkg/client/client_test.go +++ b/pkg/client/client_test.go @@ -44,6 +44,8 @@ func TestRetryLogic(t *testing.T) { return } writeOK(w, &apiv1.HostInstance{Name: "foo"}) + case "GET /hosts/foo/": + writeOK(w, make(map[string]any)) default: t.Fatal("unexpected endpoint: " + ep) } diff --git a/pkg/client/host_orchestrator_client.go b/pkg/client/host_orchestrator_client.go index caed50f9..d6499d01 100644 --- a/pkg/client/host_orchestrator_client.go +++ b/pkg/client/host_orchestrator_client.go @@ -39,11 +39,15 @@ type HostOrchestratorService interface { // Uploads user files to a previously created directory. UploadFiles(uploadDir string, filenames []string) error + UploadFilesWithOptions(uploadDir string, filenames []string, backOffOpts UploadOptions) error // Create a new device with artifacts from the build server or previously uploaded by the user. - CreateCVD(req *hoapi.CreateCVDRequest) (*hoapi.CreateCVDResponse, error) + // If not empty, the provided credentials will be use to download any necessary artifacts from the build api. + CreateCVD(req *hoapi.CreateCVDRequest, buildAPICredentials string) (*hoapi.CreateCVDResponse, error) - FetchArtifacts(req *hoapi.FetchArtifactsRequest) (*hoapi.FetchArtifactsResponse, error) + // Calls cvd fetch in the remote host, the downloaded artifacts can be used to create a CVD later. + // If not empty, the provided credentials will be used by the host orchestrator to access the build api. + FetchArtifacts(req *hoapi.FetchArtifactsRequest, buildAPICredentials string) (*hoapi.FetchArtifactsResponse, error) // Downloads runtime artifacts tar file into `dst`. DownloadRuntimeArtifacts(dst io.Writer) error @@ -52,21 +56,36 @@ type HostOrchestratorService interface { ConnectWebRTC(device string, observer wclient.Observer, logger io.Writer, opts ConnectWebRTCOpts) (*wclient.Connection, error) } -type hostOrchestratorServiceImpl struct { - httpHelper HTTPHelper - ChunkSizeBytes int64 - ChunkUploadBackOffOpts BackOffOpts +const defaultHostOrchestratorCredentialsHeader = "X-Cutf-Host-Orchestrator-BuildAPI-Creds" + +func NewHostOrchestratoService(url string) HostOrchestratorService { + return &HostOrchestratorServiceImpl{ + HTTPHelper: HTTPHelper{ + Client: http.DefaultClient, + RootEndpoint: url, + }, + WaitRetries: 3, + WaitRetryDelay: 5 * time.Second, + BuildAPICredentialsHeader: defaultHostOrchestratorCredentialsHeader, + } +} + +type HostOrchestratorServiceImpl struct { + HTTPHelper HTTPHelper + WaitRetries uint + WaitRetryDelay time.Duration + BuildAPICredentialsHeader string } -func (c *hostOrchestratorServiceImpl) getInfraConfig() (*hoapi.InfraConfig, error) { +func (c *HostOrchestratorServiceImpl) getInfraConfig() (*hoapi.InfraConfig, error) { var res hoapi.InfraConfig - if err := c.httpHelper.NewGetRequest(fmt.Sprintf("/infra_config")).Do(&res); err != nil { + if err := c.HTTPHelper.NewGetRequest(fmt.Sprintf("/infra_config")).Do(&res); err != nil { return nil, err } return &res, nil } -func (c *hostOrchestratorServiceImpl) ConnectWebRTC(device string, observer wclient.Observer, logger io.Writer, opts ConnectWebRTCOpts) (*wclient.Connection, error) { +func (c *HostOrchestratorServiceImpl) ConnectWebRTC(device string, observer wclient.Observer, logger io.Writer, opts ConnectWebRTCOpts) (*wclient.Connection, error) { polledConn, err := c.createPolledConnection(device) if err != nil { return nil, fmt.Errorf("Failed to create polled connection: %w", err) @@ -88,7 +107,7 @@ func (c *hostOrchestratorServiceImpl) ConnectWebRTC(device string, observer wcli return conn, nil } -func (c *hostOrchestratorServiceImpl) initHandling(connID string, iceServers []webrtc.ICEServer, logger io.Writer) wclient.Signaling { +func (c *HostOrchestratorServiceImpl) initHandling(connID string, iceServers []webrtc.ICEServer, logger io.Writer) wclient.Signaling { sendCh := make(chan any) recvCh := make(chan map[string]any) @@ -113,14 +132,14 @@ const ( maxConsecutiveErrors = 10 ) -func (c *hostOrchestratorServiceImpl) webRTCPoll(sinkCh chan map[string]any, connID string, stopCh chan bool, logger io.Writer) { +func (c *HostOrchestratorServiceImpl) webRTCPoll(sinkCh chan map[string]any, connID string, stopCh chan bool, logger io.Writer) { start := 0 pollInterval := initialPollInterval errCount := 0 for { path := fmt.Sprintf("/polled_connections/%s/messages?start=%d", connID, start) var messages []map[string]any - if err := c.httpHelper.NewGetRequest(path).Do(&messages); err != nil { + if err := c.HTTPHelper.NewGetRequest(path).Do(&messages); err != nil { fmt.Fprintf(logger, "Error polling messages: %v\n", err) errCount++ if errCount >= maxConsecutiveErrors { @@ -158,7 +177,7 @@ func (c *hostOrchestratorServiceImpl) webRTCPoll(sinkCh chan map[string]any, con } } -func (c *hostOrchestratorServiceImpl) webRTCForward(srcCh chan any, connID string, stopPollCh chan bool, logger io.Writer) { +func (c *HostOrchestratorServiceImpl) webRTCForward(srcCh chan any, connID string, stopPollCh chan bool, logger io.Writer) { for { msg, open := <-srcCh if !open { @@ -170,7 +189,7 @@ func (c *hostOrchestratorServiceImpl) webRTCForward(srcCh chan any, connID strin path := fmt.Sprintf("/polled_connections/%s/:forward", connID) i := 0 for ; i < maxConsecutiveErrors; i++ { - rb := c.httpHelper.NewPostRequest(path, &forwardMsg) + rb := c.HTTPHelper.NewPostRequest(path, &forwardMsg) if err := rb.Do(nil); err != nil { fmt.Fprintf(logger, "Error sending message to device: %v\n", err) } else { @@ -185,25 +204,31 @@ func (c *hostOrchestratorServiceImpl) webRTCForward(srcCh chan any, connID strin } } -func (c *hostOrchestratorServiceImpl) createPolledConnection(device string) (*hoapi.NewConnReply, error) { +func (c *HostOrchestratorServiceImpl) createPolledConnection(device string) (*hoapi.NewConnReply, error) { var res hoapi.NewConnReply - rb := c.httpHelper.NewPostRequest("/polled_connections", &hoapi.NewConnMsg{DeviceId: device}) + rb := c.HTTPHelper.NewPostRequest("/polled_connections", &hoapi.NewConnMsg{DeviceId: device}) if err := rb.Do(&res); err != nil { return nil, err } return &res, nil } -func (c *hostOrchestratorServiceImpl) waitForOperation(op *hoapi.Operation, res any) error { +func (c *HostOrchestratorServiceImpl) waitForOperation(op *hoapi.Operation, res any) error { path := "/operations/" + op.Name + "/:wait" - return c.httpHelper.NewPostRequest(path, nil).Do(res) + retryOpts := RetryOptions{ + StatusCodes: []int{http.StatusServiceUnavailable}, + NumRetries: c.WaitRetries, + RetryDelay: c.WaitRetryDelay, + } + return c.HTTPHelper.NewPostRequest(path, nil).DoWithRetries(res, retryOpts) } -func (c *hostOrchestratorServiceImpl) FetchArtifacts(req *hoapi.FetchArtifactsRequest) (*hoapi.FetchArtifactsResponse, error) { +func (c *HostOrchestratorServiceImpl) FetchArtifacts(req *hoapi.FetchArtifactsRequest, creds string) (*hoapi.FetchArtifactsResponse, error) { var op hoapi.Operation - rb := c.httpHelper.NewPostRequest("/artifacts", req) - // Cloud Orchestrator only checks for the presence of the header, hence an empty string value is ok. - rb.Header.Add(headerNameCOInjectBuildAPICreds, "") + rb := c.HTTPHelper.NewPostRequest("/artifacts", req) + if creds != "" { + rb.AddHeader(c.BuildAPICredentialsHeader, creds) + } if err := rb.Do(&op); err != nil { return nil, err } @@ -215,11 +240,12 @@ func (c *hostOrchestratorServiceImpl) FetchArtifacts(req *hoapi.FetchArtifactsRe return res, nil } -func (c *hostOrchestratorServiceImpl) CreateCVD(req *hoapi.CreateCVDRequest) (*hoapi.CreateCVDResponse, error) { +func (c *HostOrchestratorServiceImpl) CreateCVD(req *hoapi.CreateCVDRequest, creds string) (*hoapi.CreateCVDResponse, error) { var op hoapi.Operation - rb := c.httpHelper.NewPostRequest("/cvds", req) - // Cloud Orchestrator only checks for the existence of the header, hence an empty string value is ok. - rb.Header.Add(headerNameCOInjectBuildAPICreds, "") + rb := c.HTTPHelper.NewPostRequest("/cvds", req) + if creds != "" { + rb.AddHeader(c.BuildAPICredentialsHeader, creds) + } if err := rb.Do(&op); err != nil { return nil, err } @@ -230,20 +256,20 @@ func (c *hostOrchestratorServiceImpl) CreateCVD(req *hoapi.CreateCVDRequest) (*h return res, nil } -func (c *hostOrchestratorServiceImpl) ListCVDs() ([]*hoapi.CVD, error) { +func (c *HostOrchestratorServiceImpl) ListCVDs() ([]*hoapi.CVD, error) { var res hoapi.ListCVDsResponse - if err := c.httpHelper.NewGetRequest("/cvds").Do(&res); err != nil { + if err := c.HTTPHelper.NewGetRequest("/cvds").Do(&res); err != nil { return nil, err } return res.CVDs, nil } -func (c *hostOrchestratorServiceImpl) DownloadRuntimeArtifacts(dst io.Writer) error { - req, err := http.NewRequest("POST", c.httpHelper.RootEndpoint+"/runtimeartifacts/:pull", nil) +func (c *HostOrchestratorServiceImpl) DownloadRuntimeArtifacts(dst io.Writer) error { + req, err := http.NewRequest("POST", c.HTTPHelper.RootEndpoint+"/runtimeartifacts/:pull", nil) if err != nil { return err } - res, err := c.httpHelper.Client.Do(req) + res, err := c.HTTPHelper.Client.Do(req) if err != nil { return err } @@ -257,9 +283,9 @@ func (c *hostOrchestratorServiceImpl) DownloadRuntimeArtifacts(dst io.Writer) er return nil } -func (c *hostOrchestratorServiceImpl) CreateUploadDir() (string, error) { +func (c *HostOrchestratorServiceImpl) CreateUploadDir() (string, error) { uploadDir := &hoapi.UploadDirectory{} - if err := c.httpHelper.NewPostRequest("/userartifacts", nil).Do(uploadDir); err != nil { + if err := c.HTTPHelper.NewPostRequest("/userartifacts", nil).Do(uploadDir); err != nil { return "", err } return uploadDir.Name, nil @@ -267,17 +293,36 @@ func (c *hostOrchestratorServiceImpl) CreateUploadDir() (string, error) { const openConnections = 32 -func (c *hostOrchestratorServiceImpl) UploadFiles(uploadDir string, filenames []string) error { - if c.ChunkSizeBytes == 0 { +func (c *HostOrchestratorServiceImpl) UploadFiles(uploadDir string, filenames []string) error { + return c.UploadFilesWithOptions(uploadDir, filenames, DefaultUploadOptions()) +} + +func DefaultUploadOptions() UploadOptions { + return UploadOptions{ + InitialDuration: 500 * time.Millisecond, + RandomizationFactor: 0.5, + Multiplier: 1.5, + MaxElapsedTime: 2 * time.Minute, + ChunkSizeBytes: 16 * 1024 * 1024, // 16 MB + NumWorkers: 32, + } +} + +func (c *HostOrchestratorServiceImpl) UploadFilesWithOptions(uploadDir string, filenames []string, uploadOpts UploadOptions) error { + if uploadOpts.ChunkSizeBytes == 0 { panic("ChunkSizeBytes value cannot be zero") } + if uploadOpts.NumWorkers == 0 { + panic("NumWorkers value cannot be zero") + } + if uploadOpts.MaxElapsedTime == 0 { + panic("MaxElapsedTime value cannot be zero") + } uploader := &FilesUploader{ - Client: c.httpHelper.Client, - EndpointURL: c.httpHelper.RootEndpoint + "/userartifacts/" + uploadDir, - ChunkSizeBytes: c.ChunkSizeBytes, - DumpOut: c.httpHelper.Dumpster, - NumWorkers: openConnections, - BackOffOpts: c.ChunkUploadBackOffOpts, + Client: c.HTTPHelper.Client, + EndpointURL: c.HTTPHelper.RootEndpoint + "/userartifacts/" + uploadDir, + DumpOut: c.HTTPHelper.Dumpster, + UploadOptions: uploadOpts, } return uploader.Upload(filenames) } diff --git a/pkg/client/host_orchestrator_client_test.go b/pkg/client/host_orchestrator_client_test.go index b97eb3a4..5b0368d8 100644 --- a/pkg/client/host_orchestrator_client_test.go +++ b/pkg/client/host_orchestrator_client_test.go @@ -40,7 +40,7 @@ func TestUploadFilesChunkSizeBytesIsZeroPanic(t *testing.T) { } srv, _ := NewService(opts) - srv.HostService("foo").UploadFiles("bar", []string{"baz"}) + srv.HostService("foo").UploadFilesWithOptions("bar", []string{"baz"}, UploadOptions{ChunkSizeBytes: 0}) } func TestUploadFilesSucceeds(t *testing.T) { @@ -94,7 +94,6 @@ func TestUploadFilesSucceeds(t *testing.T) { default: t.Fatal("unexpected endpoint: " + ep) } - })) defer ts.Close() opts := &ServiceOptions{ @@ -104,7 +103,14 @@ func TestUploadFilesSucceeds(t *testing.T) { } srv, _ := NewService(opts) - err := srv.HostService(host).UploadFiles(uploadDir, []string{quxFile, waldoFile, xyzzyFile}) + err := srv.HostService(host).UploadFilesWithOptions(uploadDir, []string{quxFile, waldoFile, xyzzyFile}, + UploadOptions{ + InitialDuration: 100 * time.Millisecond, + Multiplier: 2, + MaxElapsedTime: 1 * time.Second, + ChunkSizeBytes: 2, + NumWorkers: 10, + }) if err != nil { t.Fatal(err) } @@ -132,15 +138,16 @@ func TestUploadFilesExponentialBackoff(t *testing.T) { RootEndpoint: ts.URL, DumpOut: io.Discard, ChunkSizeBytes: 2, - ChunkUploadBackOffOpts: BackOffOpts{ - InitialDuration: 100 * time.Millisecond, - Multiplier: 2, - MaxElapsedTime: 1 * time.Minute, - }, } srv, _ := NewService(opts) - err := srv.HostService("foo").UploadFiles("bar", []string{waldoFile}) + err := srv.HostService("foo").UploadFilesWithOptions("bar", []string{waldoFile}, UploadOptions{ + InitialDuration: 100 * time.Millisecond, + Multiplier: 2, + MaxElapsedTime: 1 * time.Second, + ChunkSizeBytes: 2, + NumWorkers: 10, + }) if err != nil { t.Fatal(err) @@ -149,7 +156,7 @@ func TestUploadFilesExponentialBackoff(t *testing.T) { t.Fatal("first retry shouldn't be in less than 100ms") } if timestamps[2].Sub(timestamps[1]) < 200*time.Millisecond { - t.Fatal("first retry shouldn't be in less than 200ms") + t.Fatal("second retry shouldn't be in less than 200ms") } } @@ -167,15 +174,17 @@ func TestUploadFilesExponentialBackoffReachedElapsedTime(t *testing.T) { RootEndpoint: ts.URL, DumpOut: io.Discard, ChunkSizeBytes: 2, - ChunkUploadBackOffOpts: BackOffOpts{ - InitialDuration: 100 * time.Millisecond, - Multiplier: 2, - MaxElapsedTime: 1 * time.Second, - }, } srv, _ := NewService(opts) - err := srv.HostService("foo").UploadFiles("bar", []string{waldoFile}) + err := srv.HostService("foo").UploadFilesWithOptions("bar", []string{waldoFile}, UploadOptions{ + InitialDuration: 100 * time.Millisecond, + RandomizationFactor: 0.5, + Multiplier: 2, + MaxElapsedTime: 1 * time.Second, + ChunkSizeBytes: 2, + NumWorkers: 10, + }) if err == nil { t.Fatal("expected error") diff --git a/pkg/client/httputils.go b/pkg/client/httputils.go index c1af683f..ad633d04 100644 --- a/pkg/client/httputils.go +++ b/pkg/client/httputils.go @@ -27,63 +27,54 @@ import ( "net/http/httputil" "os" "path/filepath" + "slices" "strconv" "strings" "sync" "time" ) -type BackOffOpts struct { - InitialDuration time.Duration - RandomizationFactor float64 - Multiplier float64 - MaxElapsedTime time.Duration -} - -func DefaultChunkUploadBackOffOpts() BackOffOpts { - return BackOffOpts{ - InitialDuration: 500 * time.Millisecond, - RandomizationFactor: 0.5, - Multiplier: 1.5, - MaxElapsedTime: 2 * time.Minute, - } -} - type HTTPHelper struct { - Client *http.Client - RootEndpoint string - Retries uint - RetryDelay time.Duration - ChunkSizeBytes int64 - ChunkUploadBackOffOpts BackOffOpts - Dumpster io.Writer + Client *http.Client + RootEndpoint string + Dumpster io.Writer } func (h *HTTPHelper) NewGetRequest(path string) *HTTPRequestBuilder { + req, err := http.NewRequest(http.MethodGet, h.RootEndpoint+path, nil) return &HTTPRequestBuilder{ - helper: h, - url: h.RootEndpoint + path, - method: "GET", - Header: make(http.Header), + helper: h, + request: req, + err: err, } } func (h *HTTPHelper) NewDeleteRequest(path string) *HTTPRequestBuilder { + req, err := http.NewRequest(http.MethodDelete, h.RootEndpoint+path, nil) return &HTTPRequestBuilder{ - helper: h, - url: h.RootEndpoint + path, - method: "DELETE", - Header: make(http.Header), + helper: h, + request: req, + err: err, } } func (h *HTTPHelper) NewPostRequest(path string, jsonBody any) *HTTPRequestBuilder { + body := []byte{} + var err error + if jsonBody != nil { + if body, err = json.Marshal(jsonBody); err != nil { + return &HTTPRequestBuilder{helper: h, request: nil, err: err} + } + } + req, err := http.NewRequest(http.MethodPost, h.RootEndpoint+path, bytes.NewBuffer(body)) + if err != nil { + return &HTTPRequestBuilder{helper: h, request: nil, err: err} + } + req.Header.Set("Content-Type", "application/json") return &HTTPRequestBuilder{ - helper: h, - url: h.RootEndpoint + path, - method: "POST", - Header: make(http.Header), - jsonBody: jsonBody, + helper: h, + request: req, + err: err, } } @@ -112,45 +103,54 @@ func (h *HTTPHelper) dumpResponse(r *http.Response) error { } type HTTPRequestBuilder struct { - helper *HTTPHelper - url string - Header http.Header - method string - jsonBody any + helper *HTTPHelper + request *http.Request + err error } -func (rb *HTTPRequestBuilder) Do(ret any) error { - var body io.Reader - if rb.jsonBody != nil { - json, err := json.Marshal(rb.jsonBody) - if err != nil { - return fmt.Errorf("Error marshaling request: %w", err) - } - body = bytes.NewBuffer(json) - rb.Header.Set("Content-Type", "application/json") +func (rb *HTTPRequestBuilder) AddHeader(key, value string) { + if rb.request == nil { + return } - req, err := http.NewRequest(rb.method, rb.url, body) - if err != nil { - return fmt.Errorf("Error creating request: %w", err) + rb.request.Header.Add(key, value) +} + +func (rb *HTTPRequestBuilder) SetHeader(key, value string) { + if rb.request == nil { + return } - for k, v := range rb.Header { - req.Header[k] = v + rb.request.Header.Set(key, value) +} + +type RetryOptions struct { + StatusCodes []int + NumRetries uint + RetryDelay time.Duration +} + +func (rb *HTTPRequestBuilder) Do(ret any) error { + return rb.DoWithRetries(ret, RetryOptions{}) +} + +func (rb *HTTPRequestBuilder) DoWithRetries(ret any, retryOpts RetryOptions) error { + if rb.err != nil { + return rb.err } - if err := rb.helper.dumpRequest(req); err != nil { + if err := rb.helper.dumpRequest(rb.request); err != nil { return err } - res, err := rb.helper.Client.Do(req) + res, err := rb.helper.Client.Do(rb.request) if err != nil { return fmt.Errorf("Error sending request: %w", err) } - for i := uint(0); i < rb.helper.Retries && isRetryableErrorCode(res.StatusCode); i++ { + for i := uint(0); i < retryOpts.NumRetries && slices.Contains(retryOpts.StatusCodes, res.StatusCode); i++ { err = rb.helper.dumpResponse(res) res.Body.Close() if err != nil { return err } - time.Sleep(rb.helper.RetryDelay) - if res, err = rb.helper.Client.Do(req); err != nil { + time.Sleep(retryOpts.RetryDelay) + if res, err = rb.helper.Client.Do(rb.request); err != nil { return fmt.Errorf("Error sending request: %w", err) } } @@ -158,6 +158,10 @@ func (rb *HTTPRequestBuilder) Do(ret any) error { if err := rb.helper.dumpResponse(res); err != nil { return err } + return rb.parseResponse(res, ret) +} + +func (rb *HTTPRequestBuilder) parseResponse(res *http.Response, ret any) error { dec := json.NewDecoder(res.Body) if res.StatusCode < 200 || res.StatusCode > 299 { errpl := new(ApiCallError) @@ -179,13 +183,20 @@ type fileInfo struct { TotalChunks int } +type UploadOptions struct { + InitialDuration time.Duration + RandomizationFactor float64 + Multiplier float64 + MaxElapsedTime time.Duration + ChunkSizeBytes int64 + NumWorkers int +} + type FilesUploader struct { - Client *http.Client - EndpointURL string - ChunkSizeBytes int64 - DumpOut io.Writer - NumWorkers int - BackOffOpts + Client *http.Client + EndpointURL string + DumpOut io.Writer + UploadOptions } func (u *FilesUploader) Upload(files []string) error { @@ -257,12 +268,12 @@ func (u *FilesUploader) startWorkers(ctx context.Context, jobsChan <-chan upload for i := 0; i < u.NumWorkers; i++ { wg.Add(1) w := uploadChunkWorker{ - Context: ctx, - Client: u.Client, - EndpointURL: u.EndpointURL, - DumpOut: u.DumpOut, - JobsChan: jobsChan, - BackOffOpts: u.BackOffOpts, + Context: ctx, + Client: u.Client, + EndpointURL: u.EndpointURL, + DumpOut: u.DumpOut, + JobsChan: jobsChan, + UploadOptions: u.UploadOptions, } go func() { defer wg.Done() @@ -294,17 +305,17 @@ type uploadChunkWorker struct { EndpointURL string DumpOut io.Writer JobsChan <-chan uploadChunkJob - BackOffOpts + UploadOptions } // Returns a channel that will return the result for each of the handled `uploadChunkJob` instances. func (w *uploadChunkWorker) Start() <-chan error { ch := make(chan error) b := backoff.NewExponentialBackOff() - b.InitialInterval = w.BackOffOpts.InitialDuration - b.RandomizationFactor = w.BackOffOpts.RandomizationFactor - b.Multiplier = w.BackOffOpts.Multiplier - b.MaxElapsedTime = w.BackOffOpts.MaxElapsedTime + b.InitialInterval = w.UploadOptions.InitialDuration + b.RandomizationFactor = w.UploadOptions.RandomizationFactor + b.Multiplier = w.UploadOptions.Multiplier + b.MaxElapsedTime = w.UploadOptions.MaxElapsedTime b.Reset() go func() { defer close(ch)