From ddcdda7b275dbfada8c5baf1a7a21caa219eddd7 Mon Sep 17 00:00:00 2001 From: schmidtw Date: Tue, 3 Oct 2023 07:34:19 -0700 Subject: [PATCH 1/4] Example comman is able to connect with mtls themis. Adding options to make the usage of credentials easier. --- config.go | 51 ++++++++++++++- go.mod | 4 +- go.sum | 13 ++-- internal/credentials/cmd/example/main.go | 62 ++++++++++++++---- internal/credentials/credentials.go | 82 +++++++++++++++++++----- internal/credentials/credentials_test.go | 55 ++++++++++++++++ internal/credentials/options.go | 28 ++++++++ 7 files changed, 262 insertions(+), 33 deletions(-) diff --git a/config.go b/config.go index 735937a..9b49917 100644 --- a/config.go +++ b/config.go @@ -6,15 +6,62 @@ package main import ( "fmt" "os" + "time" "github.com/goschtalt/goschtalt" + "github.com/xmidt-org/arrange/arrangehttp" "github.com/xmidt-org/sallust" + "github.com/xmidt-org/wrp-go/v3" "gopkg.in/dealancer/validate.v2" ) type Config struct { - SpecialValue string - Logger sallust.Config + Identity Identity + OperationalState OperationalState + XmidtCredentials XmidtCredentials + Logger sallust.Config +} + +type Identity struct { + DeviceID wrp.DeviceID + SerialNumber string + HardwareModel string + HardwareManufacturer string + FirmwareVersion string + PartnerID string +} + +type OperationalState struct { + LastRebootReason string + BootTime time.Time +} + +type XmidtCredentials struct { + URL string + HTTPClient arrangehttp.ClientConfig + RefetchPercent float64 +} + +type XmidtService struct { + URL string + JwtTxtRedirector JwtTxtRedirector + Backoff Backoff +} + +type JwtTxtRedirector struct { + Required bool + AllowedAlgorithms []string + Timeout time.Duration + PEMs []string + PEMFiles []string +} + +// Backoff defines the parameters that limit the retry backoff algorithm. +// The retries are a geometric progression. +// 1, 3, 7, 15, 31 ... n = (2n+1) +type Backoff struct { + MinDelay time.Duration + MaxDelay time.Duration } // Collect and process the configuration files and env vars and diff --git a/go.mod b/go.mod index 0a6fff1..c2fc6c0 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,8 @@ require ( github.com/goschtalt/yaml-decoder v0.0.1 github.com/goschtalt/yaml-encoder v0.0.3 github.com/stretchr/testify v1.8.4 + github.com/ugorji/go/codec v1.2.11 + github.com/xmidt-org/arrange v0.5.0 github.com/xmidt-org/eventor v0.0.0-20230910205925-8ff168bd12ed github.com/xmidt-org/sallust v0.2.2 github.com/xmidt-org/wrp-go/v3 v3.2.0 @@ -26,7 +28,7 @@ require ( github.com/miekg/dns v1.1.56 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/ugorji/go/codec v1.2.11 // indirect + github.com/xmidt-org/httpaux v0.4.0 // indirect go.uber.org/dig v1.17.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/mod v0.12.0 // indirect diff --git a/go.sum b/go.sum index 5948c85..6ef26c7 100644 --- a/go.sum +++ b/go.sum @@ -27,8 +27,8 @@ github.com/goschtalt/yaml-encoder v0.0.3 h1:vfQ3vXZNvoEFPa3NzOWNtweYVa+2qMh8eqhX github.com/goschtalt/yaml-encoder v0.0.3/go.mod h1:E9ANM2mgRmoqP+JTFFv03fVWcnn+QrIDfVu5shDvX3A= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= -github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= -github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw= @@ -47,6 +47,7 @@ github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XF github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -56,14 +57,18 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +github.com/xmidt-org/arrange v0.5.0 h1:ajkVHkr7dXnfCYm/6eafWoOab+6A3b2jEHQO0IdIIb0= +github.com/xmidt-org/arrange v0.5.0/go.mod h1:PoZB9lR49ma0osydQbaWpNeA3XPoLkjP5RYUoOw8wZU= github.com/xmidt-org/eventor v0.0.0-20230910205925-8ff168bd12ed h1:KpcgFuumKrt/824H3gtmNI/IvgjsBo6rnlSnwXlFu60= github.com/xmidt-org/eventor v0.0.0-20230910205925-8ff168bd12ed/go.mod h1:X9Og+8y1Llz7N8F20UmjZUNgrxHubMVfBcroJ5SPtIY= +github.com/xmidt-org/httpaux v0.4.0 h1:cAL/MzIBpSsv4xZZeq/Eu1J5M3vfNe49xr41mP3COKU= +github.com/xmidt-org/httpaux v0.4.0/go.mod h1:UypqZwuZV1nn8D6+K1JDb+im9IZrLNg/2oO/Bgiybxc= github.com/xmidt-org/sallust v0.2.2 h1:MrINLEr7cMj6ENx/O76fvpfd5LNGYnk7OipZAGXPYA0= github.com/xmidt-org/sallust v0.2.2/go.mod h1:ytBoypcPw10OmjM6b92Jx3eoqWX4J5zVXOQozGwz4qs= github.com/xmidt-org/wrp-go/v3 v3.2.0 h1:XX5c0ZJYaTEvlHFk0lzxadoOMbxg5YtUkPWNXHoxTDE= github.com/xmidt-org/wrp-go/v3 v3.2.0/go.mod h1:46ily/xzmRUhs8gSbTKNeOA6ztwcHauZFnfr4hRpoHA= -go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= -go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/dig v1.17.0 h1:5Chju+tUvcC+N7N6EV08BJz41UZuO3BmHcN4A287ZLI= go.uber.org/dig v1.17.0/go.mod h1:rTxpf7l5I0eBTlE6/9RL+lDybC7WFwY2QH55ZSjy1mU= go.uber.org/fx v1.20.0 h1:ZMC/pnRvhsthOZh9MZjMq5U8Or3mA9zBSPaLnzs3ihQ= diff --git a/internal/credentials/cmd/example/main.go b/internal/credentials/cmd/example/main.go index 9206e9f..45e931b 100644 --- a/internal/credentials/cmd/example/main.go +++ b/internal/credentials/cmd/example/main.go @@ -13,6 +13,7 @@ import ( "time" "github.com/alecthomas/kong" + "github.com/golang-jwt/jwt/v5" "github.com/xmidt-org/wrp-go/v3" cred "github.com/xmidt-org/xmidt-agent/internal/credentials" "github.com/xmidt-org/xmidt-agent/internal/credentials/event" @@ -38,9 +39,9 @@ func main() { client := http.DefaultClient - if cli.Private != "" || cli.Public != "" || cli.CA != "" { - if cli.Private == "" || cli.Public == "" || cli.CA == "" { - panic("--private, --public and --ca must be specified together") + if cli.Private != "" || cli.Public != "" { + if cli.Private == "" || cli.Public == "" { + panic("--private and --public must be specified together") } cert, err := tls.LoadX509KeyPair(cli.Public, cli.Private) @@ -48,16 +49,18 @@ func main() { panic(err) } - caCert, err := os.ReadFile(cli.CA) - if err != nil { - panic(err) - } - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) - tlsConfig := &tls.Config{ Certificates: []tls.Certificate{cert}, - RootCAs: caCertPool, + } + + if cli.CA != "" { + caCert, err := os.ReadFile(cli.CA) + if err != nil { + panic(err) + } + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + tlsConfig.RootCAs = caCertPool } tr := &http.Transport{TLSClientConfig: tlsConfig} @@ -117,4 +120,41 @@ func main() { defer cancel() credentials.WaitUntilFetched(ctx) + token, expires, err := credentials.Credentials() + if err != nil { + panic(err) + } + + fmt.Printf("JWT: %s\n", token) + fmt.Printf("Expires: %s\n", expires.Format(time.RFC3339)) + + claims := jwt.RegisteredClaims{} + parser := jwt.NewParser() + _, parts, err := parser.ParseUnverified(token, &claims) + if err != nil { + panic(err) + } + + fmt.Println("Claims:") + fmt.Printf(" ID: %s\n", claims.ID) + fmt.Printf(" ExpirationTime: %s\n", claims.ExpiresAt) + fmt.Printf(" IssuedAt: %s\n", claims.IssuedAt) + fmt.Printf(" NotBefore: %s\n", claims.NotBefore) + fmt.Printf(" Issuer: %s\n", claims.Issuer) + fmt.Printf(" Subject: %s\n", claims.Subject) + fmt.Printf(" Audience: %s\n", claims.Audience) + + header, err := parser.DecodeSegment(parts[0]) + if err != nil { + panic(err) + } + + body, err := parser.DecodeSegment(parts[1]) + if err != nil { + panic(err) + } + + fmt.Println("Parts:") + fmt.Printf(" Header: %s\n", header) + fmt.Printf(" Body: %s\n", body) } diff --git a/internal/credentials/credentials.go b/internal/credentials/credentials.go index 03a58f9..9d8f74d 100644 --- a/internal/credentials/credentials.go +++ b/internal/credentials/credentials.go @@ -13,10 +13,13 @@ import ( "sync" "time" + "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" + "github.com/ugorji/go/codec" "github.com/xmidt-org/eventor" "github.com/xmidt-org/wrp-go/v3" "github.com/xmidt-org/xmidt-agent/internal/credentials/event" + "github.com/xmidt-org/xmidt-agent/internal/fs" ) var ( @@ -57,6 +60,10 @@ type Credentials struct { url string refetchPercent float64 assumedLifetime time.Duration + ignoreBody bool + required bool + fs fs.FS + filename string client *http.Client macAddress wrp.DeviceID serialNumber string @@ -70,7 +77,7 @@ type Credentials struct { partnerID func() string // dynamic // What we are using to decorate the request. - token *xmidtToken + token *xmidtInfo } // Option is the interface implemented by types that can be used to @@ -186,9 +193,27 @@ func (c *Credentials) MarkInvalid(ctx context.Context) { } +func (c *Credentials) Credentials() (string, time.Time, error) { + c.m.RLock() + defer c.m.RUnlock() + if c.token == nil { + return "", time.Time{}, ErrNoToken + } + return c.token.Token, c.token.ExpiresAt, nil +} + // Decorate decorates the request with the credentials. If the credentials // are not valid, an error is returned. func (c *Credentials) Decorate(req *http.Request) error { + err := c.decorate(req) + if c.required && err != nil { + return err + } + + return nil +} + +func (c *Credentials) decorate(req *http.Request) error { var e event.Decorate if req == nil { @@ -199,15 +224,9 @@ func (c *Credentials) Decorate(req *http.Request) error { var token string var expiresAt time.Time - c.m.RLock() - if c.token != nil { - token = c.token.Token - expiresAt = c.token.ExpiresAt - } - c.m.RUnlock() + token, expiresAt, e.Err = c.Credentials() - if token == "" { - e.Err = ErrNoToken + if e.Err != nil { return c.dispatch(e) } @@ -223,7 +242,7 @@ func (c *Credentials) Decorate(req *http.Request) error { // fetch fetches the credentials from the server. This should only be called // by the run() method. -func (c *Credentials) fetch(ctx context.Context) (*xmidtToken, time.Duration, error) { +func (c *Credentials) fetch(ctx context.Context) (*xmidtInfo, time.Duration, error) { var fe event.Fetch req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.url, nil) @@ -275,7 +294,7 @@ func (c *Credentials) fetch(ctx context.Context) (*xmidtToken, time.Duration, er return nil, retryIn, c.dispatch(fe) } - var token xmidtToken + var token xmidtInfo body, err := io.ReadAll(resp.Body) if err != nil { fe.Err = errors.Join(err, ErrFetchFailed) @@ -283,6 +302,14 @@ func (c *Credentials) fetch(ctx context.Context) (*xmidtToken, time.Duration, er } token.Token = string(body) + c.determineExpiration(resp, &token) + + fe.Expiration = token.ExpiresAt + + return &token, 0, c.dispatch(fe) +} + +func (c *Credentials) determineExpiration(resp *http.Response, token *xmidtInfo) { // One hundred years is forever. token.ExpiresAt = c.nowFunc().Add(time.Hour * 24 * 365 * 100) if c.assumedLifetime > 0 { @@ -295,9 +322,17 @@ func (c *Credentials) fetch(ctx context.Context) (*xmidtToken, time.Duration, er token.ExpiresAt = expiration } - fe.Expiration = token.ExpiresAt + if c.ignoreBody { + return + } - return &token, 0, c.dispatch(fe) + // If we are examining the JWT, parse it and get the expiration. + claims := jwt.RegisteredClaims{} + parser := jwt.NewParser() + _, _, err := parser.ParseUnverified(token.Token, &claims) + if err == nil && claims.ExpiresAt != nil { + token.ExpiresAt = claims.ExpiresAt.Time + } } // run is the main loop for the credentials service. @@ -359,6 +394,22 @@ func (c *Credentials) run(ctx context.Context) { } } +func (c *Credentials) store(token xmidtInfo) error { + if c.fs == nil { + return nil + } + + buf := make([]byte, 0, len(token.Token)) + handle := new(codec.MsgpackHandle) + enc := codec.NewEncoderBytes(&buf, handle) + err := enc.Encode(token) + if err != nil { + return err + } + + return c.fs.WriteFile(c.filename, []byte(c.token.Token), 0600) +} + // dispatch dispatches the event to the listeners and returns the error that // should be returned by the caller. func (c *Credentials) dispatch(evnt any) error { @@ -378,9 +429,10 @@ func (c *Credentials) dispatch(evnt any) error { panic("unknown event type") } -// xmidtToken is the token returned from the server as well as the expiration +// xmidtInfo is the token returned from the server as well as the expiration // time. -type xmidtToken struct { +type xmidtInfo struct { Token string ExpiresAt time.Time + Headers http.Header } diff --git a/internal/credentials/credentials_test.go b/internal/credentials/credentials_test.go index d5df461..65b92a1 100644 --- a/internal/credentials/credentials_test.go +++ b/internal/credentials/credentials_test.go @@ -548,3 +548,58 @@ func TestDecorate(t *testing.T) { assert.Equal(2, count) } + +func TestEndToEndWithJwtPayload(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + when := time.Date(2023, 10, 30, 7, 4, 26, 0, time.UTC) + + token := `eyJhbGciOiJSUzI1NiIsImtpZCI6InRoZW1pcy0yMDE3MDEiLCJ0eXAiOiJKV1QifQ.` + + `eyJhdWQiOiJYTWlEVCIsImNhcGFiaWxpdGllcyI6WyJ4MTppc3N1ZXI6dGVzdDouKjphbGwi` + + `XSwiY3VzdG9tIjoicmJsIiwiZXhwIjoxNjk4Njc0NjY2LCJpYXQiOjE2OTYwODI2NjYsImlz` + + `cyI6InRoZW1pcyIsImp0aSI6IldUZDh3SlV0Rzc3SkNZd3lWelRxRnciLCJtYWMiOiIxMTIy` + + `MzM0NDU1NjYiLCJuYmYiOjE2OTYwODI1MTYsInBhcnRuZXItaWQiOiJjb21jYXN0Iiwic2Vy` + + `aWFsIjoiMTIzNDU2Nzg5MCIsInN1YiI6ImNsaWVudDpzdXBwbGllZCIsInRydXN0IjoxMDAw` + + `LCJ1dWlkIjoiMTczYTZlMjQtODgxOC00Nzk2LTgzNzYtNzdiOTA0NmJhZmVjIn0.invalid` + + server := httptest.NewServer( + http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + r.Body.Close() + + w.Header().Add("Expires", when.Format(http.TimeFormat)) + _, _ = w.Write([]byte(token)) + }, + ), + ) + defer server.Close() + + c, err := New( + URL(server.URL), + MacAddress(wrp.DeviceID("mac:112233445566")), + SerialNumber("1234567890"), + HardwareModel("model"), + HardwareManufacturer("manufacturer"), + FirmwareVersion("version"), + LastRebootReason("reason"), + XmidtProtocol("protocol"), + BootRetryWait(1), + AddFetchListener(event.FetchListenerFunc( + func(e event.Fetch) { + assert.Equal(when.Format(http.TimeFormat), e.Expiration.Format(http.TimeFormat)) + assert.NoError(e.Err) + })), + ) + + require.NoError(err) + require.NotNil(c) + + c.Start() + defer c.Stop() + + ctx := context.Background() + deadline, cancel := context.WithDeadline(ctx, time.Now().Add(1*time.Second)) + defer cancel() + c.WaitUntilValid(deadline) +} diff --git a/internal/credentials/options.go b/internal/credentials/options.go index 57874d9..0699f82 100644 --- a/internal/credentials/options.go +++ b/internal/credentials/options.go @@ -9,6 +9,7 @@ import ( "github.com/xmidt-org/wrp-go/v3" "github.com/xmidt-org/xmidt-agent/internal/credentials/event" + "github.com/xmidt-org/xmidt-agent/internal/fs" ) type optionFunc func(*Credentials) error @@ -72,6 +73,33 @@ func AssumedLifetime(lifetime time.Duration) Option { }) } +// IgnoreBody is a flag that indicates whether the body of the response should +// be ignored instead of examined for an expiration time. The default is to +// examine the body. +func IgnoreBody() Option { + return nilOptionFunc( + func(c *Credentials) { + c.ignoreBody = true + }) +} + +// Required is a flag that indicates whether the credentials are required to +// successfully decorate a request. The default is optional. +func Required() Option { + return nilOptionFunc( + func(c *Credentials) { + c.required = true + }) +} + +func LocalStorage(fs fs.FS, filename string) Option { + return nilOptionFunc( + func(c *Credentials) { + c.fs = fs + c.filename = filename + }) +} + // MacAddress is the MAC address of the device. func MacAddress(macAddress wrp.DeviceID) Option { return nilOptionFunc( From 69f00274981e3578982fa86f62ed480c4bedef01 Mon Sep 17 00:00:00 2001 From: schmidtw Date: Tue, 3 Oct 2023 19:04:40 -0700 Subject: [PATCH 2/4] Improve the fs options and tests. --- internal/fs/fs_test.go | 99 +++++++++++++++++++++++++++++++++++------- internal/fs/mem/mem.go | 28 ++---------- internal/fs/options.go | 61 ++++++++++++++++++++------ 3 files changed, 135 insertions(+), 53 deletions(-) diff --git a/internal/fs/fs_test.go b/internal/fs/fs_test.go index 26dd464..7d3dcad 100644 --- a/internal/fs/fs_test.go +++ b/internal/fs/fs_test.go @@ -18,7 +18,7 @@ var ( errUnknown = errors.New("unknown error") ) -func TestMakeDir(t *testing.T) { +func TestWithDirOrSimiar(t *testing.T) { tests := []struct { description string opt xafs.Option @@ -29,40 +29,104 @@ func TestMakeDir(t *testing.T) { }{ { description: "simple path", - opt: xafs.MakeDir("foo", 0755), + opt: xafs.WithDir("foo", 0755), expect: mem.New(mem.WithDir("foo", 0755)), }, { description: "simple existing path", - opt: xafs.MakeDir("foo", 0755), + opt: xafs.WithDir("foo", 0755), start: mem.New(mem.WithDir("foo", 0755)), expect: mem.New(mem.WithDir("foo", 0755)), }, { description: "not a directory", - opt: xafs.MakeDir("foo", 0755), + opt: xafs.WithDir("foo", 0755), start: mem.New(mem.WithFile("foo", "data", 0755)), expectErr: xafs.ErrNotDirectory, - }, { - description: "not able to read", - opt: xafs.MakeDir("foo", 0755), - start: mem.New(mem.WithDir("foo", 0111)), - expectErr: fs.ErrPermission, }, { description: "error opening the file", - opt: xafs.MakeDir("foo", 0755), + opt: xafs.WithDir("foo", 0755), start: mem.New(mem.WithError("foo", errUnknown)), expectErr: errUnknown, }, { - description: "two directory path", + description: "three directory path", opts: []xafs.Option{ - xafs.MakeDir("foo", 0700), - xafs.MakeDir("foo/bar", 0750), - xafs.MakeDir("foo/bar/car", 0755), + xafs.WithDir("foo", 0700), + xafs.WithDir("foo/bar", 0750), + xafs.WithDir("foo/bar/car", 0755), }, expect: mem.New( mem.WithDir("foo", 0700), mem.WithDir("foo/bar", 0750), mem.WithDir("foo/bar/car", 0755), ), + }, { + description: "abs directory path", + start: mem.New(mem.WithDir("/", 0755)), + opts: []xafs.Option{ + xafs.WithDir("/foo", 0700), + xafs.WithDir("/foo/bar", 0750), + xafs.WithDir("/foo/bar/car", 0755), + }, + expect: mem.New( + mem.WithDir("/", 0755), + mem.WithDir("/foo", 0700), + mem.WithDir("/foo/bar", 0750), + mem.WithDir("/foo/bar/car", 0755), + ), + }, { + description: "WithDirs three directory path", + start: mem.New(), + opts: []xafs.Option{ + xafs.WithDirs("foo/bar/car", 0755), + }, + expect: mem.New( + mem.WithDir("foo", 0755), + mem.WithDir("foo/bar", 0755), + mem.WithDir("foo/bar/car", 0755), + ), + }, { + description: "WithDirs three directory path one exists", + start: mem.New(mem.WithDir("foo", 0711)), + opts: []xafs.Option{ + xafs.WithDirs("foo/bar/car", 0755), + }, + expect: mem.New( + mem.WithDir("foo", 0711), + mem.WithDir("foo/bar", 0755), + mem.WithDir("foo/bar/car", 0755), + ), + }, { + description: "abs three directory path", + start: mem.New(mem.WithDir("/", 0700)), + opts: []xafs.Option{ + xafs.WithDirs("/boo/egg/cat", 0755), + }, + expect: mem.New( + mem.WithDir("/", 0700), + mem.WithDir("/boo", 0755), + mem.WithDir("/boo/egg", 0755), + mem.WithDir("/boo/egg/cat", 0755), + ), + }, { + description: "WithPath two directory path one exists, and a filename", + start: mem.New(mem.WithDir("foo", 0711)), + opts: []xafs.Option{ + xafs.WithPath("foo/bar/car.json", 0755), + }, + expect: mem.New( + mem.WithDir("foo", 0711), + mem.WithDir("foo/bar", 0755), + ), + }, { + description: "Ensure Operate can handle nil options", + start: mem.New(mem.WithDir("foo", 0711)), + opts: []xafs.Option{ + nil, + xafs.WithPath("foo/bar/car.json", 0755), + }, + expect: mem.New( + mem.WithDir("foo", 0711), + mem.WithDir("foo/bar", 0755), + ), }, } @@ -71,7 +135,12 @@ func TestMakeDir(t *testing.T) { require := require.New(t) assert := assert.New(t) - opts := append(tc.opts, tc.opt) + opts := make([]xafs.Option, 0, len(tc.opts)+1) + if tc.opt != nil { + opts = append(tc.opts, tc.opt) + } + opts = append(opts, tc.opts...) + fs := tc.start if fs == nil { fs = mem.New() diff --git a/internal/fs/mem/mem.go b/internal/fs/mem/mem.go index 6e6e118..508972c 100644 --- a/internal/fs/mem/mem.go +++ b/internal/fs/mem/mem.go @@ -105,8 +105,10 @@ func (fs *FS) Open(name string) (iofs.File, error) { } func (fs *FS) Mkdir(path string, perm iofs.FileMode) error { - if err := fs.hasPerms(path, iofs.FileMode(0111)); err != nil { - return err + if path != separator { + if err := fs.hasPerms(path, iofs.FileMode(0111)); err != nil { + return err + } } if fs.Dirs == nil { @@ -201,25 +203,3 @@ func (fs *FS) hasPerms(name string, perm iofs.FileMode) error { return nil } - -// Remove this in favor of pp ... except I can't download pp right now. -/* -func (fs *FS) String() string { - buf := strings.Builder{} - - buf.WriteString("Files:\n") - for k, v := range fs.Files { - fmt.Fprintf(&buf, " '%s': '%s'\n", k, string(v.Bytes)) - } - buf.WriteString("Dirs:\n") - for k, v := range fs.Dirs { - fmt.Fprintf(&buf, " '%s': %v\n", k, v) - } - buf.WriteString("Errs:\n") - for k, v := range fs.Errs { - fmt.Fprintf(&buf, " '%s': %v\n", k, v) - } - - return buf.String() -} -*/ diff --git a/internal/fs/options.go b/internal/fs/options.go index f346086..b07b431 100644 --- a/internal/fs/options.go +++ b/internal/fs/options.go @@ -18,9 +18,17 @@ var ( ErrInvalidSHA = errors.New("invalid SHA for file") ) -// MakeDir is an option that ensures the specified directory exists with the -// specified permissions. -func MakeDir(dir string, perm fs.FileMode) Option { +// Options provides a way to group multiple options together. +func Options(opts ...Option) Option { + return OptionFunc( + func(f FS) error { + return Operate(f, opts...) + }) +} + +// WithDir is an option that ensures the specified directory exists. If it +// does not, create it with the specified permissions. +func WithDir(dir string, perm fs.FileMode) Option { return OptionFunc( func(f FS) error { file, err := f.Open(dir) @@ -33,21 +41,46 @@ func MakeDir(dir string, perm fs.FileMode) Option { defer file.Close() stat, err := file.Stat() - if err != nil { - return err + if err == nil { + if !stat.IsDir() { + return ErrNotDirectory + } } - if !stat.IsDir() { - return ErrNotDirectory - } + return err + }) +} - mode := stat.Mode() - if (mode & fs.ModePerm & perm) != (fs.ModePerm & perm) { - return fs.ErrPermission - } +// WithDirs is an option that ensures the specified directory path exists with +// the specified permissions. The path is split on the path separator and +// each directory is created in order if needed. +// +// Notes: +// - The path should not contain the filename or that will be created as a directory. +// - The same permissions are applied to all directories that are created. +func WithDirs(path string, perm fs.FileMode) Option { + dirs := strings.Split(path, string(filepath.Separator)) + if filepath.IsAbs(path) { + dirs[0] = string(filepath.Separator) + } + + var full string + opts := make([]Option, 0, len(dirs)) + for _, dir := range dirs { + full = filepath.Join(full, dir) + opts = append(opts, WithDir(full, perm)) + } + return Options(opts...) +} - return nil - }) +// WithPath is an option that ensures the set of directories for the specified +// file exists. The directory is determined by calling filepath.Dir on the name. +// +// Notes: +// - The name should contain the filename and any path to ensure is present. +// - The same permissions are applied to all directories that are created. +func WithPath(name string, perm fs.FileMode) Option { + return WithDirs(filepath.Dir(name), perm) } // WriteFileWithSHA256 calculates and writes both the file and a checksum file. From 1dd4e3b02d638c514a86ced32f596bd572f9bb23 Mon Sep 17 00:00:00 2001 From: schmidtw Date: Tue, 3 Oct 2023 19:08:06 -0700 Subject: [PATCH 3/4] Add a comment. --- internal/fs/fs.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/fs/fs.go b/internal/fs/fs.go index 6c9971b..f14bab5 100644 --- a/internal/fs/fs.go +++ b/internal/fs/fs.go @@ -24,6 +24,7 @@ type FS interface { // Option is an interface for options that can be applied in order via the Operate function. type Option interface { + // Apply applies the option to the filesystem. Apply(FS) error } From 98c5d8e1c057baf5e2c52ffdeaf910ad3ef52c2e Mon Sep 17 00:00:00 2001 From: schmidtw Date: Tue, 3 Oct 2023 23:06:36 -0700 Subject: [PATCH 4/4] Upate the credentials api to store them to the filesystem and retreive them if requested. --- internal/credentials/credentials.go | 93 ++++++++++++++++++------ internal/credentials/credentials_test.go | 80 ++++++++++++++------ internal/credentials/event/events.go | 3 + internal/credentials/options.go | 7 +- 4 files changed, 139 insertions(+), 44 deletions(-) diff --git a/internal/credentials/credentials.go b/internal/credentials/credentials.go index 9d8f74d..97eab88 100644 --- a/internal/credentials/credentials.go +++ b/internal/credentials/credentials.go @@ -8,12 +8,12 @@ import ( "errors" "fmt" "io" + iofs "io/fs" "net/http" "strconv" "sync" "time" - "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" "github.com/ugorji/go/codec" "github.com/xmidt-org/eventor" @@ -64,6 +64,7 @@ type Credentials struct { required bool fs fs.FS filename string + perm iofs.FileMode client *http.Client macAddress wrp.DeviceID serialNumber string @@ -196,9 +197,10 @@ func (c *Credentials) MarkInvalid(ctx context.Context) { func (c *Credentials) Credentials() (string, time.Time, error) { c.m.RLock() defer c.m.RUnlock() - if c.token == nil { + if c.token == nil || c.token.Token == "" { return "", time.Time{}, ErrNoToken } + return c.token.Token, c.token.ExpiresAt, nil } @@ -237,13 +239,16 @@ func (c *Credentials) decorate(req *http.Request) error { } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + return c.dispatch(e) } // fetch fetches the credentials from the server. This should only be called // by the run() method. func (c *Credentials) fetch(ctx context.Context) (*xmidtInfo, time.Duration, error) { - var fe event.Fetch + fe := event.Fetch{ + Origin: "network", + } req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.url, nil) if err != nil { @@ -321,38 +326,43 @@ func (c *Credentials) determineExpiration(resp *http.Response, token *xmidtInfo) // Even better, we were told when it expires. token.ExpiresAt = expiration } - - if c.ignoreBody { - return - } - - // If we are examining the JWT, parse it and get the expiration. - claims := jwt.RegisteredClaims{} - parser := jwt.NewParser() - _, _, err := parser.ParseUnverified(token.Token, &claims) - if err == nil && claims.ExpiresAt != nil { - token.ExpiresAt = claims.ExpiresAt.Time - } } // run is the main loop for the credentials service. func (c *Credentials) run(ctx context.Context) { var ( - timer *time.Timer - fetched bool - valid bool + timer *time.Timer + skipFetch bool + fromDisc bool + fetched bool + valid bool + retryIn time.Duration ) c.wg.Add(1) defer c.wg.Done() + token, err := c.load() + if err == nil && token != nil { + fromDisc = true + skipFetch = true + } + for { - token, retryIn, err := c.fetch(ctx) + if !skipFetch { + token, retryIn, err = c.fetch(ctx) + if err == nil { + fromDisc = false + } + } if !fetched { close(c.fetched) fetched = true } + // Only skip the fetch once. + skipFetch = false + // Assume we failed, so retry in 1 second or when the server suggested. next := max(time.Second, retryIn) @@ -368,6 +378,10 @@ func (c *Credentials) run(ctx context.Context) { valid = true } + if !fromDisc { + _ = c.store(token) + } + until := expires.Sub(c.nowFunc()) if 0 < until { // Add a timer to fetch the token again @@ -394,7 +408,7 @@ func (c *Credentials) run(ctx context.Context) { } } -func (c *Credentials) store(token xmidtInfo) error { +func (c *Credentials) store(token *xmidtInfo) error { if c.fs == nil { return nil } @@ -407,7 +421,43 @@ func (c *Credentials) store(token xmidtInfo) error { return err } - return c.fs.WriteFile(c.filename, []byte(c.token.Token), 0600) + return fs.Operate(c.fs, + fs.WithPath(c.filename, c.perm), + fs.WriteFileWithSHA256(c.filename, buf, c.perm)) +} + +func (c *Credentials) load() (*xmidtInfo, error) { + fe := event.Fetch{ + Origin: "fs", + } + + if c.fs == nil { + return nil, nil + } + + var buf []byte + + fe.At = time.Now() + err := fs.Operate(c.fs, + fs.WithPath(c.filename, c.perm), + fs.ReadFileWithSHA256(c.filename, &buf)) + fe.Duration = time.Since(fe.At) + if err != nil { + fe.Err = errors.Join(err, ErrFetchFailed) + return nil, c.dispatch(fe) + } + + handle := new(codec.MsgpackHandle) + dec := codec.NewDecoderBytes(buf, handle) + + var token xmidtInfo + err = dec.Decode(&token) + if err != nil { + fe.Err = errors.Join(err, ErrFetchFailed) + return nil, c.dispatch(fe) + } + fe.Expiration = token.ExpiresAt + return &token, c.dispatch(fe) } // dispatch dispatches the event to the listeners and returns the error that @@ -434,5 +484,4 @@ func (c *Credentials) dispatch(evnt any) error { type xmidtInfo struct { Token string ExpiresAt time.Time - Headers http.Header } diff --git a/internal/credentials/credentials_test.go b/internal/credentials/credentials_test.go index 65b92a1..624f5a3 100644 --- a/internal/credentials/credentials_test.go +++ b/internal/credentials/credentials_test.go @@ -16,6 +16,7 @@ import ( "github.com/stretchr/testify/require" "github.com/xmidt-org/wrp-go/v3" "github.com/xmidt-org/xmidt-agent/internal/credentials/event" + "github.com/xmidt-org/xmidt-agent/internal/fs/mem" ) func TestNew(t *testing.T) { @@ -256,7 +257,7 @@ func TestEndToEnd429(t *testing.T) { defer c.Stop() ctx := context.Background() - deadline, cancel := context.WithDeadline(ctx, time.Now().Add(1*time.Second)) + deadline, cancel := context.WithDeadline(ctx, time.Now().Add(100*time.Millisecond)) defer cancel() c.WaitUntilFetched(deadline) assert.Equal(1, called) @@ -385,6 +386,8 @@ func TestEndToEnd(t *testing.T) { ) defer server.Close() + fs := mem.New(mem.WithDir(".", 0755)) + c, err := New( URL(server.URL), MacAddress(wrp.DeviceID("mac:112233445566")), @@ -396,6 +399,7 @@ func TestEndToEnd(t *testing.T) { XmidtProtocol("protocol"), BootRetryWait(1), AssumedLifetime(24*time.Hour), + LocalStorage(fs, "credentials.msgpack", 0600), AddFetchListener(event.FetchListenerFunc( func(e event.Fetch) { fmt.Println("Fetch:") @@ -452,6 +456,12 @@ func TestEndToEnd(t *testing.T) { // Multiple calls to Stop is ok. c.Stop() + + _, found := fs.Files["credentials.msgpack"] + assert.True(found) + + _, found = fs.Files["credentials.msgpack.sha256"] + assert.True(found) } func TestContextExpires(t *testing.T) { @@ -510,6 +520,7 @@ func TestDecorate(t *testing.T) { LastRebootReason("reason"), XmidtProtocol("protocol"), BootRetryWait(1), + Required(), AddFetchListener(event.FetchListenerFunc( func(e event.Fetch) { assert.NoError(e.Err) @@ -549,33 +560,27 @@ func TestDecorate(t *testing.T) { assert.Equal(2, count) } -func TestEndToEndWithJwtPayload(t *testing.T) { +func TestToAndFromFile(t *testing.T) { assert := assert.New(t) require := require.New(t) - when := time.Date(2023, 10, 30, 7, 4, 26, 0, time.UTC) - - token := `eyJhbGciOiJSUzI1NiIsImtpZCI6InRoZW1pcy0yMDE3MDEiLCJ0eXAiOiJKV1QifQ.` + - `eyJhdWQiOiJYTWlEVCIsImNhcGFiaWxpdGllcyI6WyJ4MTppc3N1ZXI6dGVzdDouKjphbGwi` + - `XSwiY3VzdG9tIjoicmJsIiwiZXhwIjoxNjk4Njc0NjY2LCJpYXQiOjE2OTYwODI2NjYsImlz` + - `cyI6InRoZW1pcyIsImp0aSI6IldUZDh3SlV0Rzc3SkNZd3lWelRxRnciLCJtYWMiOiIxMTIy` + - `MzM0NDU1NjYiLCJuYmYiOjE2OTYwODI1MTYsInBhcnRuZXItaWQiOiJjb21jYXN0Iiwic2Vy` + - `aWFsIjoiMTIzNDU2Nzg5MCIsInN1YiI6ImNsaWVudDpzdXBwbGllZCIsInRydXN0IjoxMDAw` + - `LCJ1dWlkIjoiMTczYTZlMjQtODgxOC00Nzk2LTgzNzYtNzdiOTA0NmJhZmVjIn0.invalid` - + var count int server := httptest.NewServer( http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { r.Body.Close() - w.Header().Add("Expires", when.Format(http.TimeFormat)) - _, _ = w.Write([]byte(token)) + w.Header().Add("Expires", time.Now().Add(1*time.Hour).Format(http.TimeFormat)) + _, _ = w.Write([]byte(`token`)) + count++ }, ), ) defer server.Close() - c, err := New( + fs := mem.New(mem.WithDir(".", 0755)) + + opts := []Option{ URL(server.URL), MacAddress(wrp.DeviceID("mac:112233445566")), SerialNumber("1234567890"), @@ -585,21 +590,54 @@ func TestEndToEndWithJwtPayload(t *testing.T) { LastRebootReason("reason"), XmidtProtocol("protocol"), BootRetryWait(1), + LocalStorage(fs, "credentials.msgpack", 0600), + } + + var listenerCount int + copts := append(opts, AddFetchListener(event.FetchListenerFunc( func(e event.Fetch) { - assert.Equal(when.Format(http.TimeFormat), e.Expiration.Format(http.TimeFormat)) - assert.NoError(e.Err) - })), - ) + if listenerCount == 0 { + assert.Equal("fs", e.Origin) + assert.Error(e.Err) + } else { + assert.Equal("network", e.Origin) + assert.NoError(e.Err) + } + listenerCount++ + }))) + + c, err := New(copts...) require.NoError(err) require.NotNil(c) c.Start() - defer c.Stop() ctx := context.Background() deadline, cancel := context.WithDeadline(ctx, time.Now().Add(1*time.Second)) defer cancel() - c.WaitUntilValid(deadline) + c.WaitUntilFetched(deadline) + + c.Stop() + + dopts := append(opts, + AddFetchListener(event.FetchListenerFunc( + func(e event.Fetch) { + assert.Equal("fs", e.Origin) + assert.NoError(e.Err) + }))) + + d, err := New(dopts...) + require.NoError(err) + + d.Start() + ctx = context.Background() + deadline, cancel = context.WithDeadline(ctx, time.Now().Add(1*time.Second)) + defer cancel() + d.WaitUntilFetched(deadline) + + d.Stop() + + assert.Equal(1, count) } diff --git a/internal/credentials/event/events.go b/internal/credentials/event/events.go index 91607f1..1be54ac 100644 --- a/internal/credentials/event/events.go +++ b/internal/credentials/event/events.go @@ -15,6 +15,9 @@ type CancelListenerFunc func() // Fetch is the event that is sent when the credentials are fetched. type Fetch struct { + // The origin of the data - "fs" or "network" are the only valid values. + Origin string + // At holds the time when the fetch request was made. At time.Time diff --git a/internal/credentials/options.go b/internal/credentials/options.go index 0699f82..57fe799 100644 --- a/internal/credentials/options.go +++ b/internal/credentials/options.go @@ -4,6 +4,7 @@ package credentials import ( + iofs "io/fs" "net/http" "time" @@ -92,11 +93,15 @@ func Required() Option { }) } -func LocalStorage(fs fs.FS, filename string) Option { +// LocalStorage is the local storage used to cache the credentials. +// +// The filename (and path) is relative to the provided filesystem. +func LocalStorage(fs fs.FS, filename string, perm iofs.FileMode) Option { return nilOptionFunc( func(c *Credentials) { c.fs = fs c.filename = filename + c.perm = perm }) }