diff --git a/millennium.go b/millennium.go index e0b4b78..1f38d06 100644 --- a/millennium.go +++ b/millennium.go @@ -12,6 +12,7 @@ import ( "io" "net/http" "net/url" + "slices" "strings" "time" @@ -25,6 +26,7 @@ type AuthType string // Authentication types available for Millennium const ( NTLM AuthType = "NTLM" + Basic AuthType = "BASIC" Session AuthType = "SESSION" ) @@ -157,9 +159,7 @@ func (m *Millennium) Login(username string, password string, authType AuthType) m.Client.HTTPClient.Transport = ntlmssp.Negotiator{ RoundTripper: &http.Transport{}, } - } - - if authType == Session { + } else if authType == Session { var responseLogin ResponseLogin m.headers.Set("WTS-Authorization", fmt.Sprintf("%s/%s", strings.ToUpper(m.credentials.Username), strings.ToUpper(m.credentials.Password))) if err := m.Post("login", []byte{}, &responseLogin); err != nil { @@ -216,6 +216,7 @@ func (m *Millennium) Request(r RequestMethod) (err error) { requestBody := bodyReader req, err := retryablehttp.NewRequestWithContext(m.Context, requestMethod, requestURL, requestBody) + if err != nil { return fmt.Errorf("unable to start new request to Millennium: %w", err) } @@ -224,8 +225,8 @@ func (m *Millennium) Request(r RequestMethod) (err error) { req.Header = m.headers } - // If authType is NTLM, set basic auth on request - if m.credentials.AuthType == NTLM { + // If authType is NTLM or Basic, set basic auth on request + if m.credentials.AuthType == NTLM || m.credentials.AuthType == Basic { req.SetBasicAuth(m.credentials.Username, m.credentials.Password) } @@ -243,6 +244,12 @@ func (m *Millennium) sendRequest(request *retryablehttp.Request, response interf return fmt.Errorf("unable to send request: %w", err) } + if !slices.Contains([]int{http.StatusOK, http.StatusNoContent, http.StatusCreated, http.StatusTemporaryRedirect, http.StatusPermanentRedirect}, res.StatusCode) { + defer res.Body.Close() + + return fmt.Errorf("unable to send request: %s", res.Status) + } + return m.getResponse(res, &response) } diff --git a/millennium_test.go b/millennium_test.go index 35b974b..678bda0 100644 --- a/millennium_test.go +++ b/millennium_test.go @@ -151,6 +151,21 @@ func (s *mockHTTPServer) Start() *httptest.Server { Body: s.jsonError("Query error", http.StatusInternalServerError), }) }) + mux.HandleFunc("/api/test.basicauth", func(w http.ResponseWriter, r *http.Request) { + username, password, ok := r.BasicAuth() + if !ok || username != "correct_user" || password != "correct_password" { + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + http.Error(w, "", http.StatusUnauthorized) + return + } + + s.writeOutput(&writeOutputParams{ + Writer: w, + Request: r, + StatusCode: 200, + Body: []byte(`{"odata.count": 1,"value":[{"number":1,"string":"test","bool":true}]}`), + }) + }) s.testServer = httptest.NewServer(mux) return s.testServer @@ -298,6 +313,35 @@ func TestNTLM(t *testing.T) { } } +func TestBasicAuth(t *testing.T) { + client := NewTestClient(t) + err := client.Login("test", "test", Basic) + if err != nil { + t.Error(err) + } + + var _r interface{} + + _, err = client.Get("test.basicauth", url.Values{}, &_r) + if err == nil { + t.Error("Expected error") + } + + err = client.Login("correct_user", "correct_password", Basic) + if err != nil { + t.Error(err) + } + + x, err := client.Get("test.basicauth", url.Values{}, &_r) + if err != nil { + t.Fatal(err) + } + + if x == 0 { + t.Error("Zero records returned") + } +} + func TestRequest(t *testing.T) { client := NewTestClient(t)