Skip to content

Commit

Permalink
Reuse MakeURL moved to common package
Browse files Browse the repository at this point in the history
  • Loading branch information
vapopov committed Jan 23, 2025
1 parent 30e5bcd commit 721915e
Show file tree
Hide file tree
Showing 8 changed files with 230 additions and 234 deletions.
109 changes: 12 additions & 97 deletions lib/autoupdate/agent/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
package agent

import (
"encoding/json"
"errors"
"fmt"
"io/fs"
Expand All @@ -30,6 +29,8 @@ import (
"github.com/google/renameio/v2"
"github.com/gravitational/trace"
"gopkg.in/yaml.v3"

"github.com/gravitational/teleport/lib/autoupdate"
)

const (
Expand Down Expand Up @@ -84,13 +85,13 @@ type Revision struct {
// Version is the version of Teleport.
Version string `yaml:"version" json:"version"`
// Flags describe the edition of Teleport.
Flags InstallFlags `yaml:"flags,flow,omitempty" json:"flags,omitempty"`
Flags autoupdate.InstallFlags `yaml:"flags,flow,omitempty" json:"flags,omitempty"`
}

// NewRevision create a Revision.
// If version is not set, no flags are returned.
// This ensures that all Revisions without versions are zero-valued.
func NewRevision(version string, flags InstallFlags) Revision {
func NewRevision(version string, flags autoupdate.InstallFlags) Revision {
if version != "" {
return Revision{
Version: version,
Expand All @@ -113,16 +114,16 @@ func NewRevisionFromDir(dir string) (Revision, error) {
}
switch flags := parts[1:]; len(flags) {
case 2:
if flags[1] != FlagFIPS.DirFlag() {
if flags[1] != autoupdate.FlagFIPS.DirFlag() {
break
}
out.Flags |= FlagFIPS
out.Flags |= autoupdate.FlagFIPS
fallthrough
case 1:
if flags[0] != FlagEnterprise.DirFlag() {
if flags[0] != autoupdate.FlagEnterprise.DirFlag() {
break
}
out.Flags |= FlagEnterprise
out.Flags |= autoupdate.FlagEnterprise
fallthrough
case 0:
return out, nil
Expand All @@ -135,11 +136,11 @@ func (r Revision) Dir() string {
// Do not change the order of these statements.
// Otherwise, installed versions will no longer match update.yaml.
var suffix string
if r.Flags&(FlagEnterprise|FlagFIPS) != 0 {
suffix += "_" + FlagEnterprise.DirFlag()
if r.Flags&(autoupdate.FlagEnterprise|autoupdate.FlagFIPS) != 0 {
suffix += "_" + autoupdate.FlagEnterprise.DirFlag()
}
if r.Flags&FlagFIPS != 0 {
suffix += "_" + FlagFIPS.DirFlag()
if r.Flags&autoupdate.FlagFIPS != 0 {
suffix += "_" + autoupdate.FlagFIPS.DirFlag()
}
return r.Version + suffix
}
Expand Down Expand Up @@ -239,89 +240,3 @@ type FindResp struct {
// Jitter duration before an automated install
Jitter time.Duration `yaml:"jitter"`
}

// InstallFlags sets flags for the Teleport installation
type InstallFlags int

const (
// FlagEnterprise installs enterprise Teleport
FlagEnterprise InstallFlags = 1 << iota
// FlagFIPS installs FIPS Teleport
FlagFIPS
)

// NewInstallFlagsFromStrings returns InstallFlags given a slice of human-readable strings.
func NewInstallFlagsFromStrings(s []string) InstallFlags {
var out InstallFlags
for _, f := range s {
for _, flag := range []InstallFlags{
FlagEnterprise,
FlagFIPS,
} {
if f == flag.String() {
out |= flag
}
}
}
return out
}

// Strings converts InstallFlags to a slice of human-readable strings.
func (i InstallFlags) Strings() []string {
var out []string
for _, flag := range []InstallFlags{
FlagEnterprise,
FlagFIPS,
} {
if i&flag != 0 {
out = append(out, flag.String())
}
}
return out
}

// String returns the string representation of a single InstallFlag flag, or "Unknown".
func (i InstallFlags) String() string {
switch i {
case 0:
return ""
case FlagEnterprise:
return "Enterprise"
case FlagFIPS:
return "FIPS"
}
return "Unknown"
}

// DirFlag returns the directory path representation of a single InstallFlag flag, or "unknown".
func (i InstallFlags) DirFlag() string {
switch i {
case 0:
return ""
case FlagEnterprise:
return "ent"
case FlagFIPS:
return "fips"
}
return "unknown"
}

func (i InstallFlags) MarshalYAML() (any, error) {
return i.Strings(), nil
}

func (i InstallFlags) MarshalJSON() ([]byte, error) {
return json.Marshal(i.Strings())
}

func (i *InstallFlags) UnmarshalYAML(n *yaml.Node) error {
var s []string
if err := n.Decode(&s); err != nil {
return trace.Wrap(err)
}
if i == nil {
return trace.BadParameter("nil install flags while parsing YAML")
}
*i = NewInstallFlagsFromStrings(s)
return nil
}
75 changes: 3 additions & 72 deletions lib/autoupdate/agent/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ package agent
import (
"testing"

"github.com/gravitational/teleport/lib/autoupdate"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
)

func TestNewRevisionFromDir(t *testing.T) {
Expand All @@ -46,15 +46,15 @@ func TestNewRevisionFromDir(t *testing.T) {
dir: "1.2.3_ent_fips",
rev: Revision{
Version: "1.2.3",
Flags: FlagEnterprise | FlagFIPS,
Flags: autoupdate.FlagEnterprise | autoupdate.FlagFIPS,
},
},
{
name: "ent",
dir: "1.2.3_ent",
rev: Revision{
Version: "1.2.3",
Flags: FlagEnterprise,
Flags: autoupdate.FlagEnterprise,
},
},
{
Expand Down Expand Up @@ -124,72 +124,3 @@ func TestNewRevisionFromDir(t *testing.T) {
})
}
}

func TestInstallFlagsYAML(t *testing.T) {
t.Parallel()

for _, tt := range []struct {
name string
yaml string
flags InstallFlags
skipYAML bool
}{
{
name: "both",
yaml: `["Enterprise", "FIPS"]`,
flags: FlagEnterprise | FlagFIPS,
},
{
name: "order",
yaml: `["FIPS", "Enterprise"]`,
flags: FlagEnterprise | FlagFIPS,
skipYAML: true,
},
{
name: "extra",
yaml: `["FIPS", "Enterprise", "bad"]`,
flags: FlagEnterprise | FlagFIPS,
skipYAML: true,
},
{
name: "enterprise",
yaml: `["Enterprise"]`,
flags: FlagEnterprise,
},
{
name: "fips",
yaml: `["FIPS"]`,
flags: FlagFIPS,
},
{
name: "empty",
yaml: `[]`,
},
{
name: "nil",
skipYAML: true,
},
} {
t.Run(tt.name, func(t *testing.T) {
var flags InstallFlags
err := yaml.Unmarshal([]byte(tt.yaml), &flags)
require.NoError(t, err)
require.Equal(t, tt.flags, flags)

// verify test YAML
var v any
err = yaml.Unmarshal([]byte(tt.yaml), &v)
require.NoError(t, err)
res, err := yaml.Marshal(v)
require.NoError(t, err)

// compare verified YAML to flag output
out, err := yaml.Marshal(flags)
require.NoError(t, err)

if !tt.skipYAML {
require.Equal(t, string(res), string(out))
}
})
}
}
33 changes: 4 additions & 29 deletions lib/autoupdate/agent/installer.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,14 @@ import (
"os"
"path"
"path/filepath"
"runtime"
"syscall"
"text/template"
"time"

"github.com/google/renameio/v2"
"github.com/gravitational/trace"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/autoupdate"
"github.com/gravitational/teleport/lib/utils"
)

Expand Down Expand Up @@ -135,7 +134,7 @@ func (li *LocalInstaller) Install(ctx context.Context, rev Revision, template st
sumPath := filepath.Join(versionDir, checksumType)

// generate download URI from template
uri, err := makeURL(template, rev)
uri, err := autoupdate.MakeURL(template, autoupdate.DefaultBaseURL, autoupdate.DefaultPackage, rev.Version, rev.Flags)
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -229,30 +228,6 @@ func (li *LocalInstaller) Install(ctx context.Context, rev Revision, template st
return nil
}

// makeURL to download the Teleport tgz.
func makeURL(uriTmpl string, rev Revision) (string, error) {
tmpl, err := template.New("uri").Parse(uriTmpl)
if err != nil {
return "", trace.Wrap(err)
}
var uriBuf bytes.Buffer
params := struct {
OS, Version, Arch string
FIPS, Enterprise bool
}{
OS: runtime.GOOS,
Version: rev.Version,
Arch: runtime.GOARCH,
FIPS: rev.Flags&FlagFIPS != 0,
Enterprise: rev.Flags&(FlagEnterprise|FlagFIPS) != 0,
}
err = tmpl.Execute(&uriBuf, params)
if err != nil {
return "", trace.Wrap(err)
}
return uriBuf.String(), nil
}

// readChecksum from the version directory.
func readChecksum(path string) ([]byte, error) {
f, err := os.Open(path)
Expand Down Expand Up @@ -354,7 +329,7 @@ func (li *LocalInstaller) download(ctx context.Context, w io.Writer, max int64,
return shaReader.Sum(nil), nil
}

func (li *LocalInstaller) extract(ctx context.Context, dstDir string, src io.Reader, max int64, flags InstallFlags) error {
func (li *LocalInstaller) extract(ctx context.Context, dstDir string, src io.Reader, max int64, flags autoupdate.InstallFlags) error {
if err := os.MkdirAll(dstDir, systemDirMode); err != nil {
return trace.Wrap(err)
}
Expand All @@ -372,7 +347,7 @@ func (li *LocalInstaller) extract(ctx context.Context, dstDir string, src io.Rea
}
li.Log.InfoContext(ctx, "Extracting Teleport tarball.", "path", dstDir, "size", max)

err = utils.Extract(zr, dstDir, tgzExtractPaths(flags&(FlagEnterprise|FlagFIPS) != 0)...)
err = utils.Extract(zr, dstDir, tgzExtractPaths(flags&(autoupdate.FlagEnterprise|autoupdate.FlagFIPS) != 0)...)
if err != nil {
return trace.Wrap(err)
}
Expand Down
4 changes: 3 additions & 1 deletion lib/autoupdate/agent/installer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/lib/autoupdate"
)

func TestLocalInstaller_Install(t *testing.T) {
Expand All @@ -52,7 +54,7 @@ func TestLocalInstaller_Install(t *testing.T) {
reservedTmp uint64
reservedInstall uint64
existingSum string
flags InstallFlags
flags autoupdate.InstallFlags

errMatch string
}{
Expand Down
Loading

0 comments on commit 721915e

Please sign in to comment.