diff --git a/schemaregistry/internal/client_config.go b/schemaregistry/internal/client_config.go index 6578947e9..7016d803e 100644 --- a/schemaregistry/internal/client_config.go +++ b/schemaregistry/internal/client_config.go @@ -22,7 +22,7 @@ import ( // ClientConfig is used to pass multiple configuration options to the Schema Registry client. type ClientConfig struct { - // SchemaRegistryURL determines the URL of Schema Registry. + // SchemaRegistryURL is a comma-space separated list of URLs for the Schema Registry. SchemaRegistryURL string // BasicAuthUserInfo specifies the user info in the form of {username}:{password}. diff --git a/schemaregistry/internal/rest_service.go b/schemaregistry/internal/rest_service.go index 1daca2170..87f98393c 100644 --- a/schemaregistry/internal/rest_service.go +++ b/schemaregistry/internal/rest_service.go @@ -112,7 +112,7 @@ func NewRequest(method string, endpoint string, body interface{}, arguments ...i // RestService represents a REST client type RestService struct { - url *url.URL + urls []*url.URL headers http.Header maxRetries int retriesWaitMs int @@ -124,21 +124,22 @@ type RestService struct { // NewRestService returns a new REST client for the Confluent Schema Registry func NewRestService(conf *ClientConfig) (*RestService, error) { urlConf := conf.SchemaRegistryURL - u, err := url.Parse(urlConf) - - if err != nil { - return nil, err + urlStrs := strings.Split(urlConf, ",") + urls := make([]*url.URL, len(urlStrs)) + for i, urlStr := range urlStrs { + u, err := url.Parse(strings.TrimSpace(urlStr)) + if err != nil { + return nil, err + } + urls[i] = u } - headers, err := NewAuthHeader(u, conf) + headers, err := NewAuthHeader(urls[0], conf) if err != nil { return nil, err } headers.Add("Content-Type", "application/vnd.schemaregistry.v1+json") - if err != nil { - return nil, err - } if conf.HTTPClient == nil { transport, err := configureTransport(conf) @@ -155,7 +156,7 @@ func NewRestService(conf *ClientConfig) (*RestService, error) { } return &RestService{ - url: u, + urls: urls, headers: headers, maxRetries: conf.MaxRetries, retriesWaitMs: conf.RetriesWaitMs, @@ -337,19 +338,51 @@ func NewAuthHeader(service *url.URL, conf *ClientConfig) (http.Header, error) { return header, err } -// HandleRequest sends a HTTP(S) request to the Schema Registry, placing results into the response object +// HandleRequest sends a request to the Schema Registry, iterating over the list of URLs func (rs *RestService) HandleRequest(request *API, response interface{}) error { - urlPath := path.Join(rs.url.Path, fmt.Sprintf(request.endpoint, request.arguments...)) - endpoint, err := rs.url.Parse(urlPath) - if err != nil { + var resp *http.Response + var err error + for i, u := range rs.urls { + resp, err = rs.HandleHTTPRequest(u, request) + if err != nil { + if i == len(rs.urls)-1 { + return err + } + continue + } + if isSuccess(resp.StatusCode) || !isRetriable(resp.StatusCode) || i >= rs.maxRetries { + break + } + } + defer resp.Body.Close() + if isSuccess(resp.StatusCode) { + if err = json.NewDecoder(resp.Body).Decode(response); err != nil { + return err + } + return nil + } + + var failure rest.Error + if err = json.NewDecoder(resp.Body).Decode(&failure); err != nil { return err } + return &failure +} + +// HandleHTTPRequest sends a HTTP(S) request to the Schema Registry, placing results into the response object +func (rs *RestService) HandleHTTPRequest(url *url.URL, request *API) (*http.Response, error) { + urlPath := path.Join(url.Path, fmt.Sprintf(request.endpoint, request.arguments...)) + endpoint, err := url.Parse(urlPath) + if err != nil { + return nil, err + } + var readCloser io.ReadCloser if request.body != nil { outbuf, err := json.Marshal(request.body) if err != nil { - return err + return nil, err } readCloser = ioutil.NopCloser(bytes.NewBuffer(outbuf)) } @@ -365,30 +398,16 @@ func (rs *RestService) HandleRequest(request *API, response interface{}) error { for i := 0; i < rs.maxRetries+1; i++ { resp, err = rs.Do(req) if err != nil { - return err + return nil, err } if isSuccess(resp.StatusCode) || !isRetriable(resp.StatusCode) || i >= rs.maxRetries { - break + return resp, nil } time.Sleep(rs.fullJitter(i)) } - - defer resp.Body.Close() - if resp.StatusCode == 200 { - if err = json.NewDecoder(resp.Body).Decode(response); err != nil { - return err - } - return nil - } - - var failure rest.Error - if err := json.NewDecoder(resp.Body).Decode(&failure); err != nil { - return err - } - - return &failure + return nil, fmt.Errorf("failed to send request after %d retries", rs.maxRetries) } func (rs *RestService) fullJitter(retriesAttempted int) time.Duration { diff --git a/schemaregistry/serde/avrov2/avro_test.go b/schemaregistry/serde/avrov2/avro_test.go index 181742066..ab5b1363f 100644 --- a/schemaregistry/serde/avrov2/avro_test.go +++ b/schemaregistry/serde/avrov2/avro_test.go @@ -937,6 +937,8 @@ func TestAvroSerdeWithCELFieldTransformDisable(t *testing.T) { OnFailure: nil, Disabled: &[]bool{true}[0], }) + ser.RuleRegistry = ®istry + id, err := client.Register("topic1-value", info, false) serde.MaybeFail("Schema registration", err) if id <= 0 { @@ -960,7 +962,7 @@ func TestAvroSerdeWithCELFieldTransformDisable(t *testing.T) { deser.MessageFactory = testMessageFactory newobj, err := deser.Deserialize("topic1", bytes) - serde.MaybeFail("deserialization", err, serde.Expect(newobj, &obj)) + serde.MaybeFail("deserialization", err, serde.Expect(newobj.(*DemoSchema).StringField, "hi")) } func TestAvroSerdeWithCELFieldTransform(t *testing.T) {