Skip to content

Commit

Permalink
feat: authenticate with access token
Browse files Browse the repository at this point in the history
  • Loading branch information
k-capehart committed Jun 7, 2024
1 parent b03426b commit 20c1907
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 18 deletions.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,20 @@ if err != nil {
}
```

Authenticate with an Access Token
- Implement your own OAuth flow and use the resulting `access_token` in the response

```go
sf, sfErr := salesforce.Init(salesforce.Creds{
Domain: DOMAIN,
AccessToken: ACCESS_TOKEN,
})
if err != nil {
panic(err)
}
```


## SOQL
Query Salesforce records
- [Review Salesforce REST API resources for queries](https://developer.salesforce.com/docs/atlas.en-us.api_rest.meta/api_rest/resources_query.htm)
Expand Down
32 changes: 30 additions & 2 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,35 @@ type Creds struct {
SecurityToken string
ConsumerKey string
ConsumerSecret string
AccessToken string
}

const (
grantTypePassword = "password"
grantTypeClientCredentials = "client_credentials"
)

func validateAuth(sf Salesforce) error {
if sf.auth == nil || sf.auth.AccessToken == "" {
return errors.New("not authenticated: please use salesforce.Init()")
}
return nil
}

func validateSession(auth authentication) error {
if err := validateAuth(Salesforce{auth: &auth}); err != nil {
return err
}
resp, err := doRequest(http.MethodGet, "/limits", jsonType, auth, "")
if err != nil {
return err
}
if resp.StatusCode != http.StatusOK {
return processSalesforceError(*resp)
}
return nil
}

func doAuth(url string, body *strings.Reader) (*authentication, error) {
resp, err := http.Post(url, "application/x-www-form-urlencoded", body)
if err != nil {
Expand All @@ -61,7 +81,7 @@ func doAuth(url string, body *strings.Reader) (*authentication, error) {

func usernamePasswordFlow(domain string, username string, password string, securityToken string, consumerKey string, consumerSecret string) (*authentication, error) {
payload := url.Values{
"grant_type": {"password"},
"grant_type": {grantTypePassword},
"client_id": {consumerKey},
"client_secret": {consumerSecret},
"username": {username},
Expand All @@ -78,7 +98,7 @@ func usernamePasswordFlow(domain string, username string, password string, secur

func clientCredentialsFlow(domain string, consumerKey string, consumerSecret string) (*authentication, error) {
payload := url.Values{
"grant_type": {"client_credentials"},
"grant_type": {grantTypeClientCredentials},
"client_id": {consumerKey},
"client_secret": {consumerSecret},
}
Expand All @@ -90,3 +110,11 @@ func clientCredentialsFlow(domain string, consumerKey string, consumerSecret str
}
return auth, nil
}

func setAccessToken(domain string, accessToken string) (*authentication, error) {
auth := &authentication{InstanceUrl: domain, AccessToken: accessToken}
if err := validateSession(*auth); err != nil {
return nil, err
}
return auth, nil
}
73 changes: 66 additions & 7 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package salesforce

import (
"net/http"
"net/http/httptest"
"reflect"
"testing"
)
Expand Down Expand Up @@ -53,9 +52,7 @@ func Test_usernamePasswordFlow(t *testing.T) {
server, _ := setupTestServer(auth, http.StatusOK)
defer server.Close()

badServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusForbidden)
}))
badServer, _ := setupTestServer(auth, http.StatusForbidden)
defer badServer.Close()

type args struct {
Expand Down Expand Up @@ -124,9 +121,7 @@ func Test_clientCredentialsFlow(t *testing.T) {
server, _ := setupTestServer(auth, http.StatusOK)
defer server.Close()

badServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusForbidden)
}))
badServer, _ := setupTestServer(auth, http.StatusForbidden)
defer badServer.Close()

type args struct {
Expand Down Expand Up @@ -174,3 +169,67 @@ func Test_clientCredentialsFlow(t *testing.T) {
})
}
}

func Test_setAccessToken(t *testing.T) {
auth := authentication{
InstanceUrl: "example.com",
AccessToken: "1234",
}
server, _ := setupTestServer(auth, http.StatusOK)
defer server.Close()

badServer, _ := setupTestServer(auth, http.StatusForbidden)
defer badServer.Close()

type args struct {
domain string
accessToken string
}
tests := []struct {
name string
args args
want *authentication
wantErr bool
}{
{
name: "authentication_success",
args: args{
domain: server.URL,
accessToken: "1234",
},
want: &auth,
wantErr: false,
},
{
name: "authentication_fail_http_error",
args: args{
domain: badServer.URL,
accessToken: "1234",
},
want: nil,
wantErr: true,
},
{
name: "authentication_fail_no_token",
args: args{
domain: server.URL,
accessToken: "",
},
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := setAccessToken(tt.args.domain, tt.args.accessToken)
if (err != nil) != tt.wantErr {
t.Errorf("setAccessToken() error = %v, wantErr %v", err, tt.wantErr)
return
}
if (tt.want == nil && !reflect.DeepEqual(got, tt.want)) ||
(tt.want != nil && !reflect.DeepEqual(got.AccessToken, tt.want.AccessToken)) {
t.Errorf("setAccessToken() = %v, want %v", got, tt.want)
}
})
}
}
19 changes: 11 additions & 8 deletions salesforce.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,14 +192,6 @@ func Init(creds Creds) (*Salesforce, error) {
var auth *authentication
var err error
if creds != (Creds{}) && creds.Domain != "" && creds.ConsumerKey != "" && creds.ConsumerSecret != "" &&
(creds.Username == "" || creds.Password == "" || creds.SecurityToken == "") {

auth, err = clientCredentialsFlow(
creds.Domain,
creds.ConsumerKey,
creds.ConsumerSecret,
)
} else if creds != (Creds{}) && creds.Domain != "" && creds.ConsumerKey != "" && creds.ConsumerSecret != "" &&
creds.Username != "" && creds.Password != "" && creds.SecurityToken != "" {

auth, err = usernamePasswordFlow(
Expand All @@ -210,6 +202,17 @@ func Init(creds Creds) (*Salesforce, error) {
creds.ConsumerKey,
creds.ConsumerSecret,
)
} else if creds != (Creds{}) && creds.Domain != "" && creds.ConsumerKey != "" && creds.ConsumerSecret != "" {
auth, err = clientCredentialsFlow(
creds.Domain,
creds.ConsumerKey,
creds.ConsumerSecret,
)
} else if creds != (Creds{}) && creds.AccessToken != "" {
auth, err = setAccessToken(
creds.Domain,
creds.AccessToken,
)
}

if err != nil {
Expand Down
20 changes: 19 additions & 1 deletion salesforce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,24 @@ func TestInit(t *testing.T) {
want: &Salesforce{auth: &sfAuth},
wantErr: false,
},
{
name: "authentication_client_credentials",
args: args{creds: Creds{
Domain: server.URL,
ConsumerKey: "key",
ConsumerSecret: "secret",
}},
want: &Salesforce{auth: &sfAuth},
wantErr: false,
},
{
name: "authentication_access_token",
args: args{creds: Creds{
Domain: server.URL,
AccessToken: "1234",
}},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -386,7 +404,7 @@ func TestInit(t *testing.T) {
t.Errorf("Init() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
if tt.want != nil && !tt.wantErr && !reflect.DeepEqual(got, tt.want) {
t.Errorf("Init() = %v, want %v", *got.auth, *tt.want.auth)
}
})
Expand Down

0 comments on commit 20c1907

Please sign in to comment.