Skip to content

Commit

Permalink
use netxlite.ReadAllContext
Browse files Browse the repository at this point in the history
  • Loading branch information
ainghazal committed Jun 25, 2024
1 parent 84cb927 commit 25e5661
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 11 deletions.
7 changes: 4 additions & 3 deletions internal/experiment/wireguard/urlget.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package wireguard

import (
"io"
"context"
"net/http"
"time"

"github.com/ooni/probe-cli/v3/internal/measurexlite"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/netxlite"
)

const (
Expand All @@ -16,7 +17,7 @@ const (

// urlget implements an straightforward urlget experiment using the standard library.
// By default we pass the wireguard tunnel DialContext to the `http.Transport` on the `http.Client` creation.
func (m *Measurer) urlget(url string, zeroTime time.Time, logger model.Logger) *URLGetResult {
func (m *Measurer) urlget(ctx context.Context, url string, zeroTime time.Time, logger model.Logger) *URLGetResult {
if m.dialContextFn == nil {
m.dialContextFn = m.tnet.DialContext
}
Expand All @@ -35,7 +36,7 @@ func (m *Measurer) urlget(url string, zeroTime time.Time, logger model.Logger) *
logger.Warnf("urlget error: %v", err.Error())
return newURLResultFromError(url, zeroTime, start, err)
}
body, err := io.ReadAll(r.Body)
body, err := netxlite.ReadAllContext(ctx, r.Body)
if err != nil {
logger.Warnf("urlget error: %v", err.Error())
return newURLResultFromError(url, zeroTime, start, err)
Expand Down
14 changes: 7 additions & 7 deletions internal/experiment/wireguard/urlget_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,35 +22,35 @@ func (c *failingHttpClient) Get(string) (*http.Response, error) {
func Test_urlget(t *testing.T) {
t.Run("dummy server gets a URLGetResult, with no error", func(t *testing.T) {
expected := "dummy data"
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte(expected))
}))
defer svr.Close()
defer srv.Close()

m := &Measurer{}
m.dialContextFn = func(_ context.Context, network, address string) (net.Conn, error) {
return net.Dial(network, address)
}
r := m.urlget(svr.URL, time.Now(), model.DiscardLogger)
r := m.urlget(context.Background(), srv.URL, time.Now(), model.DiscardLogger)
if r.StatusCode != 200 {
t.Fatal("expected statusCode==200")
}
})

t.Run("dummy server gets a URLGetResult with 500 status code", func(t *testing.T) {
expected := "dummy data"
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(500)
w.Write([]byte(expected))
}))
defer svr.Close()
defer srv.Close()

m := &Measurer{}
m.dialContextFn = func(_ context.Context, network, address string) (net.Conn, error) {
return net.Dial(network, address)
}
r := m.urlget(svr.URL, time.Now(), model.DiscardLogger)
r := m.urlget(context.Background(), srv.URL, time.Now(), model.DiscardLogger)
if r.StatusCode != 500 {
t.Fatal("expected statusCode==500")
}
Expand All @@ -60,7 +60,7 @@ func Test_urlget(t *testing.T) {
m := &Measurer{}
m.httpClient = &failingHttpClient{}

r := m.urlget("http://example.org", time.Now(), model.DiscardLogger)
r := m.urlget(context.Background(), "http://example.org", time.Now(), model.DiscardLogger)
expectedError := "unknown_failure: some error"
if *r.Failure != expectedError {
t.Fatal("expected error")
Expand Down
2 changes: 1 addition & 1 deletion internal/experiment/wireguard/wireguard.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func (m *Measurer) Run(ctx context.Context, args *model.ExperimentArgs) error {
// 3. use tunnel
if err == nil {
sess.Logger().Info("Using the wireguard tunnel.")
urlgetResult := m.urlget(defaultURLGetTarget, zeroTime, sess.Logger())
urlgetResult := m.urlget(ctx, defaultURLGetTarget, zeroTime, sess.Logger())
testkeys.URLGet = append(testkeys.URLGet, urlgetResult)
testkeys.NetworkEvents = m.events.log()
}
Expand Down

0 comments on commit 25e5661

Please sign in to comment.