diff --git a/internal/experiment/wireguard/urlget.go b/internal/experiment/wireguard/urlget.go index 4f1e3c748..eae2f9a5e 100644 --- a/internal/experiment/wireguard/urlget.go +++ b/internal/experiment/wireguard/urlget.go @@ -1,12 +1,13 @@ package wireguard import ( - "io" + "context" "net/http" "time" "github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/netxlite" ) const ( @@ -16,7 +17,7 @@ const ( // urlget implements an straightforward urlget experiment using the standard library. // By default we pass the wireguard tunnel DialContext to the `http.Transport` on the `http.Client` creation. -func (m *Measurer) urlget(url string, zeroTime time.Time, logger model.Logger) *URLGetResult { +func (m *Measurer) urlget(ctx context.Context, url string, zeroTime time.Time, logger model.Logger) *URLGetResult { if m.dialContextFn == nil { m.dialContextFn = m.tnet.DialContext } @@ -35,7 +36,7 @@ func (m *Measurer) urlget(url string, zeroTime time.Time, logger model.Logger) * logger.Warnf("urlget error: %v", err.Error()) return newURLResultFromError(url, zeroTime, start, err) } - body, err := io.ReadAll(r.Body) + body, err := netxlite.ReadAllContext(ctx, r.Body) if err != nil { logger.Warnf("urlget error: %v", err.Error()) return newURLResultFromError(url, zeroTime, start, err) diff --git a/internal/experiment/wireguard/urlget_test.go b/internal/experiment/wireguard/urlget_test.go index 4add4e9f3..f044a6411 100644 --- a/internal/experiment/wireguard/urlget_test.go +++ b/internal/experiment/wireguard/urlget_test.go @@ -22,17 +22,17 @@ func (c *failingHttpClient) Get(string) (*http.Response, error) { func Test_urlget(t *testing.T) { t.Run("dummy server gets a URLGetResult, with no error", func(t *testing.T) { expected := "dummy data" - svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) w.Write([]byte(expected)) })) - defer svr.Close() + defer srv.Close() m := &Measurer{} m.dialContextFn = func(_ context.Context, network, address string) (net.Conn, error) { return net.Dial(network, address) } - r := m.urlget(svr.URL, time.Now(), model.DiscardLogger) + r := m.urlget(context.Background(), srv.URL, time.Now(), model.DiscardLogger) if r.StatusCode != 200 { t.Fatal("expected statusCode==200") } @@ -40,17 +40,17 @@ func Test_urlget(t *testing.T) { t.Run("dummy server gets a URLGetResult with 500 status code", func(t *testing.T) { expected := "dummy data" - svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(500) w.Write([]byte(expected)) })) - defer svr.Close() + defer srv.Close() m := &Measurer{} m.dialContextFn = func(_ context.Context, network, address string) (net.Conn, error) { return net.Dial(network, address) } - r := m.urlget(svr.URL, time.Now(), model.DiscardLogger) + r := m.urlget(context.Background(), srv.URL, time.Now(), model.DiscardLogger) if r.StatusCode != 500 { t.Fatal("expected statusCode==500") } @@ -60,7 +60,7 @@ func Test_urlget(t *testing.T) { m := &Measurer{} m.httpClient = &failingHttpClient{} - r := m.urlget("http://example.org", time.Now(), model.DiscardLogger) + r := m.urlget(context.Background(), "http://example.org", time.Now(), model.DiscardLogger) expectedError := "unknown_failure: some error" if *r.Failure != expectedError { t.Fatal("expected error") diff --git a/internal/experiment/wireguard/wireguard.go b/internal/experiment/wireguard/wireguard.go index acabe54b3..9171ceb14 100644 --- a/internal/experiment/wireguard/wireguard.go +++ b/internal/experiment/wireguard/wireguard.go @@ -118,7 +118,7 @@ func (m *Measurer) Run(ctx context.Context, args *model.ExperimentArgs) error { // 3. use tunnel if err == nil { sess.Logger().Info("Using the wireguard tunnel.") - urlgetResult := m.urlget(defaultURLGetTarget, zeroTime, sess.Logger()) + urlgetResult := m.urlget(ctx, defaultURLGetTarget, zeroTime, sess.Logger()) testkeys.URLGet = append(testkeys.URLGet, urlgetResult) testkeys.NetworkEvents = m.events.log() }