diff --git a/go.mod b/go.mod index abf47d4..aeb5983 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/vartanbeno/go-reddit/v2 +module github.com/bevzzz/go-reddit/v2 go 1.15 diff --git a/reddit/post_test.go b/reddit/post_test.go index 78d26b3..e53962b 100644 --- a/reddit/post_test.go +++ b/reddit/post_test.go @@ -36,6 +36,10 @@ var expectedPostAndComments = &PostAndComments{ AuthorID: "t2_testuser", IsSelfPost: true, + + Hidden: true, + RemovedByCategory: "moderator", + BannedBy: "", }, Comments: []*Comment{ { diff --git a/reddit/reddit-oauth.go b/reddit/reddit-oauth.go index 13030d6..a74444d 100644 --- a/reddit/reddit-oauth.go +++ b/reddit/reddit-oauth.go @@ -12,16 +12,30 @@ Docs: - Script (the simplest type of app). Select this if you are the only person who will use the app. Only has access to your account. -Best option for a client like this is to use the script option. +This package currently supports clients for "script" and "web app" options. 2. After creating the app, you will get a client id and client secret. 3. Send a POST request (with the Content-Type header set to "application/x-www-form-urlencoded") -to https://www.reddit.com/api/v1/access_token with the following form values: +to https://www.reddit.com/api/v1/access_token to obtain the access (and refresh, read further) token. + +To authorize a "script" app, include the following form values: - grant_type=password - username={your Reddit username} - password={your Reddit password} +To authorize a "web" app, you will first need to obtain a "code" that can be exchanged for an access token. +It's a two-step process: + +3.1 Redirect the user to https://www.reddit.com/api/v1/authorize (Reddit's official authorization URL). +To find out more about the required request parameters, see the "OAuth2" article from the Docs. + +3.2. User will be redirected to your app's "Redirect URI". Extract "code" from the query parameters +and exchange it for the access_token, including the following form values: + - grant_type=authorization_code + - code={code} + - redirect_uri={you app's "Redirect URI"} + 4. You should receive a response body like the following: { "access_token": "70743860-DRhHVNSEOMu1ldlI", @@ -29,17 +43,58 @@ to https://www.reddit.com/api/v1/access_token with the following form values: "expires_in": 3600, "scope": "*" } + +Note: web apps can obtain a refresh token by adding `&duration=permanent` parameter to the "authorization URL" (step 3.1). */ package reddit import ( "context" + "fmt" "net/http" + "time" "golang.org/x/oauth2" ) +// webAppOAuthParams are used to retrieve access token using "code flow" or refresh_token, +// see https://github.com/reddit-archive/reddit/wiki/OAuth2#token-retrieval-code-flow. +type webAppOAuthParams struct { + // Code can be exchanged for access_token. + Code string + + // RedirectURI is used to build an AuthCodeURL when requesting users to grant access, + // and later exchanging code for access_token. The URI must be valid, as it will receive + // a request containing the `code` after user grants access to the app. Part of the "code flow". + RedirectURI string + + // RefreshToken should be set to retrieve a new access_token, ignoring the "code flow". + RefreshToken string +} + +// TokenSource creates a reusable token source base on the provided configuration. If code is set, +// it is exchanged for an access_token. If, on the other hand RefreshToken is set, we assume that +// the initial authorization has already happened and create an oauth2.Token with immediate expiry. +func (p webAppOAuthParams) TokenSource(ctx context.Context, config *oauth2.Config) (oauth2.TokenSource, error) { + var tok *oauth2.Token + var err error + + if p.RefreshToken != "" { + tok = &oauth2.Token{ + RefreshToken: p.RefreshToken, + Expiry: time.Now(), // refresh before using + } + } else if p.Code != "" { + if tok, err = config.Exchange(ctx, p.Code); err != nil { + return nil, fmt.Errorf("exchange code: %w", err) + } + } + return config.TokenSource(ctx, tok), err +} + +// oauthTokenSource retrieves access_token from resource owner's +// username and password. It implements oauth2.TokenSource. type oauthTokenSource struct { ctx context.Context config *oauth2.Config @@ -50,7 +105,8 @@ func (s *oauthTokenSource) Token() (*oauth2.Token, error) { return s.config.PasswordCredentialsToken(s.ctx, s.username, s.password) } -func oauthTransport(client *Client) http.RoundTripper { +// oauthTransport returns a Transport to handle authorization based the selected app type. +func oauthTransport(client *Client) (*oauth2.Transport, error) { httpClient := &http.Client{Transport: client.client.Transport} ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient) @@ -63,15 +119,51 @@ func oauthTransport(client *Client) http.RoundTripper { }, } - tokenSource := oauth2.ReuseTokenSource(nil, &oauthTokenSource{ - ctx: ctx, - config: config, - username: client.Username, - password: client.Password, - }) + transport := &oauth2.Transport{Base: client.client.Transport} - return &oauth2.Transport{ - Source: tokenSource, - Base: client.client.Transport, + switch client.appType { + case Script: + transport.Source = oauth2.ReuseTokenSource(nil, &oauthTokenSource{ + ctx: ctx, + config: config, + username: client.Username, + password: client.Password, + }) + case WebApp: + config.RedirectURL = client.webOauth.RedirectURI + ts, err := client.webOauth.TokenSource(ctx, config) + if err != nil { + return nil, err + } + transport.Source = ts + default: + // Should we panic here? There is not supposed to be any other app type. } + + return transport, nil +} + +// AuthCodeURL is a util function for buiding a URL to request permission grant from a user. +// +// TODO: Currently only works with defaultAuthURL, +// but should be able to use a custom AuthURL. Need to find an elegant solution. +// +// By default, Reddit will only issue an access_token to a WebApp for 1h, +// after which the app would need to ask the user to grant access again. +// `permanent` should be set to true to additionally request a refresh_token. +func AuthCodeURL(clientID, redirectURI, state string, scopes []string, permanent bool) string { + config := &oauth2.Config{ + ClientID: clientID, + Endpoint: oauth2.Endpoint{ + AuthURL: defaultAuthURL, + }, + RedirectURL: redirectURI, + Scopes: scopes, + } + var opts []oauth2.AuthCodeOption + if permanent { + opts = append(opts, oauth2.SetAuthURLParam("duration", "permanent")) + } + + return config.AuthCodeURL(state, opts...) } diff --git a/reddit/reddit-oauth_test.go b/reddit/reddit-oauth_test.go new file mode 100644 index 0000000..0c679b1 --- /dev/null +++ b/reddit/reddit-oauth_test.go @@ -0,0 +1,152 @@ +package reddit + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" +) + +const ( + testCode = "test_code" + testAccessToken = "test_access_token" + testRefreshToken = "test_refresh_token" + testRedirectURI = "http://localhost:5000/auth" // doens't need to be a valid URL + + clientId = "test_client" + clientSecret = "test_secret" + + subreddit = "golang" +) + +func TestAuthCodeURL(t *testing.T) { + state := "test_state" + scopes := []string{"scope_a", "scope_b"} + + for _, tt := range []struct { + name string + permanent bool + }{ + {"not requesting refresh token", false}, + {"request refresh token", true}, + } { + t.Run(tt.name, func(t *testing.T) { + got, err := url.Parse(AuthCodeURL(clientId, testRedirectURI, state, scopes, tt.permanent)) + if err != nil { + t.Fatal(err) + } + + checkQueryParameter(t, got, "client_id", clientId) + checkQueryParameter(t, got, "state", state) + checkQueryParameter(t, got, "redirect_uri", testRedirectURI) + checkQueryParameter(t, got, "scope", strings.Join(scopes, " ")) + + if tt.permanent { + checkQueryParameter(t, got, "duration", "permanent") + } else { + checkQueryParameter(t, got, "duration", "") + } + }) + } +} + +func TestWebAppOauth(t *testing.T) { + srv := testRedditServer(t) + t.Cleanup(srv.Close) + + for _, tt := range []struct { + name string + opt Opt + }{ + {"web app with code", WithWebAppCode(testCode, testRedirectURI)}, + {"web app with refresh_token", WithWebAppRefresh(testRefreshToken)}, + } { + t.Run(tt.name, func(t *testing.T) { + rc, err := NewClient( + Credentials{ID: clientId, Secret: clientSecret}, + WithBaseURL(srv.URL), + WithTokenURL(srv.URL+"/access_token"), + tt.opt, + ) + if err != nil { + t.Fatalf("create client: %v", err) + } + + // Make a request: check that the client has received the correct access token + _, _, err = rc.Subreddit.TopPosts(context.Background(), subreddit, nil) + if err != nil { + t.Errorf("make authorized request: %v", err) + } + }) + } +} + +// testRedditServer mocks both reddit.com (for authorization) and oauth.reddit.com (for interactions). +// It only handles a number of endpoints necessary for tests. +func testRedditServer(tb testing.TB) *httptest.Server { + mux := http.NewServeMux() + + // Exchange code for access_token + mux.HandleFunc("/access_token", func(w http.ResponseWriter, r *http.Request) { + enc := json.NewEncoder(w) + w.Header().Set("Content-type", "application/json") + + // Validate grant type + var ok bool + switch r.FormValue("grant_type") { + case "authorization_code": + if code := r.FormValue("code"); code == testCode { + // Actual Reddit API returns a different error message + ok = true + } + case "refresh_token": + if rt := r.FormValue("refresh_token"); rt == testRefreshToken { + ok = true + } + default: + tb.Log("unexpected grant type:", r.FormValue("grant_type")) + } + + if !ok { + // Actual Reddit API returns a different error message + enc.Encode(map[string]string{"error": "bad_request"}) + return + } + + enc.Encode(map[string]interface{}{ + "access_token": testAccessToken, + "token_type": "bearer", + "expires_in": 10 * time.Second, + "scope": "scope1,scope2", + "refresh_token": testRefreshToken, + }) + }) + + // Return the top post for the subreddit + mux.HandleFunc("/r/"+subreddit+"/top", func(w http.ResponseWriter, r *http.Request) { + if tok := strings.TrimLeft(r.Header.Get("Authorization"), "Bearer "); tok != testAccessToken { + http.Error(w, "", http.StatusUnauthorized) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "kind": kindPost, + "data": map[string]string{}, // data not needed for the test + }) + }) + + srv := httptest.NewServer(mux) + return srv +} + +// checkQueryParameter validates URL query parameters. +func checkQueryParameter(tb testing.TB, URL *url.URL, param, want string) { + if got := URL.Query().Get(param); got != want { + tb.Errorf("%s: got %q, want %q", param, got, want) + } +} diff --git a/reddit/reddit-options.go b/reddit/reddit-options.go index 25d397b..a694472 100644 --- a/reddit/reddit-options.go +++ b/reddit/reddit-options.go @@ -55,6 +55,37 @@ func WithTokenURL(u string) Opt { } } +// WithWebAppCode sets webOauth parameters for the client. +// Can be used to authorize a client immediately after receiving a callback +// to the web apps' redirect URI. +// Unlike BaseURL and TokenURL, redirectURI is a required parameter, +// because it is client-specific and no sensible default can be provided. +// Changes the client's appType to WebApp. +func WithWebAppCode(code, redirectURI string) Opt { + return func(c *Client) error { + c.appType = WebApp + c.webOauth = webAppOAuthParams{ + Code: code, + RedirectURI: redirectURI, + } + return nil + } +} + +// WithWebAppCode sets webOauth parameters for the client. It should be used in cases +// where the client wishes to "restore" its session with a cached refresh token, +// and is therefore mutually exclusive with WithWebAppCode option. +// Changes the client's appType to WebApp. +func WithWebAppRefresh(refreshToken string) Opt { + return func(c *Client) error { + c.appType = WebApp + c.webOauth = webAppOAuthParams{ + RefreshToken: refreshToken, + } + return nil + } +} + // FromEnv configures the client with values from environment variables. // Supported environment variables: // GO_REDDIT_CLIENT_ID to set the client's id. diff --git a/reddit/reddit-user-agent.go b/reddit/reddit-user-agent.go index 0e69b67..7158323 100644 --- a/reddit/reddit-user-agent.go +++ b/reddit/reddit-user-agent.go @@ -2,21 +2,6 @@ package reddit import "net/http" -// cloneRequest returns a clone of the provided *http.Request. -// The clone is a shallow copy of the struct and its Header map, -// since we'll only be modify the headers. -// Per the specification of http.RoundTripper, we should not directly modify a request. -func cloneRequest(r *http.Request) *http.Request { - r2 := new(http.Request) - *r2 = *r - // deep copy of the Header - r2.Header = make(http.Header, len(r.Header)) - for k, s := range r.Header { - r2.Header[k] = append([]string(nil), s...) - } - return r2 -} - // Sets the User-Agent header for requests. // We need to set a custom user agent because using the one set by the // stdlib gives us 429 Too Many Requests responses from the Reddit API. @@ -25,14 +10,8 @@ type userAgentTransport struct { Base http.RoundTripper } -func (t *userAgentTransport) setUserAgent(req *http.Request) *http.Request { - req2 := cloneRequest(req) - req2.Header.Set(headerUserAgent, t.userAgent) - return req2 -} - func (t *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) { - req2 := t.setUserAgent(req) + req2 := withUserAgent(req, t.userAgent) return t.base().RoundTrip(req2) } @@ -42,3 +21,25 @@ func (t *userAgentTransport) base() http.RoundTripper { } return http.DefaultTransport } + +// withUserAgent creates a copy of the request with the "User-Agent" header set. +// Per the specification of http.RoundTripper, we should not modify the request directly. +func withUserAgent(req *http.Request, agent string) *http.Request { + req2 := cloneRequest(req) + req2.Header.Set(headerUserAgent, agent) + return req2 +} + +// cloneRequest returns a clone of the provided *http.Request. +// The clone is a shallow copy of the struct and its Header map, +// since we'll only need to modify the headers. +func cloneRequest(r *http.Request) *http.Request { + r2 := new(http.Request) + *r2 = *r + // deep copy of the Header + r2.Header = make(http.Header, len(r.Header)) + for k, s := range r.Header { + r2.Header[k] = append([]string(nil), s...) + } + return r2 +} diff --git a/reddit/reddit.go b/reddit/reddit.go index bafda0c..c7ae5a0 100644 --- a/reddit/reddit.go +++ b/reddit/reddit.go @@ -26,6 +26,7 @@ const ( defaultBaseURL = "https://oauth.reddit.com" defaultBaseURLReadonly = "https://reddit.com" defaultTokenURL = "https://www.reddit.com/api/v1/access_token" + defaultAuthURL = "https://www.reddit.com/api/v1/authorize" mediaTypeJSON = "application/json" mediaTypeForm = "application/x-www-form-urlencoded" @@ -39,6 +40,15 @@ const ( headerRateLimitReset = "x-ratelimit-reset" ) +// AppType represents the possible OAuth2 application types as defined by Reddit. +// See: https://github.com/reddit-archive/reddit/wiki/oauth2-app-types +type AppType int + +const ( + Script AppType = iota // default to "Script" for backward compatibility + WebApp +) + var defaultClient, _ = NewReadonlyClient() // DefaultClient returns a valid, read-only client with limited access to the Reddit API. @@ -70,11 +80,20 @@ type Client struct { rateMu sync.Mutex rate Rate - ID string - Secret string + // AppType specifies the type of the app the client is used for. In particular, + // it determines the OAuth flow that will be used for token retrieval and renewal. + appType AppType + + ID string + Secret string + + // OAuth parameters for Script App Username string Password string + // OAuth parameters for Web App. + webOauth webAppOAuthParams + // This is the client's user ID in Reddit's database. redditID string @@ -110,7 +129,7 @@ func newClient() *Client { baseURL, _ := url.Parse(defaultBaseURL) tokenURL, _ := url.Parse(defaultTokenURL) - client := &Client{client: &http.Client{}, BaseURL: baseURL, TokenURL: tokenURL} + client := &Client{client: &http.Client{}, BaseURL: baseURL, TokenURL: tokenURL, appType: Script} client.Account = &AccountService{client: client} client.Collection = &CollectionService{client: client} @@ -163,7 +182,10 @@ func NewClient(credentials Credentials, opts ...Opt) (*Client, error) { client.client.CheckRedirect = client.redirect } - oauthTransport := oauthTransport(client) + oauthTransport, err := oauthTransport(client) + if err != nil { + return nil, err + } client.client.Transport = oauthTransport return client, nil diff --git a/reddit/stream.go b/reddit/stream.go index 14f5b85..09c4d32 100644 --- a/reddit/stream.go +++ b/reddit/stream.go @@ -16,6 +16,7 @@ type StreamService struct { // - a channel into which new posts will be sent // - a channel into which any errors will be sent // - a function that the client can call once to stop the streaming and close the channels +// // Because of the 100 post limit imposed by Reddit when fetching posts, some high-traffic // streams might drop submissions between API requests, such as when streaming r/all. func (s *StreamService) Posts(subreddit string, opts ...StreamOpt) (<-chan *Post, <-chan error, func()) { diff --git a/reddit/things.go b/reddit/things.go index 0a2f020..e773214 100644 --- a/reddit/things.go +++ b/reddit/things.go @@ -553,6 +553,10 @@ type Post struct { IsSelfPost bool `json:"is_self"` Saved bool `json:"saved"` Stickied bool `json:"stickied"` + Hidden bool `json:"hidden"` + + RemovedByCategory string `json:"removed_by_category"` + BannedBy string `json:"banned_by"` } // Subreddit holds information about a subreddit diff --git a/reddit/user.go b/reddit/user.go index f1839f5..7ca9665 100644 --- a/reddit/user.go +++ b/reddit/user.go @@ -24,6 +24,7 @@ type User struct { PostKarma int `json:"link_karma"` CommentKarma int `json:"comment_karma"` + TotalKarma int `json:"total_karma"` IsFriend bool `json:"is_friend"` IsEmployee bool `json:"is_employee"` @@ -40,6 +41,7 @@ type UserSummary struct { PostKarma int `json:"link_karma"` CommentKarma int `json:"comment_karma"` + TotalKarma int `json:"total_karma"` NSFW bool `json:"profile_over_18"` } diff --git a/reddit/user_test.go b/reddit/user_test.go index 81a2fd0..d008407 100644 --- a/reddit/user_test.go +++ b/reddit/user_test.go @@ -18,6 +18,7 @@ var expectedUser = &User{ PostKarma: 8239, CommentKarma: 130514, + TotalKarma: 10765, HasVerifiedEmail: true, } @@ -28,6 +29,7 @@ var expectedUsers = map[string]*UserSummary{ Created: &Timestamp{time.Date(2017, 3, 12, 2, 1, 47, 0, time.UTC)}, PostKarma: 488, CommentKarma: 22223, + TotalKarma: 28904, NSFW: false, }, "t2_2": { @@ -35,6 +37,7 @@ var expectedUsers = map[string]*UserSummary{ Created: &Timestamp{time.Date(2015, 12, 20, 18, 12, 51, 0, time.UTC)}, PostKarma: 8277, CommentKarma: 131948, + TotalKarma: 140026, NSFW: false, }, "t2_3": { @@ -42,6 +45,7 @@ var expectedUsers = map[string]*UserSummary{ Created: &Timestamp{time.Date(2013, 3, 4, 15, 46, 31, 0, time.UTC)}, PostKarma: 126887, CommentKarma: 81918, + TotalKarma: 200341, NSFW: true, }, } @@ -166,6 +170,7 @@ var expectedSearchUsers = []*User{ PostKarma: 1075227, CommentKarma: 339569, + TotalKarma: 10123, HasVerifiedEmail: true, }, @@ -176,6 +181,7 @@ var expectedSearchUsers = []*User{ PostKarma: 76744, CommentKarma: 42717, + TotalKarma: 84842, HasVerifiedEmail: true, }, diff --git a/testdata/post/post.json b/testdata/post/post.json index ee9ea26..512a678 100644 --- a/testdata/post/post.json +++ b/testdata/post/post.json @@ -19,7 +19,7 @@ "title": "Test", "link_flair_richtext": [], "subreddit_name_prefixed": "r/test", - "hidden": false, + "hidden": true, "pwls": 6, "link_flair_css_class": null, "downs": 0, @@ -61,7 +61,7 @@ "created": 1595096767.0, "link_flair_type": "text", "wls": 6, - "removed_by_category": null, + "removed_by_category": "moderator", "banned_by": null, "author_flair_type": "text", "domain": "self.test", diff --git a/testdata/user/get-multiple-by-id.json b/testdata/user/get-multiple-by-id.json index d8167b2..a36d8ae 100644 --- a/testdata/user/get-multiple-by-id.json +++ b/testdata/user/get-multiple-by-id.json @@ -3,6 +3,7 @@ "comment_karma": 22223, "created_utc": 1489284107, "link_karma": 488, + "total_karma": 28904, "name": "test_user_1", "profile_color": "", "profile_img": "https://www.redditstatic.com/avatars/avatar_default_01_94E044.png", @@ -12,6 +13,7 @@ "comment_karma": 131948, "created_utc": 1450635171, "link_karma": 8277, + "total_karma": 140026, "name": "test_user_2", "profile_color": "", "profile_img": "https://www.redditstatic.com/avatars/avatar_default_16_25B79F.png", @@ -21,6 +23,7 @@ "comment_karma": 81918, "created_utc": 1362411991, "link_karma": 126887, + "total_karma": 200341, "name": "test_user_3", "profile_color": "", "profile_img": "https://www.redditstatic.com/avatars/avatar_default_18_46A508.png", diff --git a/testdata/user/get.json b/testdata/user/get.json index 46300da..2282545 100644 --- a/testdata/user/get.json +++ b/testdata/user/get.json @@ -12,6 +12,7 @@ "created_utc": 1350555071.0, "link_karma": 8239, "comment_karma": 130514, + "total_karma": 10765, "is_gold": false, "is_mod": true, "verified": true, diff --git a/testdata/user/list.json b/testdata/user/list.json index b6c6ba6..a3e25cb 100644 --- a/testdata/user/list.json +++ b/testdata/user/list.json @@ -33,7 +33,10 @@ "previous_names": [], "user_is_moderator": false, "over_18": false, - "icon_size": [256, 256], + "icon_size": [ + 256, + 256 + ], "primary_color": "", "icon_img": "https://styles.redditmedia.com/t5_3kdh5/styles/profileIcon_0ws73gmqq8t21.png?width=256&height=256&crop=256:256,smart&s=a4d69298f5514b44cfa28a428c0953ebe0d5f6a1", "icon_color": "", @@ -49,7 +52,10 @@ "name": "t5_3kdh5", "is_default_banner": false, "url": "/user/washingtonpost/", - "banner_size": [1280, 384], + "banner_size": [ + 1280, + 384 + ], "user_is_contributor": false, "public_description": "Democracy Dies in Dankness. Official account.\n\nOur award-winning journalists have covered Washington and the world since 1877. Modded by /u/GenePark.", "link_flair_enabled": true, @@ -62,6 +68,7 @@ "is_mod": true, "accept_chats": false, "link_karma": 1075227, + "total_karma": 10123, "has_verified_email": true, "id": "179965", "accept_pms": true @@ -96,7 +103,10 @@ "previous_names": [], "user_is_moderator": false, "over_18": false, - "icon_size": [256, 256], + "icon_size": [ + 256, + 256 + ], "primary_color": "", "icon_img": "https://styles.redditmedia.com/t5_i4xj7/styles/profileIcon_mlsb0hlsebs01.jpg?width=256&height=256&crop=256:256,smart&s=7cb6c6fcf5079cd5514ea626e73398429f3b4b54", "icon_color": "", @@ -112,7 +122,10 @@ "name": "t5_i4xj7", "is_default_banner": false, "url": "/user/reuters/", - "banner_size": [1280, 384], + "banner_size": [ + 1280, + 384 + ], "user_is_contributor": false, "public_description": "", "link_flair_enabled": true, @@ -125,6 +138,7 @@ "is_mod": false, "accept_chats": false, "link_karma": 76744, + "total_karma": 84842, "has_verified_email": true, "id": "11kowl2w", "accept_pms": true