diff --git a/pkg/cli/cli.go b/pkg/cli/cli.go index 7ff725f3..5080b9f6 100644 --- a/pkg/cli/cli.go +++ b/pkg/cli/cli.go @@ -720,7 +720,7 @@ func runPullCommand(c *cobra.Command, args []string, flags *CVDRemoteFlags, opts if err != nil { return err } - if err := service.DownloadRuntimeArtifacts(host, f); err != nil { + if err := service.HostService(host).DownloadRuntimeArtifacts(f); err != nil { return err } if err := f.Close(); err != nil { diff --git a/pkg/cli/cli_test.go b/pkg/cli/cli_test.go index 939fceba..03e6202a 100644 --- a/pkg/cli/cli_test.go +++ b/pkg/cli/cli_test.go @@ -103,47 +103,53 @@ func (fakeService) DeleteHosts(name []string) error { return nil } -func (fakeService) GetInfraConfig(host string) (*apiv1.InfraConfig, error) { +const serviceURL = "http://waldo.com" + +func (fakeService) RootURI() string { + return serviceURL + "/v1" +} + +func (fakeService) HostService(host string) client.HostOrchestratorService { + if host == "" { + panic("empty host") + } + return &fakeHostService{} +} + +type fakeHostService struct{} + +func (fakeHostService) GetInfraConfig() (*apiv1.InfraConfig, error) { return nil, nil } -func (fakeService) ConnectWebRTC(host, device string, observer wclient.Observer, logger io.Writer, opts client.ConnectWebRTCOpts) (*wclient.Connection, error) { +func (fakeHostService) ConnectWebRTC(device string, observer wclient.Observer, logger io.Writer, opts client.ConnectWebRTCOpts) (*wclient.Connection, error) { return nil, nil } -func (fakeService) FetchArtifacts(host string, req *hoapi.FetchArtifactsRequest) (*hoapi.FetchArtifactsResponse, error) { +func (fakeHostService) FetchArtifacts(req *hoapi.FetchArtifactsRequest) (*hoapi.FetchArtifactsResponse, error) { return &hoapi.FetchArtifactsResponse{AndroidCIBundle: &hoapi.AndroidCIBundle{}}, nil } -func (fakeService) CreateCVD(host string, req *hoapi.CreateCVDRequest) (*hoapi.CreateCVDResponse, error) { - if host == "" { - panic("empty host") - } +func (fakeHostService) CreateCVD(req *hoapi.CreateCVDRequest) (*hoapi.CreateCVDResponse, error) { return &hoapi.CreateCVDResponse{CVDs: []*hoapi.CVD{{Name: "cvd-1"}}}, nil } -func (fakeService) ListCVDs(host string) ([]*hoapi.CVD, error) { +func (fakeHostService) ListCVDs() ([]*hoapi.CVD, error) { return []*hoapi.CVD{{Name: "cvd-1"}}, nil } -func (fakeService) CreateUpload(host string) (string, error) { +func (fakeHostService) CreateUploadDir() (string, error) { return "", nil } -func (fakeService) UploadFiles(host, uploadDir string, filenames []string) error { +func (fakeHostService) UploadFiles(uploadDir string, filenames []string) error { return nil } -func (fakeService) DownloadRuntimeArtifacts(host string, dst io.Writer) error { +func (fakeHostService) DownloadRuntimeArtifacts(dst io.Writer) error { return nil } -const serviceURL = "http://waldo.com" - -func (fakeService) RootURI() string { - return serviceURL + "/v1" -} - func TestCommandSucceeds(t *testing.T) { tests := []struct { Name string diff --git a/pkg/cli/conn.go b/pkg/cli/conn.go index 8d03ed68..9a8520ad 100644 --- a/pkg/cli/conn.go +++ b/pkg/cli/conn.go @@ -390,7 +390,7 @@ func NewConnController( opts := client.ConnectWebRTCOpts{ LocalICEConfig: localICEConfig, } - conn, err := service.ConnectWebRTC(cvd.Host, cvd.WebRTCDeviceID, tc, logger.Writer(), opts) + conn, err := service.HostService(cvd.Host).ConnectWebRTC(cvd.WebRTCDeviceID, tc, logger.Writer(), opts) if err != nil { return nil, fmt.Errorf("Failed to connect to %q: %w", cvd.WebRTCDeviceID, err) } diff --git a/pkg/cli/cvd.go b/pkg/cli/cvd.go index ed2cb920..3e7e8ff8 100644 --- a/pkg/cli/cvd.go +++ b/pkg/cli/cvd.go @@ -131,11 +131,11 @@ func (c *cvdCreator) createCVDFromLocalBuild() ([]*hoapi.CVD, error) { return nil, fmt.Errorf("Invalid cvd host package: %w", err) } names = append(names, filepath.Join(vars.HostOut, CVDHostPackageName)) - uploadDir, err := c.service.CreateUpload(c.opts.Host) + uploadDir, err := c.service.HostService(c.opts.Host).CreateUploadDir() if err != nil { return nil, err } - if err := c.service.UploadFiles(c.opts.Host, uploadDir, names); err != nil { + if err := c.service.HostService(c.opts.Host).UploadFiles(uploadDir, names); err != nil { return nil, err } req := hoapi.CreateCVDRequest{ @@ -148,7 +148,7 @@ func (c *cvdCreator) createCVDFromLocalBuild() ([]*hoapi.CVD, error) { }, AdditionalInstancesNum: c.opts.AdditionalInstancesNum(), } - res, err := c.service.CreateCVD(c.opts.Host, &req) + res, err := c.service.HostService(c.opts.Host).CreateCVD(&req) if err != nil { return nil, err } @@ -173,7 +173,7 @@ func (c *cvdCreator) createWithCanonicalConfig() ([]*hoapi.CVD, error) { EnvConfig: c.opts.EnvConfig, } c.statePrinter.Print(stateMsgFetchAndStart) - res, err := c.service.CreateCVD(c.opts.Host, createReq) + res, err := c.service.HostService(c.opts.Host).CreateCVD(createReq) c.statePrinter.PrintDone(stateMsgFetchAndStart, err) if err != nil { return nil, err @@ -197,7 +197,7 @@ func (c *cvdCreator) createWithOpts() ([]*hoapi.CVD, error) { AndroidCIBundle: &hoapi.AndroidCIBundle{Build: mainBuild, Type: hoapi.MainBundleType}, } c.statePrinter.Print(stateMsgFetchMainBundle) - fetchMainBuildRes, err := c.service.FetchArtifacts(c.opts.Host, fetchReq) + fetchMainBuildRes, err := c.service.HostService(c.opts.Host).FetchArtifacts(fetchReq) c.statePrinter.PrintDone(stateMsgFetchMainBundle, err) if err != nil { return nil, err @@ -216,7 +216,7 @@ func (c *cvdCreator) createWithOpts() ([]*hoapi.CVD, error) { AdditionalInstancesNum: c.opts.AdditionalInstancesNum(), } c.statePrinter.Print(stateMsgStartCVD) - res, err := c.service.CreateCVD(c.opts.Host, createReq) + res, err := c.service.HostService(c.opts.Host).CreateCVD(createReq) c.statePrinter.PrintDone(stateMsgStartCVD, err) if err != nil { return nil, err @@ -291,7 +291,7 @@ func flattenCVDs(hosts []*RemoteHost) []*RemoteCVD { // Calling listCVDConnectionsByHost is inefficient, this internal function avoids that for listAllCVDs. func listHostCVDsInner(service client.Service, host string, statuses map[RemoteCVDLocator]ConnStatus) ([]*RemoteCVD, error) { - cvds, err := service.ListCVDs(host) + cvds, err := service.HostService(host).ListCVDs() if err != nil { return nil, err } diff --git a/pkg/client/client.go b/pkg/client/client.go index 7b4182a9..2ac02365 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -24,9 +24,7 @@ import ( "time" apiv1 "github.com/google/cloud-android-orchestration/api/v1" - wclient "github.com/google/cloud-android-orchestration/pkg/webrtcclient" - hoapi "github.com/google/android-cuttlefish/frontend/src/liboperator/api/v1" "github.com/hashicorp/go-multierror" "github.com/pion/webrtc/v3" ) @@ -67,10 +65,6 @@ type ServiceOptions struct { ChunkUploadBackOffOpts BackOffOpts } -type ConnectWebRTCOpts struct { - LocalICEConfig *wclient.ICEConfig -} - type Service interface { CreateHost(req *apiv1.CreateHostRequest) (*apiv1.HostInstance, error) @@ -78,22 +72,7 @@ type Service interface { DeleteHosts(names []string) error - GetInfraConfig(host string) (*apiv1.InfraConfig, error) - - ConnectWebRTC(host, device string, observer wclient.Observer, logger io.Writer, opts ConnectWebRTCOpts) (*wclient.Connection, error) - - FetchArtifacts(host string, req *hoapi.FetchArtifactsRequest) (*hoapi.FetchArtifactsResponse, error) - - CreateCVD(host string, req *hoapi.CreateCVDRequest) (*hoapi.CreateCVDResponse, error) - - ListCVDs(host string) ([]*hoapi.CVD, error) - - // Downloads runtime artifacts tar file from passed `host` into `dst` filename. - DownloadRuntimeArtifacts(host string, dst io.Writer) error - - CreateUpload(host string) (string, error) - - UploadFiles(host, uploadDir string, filenames []string) error + HostService(host string) HostOrchestratorService RootURI() string } @@ -166,152 +145,11 @@ func (c *serviceImpl) DeleteHosts(names []string) error { return merr } -func (c *serviceImpl) GetInfraConfig(host string) (*apiv1.InfraConfig, error) { - var res apiv1.InfraConfig - if err := c.httpHelper.NewGetRequest(fmt.Sprintf("/hosts/%s/infra_config", host)).Do(&res); err != nil { - return nil, err - } - return &res, nil -} - -func (c *serviceImpl) ConnectWebRTC(host, device string, observer wclient.Observer, logger io.Writer, opts ConnectWebRTCOpts) (*wclient.Connection, error) { - polledConn, err := c.createPolledConnection(host, device) - if err != nil { - return nil, fmt.Errorf("Failed to create polled connection: %w", err) - } - iceServers := []webrtc.ICEServer{} - if opts.LocalICEConfig != nil { - iceServers = append(iceServers, opts.LocalICEConfig.ICEServers...) - } - infraConfig, err := c.GetInfraConfig(host) - if err != nil { - return nil, fmt.Errorf("Failed to obtain infra config: %w", err) - } - iceServers = append(iceServers, asWebRTCICEServers(infraConfig.IceServers)...) - signaling := c.initHandling(host, polledConn.ConnId, iceServers) - conn, err := wclient.NewConnectionWithLogger(&signaling, observer, logger) - if err != nil { - return nil, fmt.Errorf("Failed to connect to device over webrtc: %w", err) - } - return conn, nil -} - -func (c *serviceImpl) createPolledConnection(host, device string) (*apiv1.NewConnReply, error) { - path := fmt.Sprintf("/hosts/%s/polled_connections", host) - req := apiv1.NewConnMsg{DeviceId: device} - var res apiv1.NewConnReply - if err := c.httpHelper.NewPostRequest(path, &req).Do(&res); err != nil { - return nil, err - } - return &res, nil -} - -func (c *serviceImpl) initHandling(host, connID string, iceServers []webrtc.ICEServer) wclient.Signaling { - sendCh := make(chan any) - recvCh := make(chan map[string]any) - - // The forwarding goroutine will close this channel and stop when the send - // channel is closed, which will cause the polling go routine to close its own - // channel and stop as well. - stopPollCh := make(chan bool) - go c.webRTCPoll(recvCh, host, connID, stopPollCh) - go c.webRTCForward(sendCh, host, connID, stopPollCh) - - return wclient.Signaling{ - SendCh: sendCh, - RecvCh: recvCh, - ICEServers: iceServers, - ClientICEServers: iceServers, - } -} - -const ( - initialPollInterval = 100 * time.Millisecond - maxPollInterval = 2 * time.Second - maxConsecutiveErrors = 10 -) - -func (c *serviceImpl) webRTCPoll(sinkCh chan map[string]any, host, connID string, stopCh chan bool) { - start := 0 - pollInterval := initialPollInterval - errCount := 0 - for { - path := fmt.Sprintf("/hosts/%s/polled_connections/%s/messages?start=%d", host, connID, start) - var messages []map[string]any - if err := c.httpHelper.NewGetRequest(path).Do(&messages); err != nil { - fmt.Fprintf(c.ErrOut, "Error polling messages: %v\n", err) - errCount++ - if errCount >= maxConsecutiveErrors { - fmt.Fprintln(c.ErrOut, "Reached maximum number of consecutive polling errors, exiting") - close(sinkCh) - return - } - } else { - errCount = 0 - } - if len(messages) > 0 { - pollInterval = initialPollInterval - } else { - pollInterval = 2 * pollInterval - if pollInterval > maxPollInterval { - pollInterval = maxPollInterval - } - } - for _, message := range messages { - if message["message_type"] != "device_msg" { - fmt.Fprintf(c.ErrOut, "unexpected message type: %s\n", message["message_type"]) - continue - } - sinkCh <- message["payload"].(map[string]any) - start++ - } - select { - case _, _ = <-stopCh: - // The forwarding goroutine has requested a stop - close(sinkCh) - return - case <-time.After(pollInterval): - // poll for messages again - } - } -} - -func (c *serviceImpl) webRTCForward(srcCh chan any, host, connID string, stopPollCh chan bool) { - for { - msg, open := <-srcCh - if !open { - // The webrtc client closed the channel - close(stopPollCh) - break - } - forwardMsg := apiv1.ForwardMsg{Payload: msg} - path := fmt.Sprintf("/hosts/%s/polled_connections/%s/:forward", host, connID) - i := 0 - for ; i < maxConsecutiveErrors; i++ { - if err := c.httpHelper.NewPostRequest(path, &forwardMsg).Do(nil); err != nil { - fmt.Fprintf(c.ErrOut, "Error sending message to device: %v\n", err) - } else { - break - } - } - if i == maxConsecutiveErrors { - fmt.Fprintln(c.ErrOut, "Reached maximum number of sending errors, exiting") - close(stopPollCh) - return - } - } -} - func (c *serviceImpl) waitForOperation(op *apiv1.Operation, res any) error { path := "/operations/" + op.Name + "/:wait" return c.httpHelper.NewPostRequest(path, nil).Do(res) } -func (c *serviceImpl) waitForHostOperation(host string, op *hoapi.Operation, res any) error { - path := "/hosts/" + host + "/operations/" + op.Name + "/:wait" - return c.httpHelper.NewPostRequest(path, nil).Do(res) -} - func asWebRTCICEServers(in []apiv1.IceServer) []webrtc.ICEServer { out := []webrtc.ICEServer{} for _, s := range in { @@ -324,93 +162,18 @@ func asWebRTCICEServers(in []apiv1.IceServer) []webrtc.ICEServer { const headerNameCOInjectBuildAPICreds = "X-Cutf-Cloud-Orchestrator-Inject-BuildAPI-Creds" -func (c *serviceImpl) FetchArtifacts( - host string, req *hoapi.FetchArtifactsRequest) (*hoapi.FetchArtifactsResponse, error) { - var op hoapi.Operation - path := "/hosts/" + host + "/artifacts" - rb := c.httpHelper.NewPostRequest(path, req) - // Cloud Orchestrator only checks for the presence of the header, hence an empty string value is ok. - rb.Header.Add(headerNameCOInjectBuildAPICreds, "") - if err := rb.Do(&op); err != nil { - return nil, err - } - - res := &hoapi.FetchArtifactsResponse{} - if err := c.waitForHostOperation(host, &op, &res); err != nil { - return nil, err - } - return res, nil -} - -func (c *serviceImpl) CreateCVD(host string, req *hoapi.CreateCVDRequest) (*hoapi.CreateCVDResponse, error) { - var op hoapi.Operation - rb := c.httpHelper.NewPostRequest("/hosts/"+host+"/cvds", req) - // Cloud Orchestrator only checks for the existence of the header, hence an empty string value is ok. - rb.Header.Add(headerNameCOInjectBuildAPICreds, "") - if err := rb.Do(&op); err != nil { - return nil, err - } - res := &hoapi.CreateCVDResponse{} - if err := c.waitForHostOperation(host, &op, &res); err != nil { - return nil, err - } - return res, nil -} - -func (c *serviceImpl) ListCVDs(host string) ([]*hoapi.CVD, error) { - var res hoapi.ListCVDsResponse - if err := c.httpHelper.NewGetRequest("/hosts/" + host + "/cvds").Do(&res); err != nil { - return nil, err - } - return res.CVDs, nil -} - -func (c *serviceImpl) DownloadRuntimeArtifacts(host string, dst io.Writer) error { - req, err := http.NewRequest("POST", c.RootEndpoint+"/hosts/"+host+"/runtimeartifacts/:pull", nil) - if err != nil { - return err - } - res, err := c.httpHelper.Client.Do(req) - if err != nil { - return err - } - defer res.Body.Close() - if _, err := io.Copy(dst, res.Body); err != nil { - return err - } - if res.StatusCode < 200 || res.StatusCode > 299 { - return &ApiCallError{ErrorMsg: res.Status} - } - return nil -} - -func (c *serviceImpl) CreateUpload(host string) (string, error) { - uploadDir := &hoapi.UploadDirectory{} - if err := c.httpHelper.NewPostRequest("/hosts/"+host+"/userartifacts", nil).Do(uploadDir); err != nil { - return "", err - } - return uploadDir.Name, nil +func (s *serviceImpl) RootURI() string { + return s.RootEndpoint } -const openConnections = 32 - -func (c *serviceImpl) UploadFiles(host, uploadDir string, filenames []string) error { - if c.ChunkSizeBytes == 0 { - panic("ChunkSizeBytes value cannot be zero") +func (s *serviceImpl) HostService(host string) HostOrchestratorService { + hs := &hostOrchestratorServiceImpl{ + httpHelper: s.httpHelper, + ChunkSizeBytes: s.ChunkSizeBytes, + ChunkUploadBackOffOpts: s.ChunkUploadBackOffOpts, } - uploader := &FilesUploader{ - Client: c.httpHelper.Client, - EndpointURL: c.RootEndpoint + "/hosts/" + host + "/userartifacts/" + uploadDir, - ChunkSizeBytes: c.ChunkSizeBytes, - DumpOut: c.DumpOut, - NumWorkers: openConnections, - BackOffOpts: c.ChunkUploadBackOffOpts, - } - return uploader.Upload(filenames) -} - -func (s *serviceImpl) RootURI() string { - return s.RootEndpoint + hs.httpHelper.RootEndpoint = s.httpHelper.RootEndpoint + "/hosts/" + host + return hs } func BuildRootEndpoint(serviceURL, version, zone string) string { diff --git a/pkg/client/client_test.go b/pkg/client/client_test.go index ee8f2901..185f611d 100644 --- a/pkg/client/client_test.go +++ b/pkg/client/client_test.go @@ -16,14 +16,10 @@ package client import ( "encoding/json" - "fmt" "io" "net/http" "net/http/httptest" - "os" - "path/filepath" "regexp" - "sync" "testing" "time" @@ -75,163 +71,6 @@ func TestRetryLogic(t *testing.T) { } } -func TestUploadFilesChunkSizeBytesIsZeroPanic(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("did not panic") - } - }() - opts := &ServiceOptions{ - RootEndpoint: "https://test.com", - DumpOut: io.Discard, - } - srv, _ := NewService(opts) - - srv.UploadFiles("foo", "bar", []string{"baz"}) -} - -func TestUploadFilesSucceeds(t *testing.T) { - host := "foo" - uploadDir := "bar" - tempDir := createTempDir(t) - defer os.RemoveAll(tempDir) - quxFile := createTempFile(t, tempDir, "qux", []byte("lorem")) - waldoFile := createTempFile(t, tempDir, "waldo", []byte("l")) - xyzzyFile := createTempFile(t, tempDir, "xyzzy", []byte("abraca")) - mu := sync.Mutex{} - // expected uploads are keyed by format %filename %chunknumber of %chunktotal with the chunk content as value - uploads := map[string]struct{ Content []byte }{ - // qux - "qux 1 of 3": {Content: []byte("lo")}, - "qux 2 of 3": {Content: []byte("re")}, - "qux 3 of 3": {Content: []byte("m")}, - // waldo - "waldo 1 of 1": {Content: []byte("l")}, - // xyzzy - "xyzzy 1 of 3": {Content: []byte("ab")}, - "xyzzy 2 of 3": {Content: []byte("ra")}, - "xyzzy 3 of 3": {Content: []byte("ca")}, - } - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - mu.Lock() - defer mu.Unlock() - switch ep := r.Method + " " + r.URL.Path; ep { - case "PUT /hosts/" + host + "/userartifacts/" + uploadDir: - chunkNumber := r.PostFormValue("chunk_number") - chunkTotal := r.PostFormValue("chunk_total") - f, fheader, err := r.FormFile("file") - if err != nil { - t.Fatal(err) - } - expectedUploadKey := fmt.Sprintf("%s %s of %s", fheader.Filename, chunkNumber, chunkTotal) - val, ok := uploads[expectedUploadKey] - if !ok { - t.Fatalf("unexpected upload with filename: %q, chunk number: %s, chunk total: %s", - fheader.Filename, chunkNumber, chunkTotal) - } - delete(uploads, expectedUploadKey) - b, err := io.ReadAll(f) - if err != nil { - t.Fatal(err) - } - if diff := cmp.Diff(val.Content, b); diff != "" { - t.Fatalf("chunk content mismatch %q (-want +got):\n%s", fheader.Filename, diff) - } - writeOK(w, struct{}{}) - default: - t.Fatal("unexpected endpoint: " + ep) - } - - })) - defer ts.Close() - opts := &ServiceOptions{ - RootEndpoint: ts.URL, - DumpOut: io.Discard, - ChunkSizeBytes: 2, - } - srv, _ := NewService(opts) - - err := srv.UploadFiles(host, uploadDir, []string{quxFile, waldoFile, xyzzyFile}) - if err != nil { - t.Fatal(err) - } - - if len(uploads) != 0 { - t.Errorf("missing chunk uploads: %v", uploads) - } -} - -func TestUploadFilesExponentialBackoff(t *testing.T) { - tempDir := createTempDir(t) - defer os.RemoveAll(tempDir) - waldoFile := createTempFile(t, tempDir, "waldo", []byte("l")) - timestamps := make([]time.Time, 0) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - timestamps = append(timestamps, time.Now()) - if len(timestamps) < 3 { - writeErr(w, 500) - return - } - writeOK(w, struct{}{}) - })) - defer ts.Close() - opts := &ServiceOptions{ - RootEndpoint: ts.URL, - DumpOut: io.Discard, - ChunkSizeBytes: 2, - ChunkUploadBackOffOpts: BackOffOpts{ - InitialDuration: 100 * time.Millisecond, - Multiplier: 2, - MaxElapsedTime: 1 * time.Minute, - }, - } - srv, _ := NewService(opts) - - err := srv.UploadFiles("foo", "bar", []string{waldoFile}) - - if err != nil { - t.Fatal(err) - } - if timestamps[1].Sub(timestamps[0]) < 100*time.Millisecond { - t.Fatal("first retry shouldn't be in less than 100ms") - } - if timestamps[2].Sub(timestamps[1]) < 200*time.Millisecond { - t.Fatal("first retry shouldn't be in less than 200ms") - } -} - -func TestUploadFilesExponentialBackoffReachedElapsedTime(t *testing.T) { - tempDir := createTempDir(t) - defer os.RemoveAll(tempDir) - waldoFile := createTempFile(t, tempDir, "waldo", []byte("l")) - attempts := 0 - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - attempts = attempts + 1 - writeErr(w, 500) - })) - defer ts.Close() - opts := &ServiceOptions{ - RootEndpoint: ts.URL, - DumpOut: io.Discard, - ChunkSizeBytes: 2, - ChunkUploadBackOffOpts: BackOffOpts{ - InitialDuration: 100 * time.Millisecond, - Multiplier: 2, - MaxElapsedTime: 1 * time.Second, - }, - } - srv, _ := NewService(opts) - - err := srv.UploadFiles("foo", "bar", []string{waldoFile}) - - if err == nil { - t.Fatal("expected error") - } - if attempts == 0 { - t.Fatal("server was never reached") - } -} - func TestDeleteHosts(t *testing.T) { existingNames := map[string]struct{}{"bar": {}, "baz": {}} ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -265,22 +104,6 @@ func TestDeleteHosts(t *testing.T) { } } -func createTempDir(t *testing.T) string { - dir, err := os.MkdirTemp("", "cvdrTest") - if err != nil { - t.Fatal(err) - } - return dir -} - -func createTempFile(t *testing.T, dir, name string, content []byte) string { - file := filepath.Join(dir, name) - if err := os.WriteFile(file, content, 0666); err != nil { - t.Fatal(err) - } - return file -} - func writeErr(w http.ResponseWriter, statusCode int) { write(w, &apiv1.Error{Code: statusCode}, statusCode) } diff --git a/pkg/client/host_orchestrator_client.go b/pkg/client/host_orchestrator_client.go new file mode 100644 index 00000000..91104b1a --- /dev/null +++ b/pkg/client/host_orchestrator_client.go @@ -0,0 +1,284 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package client + +import ( + "fmt" + "io" + "net/http" + "time" + + apiv1 "github.com/google/cloud-android-orchestration/api/v1" + wclient "github.com/google/cloud-android-orchestration/pkg/webrtcclient" + + hoapi "github.com/google/android-cuttlefish/frontend/src/liboperator/api/v1" + "github.com/pion/webrtc/v3" +) + +type ConnectWebRTCOpts struct { + LocalICEConfig *wclient.ICEConfig +} + +// A client to the host orchestrator service running in a remote host. +type HostOrchestratorService interface { + // Lists currently running devices. + ListCVDs() ([]*hoapi.CVD, error) + + // Creates a directory in the host where user artifacts can be uploaded to. + CreateUploadDir() (string, error) + + // Uploads user files to a previously created directory. + UploadFiles(uploadDir string, filenames []string) error + + // Create a new device with artifacts from the build server or previously uploaded by the user. + CreateCVD(req *hoapi.CreateCVDRequest) (*hoapi.CreateCVDResponse, error) + + FetchArtifacts(req *hoapi.FetchArtifactsRequest) (*hoapi.FetchArtifactsResponse, error) + + // Downloads runtime artifacts tar file into `dst`. + DownloadRuntimeArtifacts(dst io.Writer) error + + // Creates a webRTC connection to a device running in this host. + ConnectWebRTC(device string, observer wclient.Observer, logger io.Writer, opts ConnectWebRTCOpts) (*wclient.Connection, error) +} + +type hostOrchestratorServiceImpl struct { + httpHelper HTTPHelper + ChunkSizeBytes int64 + ChunkUploadBackOffOpts BackOffOpts +} + +func (c *hostOrchestratorServiceImpl) getInfraConfig() (*apiv1.InfraConfig, error) { + var res apiv1.InfraConfig + if err := c.httpHelper.NewGetRequest(fmt.Sprintf("/infra_config")).Do(&res); err != nil { + return nil, err + } + return &res, nil +} + +func (c *hostOrchestratorServiceImpl) ConnectWebRTC(device string, observer wclient.Observer, logger io.Writer, opts ConnectWebRTCOpts) (*wclient.Connection, error) { + polledConn, err := c.createPolledConnection(device) + if err != nil { + return nil, fmt.Errorf("Failed to create polled connection: %w", err) + } + iceServers := []webrtc.ICEServer{} + if opts.LocalICEConfig != nil { + iceServers = append(iceServers, opts.LocalICEConfig.ICEServers...) + } + infraConfig, err := c.getInfraConfig() + if err != nil { + return nil, fmt.Errorf("Failed to obtain infra config: %w", err) + } + iceServers = append(iceServers, asWebRTCICEServers(infraConfig.IceServers)...) + signaling := c.initHandling(polledConn.ConnId, iceServers, logger) + conn, err := wclient.NewConnectionWithLogger(&signaling, observer, logger) + if err != nil { + return nil, fmt.Errorf("Failed to connect to device over webrtc: %w", err) + } + return conn, nil +} + +func (c *hostOrchestratorServiceImpl) initHandling(connID string, iceServers []webrtc.ICEServer, logger io.Writer) wclient.Signaling { + sendCh := make(chan any) + recvCh := make(chan map[string]any) + + // The forwarding goroutine will close this channel and stop when the send + // channel is closed, which will cause the polling go routine to close its own + // channel and stop as well. + stopPollCh := make(chan bool) + go c.webRTCPoll(recvCh, connID, stopPollCh, logger) + go c.webRTCForward(sendCh, connID, stopPollCh, logger) + + return wclient.Signaling{ + SendCh: sendCh, + RecvCh: recvCh, + ICEServers: iceServers, + ClientICEServers: iceServers, + } +} + +const ( + initialPollInterval = 100 * time.Millisecond + maxPollInterval = 2 * time.Second + maxConsecutiveErrors = 10 +) + +func (c *hostOrchestratorServiceImpl) webRTCPoll(sinkCh chan map[string]any, connID string, stopCh chan bool, logger io.Writer) { + start := 0 + pollInterval := initialPollInterval + errCount := 0 + for { + path := fmt.Sprintf("/polled_connections/%s/messages?start=%d", connID, start) + var messages []map[string]any + if err := c.httpHelper.NewGetRequest(path).Do(&messages); err != nil { + fmt.Fprintf(logger, "Error polling messages: %v\n", err) + errCount++ + if errCount >= maxConsecutiveErrors { + fmt.Fprintln(logger, "Reached maximum number of consecutive polling errors, exiting") + close(sinkCh) + return + } + } else { + errCount = 0 + } + if len(messages) > 0 { + pollInterval = initialPollInterval + } else { + pollInterval = 2 * pollInterval + if pollInterval > maxPollInterval { + pollInterval = maxPollInterval + } + } + for _, message := range messages { + if message["message_type"] != "device_msg" { + fmt.Fprintf(logger, "unexpected message type: %s\n", message["message_type"]) + continue + } + sinkCh <- message["payload"].(map[string]any) + start++ + } + select { + case _, _ = <-stopCh: + // The forwarding goroutine has requested a stop + close(sinkCh) + return + case <-time.After(pollInterval): + // poll for messages again + } + } +} + +func (c *hostOrchestratorServiceImpl) webRTCForward(srcCh chan any, connID string, stopPollCh chan bool, logger io.Writer) { + for { + msg, open := <-srcCh + if !open { + // The webrtc client closed the channel + close(stopPollCh) + break + } + forwardMsg := apiv1.ForwardMsg{Payload: msg} + path := fmt.Sprintf("/polled_connections/%s/:forward", connID) + i := 0 + for ; i < maxConsecutiveErrors; i++ { + rb := c.httpHelper.NewPostRequest(path, &forwardMsg) + if err := rb.Do(nil); err != nil { + fmt.Fprintf(logger, "Error sending message to device: %v\n", err) + } else { + break + } + } + if i == maxConsecutiveErrors { + fmt.Fprintln(logger, "Reached maximum number of sending errors, exiting") + close(stopPollCh) + return + } + } +} + +func (c *hostOrchestratorServiceImpl) createPolledConnection(device string) (*apiv1.NewConnReply, error) { + var res apiv1.NewConnReply + rb := c.httpHelper.NewPostRequest("/polled_connections", &apiv1.NewConnMsg{DeviceId: device}) + if err := rb.Do(&res); err != nil { + return nil, err + } + return &res, nil +} + +func (c *hostOrchestratorServiceImpl) waitForOperation(op *hoapi.Operation, res any) error { + path := "/operations/" + op.Name + "/:wait" + return c.httpHelper.NewPostRequest(path, nil).Do(res) +} + +func (c *hostOrchestratorServiceImpl) FetchArtifacts(req *hoapi.FetchArtifactsRequest) (*hoapi.FetchArtifactsResponse, error) { + var op hoapi.Operation + rb := c.httpHelper.NewPostRequest("/artifacts", req) + // Cloud Orchestrator only checks for the presence of the header, hence an empty string value is ok. + rb.Header.Add(headerNameCOInjectBuildAPICreds, "") + if err := rb.Do(&op); err != nil { + return nil, err + } + + res := &hoapi.FetchArtifactsResponse{} + if err := c.waitForOperation(&op, &res); err != nil { + return nil, err + } + return res, nil +} + +func (c *hostOrchestratorServiceImpl) CreateCVD(req *hoapi.CreateCVDRequest) (*hoapi.CreateCVDResponse, error) { + var op hoapi.Operation + rb := c.httpHelper.NewPostRequest("/cvds", req) + // Cloud Orchestrator only checks for the existence of the header, hence an empty string value is ok. + rb.Header.Add(headerNameCOInjectBuildAPICreds, "") + if err := rb.Do(&op); err != nil { + return nil, err + } + res := &hoapi.CreateCVDResponse{} + if err := c.waitForOperation(&op, &res); err != nil { + return nil, err + } + return res, nil +} + +func (c *hostOrchestratorServiceImpl) ListCVDs() ([]*hoapi.CVD, error) { + var res hoapi.ListCVDsResponse + if err := c.httpHelper.NewGetRequest("/cvds").Do(&res); err != nil { + return nil, err + } + return res.CVDs, nil +} + +func (c *hostOrchestratorServiceImpl) DownloadRuntimeArtifacts(dst io.Writer) error { + req, err := http.NewRequest("POST", c.httpHelper.RootEndpoint+"/runtimeartifacts/:pull", nil) + if err != nil { + return err + } + res, err := c.httpHelper.Client.Do(req) + if err != nil { + return err + } + defer res.Body.Close() + if _, err := io.Copy(dst, res.Body); err != nil { + return err + } + if res.StatusCode < 200 || res.StatusCode > 299 { + return &ApiCallError{ErrorMsg: res.Status} + } + return nil +} + +func (c *hostOrchestratorServiceImpl) CreateUploadDir() (string, error) { + uploadDir := &hoapi.UploadDirectory{} + if err := c.httpHelper.NewPostRequest("/userartifacts", nil).Do(uploadDir); err != nil { + return "", err + } + return uploadDir.Name, nil +} + +const openConnections = 32 + +func (c *hostOrchestratorServiceImpl) UploadFiles(uploadDir string, filenames []string) error { + if c.ChunkSizeBytes == 0 { + panic("ChunkSizeBytes value cannot be zero") + } + uploader := &FilesUploader{ + Client: c.httpHelper.Client, + EndpointURL: c.httpHelper.RootEndpoint + "/userartifacts/" + uploadDir, + ChunkSizeBytes: c.ChunkSizeBytes, + DumpOut: c.httpHelper.Dumpster, + NumWorkers: openConnections, + BackOffOpts: c.ChunkUploadBackOffOpts, + } + return uploader.Upload(filenames) +} diff --git a/pkg/client/host_orchestrator_client_test.go b/pkg/client/host_orchestrator_client_test.go new file mode 100644 index 00000000..b97eb3a4 --- /dev/null +++ b/pkg/client/host_orchestrator_client_test.go @@ -0,0 +1,202 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/google/go-cmp/cmp" +) + +func TestUploadFilesChunkSizeBytesIsZeroPanic(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("did not panic") + } + }() + opts := &ServiceOptions{ + RootEndpoint: "https://test.com", + DumpOut: io.Discard, + } + srv, _ := NewService(opts) + + srv.HostService("foo").UploadFiles("bar", []string{"baz"}) +} + +func TestUploadFilesSucceeds(t *testing.T) { + host := "foo" + uploadDir := "bar" + tempDir := createTempDir(t) + defer os.RemoveAll(tempDir) + quxFile := createTempFile(t, tempDir, "qux", []byte("lorem")) + waldoFile := createTempFile(t, tempDir, "waldo", []byte("l")) + xyzzyFile := createTempFile(t, tempDir, "xyzzy", []byte("abraca")) + mu := sync.Mutex{} + // expected uploads are keyed by format %filename %chunknumber of %chunktotal with the chunk content as value + uploads := map[string]struct{ Content []byte }{ + // qux + "qux 1 of 3": {Content: []byte("lo")}, + "qux 2 of 3": {Content: []byte("re")}, + "qux 3 of 3": {Content: []byte("m")}, + // waldo + "waldo 1 of 1": {Content: []byte("l")}, + // xyzzy + "xyzzy 1 of 3": {Content: []byte("ab")}, + "xyzzy 2 of 3": {Content: []byte("ra")}, + "xyzzy 3 of 3": {Content: []byte("ca")}, + } + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + defer mu.Unlock() + switch ep := r.Method + " " + r.URL.Path; ep { + case "PUT /hosts/" + host + "/userartifacts/" + uploadDir: + chunkNumber := r.PostFormValue("chunk_number") + chunkTotal := r.PostFormValue("chunk_total") + f, fheader, err := r.FormFile("file") + if err != nil { + t.Fatal(err) + } + expectedUploadKey := fmt.Sprintf("%s %s of %s", fheader.Filename, chunkNumber, chunkTotal) + val, ok := uploads[expectedUploadKey] + if !ok { + t.Fatalf("unexpected upload with filename: %q, chunk number: %s, chunk total: %s", + fheader.Filename, chunkNumber, chunkTotal) + } + delete(uploads, expectedUploadKey) + b, err := io.ReadAll(f) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(val.Content, b); diff != "" { + t.Fatalf("chunk content mismatch %q (-want +got):\n%s", fheader.Filename, diff) + } + writeOK(w, struct{}{}) + default: + t.Fatal("unexpected endpoint: " + ep) + } + + })) + defer ts.Close() + opts := &ServiceOptions{ + RootEndpoint: ts.URL, + DumpOut: io.Discard, + ChunkSizeBytes: 2, + } + srv, _ := NewService(opts) + + err := srv.HostService(host).UploadFiles(uploadDir, []string{quxFile, waldoFile, xyzzyFile}) + if err != nil { + t.Fatal(err) + } + + if len(uploads) != 0 { + t.Errorf("missing chunk uploads: %v", uploads) + } +} + +func TestUploadFilesExponentialBackoff(t *testing.T) { + tempDir := createTempDir(t) + defer os.RemoveAll(tempDir) + waldoFile := createTempFile(t, tempDir, "waldo", []byte("l")) + timestamps := make([]time.Time, 0) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + timestamps = append(timestamps, time.Now()) + if len(timestamps) < 3 { + writeErr(w, 500) + return + } + writeOK(w, struct{}{}) + })) + defer ts.Close() + opts := &ServiceOptions{ + RootEndpoint: ts.URL, + DumpOut: io.Discard, + ChunkSizeBytes: 2, + ChunkUploadBackOffOpts: BackOffOpts{ + InitialDuration: 100 * time.Millisecond, + Multiplier: 2, + MaxElapsedTime: 1 * time.Minute, + }, + } + srv, _ := NewService(opts) + + err := srv.HostService("foo").UploadFiles("bar", []string{waldoFile}) + + if err != nil { + t.Fatal(err) + } + if timestamps[1].Sub(timestamps[0]) < 100*time.Millisecond { + t.Fatal("first retry shouldn't be in less than 100ms") + } + if timestamps[2].Sub(timestamps[1]) < 200*time.Millisecond { + t.Fatal("first retry shouldn't be in less than 200ms") + } +} + +func TestUploadFilesExponentialBackoffReachedElapsedTime(t *testing.T) { + tempDir := createTempDir(t) + defer os.RemoveAll(tempDir) + waldoFile := createTempFile(t, tempDir, "waldo", []byte("l")) + attempts := 0 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts = attempts + 1 + writeErr(w, 500) + })) + defer ts.Close() + opts := &ServiceOptions{ + RootEndpoint: ts.URL, + DumpOut: io.Discard, + ChunkSizeBytes: 2, + ChunkUploadBackOffOpts: BackOffOpts{ + InitialDuration: 100 * time.Millisecond, + Multiplier: 2, + MaxElapsedTime: 1 * time.Second, + }, + } + srv, _ := NewService(opts) + + err := srv.HostService("foo").UploadFiles("bar", []string{waldoFile}) + + if err == nil { + t.Fatal("expected error") + } + if attempts == 0 { + t.Fatal("server was never reached") + } +} + +func createTempDir(t *testing.T) string { + dir, err := os.MkdirTemp("", "cvdrTest") + if err != nil { + t.Fatal(err) + } + return dir +} + +func createTempFile(t *testing.T, dir, name string, content []byte) string { + file := filepath.Join(dir, name) + if err := os.WriteFile(file, content, 0666); err != nil { + t.Fatal(err) + } + return file +}