forked from dghubble/oauth1
-
Notifications
You must be signed in to change notification settings - Fork 0
/
auther.go
264 lines (244 loc) · 8.69 KB
/
auther.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
package oauth1
import (
"bytes"
"crypto/rand"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"time"
)
// HTTP header names and values used when signing requests.
const (
	authorizationHeaderParam = "Authorization"
	authorizationPrefix      = "OAuth " // the trailing space separates the scheme from its parameters
	contentType              = "Content-Type"
	formContentType          = "application/x-www-form-urlencoded"
)

// OAuth1 protocol parameter names and the protocol version sent in
// oauth_version.
const (
	oauthConsumerKeyParam     = "oauth_consumer_key"
	oauthNonceParam           = "oauth_nonce"
	oauthSignatureParam       = "oauth_signature"
	oauthSignatureMethodParam = "oauth_signature_method"
	oauthTimestampParam       = "oauth_timestamp"
	oauthTokenParam           = "oauth_token"
	oauthVersionParam         = "oauth_version"
	oauthCallbackParam        = "oauth_callback"
	oauthVerifierParam        = "oauth_verifier"
	defaultOauthVersion       = "1.0"
)
type (
	// clock abstracts the source of the current time, so a fixed clock can be
	// substituted for time.Now.
	clock interface {
		Now() time.Time
	}

	// noncer abstracts the source of oauth_nonce values.
	noncer interface {
		Nonce() string
	}

	// auther signs HTTP requests by adding an OAuth1 "Authorization" header.
	// A nil clock or noncer falls back to the default implementation used by
	// the epoch and nonce methods.
	auther struct {
		config *Config
		clock  clock
		noncer noncer
	}
)
// newAuther returns an auther for config whose clock and noncer are unset,
// so the default time and nonce sources are used.
func newAuther(config *Config) *auther {
	a := new(auther)
	a.config = config
	return a
}
// setRequestTokenAuthHeader signs req as a temporary credential (request
// token) request per RFC 5849 2.1 and stores the result in the Authorization
// header. The configured CallbackURL is sent as oauth_callback, and the
// signature is computed with an empty token secret.
func (a *auther) setRequestTokenAuthHeader(req *http.Request) error {
	protocol := a.commonOAuthParams()
	protocol[oauthCallbackParam] = a.config.CallbackURL

	all, err := collectParameters(req, protocol)
	if err != nil {
		return err
	}
	// No token secret exists yet at this stage of the flow.
	sig, err := a.signer().Sign("", signatureBase(req, all))
	if err != nil {
		return err
	}

	protocol[oauthSignatureParam] = sig
	req.Header.Set(authorizationHeaderParam, authHeaderValue(protocol))
	return nil
}
// setAccessTokenAuthHeader signs req as a token credential (access token)
// request per RFC 5849 2.3 and stores the result in the Authorization header.
// requestToken and verifier are sent as oauth_token and oauth_verifier, and
// requestSecret is the token secret used when computing the signature.
func (a *auther) setAccessTokenAuthHeader(req *http.Request, requestToken, requestSecret, verifier string) error {
	headerParams := a.commonOAuthParams()
	headerParams[oauthTokenParam] = requestToken
	headerParams[oauthVerifierParam] = verifier

	signed, err := collectParameters(req, headerParams)
	if err != nil {
		return err
	}
	base := signatureBase(req, signed)
	s := a.signer()
	signature, err := s.Sign(requestSecret, base)
	if err != nil {
		return err
	}

	headerParams[oauthSignatureParam] = signature
	req.Header.Set(authorizationHeaderParam, authHeaderValue(headerParams))
	return nil
}
// setRequestAuthHeader signs an authenticated request made with accessToken
// (a token credential) per RFC 5849 3.1 and stores the result in the
// Authorization header. accessToken.Token is sent as oauth_token and
// accessToken.TokenSecret is the token secret used for the signature.
func (a *auther) setRequestAuthHeader(req *http.Request, accessToken *Token) error {
	fields := a.commonOAuthParams()
	fields[oauthTokenParam] = accessToken.Token

	collected, err := collectParameters(req, fields)
	if err != nil {
		return err
	}
	base := signatureBase(req, collected)
	sig, err := a.signer().Sign(accessToken.TokenSecret, base)
	if err != nil {
		return err
	}

	fields[oauthSignatureParam] = sig
	req.Header.Set(authorizationHeaderParam, authHeaderValue(fields))
	return nil
}
// commonOAuthParams builds a fresh map of the OAuth1 protocol parameters
// shared by every signed request: consumer key, signature method, timestamp,
// nonce and version. oauth_signature is deliberately absent; callers add it
// after signing.
func (a *auther) commonOAuthParams() map[string]string {
	params := make(map[string]string)
	params[oauthConsumerKeyParam] = a.config.ConsumerKey
	params[oauthSignatureMethodParam] = a.signer().Name()
	params[oauthTimestampParam] = strconv.FormatInt(a.epoch(), 10)
	params[oauthNonceParam] = a.nonce()
	params[oauthVersionParam] = defaultOauthVersion
	return params
}
// nonce returns the value for the oauth_nonce parameter. When the auther has
// a noncer, its Nonce result is used as-is. Otherwise 16 bytes are read from
// crypto/rand and hex encoded, producing a 32-character lowercase hex string.
//
// It panics if the secure random source fails: silently continuing would send
// a predictable all-zero nonce, and Go 1.24+ crypto/rand.Read already aborts
// the program in that situation.
func (a *auther) nonce() string {
	if a.noncer != nil {
		return a.noncer.Nonce()
	}
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		panic("oauth1: crypto/rand failed to generate nonce: " + err.Error())
	}
	return fmt.Sprintf("%x", b)
}
// epoch reports the current time as Unix seconds. The time comes from the
// auther's clock when one is set and from time.Now otherwise.
func (a *auther) epoch() int64 {
	now := time.Now
	if a.clock != nil {
		now = a.clock.Now
	}
	return now().Unix()
}
// signer selects the Signer for OAuth1 signatures. It returns the Config's
// Signer when one is set; otherwise it returns a new HMACSigner built from the
// Config's ConsumerSecret.
func (a *auther) signer() Signer {
	if a.config.Signer == nil {
		return &HMACSigner{ConsumerSecret: a.config.ConsumerSecret}
	}
	return a.config.Signer
}
// authHeaderValue formats OAuth parameters according to RFC 5849 3.5.1. OAuth
// params are percent encoded, sorted by key (for testability), and joined by
// "=" into pairs. Pairs are joined with a ", " comma separator into a header
// string.
// The given OAuth params should include the "oauth_signature" key.
func authHeaderValue(oauthParams map[string]string) string {
pairs := sortParameters(encodeParameters(oauthParams), `%s="%s"`)
return authorizationPrefix + strings.Join(pairs, ", ")
}
// encodeParameters returns a new map in which every key and value of params
// has been passed through PercentEncode (RFC 5849 3.6, RFC 3986 2.1). The input
// map is left untouched, and the result is non-nil even for empty input.
func encodeParameters(params map[string]string) map[string]string {
	out := make(map[string]string, len(params))
	for k, v := range params {
		out[PercentEncode(k)] = PercentEncode(v)
	}
	return out
}
// sortParameters orders params by key and formats each entry with format,
// which receives the key then the value (for example "%s=%s"). It returns the
// formatted pairs in key order; the slice is empty but non-nil when params
// is empty.
func sortParameters(params map[string]string, format string) []string {
	keys := make([]string, 0, len(params))
	for k := range params {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	formatted := make([]string, 0, len(keys))
	for _, k := range keys {
		formatted = append(formatted, fmt.Sprintf(format, k, params[k]))
	}
	return formatted
}
// collectParameters gathers the request parameters that feed the OAuth1
// signature base string (RFC 5849 3.4.1.3). Sources are merged in this order,
// with later sources overwriting earlier ones on key collisions:
//  1. the URL query parameters,
//  2. the entity-body parameters, and
//  3. oauthParams, the OAuth protocol parameters (which should exclude
//     oauth_signature).
//
// Body parameters are read only when req.Body is non-nil and the Content-Type
// media type is application/x-www-form-urlencoded. The comparison is
// case-insensitive and ignores media type parameters such as
// "; charset=utf-8". When the body is read, the original body is closed and
// replaced with an in-memory copy of the same bytes. The request can therefore
// still be sent, even if the body then fails to parse.
//
// Duplicate keys are not supported: only the first value of a repeated query
// or body key is kept, whereas RFC 5849 would sign every value.
func collectParameters(req *http.Request, oauthParams map[string]string) (map[string]string, error) {
	params := map[string]string{}
	for key, values := range req.URL.Query() {
		// most backends do not accept duplicate query keys
		params[key] = values[0]
	}

	mediaType := req.Header.Get(contentType)
	if i := strings.IndexByte(mediaType, ';'); i >= 0 {
		// Drop media type parameters such as "; charset=utf-8".
		mediaType = mediaType[:i]
	}
	if req.Body != nil && strings.EqualFold(strings.TrimSpace(mediaType), formContentType) {
		// reads data to a []byte, draining req.Body
		b, err := ioutil.ReadAll(req.Body)
		if err != nil {
			return nil, err
		}
		// The transport only closes the body it is handed, so close the
		// drained original here. A close error is irrelevant once every byte
		// has been read.
		_ = req.Body.Close()
		// Restore the body before parsing, so it stays intact even when the
		// content is not valid form data.
		req.Body = ioutil.NopCloser(bytes.NewReader(b))

		values, err := url.ParseQuery(string(b))
		if err != nil {
			return nil, err
		}
		for key, vs := range values {
			// not supporting params with duplicate keys
			params[key] = vs[0]
		}
	}

	for key, value := range oauthParams {
		params[key] = value
	}
	return params, nil
}
// signatureBase combines the uppercase request method, percent encoded base
// string URI, and normalizes the request parameters int a parameter string.
// Returns the OAuth1 signature base string according to RFC5849 3.4.1.
func signatureBase(req *http.Request, params map[string]string) string {
method := strings.ToUpper(req.Method)
baseURL := baseURI(req)
parameterString := normalizedParameterString(params)
// signature base string constructed accoding to 3.4.1.1
baseParts := []string{method, PercentEncode(baseURL), PercentEncode(parameterString)}
return strings.Join(baseParts, "&")
}
// baseURI returns the base string URI of req as defined by RFC 5849 3.4.1.2.
// The scheme and host are lowercased. The port is omitted only when it is the
// default for the scheme (80 for http, 443 for https) and is kept otherwise,
// including for bracketed IPv6 hosts. The escaped request path is appended
// without any query string.
func baseURI(req *http.Request) string {
	scheme := strings.ToLower(req.URL.Scheme)
	host := strings.ToLower(req.URL.Host)
	// Find a trailing ":port". A ']' after the last colon means that colon is
	// inside an IPv6 literal such as "[::1]", not a port separator.
	if i := strings.LastIndexByte(host, ':'); i >= 0 && !strings.Contains(host[i:], "]") {
		port := host[i+1:]
		if (scheme == "http" && port == "80") || (scheme == "https" && port == "443") {
			host = host[:i]
		}
	}
	path := req.URL.Path
	if path != "" {
		// EscapedPath keeps the URL's original valid percent-encoding (or
		// re-encodes the path), as the OAuth1 base string URI requires.
		path = req.URL.EscapedPath()
	}
	return fmt.Sprintf("%v://%v%v", scheme, host, path)
}
// normalizedParameterString normalizes the collected request parameters
// (which should exclude oauth_signature) into the parameter string defined in
// RFC 5849 3.4.1.3.2:
//   - keys and values are percent encoded,
//   - the pairs are sorted by encoded key,
//   - each key is joined to its value with "=", and
//   - the pairs are joined with "&" (e.g. foo=bar&q=gopher).
func normalizedParameterString(params map[string]string) string {
	return strings.Join(sortParameters(encodeParameters(params), "%s=%s"), "&")
}