diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 5c20dca4..c59a627a 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -25,4 +25,5 @@ https://github.com/em0ney https://github.com/zqben402 https://github.com/dlackty https://github.com/amcintosh - +https://github.com/valiton +https://github.com/joshgarnett diff --git a/aws-es-proxy.go b/aws-es-proxy.go index 4b3e0ad4..e9729ee2 100644 --- a/aws-es-proxy.go +++ b/aws-es-proxy.go @@ -87,6 +87,7 @@ type proxy struct { password string realm string remoteTerminate bool + assumeRole string } func newProxy(args ...interface{}) *proxy { @@ -112,6 +113,7 @@ func newProxy(args ...interface{}) *proxy { password: args[8].(string), realm: args[9].(string), remoteTerminate: args[10].(bool), + assumeRole: args[11].(string), } } @@ -187,7 +189,6 @@ func (p *proxy) parseEndpoint() error { func (p *proxy) getSigner() *v4.Signer { // Refresh credentials after expiration. Required for STS if p.credentials == nil { - sess, err := session.NewSession( &aws.Config{ Region: aws.String(p.region), @@ -198,13 +199,26 @@ func (p *proxy) getSigner() *v4.Signer { logrus.Debugln(err) } - credentials := sess.Config.Credentials awsRoleARN := os.Getenv("AWS_ROLE_ARN") awsWebIdentityTokenFile := os.Getenv("AWS_WEB_IDENTITY_TOKEN_FILE") + + var creds *credentials.Credentials if awsRoleARN != "" && awsWebIdentityTokenFile != "" { - credentials = stscreds.NewWebIdentityCredentials(sess, awsRoleARN, "", awsWebIdentityTokenFile) + logrus.Infof("Using web identity credentials with role %s", awsRoleARN) + creds = stscreds.NewWebIdentityCredentials(sess, awsRoleARN, "", awsWebIdentityTokenFile) + } else if p.assumeRole != "" { + logrus.Infof("Assuming credentials from %s", p.assumeRole) + creds = stscreds.NewCredentials(sess, p.assumeRole, func(provider *stscreds.AssumeRoleProvider) { + provider.Duration = 17 * time.Minute + provider.ExpiryWindow = 13 * time.Minute + provider.MaxJitterFrac = 0.1 + }) + } else { + logrus.Infoln("Using default credentials") + creds = sess.Config.Credentials } - p.credentials = credentials + + p.credentials = creds logrus.Infoln("Generated fresh AWS Credentials object") } @@ -264,7 +278,13 @@ func (p *proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Sign the request with AWSv4 payload := bytes.NewReader(replaceBody(req)) - signer.Sign(req, payload, p.service, p.region, time.Now()) + _, err := signer.Sign(req, payload, p.service, p.region, time.Now()) + if err != nil { + p.credentials = nil + logrus.Errorln("Failed to sign", err) + http.Error(w, "Failed to sign", http.StatusForbidden) + return + } } resp, err := p.httpClient.Do(req) @@ -443,6 +463,7 @@ func main() { err error timeout int remoteTerminate bool + assumeRole string ) flag.StringVar(&endpoint, "endpoint", "", "Amazon ElasticSearch Endpoint (e.g: https://dummy-host.eu-west-1.es.amazonaws.com)") @@ -459,6 +480,7 @@ func main() { flag.StringVar(&password, "password", "", "HTTP Basic Auth Password") flag.StringVar(&realm, "realm", "", "Authentication Required") flag.BoolVar(&remoteTerminate, "remote-terminate", false, "Allow HTTP remote termination") + flag.StringVar(&assumeRole, "assume", "", "Optionally specify role to assume") flag.Parse() if endpoint == "" { @@ -505,6 +527,7 @@ func main() { password, realm, remoteTerminate, + assumeRole, ) if err = p.parseEndpoint(); err != nil {