Skip to content

Commit

Permalink
Add support for getting signed x509 cert (#61)
Browse files Browse the repository at this point in the history
* Add support for getting signed x509 cert

* address comments

Co-authored-by: hkadakia <[email protected]>
  • Loading branch information
hkadakia and hkadakia authored Feb 17, 2021
1 parent c0d8212 commit 4724598
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 18 deletions.
49 changes: 36 additions & 13 deletions api/x509cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,43 @@ func (s *SigningService) GetX509CACertificate(ctx context.Context, keyMeta *prot
return nil, status.Errorf(codes.InvalidArgument, "Bad request: %v", err)
}

// Create a context with server side timeout.
reqCtx, cancel := context.WithTimeout(ctx, config.DefaultPKCS11Timeout)
defer cancel() // Cancel ctx as soon as GetX509CACertificate returns.

if !s.KeyUsages[config.X509CertEndpoint][keyMeta.Identifier] {
statusCode = http.StatusBadRequest
err = fmt.Errorf("cannot use key %q for %q", keyMeta.Identifier, config.X509CertEndpoint)
return nil, status.Errorf(codes.InvalidArgument, "Bad request: %v", err)
}

cert, err := s.GetX509CACert(ctx, keyMeta.Identifier)
if err != nil {
statusCode = http.StatusInternalServerError
return nil, status.Error(codes.Internal, "Internal server error")
type resp struct {
cert []byte
err error
}
respCh := make(chan resp)
go func() {
cert, err := s.GetX509CACert(ctx, keyMeta.Identifier)
respCh <- resp{cert, err}
}()

select {
case <-ctx.Done():
statusCode = http.StatusBadRequest
err = fmt.Errorf("client canceled request for %q", config.SSHHostCertEndpoint)
return nil, status.Errorf(codes.Canceled, "%v", err)
case <-reqCtx.Done():
// Handle the server timeout requests.
statusCode = http.StatusServiceUnavailable
err = fmt.Errorf("request timed out for %q", config.SSHHostCertEndpoint)
return nil, status.Errorf(codes.DeadlineExceeded, "%v", err)
case response := <-respCh:
if response.err != nil {
statusCode = http.StatusInternalServerError
return nil, status.Error(codes.Internal, "Internal server error")
}
return &proto.X509Certificate{Cert: string(response.cert)}, nil
}
return &proto.X509Certificate{Cert: string(cert)}, nil
}

// PostX509Certificate signs the given CSR using the specified key and returns a PEM encoded X509 certificate.
Expand All @@ -89,7 +114,7 @@ func (s *SigningService) PostX509Certificate(ctx context.Context, request *proto
return nil, status.Errorf(codes.InvalidArgument, "Bad request: %v", err)
}

// create a context with server side timeout
// Create a context with server side timeout.
reqCtx, cancel := context.WithTimeout(ctx, config.DefaultPKCS11Timeout)
defer cancel() // Cancel ctx as soon as PostX509Certificate returns

Expand All @@ -113,24 +138,22 @@ func (s *SigningService) PostX509Certificate(ctx context.Context, request *proto
}

type resp struct {
data []byte
cert []byte
err error
}
respCh := make(chan resp)
go func() {
data, err := s.SignX509Cert(reqCtx, req, request.KeyMeta.Identifier)
respCh <- resp{data, err}
cert, err := s.SignX509Cert(reqCtx, req, request.KeyMeta.Identifier)
respCh <- resp{cert, err}
}()

select {
case <-ctx.Done():
// client canceled the request. Cancel any pending server request and return
cancel()
statusCode = http.StatusBadRequest
err = fmt.Errorf("client canceled request for %q", config.X509CertEndpoint)
return nil, status.Errorf(codes.Canceled, "%v", err)
case <-reqCtx.Done():
// server request timed out.
// Handle the server timeout requests.
statusCode = http.StatusServiceUnavailable
err = fmt.Errorf("request timed out for %q", config.X509CertEndpoint)
return nil, status.Errorf(codes.DeadlineExceeded, "%v", err)
Expand All @@ -139,6 +162,6 @@ func (s *SigningService) PostX509Certificate(ctx context.Context, request *proto
statusCode = http.StatusInternalServerError
return nil, status.Error(codes.Internal, "Internal server error")
}
return &proto.X509Certificate{Cert: string(response.data)}, nil
return &proto.X509Certificate{Cert: string(response.cert)}, nil
}
}
33 changes: 28 additions & 5 deletions api/x509cert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,18 @@ func TestGetX509CertificateAvailableSigningKeys(t *testing.T) {

func TestGetX509CACertificate(t *testing.T) {
t.Parallel()
ctx := context.Background()
cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()
timeoutCtx, timeCancel := context.WithTimeout(ctx, timeout)
defer timeCancel()
testcases := map[string]struct {
ctx context.Context
KeyUsages map[string]map[string]bool
KeyMeta *proto.KeyMeta
// if expectedCert set to nil, we are expecting an error while testing
expectedCert *proto.X509Certificate
timeout time.Duration
}{
"emptyKeyUsages": {
KeyMeta: &proto.KeyMeta{Identifier: "randomid"},
Expand Down Expand Up @@ -116,24 +123,40 @@ func TestGetX509CACertificate(t *testing.T) {
KeyMeta: &proto.KeyMeta{Identifier: "x509id2"},
expectedCert: nil,
},
"requestTimeout": {
ctx: timeoutCtx,
KeyUsages: x509keyUsage,
KeyMeta: &proto.KeyMeta{Identifier: "x509id"},
expectedCert: nil,
timeout: timeout,
},
"requestCancelled": {
ctx: cancelCtx,
KeyUsages: x509keyUsage,
KeyMeta: &proto.KeyMeta{Identifier: "x509id"},
expectedCert: nil,
timeout: timeout,
},
}
for label, tt := range testcases {
tt := tt
label := label
if tt.ctx == nil {
tt.ctx = ctx
}
t.Run(label, func(t *testing.T) {
t.Parallel()
var ctx context.Context
// bad certsign should return error anyways
msspBad := mockSigningServiceParam{KeyUsages: tt.KeyUsages, sendError: true}
msspBad := mockSigningServiceParam{KeyUsages: tt.KeyUsages, sendError: true, timeout: tt.timeout}
ssBad := initMockSigningService(msspBad)
_, err := ssBad.GetX509CACertificate(ctx, tt.KeyMeta)
_, err := ssBad.GetX509CACertificate(tt.ctx, tt.KeyMeta)
if err == nil {
t.Fatalf("in test %v: bad signing service should return error but got nil", label)
}
// good certsign
msspGood := mockSigningServiceParam{KeyUsages: tt.KeyUsages, sendError: false}
msspGood := mockSigningServiceParam{KeyUsages: tt.KeyUsages, sendError: false, timeout: tt.timeout}
ssGood := initMockSigningService(msspGood)
cert, err := ssGood.GetX509CACertificate(ctx, tt.KeyMeta)
cert, err := ssGood.GetX509CACertificate(tt.ctx, tt.KeyMeta)
if err != nil && tt.expectedCert != nil {
t.Fatalf("in test %v: not expecting error but got error %v", label, err)
}
Expand Down

0 comments on commit 4724598

Please sign in to comment.