Skip to content

Commit

Permalink
Merge branch 'google:master' into sot-migration
Browse files Browse the repository at this point in the history
  • Loading branch information
torsm authored Aug 28, 2023
2 parents 21965de + 4159c23 commit 3be84a5
Show file tree
Hide file tree
Showing 8 changed files with 223 additions and 49 deletions.
9 changes: 5 additions & 4 deletions fleetspeak/src/server/components/components.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,11 @@ func MakeComponents(cfg *cpb.Config) (*server.Components, error) {
l = &chttps.ProxyListener{l}

Check failure on line 94 in fleetspeak/src/server/components/components.go

View workflow job for this annotation

GitHub Actions / build-test-linux

github.com/google/fleetspeak/fleetspeak/src/server/components/https.ProxyListener struct literal uses unkeyed fields
}
comm, err = https.NewCommunicator(https.Params{
Listener: l,
Cert: []byte(hcfg.Certificates),
Key: []byte(hcfg.Key),
Streaming: !hcfg.DisableStreaming,
Listener: l,
Cert: []byte(hcfg.Certificates),
ClientCertHeader: hcfg.ClientCertificateHeader,
Key: []byte(hcfg.Key),
Streaming: !hcfg.DisableStreaming,
})
if err != nil {
return nil, fmt.Errorf("failed to create communicator: %v", err)
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ message HttpsConfig {
// connection causes more active connections but can reduce database load and
// server->client communications latency.
bool disable_streaming = 4;

// If set, the server will validate the client certificate from the request header.
// This should be used if TLS is terminated at the load balancer and client certificates
// can be passed upstream to the fleetspeak server as an http header.
string client_certificate_header = 5;
}

message AdminConfig {
Expand Down
52 changes: 52 additions & 0 deletions fleetspeak/src/server/https/client_certificate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package https

import (
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"net/http"
"net/url"
)

// GetClientCert returns the client certificate from either the request header or TLS connection state.
func GetClientCert(req *http.Request, hn string) (*x509.Certificate, error) {
if hn != "" {
return getCertFromHeader(hn, req.Header)
} else {
return getCertFromTLS(req)
}
}

func getCertFromHeader(hn string, rh http.Header) (*x509.Certificate, error) {
headerCert := rh.Get(hn)
if headerCert == "" {
return nil, errors.New("no certificate found in header")
}
// Most certificates are URL PEM encoded
if decodedCert, err := url.PathUnescape(headerCert); err != nil {
return nil, err
} else {
headerCert = decodedCert
}
block, rest := pem.Decode([]byte(headerCert))
if block == nil || block.Type != "CERTIFICATE" {
return nil, errors.New("failed to decode PEM block containing certificate")
}
if len(rest) != 0 {
return nil, errors.New("received more than 1 client cert")
}
cert, err := x509.ParseCertificate(block.Bytes)
return cert, err
}

func getCertFromTLS(req *http.Request) (*x509.Certificate, error) {
if req.TLS == nil {
return nil, errors.New("TLS information not found")
}
if len(req.TLS.PeerCertificates) != 1 {
return nil, fmt.Errorf("expected 1 client cert, received %v", len(req.TLS.PeerCertificates))
}
cert := req.TLS.PeerCertificates[0]
return cert, nil
}
3 changes: 2 additions & 1 deletion fleetspeak/src/server/https/https.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ type Params struct {
Listener net.Listener // Where to listen for connections, required.
Cert, Key []byte // x509 encoded certificate and matching private key, required.
Streaming bool // Whether to enable streaming communications.
ClientCertHeader string // Where to locate the client certificate from the request header, if not provided use TLS request.
StreamingLifespan time.Duration // Maximum time to keep a streaming connection open, defaults to 10 min.
StreamingCloseTime time.Duration // How much of StreamingLifespan to allocate to an orderly stream close, defaults to 30 sec.
StreamingJitter time.Duration // Maximum amount of jitter to add to StreamingLifespan.
Expand All @@ -109,7 +110,7 @@ func NewCommunicator(p Params) (*Communicator, error) {
hs: http.Server{
Handler: mux,
TLSConfig: &tls.Config{
ClientAuth: tls.RequireAnyClientCert,
ClientAuth: tls.RequestClientCert,
Certificates: []tls.Certificate{c},
CipherSuites: []uint16{
// We may as well allow only the strongest (as far as we can guess)
Expand Down
130 changes: 120 additions & 10 deletions fleetspeak/src/server/https/https_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ var (
serverCert []byte
)

func makeServer(t *testing.T, caseName string) (*server.Server, *sqlite.Datastore, string) {
func makeServer(t *testing.T, caseName, clientCertHeader string) (*server.Server, *sqlite.Datastore, string) {
cert, key, err := comtesting.ServerCert()
if err != nil {
t.Fatal(err)
Expand All @@ -66,7 +66,7 @@ func makeServer(t *testing.T, caseName string) (*server.Server, *sqlite.Datastor
if err != nil {
t.Fatal(err)
}
com, err := NewCommunicator(Params{Listener: tl, Cert: cert, Key: key, Streaming: true})
com, err := NewCommunicator(Params{Listener: tl, Cert: cert, Key: key, Streaming: true, ClientCertHeader: clientCertHeader})
if err != nil {
t.Fatal(err)
}
Expand All @@ -77,7 +77,7 @@ func makeServer(t *testing.T, caseName string) (*server.Server, *sqlite.Datastor
return ts.S, ts.DS, tl.Addr().String()
}

func makeClient(t *testing.T) (common.ClientID, *http.Client) {
func makeClient(t *testing.T) (common.ClientID, *http.Client, []byte) {
// Populate a CertPool with the server's certificate.
cp := x509.NewCertPool()
if !cp.AppendCertsFromPEM(serverCert) {
Expand Down Expand Up @@ -129,14 +129,14 @@ func makeClient(t *testing.T) (common.ClientID, *http.Client) {
ExpectContinueTimeout: 1 * time.Second,
},
}
return id, &cl
return id, &cl, bc
}

func TestNormalPoll(t *testing.T) {
ctx := context.Background()

s, ds, addr := makeServer(t, "Normal")
id, cl := makeClient(t)
s, ds, addr := makeServer(t, "Normal", "")
id, cl, _ := makeClient(t)
defer s.Stop()

u := url.URL{Scheme: "https", Host: addr, Path: "/message"}
Expand Down Expand Up @@ -171,8 +171,8 @@ func TestNormalPoll(t *testing.T) {
func TestFile(t *testing.T) {
ctx := context.Background()

s, ds, addr := makeServer(t, "File")
_, cl := makeClient(t)
s, ds, addr := makeServer(t, "File", "")
_, cl, _ := makeClient(t)
defer s.Stop()

data := []byte("The quick sly fox jumped over the lazy dogs.")
Expand Down Expand Up @@ -241,8 +241,8 @@ func readContact(body *bufio.Reader) (*fspb.ContactData, error) {
func TestStreaming(t *testing.T) {
ctx := context.Background()

s, _, addr := makeServer(t, "Streaming")
_, cl := makeClient(t)
s, _, addr := makeServer(t, "Streaming", "")
_, cl, _ := makeClient(t)
defer s.Stop()

br, bw := io.Pipe()
Expand Down Expand Up @@ -299,3 +299,113 @@ func TestStreaming(t *testing.T) {
bw.Close()
resp.Body.Close()
}

func TestHeaderNormalPoll(t *testing.T) {
ctx := context.Background()

s, ds, addr := makeServer(t, "Normal", "ssl-client-cert")
id, cl, bc := makeClient(t)
defer s.Stop()

u := url.URL{Scheme: "https", Host: addr, Path: "/message"}

req, err := http.NewRequest("POST", u.String(), nil)
req.Close = true
cc := url.PathEscape(string(bc))
req.Header.Set("ssl-client-cert", cc)
if err != nil {
t.Fatal(err)
}

// An empty body is a valid, though atypical initial request.
req = req.WithContext(ctx)
resp, err := cl.Do(req)
if err != nil {
t.Fatal(err)
}

b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Error(err)
}
resp.Body.Close()

var cd fspb.ContactData
if err := proto.Unmarshal(b, &cd); err != nil {
t.Errorf("Unable to parse returned data as ContactData: %v", err)
}
if cd.SequencingNonce == 0 {
t.Error("Expected SequencingNonce in returned ContactData")
}

// The client should now exist in the datastore.
_, err = ds.GetClientData(ctx, id)
if err != nil {
t.Errorf("Error getting client data after poll: %v", err)
}
}

func TestHeaderStreaming(t *testing.T) {
ctx := context.Background()

s, _, addr := makeServer(t, "Streaming", "ssl-client-cert")
_, cl, bc := makeClient(t)
defer s.Stop()

br, bw := io.Pipe()
go func() {
// First exchange - these writes must happen during the http.Client.Do call
// below, because the server writes headers at the end of the first message
// exchange.

// Start with the magic number:
binary.Write(bw, binary.LittleEndian, magic)

if _, err := bw.Write(makeWrapped()); err != nil {
t.Error(err)
}
}()

u := url.URL{Scheme: "https", Host: addr, Path: "/streaming-message"}
req, err := http.NewRequest("POST", u.String(), br)
req.ContentLength = -1
req.Close = true
req.Header.Set("Expect", "100-continue")

cc := url.PathEscape(string(bc))
req.Header.Set("ssl-client-cert", cc)
if err != nil {
t.Fatal(err)
}
req = req.WithContext(ctx)
resp, err := cl.Do(req)
if err != nil {
t.Fatalf("Streaming post failed (%v): %v", resp, err)
}
// Read ContactData for first exchange.
body := bufio.NewReader(resp.Body)
cd, err := readContact(body)
if err != nil {
t.Error(err)
}
if cd.AckIndex != 0 {
t.Errorf("AckIndex of initial exchange should be unset, got %d", cd.AckIndex)
}

for i := uint64(1); i < 10; i++ {
// Write another WrappedContactData.
if _, err := bw.Write(makeWrapped()); err != nil {
t.Error(err)
}
cd, err := readContact(body)
if err != nil {
t.Error(err)
}
if cd.AckIndex != i {
t.Errorf("Received ack for contact %d, but expected %d", cd.AckIndex, i)
}
}

bw.Close()
resp.Body.Close()
}
11 changes: 3 additions & 8 deletions fleetspeak/src/server/https/message_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,18 +109,13 @@ func (s messageServer) ServeHTTP(res http.ResponseWriter, req *http.Request) {
return
}

if req.TLS == nil {
pi.Status = http.StatusBadRequest
http.Error(res, "TLS information not found", pi.Status)
return
}
if len(req.TLS.PeerCertificates) != 1 {
cert, err := GetClientCert(req, s.p.ClientCertHeader)
if err != nil {
pi.Status = http.StatusBadRequest
http.Error(res, fmt.Sprintf("expected 1 client cert, received %v", len(req.TLS.PeerCertificates)), pi.Status)
http.Error(res, err.Error(), pi.Status)
return
}

cert := req.TLS.PeerCertificates[0]
if cert.PublicKey == nil {
pi.Status = http.StatusBadRequest
http.Error(res, "public key not present in client cert", pi.Status)
Expand Down
Loading

0 comments on commit 3be84a5

Please sign in to comment.