diff --git a/config.go b/config.go index 735937a..9b49917 100644 --- a/config.go +++ b/config.go @@ -6,15 +6,62 @@ package main import ( "fmt" "os" + "time" "github.com/goschtalt/goschtalt" + "github.com/xmidt-org/arrange/arrangehttp" "github.com/xmidt-org/sallust" + "github.com/xmidt-org/wrp-go/v3" "gopkg.in/dealancer/validate.v2" ) type Config struct { - SpecialValue string - Logger sallust.Config + Identity Identity + OperationalState OperationalState + XmidtCredentials XmidtCredentials + Logger sallust.Config +} + +type Identity struct { + DeviceID wrp.DeviceID + SerialNumber string + HardwareModel string + HardwareManufacturer string + FirmwareVersion string + PartnerID string +} + +type OperationalState struct { + LastRebootReason string + BootTime time.Time +} + +type XmidtCredentials struct { + URL string + HTTPClient arrangehttp.ClientConfig + RefetchPercent float64 +} + +type XmidtService struct { + URL string + JwtTxtRedirector JwtTxtRedirector + Backoff Backoff +} + +type JwtTxtRedirector struct { + Required bool + AllowedAlgorithms []string + Timeout time.Duration + PEMs []string + PEMFiles []string +} + +// Backoff defines the parameters that limit the retry backoff algorithm. +// The retries are a geometric progression. +// 1, 3, 7, 15, 31 ... n = (2n+1) +type Backoff struct { + MinDelay time.Duration + MaxDelay time.Duration } // Collect and process the configuration files and env vars and diff --git a/go.mod b/go.mod index 0a6fff1..c2fc6c0 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,8 @@ require ( github.com/goschtalt/yaml-decoder v0.0.1 github.com/goschtalt/yaml-encoder v0.0.3 github.com/stretchr/testify v1.8.4 + github.com/ugorji/go/codec v1.2.11 + github.com/xmidt-org/arrange v0.5.0 github.com/xmidt-org/eventor v0.0.0-20230910205925-8ff168bd12ed github.com/xmidt-org/sallust v0.2.2 github.com/xmidt-org/wrp-go/v3 v3.2.0 @@ -26,7 +28,7 @@ require ( github.com/miekg/dns v1.1.56 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/ugorji/go/codec v1.2.11 // indirect + github.com/xmidt-org/httpaux v0.4.0 // indirect go.uber.org/dig v1.17.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/mod v0.12.0 // indirect diff --git a/go.sum b/go.sum index 5948c85..6ef26c7 100644 --- a/go.sum +++ b/go.sum @@ -27,8 +27,8 @@ github.com/goschtalt/yaml-encoder v0.0.3 h1:vfQ3vXZNvoEFPa3NzOWNtweYVa+2qMh8eqhX github.com/goschtalt/yaml-encoder v0.0.3/go.mod h1:E9ANM2mgRmoqP+JTFFv03fVWcnn+QrIDfVu5shDvX3A= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= -github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= -github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw= @@ -47,6 +47,7 @@ github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XF github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -56,14 +57,18 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +github.com/xmidt-org/arrange v0.5.0 h1:ajkVHkr7dXnfCYm/6eafWoOab+6A3b2jEHQO0IdIIb0= +github.com/xmidt-org/arrange v0.5.0/go.mod h1:PoZB9lR49ma0osydQbaWpNeA3XPoLkjP5RYUoOw8wZU= github.com/xmidt-org/eventor v0.0.0-20230910205925-8ff168bd12ed h1:KpcgFuumKrt/824H3gtmNI/IvgjsBo6rnlSnwXlFu60= github.com/xmidt-org/eventor v0.0.0-20230910205925-8ff168bd12ed/go.mod h1:X9Og+8y1Llz7N8F20UmjZUNgrxHubMVfBcroJ5SPtIY= +github.com/xmidt-org/httpaux v0.4.0 h1:cAL/MzIBpSsv4xZZeq/Eu1J5M3vfNe49xr41mP3COKU= +github.com/xmidt-org/httpaux v0.4.0/go.mod h1:UypqZwuZV1nn8D6+K1JDb+im9IZrLNg/2oO/Bgiybxc= github.com/xmidt-org/sallust v0.2.2 h1:MrINLEr7cMj6ENx/O76fvpfd5LNGYnk7OipZAGXPYA0= github.com/xmidt-org/sallust v0.2.2/go.mod h1:ytBoypcPw10OmjM6b92Jx3eoqWX4J5zVXOQozGwz4qs= github.com/xmidt-org/wrp-go/v3 v3.2.0 h1:XX5c0ZJYaTEvlHFk0lzxadoOMbxg5YtUkPWNXHoxTDE= github.com/xmidt-org/wrp-go/v3 v3.2.0/go.mod h1:46ily/xzmRUhs8gSbTKNeOA6ztwcHauZFnfr4hRpoHA= -go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= -go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/dig v1.17.0 h1:5Chju+tUvcC+N7N6EV08BJz41UZuO3BmHcN4A287ZLI= go.uber.org/dig v1.17.0/go.mod h1:rTxpf7l5I0eBTlE6/9RL+lDybC7WFwY2QH55ZSjy1mU= go.uber.org/fx v1.20.0 h1:ZMC/pnRvhsthOZh9MZjMq5U8Or3mA9zBSPaLnzs3ihQ= diff --git a/internal/credentials/cmd/example/main.go b/internal/credentials/cmd/example/main.go index 9206e9f..45e931b 100644 --- a/internal/credentials/cmd/example/main.go +++ b/internal/credentials/cmd/example/main.go @@ -13,6 +13,7 @@ import ( "time" "github.com/alecthomas/kong" + "github.com/golang-jwt/jwt/v5" "github.com/xmidt-org/wrp-go/v3" cred "github.com/xmidt-org/xmidt-agent/internal/credentials" "github.com/xmidt-org/xmidt-agent/internal/credentials/event" @@ -38,9 +39,9 @@ func main() { client := http.DefaultClient - if cli.Private != "" || cli.Public != "" || cli.CA != "" { - if cli.Private == "" || cli.Public == "" || cli.CA == "" { - panic("--private, --public and --ca must be specified together") + if cli.Private != "" || cli.Public != "" { + if cli.Private == "" || cli.Public == "" { + panic("--private and --public must be specified together") } cert, err := tls.LoadX509KeyPair(cli.Public, cli.Private) @@ -48,16 +49,18 @@ func main() { panic(err) } - caCert, err := os.ReadFile(cli.CA) - if err != nil { - panic(err) - } - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) - tlsConfig := &tls.Config{ Certificates: []tls.Certificate{cert}, - RootCAs: caCertPool, + } + + if cli.CA != "" { + caCert, err := os.ReadFile(cli.CA) + if err != nil { + panic(err) + } + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + tlsConfig.RootCAs = caCertPool } tr := &http.Transport{TLSClientConfig: tlsConfig} @@ -117,4 +120,41 @@ func main() { defer cancel() credentials.WaitUntilFetched(ctx) + token, expires, err := credentials.Credentials() + if err != nil { + panic(err) + } + + fmt.Printf("JWT: %s\n", token) + fmt.Printf("Expires: %s\n", expires.Format(time.RFC3339)) + + claims := jwt.RegisteredClaims{} + parser := jwt.NewParser() + _, parts, err := parser.ParseUnverified(token, &claims) + if err != nil { + panic(err) + } + + fmt.Println("Claims:") + fmt.Printf(" ID: %s\n", claims.ID) + fmt.Printf(" ExpirationTime: %s\n", claims.ExpiresAt) + fmt.Printf(" IssuedAt: %s\n", claims.IssuedAt) + fmt.Printf(" NotBefore: %s\n", claims.NotBefore) + fmt.Printf(" Issuer: %s\n", claims.Issuer) + fmt.Printf(" Subject: %s\n", claims.Subject) + fmt.Printf(" Audience: %s\n", claims.Audience) + + header, err := parser.DecodeSegment(parts[0]) + if err != nil { + panic(err) + } + + body, err := parser.DecodeSegment(parts[1]) + if err != nil { + panic(err) + } + + fmt.Println("Parts:") + fmt.Printf(" Header: %s\n", header) + fmt.Printf(" Body: %s\n", body) } diff --git a/internal/credentials/credentials.go b/internal/credentials/credentials.go index 03a58f9..97eab88 100644 --- a/internal/credentials/credentials.go +++ b/internal/credentials/credentials.go @@ -8,15 +8,18 @@ import ( "errors" "fmt" "io" + iofs "io/fs" "net/http" "strconv" "sync" "time" "github.com/google/uuid" + "github.com/ugorji/go/codec" "github.com/xmidt-org/eventor" "github.com/xmidt-org/wrp-go/v3" "github.com/xmidt-org/xmidt-agent/internal/credentials/event" + "github.com/xmidt-org/xmidt-agent/internal/fs" ) var ( @@ -57,6 +60,11 @@ type Credentials struct { url string refetchPercent float64 assumedLifetime time.Duration + ignoreBody bool + required bool + fs fs.FS + filename string + perm iofs.FileMode client *http.Client macAddress wrp.DeviceID serialNumber string @@ -70,7 +78,7 @@ type Credentials struct { partnerID func() string // dynamic // What we are using to decorate the request. - token *xmidtToken + token *xmidtInfo } // Option is the interface implemented by types that can be used to @@ -186,9 +194,28 @@ func (c *Credentials) MarkInvalid(ctx context.Context) { } +func (c *Credentials) Credentials() (string, time.Time, error) { + c.m.RLock() + defer c.m.RUnlock() + if c.token == nil || c.token.Token == "" { + return "", time.Time{}, ErrNoToken + } + + return c.token.Token, c.token.ExpiresAt, nil +} + // Decorate decorates the request with the credentials. If the credentials // are not valid, an error is returned. func (c *Credentials) Decorate(req *http.Request) error { + err := c.decorate(req) + if c.required && err != nil { + return err + } + + return nil +} + +func (c *Credentials) decorate(req *http.Request) error { var e event.Decorate if req == nil { @@ -199,15 +226,9 @@ func (c *Credentials) Decorate(req *http.Request) error { var token string var expiresAt time.Time - c.m.RLock() - if c.token != nil { - token = c.token.Token - expiresAt = c.token.ExpiresAt - } - c.m.RUnlock() + token, expiresAt, e.Err = c.Credentials() - if token == "" { - e.Err = ErrNoToken + if e.Err != nil { return c.dispatch(e) } @@ -218,13 +239,16 @@ func (c *Credentials) Decorate(req *http.Request) error { } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + return c.dispatch(e) } // fetch fetches the credentials from the server. This should only be called // by the run() method. -func (c *Credentials) fetch(ctx context.Context) (*xmidtToken, time.Duration, error) { - var fe event.Fetch +func (c *Credentials) fetch(ctx context.Context) (*xmidtInfo, time.Duration, error) { + fe := event.Fetch{ + Origin: "network", + } req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.url, nil) if err != nil { @@ -275,7 +299,7 @@ func (c *Credentials) fetch(ctx context.Context) (*xmidtToken, time.Duration, er return nil, retryIn, c.dispatch(fe) } - var token xmidtToken + var token xmidtInfo body, err := io.ReadAll(resp.Body) if err != nil { fe.Err = errors.Join(err, ErrFetchFailed) @@ -283,6 +307,14 @@ func (c *Credentials) fetch(ctx context.Context) (*xmidtToken, time.Duration, er } token.Token = string(body) + c.determineExpiration(resp, &token) + + fe.Expiration = token.ExpiresAt + + return &token, 0, c.dispatch(fe) +} + +func (c *Credentials) determineExpiration(resp *http.Response, token *xmidtInfo) { // One hundred years is forever. token.ExpiresAt = c.nowFunc().Add(time.Hour * 24 * 365 * 100) if c.assumedLifetime > 0 { @@ -294,30 +326,43 @@ func (c *Credentials) fetch(ctx context.Context) (*xmidtToken, time.Duration, er // Even better, we were told when it expires. token.ExpiresAt = expiration } - - fe.Expiration = token.ExpiresAt - - return &token, 0, c.dispatch(fe) } // run is the main loop for the credentials service. func (c *Credentials) run(ctx context.Context) { var ( - timer *time.Timer - fetched bool - valid bool + timer *time.Timer + skipFetch bool + fromDisc bool + fetched bool + valid bool + retryIn time.Duration ) c.wg.Add(1) defer c.wg.Done() + token, err := c.load() + if err == nil && token != nil { + fromDisc = true + skipFetch = true + } + for { - token, retryIn, err := c.fetch(ctx) + if !skipFetch { + token, retryIn, err = c.fetch(ctx) + if err == nil { + fromDisc = false + } + } if !fetched { close(c.fetched) fetched = true } + // Only skip the fetch once. + skipFetch = false + // Assume we failed, so retry in 1 second or when the server suggested. next := max(time.Second, retryIn) @@ -333,6 +378,10 @@ func (c *Credentials) run(ctx context.Context) { valid = true } + if !fromDisc { + _ = c.store(token) + } + until := expires.Sub(c.nowFunc()) if 0 < until { // Add a timer to fetch the token again @@ -359,6 +408,58 @@ func (c *Credentials) run(ctx context.Context) { } } +func (c *Credentials) store(token *xmidtInfo) error { + if c.fs == nil { + return nil + } + + buf := make([]byte, 0, len(token.Token)) + handle := new(codec.MsgpackHandle) + enc := codec.NewEncoderBytes(&buf, handle) + err := enc.Encode(token) + if err != nil { + return err + } + + return fs.Operate(c.fs, + fs.WithPath(c.filename, c.perm), + fs.WriteFileWithSHA256(c.filename, buf, c.perm)) +} + +func (c *Credentials) load() (*xmidtInfo, error) { + fe := event.Fetch{ + Origin: "fs", + } + + if c.fs == nil { + return nil, nil + } + + var buf []byte + + fe.At = time.Now() + err := fs.Operate(c.fs, + fs.WithPath(c.filename, c.perm), + fs.ReadFileWithSHA256(c.filename, &buf)) + fe.Duration = time.Since(fe.At) + if err != nil { + fe.Err = errors.Join(err, ErrFetchFailed) + return nil, c.dispatch(fe) + } + + handle := new(codec.MsgpackHandle) + dec := codec.NewDecoderBytes(buf, handle) + + var token xmidtInfo + err = dec.Decode(&token) + if err != nil { + fe.Err = errors.Join(err, ErrFetchFailed) + return nil, c.dispatch(fe) + } + fe.Expiration = token.ExpiresAt + return &token, c.dispatch(fe) +} + // dispatch dispatches the event to the listeners and returns the error that // should be returned by the caller. func (c *Credentials) dispatch(evnt any) error { @@ -378,9 +479,9 @@ func (c *Credentials) dispatch(evnt any) error { panic("unknown event type") } -// xmidtToken is the token returned from the server as well as the expiration +// xmidtInfo is the token returned from the server as well as the expiration // time. -type xmidtToken struct { +type xmidtInfo struct { Token string ExpiresAt time.Time } diff --git a/internal/credentials/credentials_test.go b/internal/credentials/credentials_test.go index d5df461..624f5a3 100644 --- a/internal/credentials/credentials_test.go +++ b/internal/credentials/credentials_test.go @@ -16,6 +16,7 @@ import ( "github.com/stretchr/testify/require" "github.com/xmidt-org/wrp-go/v3" "github.com/xmidt-org/xmidt-agent/internal/credentials/event" + "github.com/xmidt-org/xmidt-agent/internal/fs/mem" ) func TestNew(t *testing.T) { @@ -256,7 +257,7 @@ func TestEndToEnd429(t *testing.T) { defer c.Stop() ctx := context.Background() - deadline, cancel := context.WithDeadline(ctx, time.Now().Add(1*time.Second)) + deadline, cancel := context.WithDeadline(ctx, time.Now().Add(100*time.Millisecond)) defer cancel() c.WaitUntilFetched(deadline) assert.Equal(1, called) @@ -385,6 +386,8 @@ func TestEndToEnd(t *testing.T) { ) defer server.Close() + fs := mem.New(mem.WithDir(".", 0755)) + c, err := New( URL(server.URL), MacAddress(wrp.DeviceID("mac:112233445566")), @@ -396,6 +399,7 @@ func TestEndToEnd(t *testing.T) { XmidtProtocol("protocol"), BootRetryWait(1), AssumedLifetime(24*time.Hour), + LocalStorage(fs, "credentials.msgpack", 0600), AddFetchListener(event.FetchListenerFunc( func(e event.Fetch) { fmt.Println("Fetch:") @@ -452,6 +456,12 @@ func TestEndToEnd(t *testing.T) { // Multiple calls to Stop is ok. c.Stop() + + _, found := fs.Files["credentials.msgpack"] + assert.True(found) + + _, found = fs.Files["credentials.msgpack.sha256"] + assert.True(found) } func TestContextExpires(t *testing.T) { @@ -510,6 +520,7 @@ func TestDecorate(t *testing.T) { LastRebootReason("reason"), XmidtProtocol("protocol"), BootRetryWait(1), + Required(), AddFetchListener(event.FetchListenerFunc( func(e event.Fetch) { assert.NoError(e.Err) @@ -548,3 +559,85 @@ func TestDecorate(t *testing.T) { assert.Equal(2, count) } + +func TestToAndFromFile(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + var count int + server := httptest.NewServer( + http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + r.Body.Close() + + w.Header().Add("Expires", time.Now().Add(1*time.Hour).Format(http.TimeFormat)) + _, _ = w.Write([]byte(`token`)) + count++ + }, + ), + ) + defer server.Close() + + fs := mem.New(mem.WithDir(".", 0755)) + + opts := []Option{ + URL(server.URL), + MacAddress(wrp.DeviceID("mac:112233445566")), + SerialNumber("1234567890"), + HardwareModel("model"), + HardwareManufacturer("manufacturer"), + FirmwareVersion("version"), + LastRebootReason("reason"), + XmidtProtocol("protocol"), + BootRetryWait(1), + LocalStorage(fs, "credentials.msgpack", 0600), + } + + var listenerCount int + copts := append(opts, + AddFetchListener(event.FetchListenerFunc( + func(e event.Fetch) { + if listenerCount == 0 { + assert.Equal("fs", e.Origin) + assert.Error(e.Err) + } else { + assert.Equal("network", e.Origin) + assert.NoError(e.Err) + } + listenerCount++ + }))) + + c, err := New(copts...) + + require.NoError(err) + require.NotNil(c) + + c.Start() + + ctx := context.Background() + deadline, cancel := context.WithDeadline(ctx, time.Now().Add(1*time.Second)) + defer cancel() + c.WaitUntilFetched(deadline) + + c.Stop() + + dopts := append(opts, + AddFetchListener(event.FetchListenerFunc( + func(e event.Fetch) { + assert.Equal("fs", e.Origin) + assert.NoError(e.Err) + }))) + + d, err := New(dopts...) + require.NoError(err) + + d.Start() + ctx = context.Background() + deadline, cancel = context.WithDeadline(ctx, time.Now().Add(1*time.Second)) + defer cancel() + d.WaitUntilFetched(deadline) + + d.Stop() + + assert.Equal(1, count) +} diff --git a/internal/credentials/event/events.go b/internal/credentials/event/events.go index 91607f1..1be54ac 100644 --- a/internal/credentials/event/events.go +++ b/internal/credentials/event/events.go @@ -15,6 +15,9 @@ type CancelListenerFunc func() // Fetch is the event that is sent when the credentials are fetched. type Fetch struct { + // The origin of the data - "fs" or "network" are the only valid values. + Origin string + // At holds the time when the fetch request was made. At time.Time diff --git a/internal/credentials/options.go b/internal/credentials/options.go index 57874d9..57fe799 100644 --- a/internal/credentials/options.go +++ b/internal/credentials/options.go @@ -4,11 +4,13 @@ package credentials import ( + iofs "io/fs" "net/http" "time" "github.com/xmidt-org/wrp-go/v3" "github.com/xmidt-org/xmidt-agent/internal/credentials/event" + "github.com/xmidt-org/xmidt-agent/internal/fs" ) type optionFunc func(*Credentials) error @@ -72,6 +74,37 @@ func AssumedLifetime(lifetime time.Duration) Option { }) } +// IgnoreBody is a flag that indicates whether the body of the response should +// be ignored instead of examined for an expiration time. The default is to +// examine the body. +func IgnoreBody() Option { + return nilOptionFunc( + func(c *Credentials) { + c.ignoreBody = true + }) +} + +// Required is a flag that indicates whether the credentials are required to +// successfully decorate a request. The default is optional. +func Required() Option { + return nilOptionFunc( + func(c *Credentials) { + c.required = true + }) +} + +// LocalStorage is the local storage used to cache the credentials. +// +// The filename (and path) is relative to the provided filesystem. +func LocalStorage(fs fs.FS, filename string, perm iofs.FileMode) Option { + return nilOptionFunc( + func(c *Credentials) { + c.fs = fs + c.filename = filename + c.perm = perm + }) +} + // MacAddress is the MAC address of the device. func MacAddress(macAddress wrp.DeviceID) Option { return nilOptionFunc( diff --git a/internal/fs/fs.go b/internal/fs/fs.go index 6c9971b..f14bab5 100644 --- a/internal/fs/fs.go +++ b/internal/fs/fs.go @@ -24,6 +24,7 @@ type FS interface { // Option is an interface for options that can be applied in order via the Operate function. type Option interface { + // Apply applies the option to the filesystem. Apply(FS) error } diff --git a/internal/fs/fs_test.go b/internal/fs/fs_test.go index 26dd464..7d3dcad 100644 --- a/internal/fs/fs_test.go +++ b/internal/fs/fs_test.go @@ -18,7 +18,7 @@ var ( errUnknown = errors.New("unknown error") ) -func TestMakeDir(t *testing.T) { +func TestWithDirOrSimiar(t *testing.T) { tests := []struct { description string opt xafs.Option @@ -29,40 +29,104 @@ func TestMakeDir(t *testing.T) { }{ { description: "simple path", - opt: xafs.MakeDir("foo", 0755), + opt: xafs.WithDir("foo", 0755), expect: mem.New(mem.WithDir("foo", 0755)), }, { description: "simple existing path", - opt: xafs.MakeDir("foo", 0755), + opt: xafs.WithDir("foo", 0755), start: mem.New(mem.WithDir("foo", 0755)), expect: mem.New(mem.WithDir("foo", 0755)), }, { description: "not a directory", - opt: xafs.MakeDir("foo", 0755), + opt: xafs.WithDir("foo", 0755), start: mem.New(mem.WithFile("foo", "data", 0755)), expectErr: xafs.ErrNotDirectory, - }, { - description: "not able to read", - opt: xafs.MakeDir("foo", 0755), - start: mem.New(mem.WithDir("foo", 0111)), - expectErr: fs.ErrPermission, }, { description: "error opening the file", - opt: xafs.MakeDir("foo", 0755), + opt: xafs.WithDir("foo", 0755), start: mem.New(mem.WithError("foo", errUnknown)), expectErr: errUnknown, }, { - description: "two directory path", + description: "three directory path", opts: []xafs.Option{ - xafs.MakeDir("foo", 0700), - xafs.MakeDir("foo/bar", 0750), - xafs.MakeDir("foo/bar/car", 0755), + xafs.WithDir("foo", 0700), + xafs.WithDir("foo/bar", 0750), + xafs.WithDir("foo/bar/car", 0755), }, expect: mem.New( mem.WithDir("foo", 0700), mem.WithDir("foo/bar", 0750), mem.WithDir("foo/bar/car", 0755), ), + }, { + description: "abs directory path", + start: mem.New(mem.WithDir("/", 0755)), + opts: []xafs.Option{ + xafs.WithDir("/foo", 0700), + xafs.WithDir("/foo/bar", 0750), + xafs.WithDir("/foo/bar/car", 0755), + }, + expect: mem.New( + mem.WithDir("/", 0755), + mem.WithDir("/foo", 0700), + mem.WithDir("/foo/bar", 0750), + mem.WithDir("/foo/bar/car", 0755), + ), + }, { + description: "WithDirs three directory path", + start: mem.New(), + opts: []xafs.Option{ + xafs.WithDirs("foo/bar/car", 0755), + }, + expect: mem.New( + mem.WithDir("foo", 0755), + mem.WithDir("foo/bar", 0755), + mem.WithDir("foo/bar/car", 0755), + ), + }, { + description: "WithDirs three directory path one exists", + start: mem.New(mem.WithDir("foo", 0711)), + opts: []xafs.Option{ + xafs.WithDirs("foo/bar/car", 0755), + }, + expect: mem.New( + mem.WithDir("foo", 0711), + mem.WithDir("foo/bar", 0755), + mem.WithDir("foo/bar/car", 0755), + ), + }, { + description: "abs three directory path", + start: mem.New(mem.WithDir("/", 0700)), + opts: []xafs.Option{ + xafs.WithDirs("/boo/egg/cat", 0755), + }, + expect: mem.New( + mem.WithDir("/", 0700), + mem.WithDir("/boo", 0755), + mem.WithDir("/boo/egg", 0755), + mem.WithDir("/boo/egg/cat", 0755), + ), + }, { + description: "WithPath two directory path one exists, and a filename", + start: mem.New(mem.WithDir("foo", 0711)), + opts: []xafs.Option{ + xafs.WithPath("foo/bar/car.json", 0755), + }, + expect: mem.New( + mem.WithDir("foo", 0711), + mem.WithDir("foo/bar", 0755), + ), + }, { + description: "Ensure Operate can handle nil options", + start: mem.New(mem.WithDir("foo", 0711)), + opts: []xafs.Option{ + nil, + xafs.WithPath("foo/bar/car.json", 0755), + }, + expect: mem.New( + mem.WithDir("foo", 0711), + mem.WithDir("foo/bar", 0755), + ), }, } @@ -71,7 +135,12 @@ func TestMakeDir(t *testing.T) { require := require.New(t) assert := assert.New(t) - opts := append(tc.opts, tc.opt) + opts := make([]xafs.Option, 0, len(tc.opts)+1) + if tc.opt != nil { + opts = append(tc.opts, tc.opt) + } + opts = append(opts, tc.opts...) + fs := tc.start if fs == nil { fs = mem.New() diff --git a/internal/fs/mem/mem.go b/internal/fs/mem/mem.go index 6e6e118..508972c 100644 --- a/internal/fs/mem/mem.go +++ b/internal/fs/mem/mem.go @@ -105,8 +105,10 @@ func (fs *FS) Open(name string) (iofs.File, error) { } func (fs *FS) Mkdir(path string, perm iofs.FileMode) error { - if err := fs.hasPerms(path, iofs.FileMode(0111)); err != nil { - return err + if path != separator { + if err := fs.hasPerms(path, iofs.FileMode(0111)); err != nil { + return err + } } if fs.Dirs == nil { @@ -201,25 +203,3 @@ func (fs *FS) hasPerms(name string, perm iofs.FileMode) error { return nil } - -// Remove this in favor of pp ... except I can't download pp right now. -/* -func (fs *FS) String() string { - buf := strings.Builder{} - - buf.WriteString("Files:\n") - for k, v := range fs.Files { - fmt.Fprintf(&buf, " '%s': '%s'\n", k, string(v.Bytes)) - } - buf.WriteString("Dirs:\n") - for k, v := range fs.Dirs { - fmt.Fprintf(&buf, " '%s': %v\n", k, v) - } - buf.WriteString("Errs:\n") - for k, v := range fs.Errs { - fmt.Fprintf(&buf, " '%s': %v\n", k, v) - } - - return buf.String() -} -*/ diff --git a/internal/fs/options.go b/internal/fs/options.go index f346086..b07b431 100644 --- a/internal/fs/options.go +++ b/internal/fs/options.go @@ -18,9 +18,17 @@ var ( ErrInvalidSHA = errors.New("invalid SHA for file") ) -// MakeDir is an option that ensures the specified directory exists with the -// specified permissions. -func MakeDir(dir string, perm fs.FileMode) Option { +// Options provides a way to group multiple options together. +func Options(opts ...Option) Option { + return OptionFunc( + func(f FS) error { + return Operate(f, opts...) + }) +} + +// WithDir is an option that ensures the specified directory exists. If it +// does not, create it with the specified permissions. +func WithDir(dir string, perm fs.FileMode) Option { return OptionFunc( func(f FS) error { file, err := f.Open(dir) @@ -33,21 +41,46 @@ func MakeDir(dir string, perm fs.FileMode) Option { defer file.Close() stat, err := file.Stat() - if err != nil { - return err + if err == nil { + if !stat.IsDir() { + return ErrNotDirectory + } } - if !stat.IsDir() { - return ErrNotDirectory - } + return err + }) +} - mode := stat.Mode() - if (mode & fs.ModePerm & perm) != (fs.ModePerm & perm) { - return fs.ErrPermission - } +// WithDirs is an option that ensures the specified directory path exists with +// the specified permissions. The path is split on the path separator and +// each directory is created in order if needed. +// +// Notes: +// - The path should not contain the filename or that will be created as a directory. +// - The same permissions are applied to all directories that are created. +func WithDirs(path string, perm fs.FileMode) Option { + dirs := strings.Split(path, string(filepath.Separator)) + if filepath.IsAbs(path) { + dirs[0] = string(filepath.Separator) + } + + var full string + opts := make([]Option, 0, len(dirs)) + for _, dir := range dirs { + full = filepath.Join(full, dir) + opts = append(opts, WithDir(full, perm)) + } + return Options(opts...) +} - return nil - }) +// WithPath is an option that ensures the set of directories for the specified +// file exists. The directory is determined by calling filepath.Dir on the name. +// +// Notes: +// - The name should contain the filename and any path to ensure is present. +// - The same permissions are applied to all directories that are created. +func WithPath(name string, perm fs.FileMode) Option { + return WithDirs(filepath.Dir(name), perm) } // WriteFileWithSHA256 calculates and writes both the file and a checksum file.