Skip to content

Commit

Permalink
Merge branch 'upstream-pr-609' into local-pr-609
Browse files Browse the repository at this point in the history
Pull from upstream's PR golang#609, and merge into local branch so
I can push to my fork.
  • Loading branch information
somersf committed Nov 21, 2022
2 parents ec4a9b2 + 6c5e40d commit 038532b
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 9 deletions.
86 changes: 86 additions & 0 deletions deviceauth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package oauth2

import (
"context"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"

"golang.org/x/net/context/ctxhttp"
"golang.org/x/oauth2/internal"
)

const (
errAuthorizationPending = "authorization_pending"
errSlowDown = "slow_down"
errAccessDenied = "access_denied"
errExpiredToken = "expired_token"
)

type DeviceAuth struct {
DeviceCode string `json:"device_code"`
UserCode string `json:"user_code"`
VerificationURI string `json:"verification_uri,verification_url"`
VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
ExpiresIn int `json:"expires_in"`
Interval int `json:"interval,omitempty"`
raw map[string]interface{}
}

func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuth, error) {
req, err := http.NewRequest("POST", c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")

r, err := ctxhttp.Do(ctx, nil, req)
if err != nil {
return nil, err
}

body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("oauth2: cannot auth device: %v", err)
}
if code := r.StatusCode; code < 200 || code > 299 {
return nil, &RetrieveError{
Response: r,
Body: body,
}
}

da := &DeviceAuth{}
err = json.Unmarshal(body, &da)
if err != nil {
return nil, fmt.Errorf("unmarshal %s", err)
}

_ = json.Unmarshal(body, &da.raw)

// Azure AD supplies verification_url instead of verification_uri
if da.VerificationURI == "" {
da.VerificationURI, _ = da.raw["verification_url"].(string)
}

return da, nil
}

func parseError(err error) string {
e, ok := err.(*RetrieveError)
if ok {
eResp := make(map[string]string)
_ = json.Unmarshal(e.Body, &eResp)
return eResp["error"]
}
e2, ok := err.(*internal.TokenError)
if ok {
return e2.Err
}
return ""
}
5 changes: 3 additions & 2 deletions endpoints/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ var Fitbit = oauth2.Endpoint{

// GitHub is the endpoint for Github.
var GitHub = oauth2.Endpoint{
AuthURL: "https://github.com/login/oauth/authorize",
TokenURL: "https://github.com/login/oauth/access_token",
AuthURL: "https://github.com/login/oauth/authorize",
TokenURL: "https://github.com/login/oauth/access_token",
DeviceAuthURL: "https://github.com/login/device/code",
}

// GitLab is the endpoint for GitLab.
Expand Down
37 changes: 32 additions & 5 deletions internal/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,15 @@ type Token struct {
}

// tokenJSON is the struct representing the HTTP response from OAuth2
// providers returning a token in JSON form.
// providers returning a token or error in JSON form.
type tokenJSON struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
RefreshToken string `json:"refresh_token"`
ExpiresIn expirationTime `json:"expires_in"` // at least PayPal returns string, while most return number
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
RefreshToken string `json:"refresh_token"`
ExpiresIn expirationTime `json:"expires_in"` // at least PayPal returns string, while most return number
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
ErrorURI string `json:"error_uri"`
}

func (e *tokenJSON) expiry() (t time.Time) {
Expand Down Expand Up @@ -253,6 +256,13 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) {
if err != nil {
return nil, err
}
if tokenError := vals.Get("error"); tokenError != "" {
return nil, &TokenError{
Err: tokenError,
ErrorDescription: vals.Get("error_description"),
ErrorURI: vals.Get("error_uri"),
}
}
token = &Token{
AccessToken: vals.Get("access_token"),
TokenType: vals.Get("token_type"),
Expand All @@ -269,6 +279,13 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) {
if err = json.Unmarshal(body, &tj); err != nil {
return nil, err
}
if tj.Error != "" {
return nil, &TokenError{
Err: tj.Error,
ErrorDescription: tj.ErrorDescription,
ErrorURI: tj.ErrorURI,
}
}
token = &Token{
AccessToken: tj.AccessToken,
TokenType: tj.TokenType,
Expand All @@ -292,3 +309,13 @@ type RetrieveError struct {
func (r *RetrieveError) Error() string {
return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body)
}

type TokenError struct {
Err string
ErrorDescription string
ErrorURI string
}

func (t *TokenError) Error() string {
return fmt.Sprintf("oauth2: error in token fetch response: %s\nerror_description: %s\nerror_uri: %s", t.Err, t.ErrorDescription, t.ErrorURI)
}
64 changes: 62 additions & 2 deletions oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"net/url"
"strings"
"sync"
"time"

"golang.org/x/oauth2/internal"
)
Expand Down Expand Up @@ -70,8 +71,9 @@ type TokenSource interface {
// Endpoint represents an OAuth 2.0 provider's authorization and token
// endpoint URLs.
type Endpoint struct {
AuthURL string
TokenURL string
AuthURL string
DeviceAuthURL string
TokenURL string

// AuthStyle optionally specifies how the endpoint wants the
// client ID & client secret sent. The zero value means to
Expand Down Expand Up @@ -224,6 +226,64 @@ func (c *Config) Exchange(ctx context.Context, code string, opts ...AuthCodeOpti
return retrieveToken(ctx, c, v)
}

// AuthDevice returns a device auth struct which contains a device code
// and authorization information provided for users to enter on another device.
func (c *Config) AuthDevice(ctx context.Context, opts ...AuthCodeOption) (*DeviceAuth, error) {
v := url.Values{
"client_id": {c.ClientID},
}
if len(c.Scopes) > 0 {
v.Set("scope", strings.Join(c.Scopes, " "))
}
for _, opt := range opts {
opt.setValue(v)
}
return retrieveDeviceAuth(ctx, c, v)
}

// Poll does a polling to exchange an device code for a token.
func (c *Config) Poll(ctx context.Context, da *DeviceAuth, opts ...AuthCodeOption) (*Token, error) {
v := url.Values{
"client_id": {c.ClientID},
"grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
"device_code": {da.DeviceCode},
}
if len(c.Scopes) > 0 {
v.Set("scope", strings.Join(c.Scopes, " "))
}
for _, opt := range opts {
opt.setValue(v)
}

// If no interval was provided, the client MUST use a reasonable default polling interval.
// See https://tools.ietf.org/html/draft-ietf-oauth-device-flow-07#section-3.5
interval := da.Interval
if interval == 0 {
interval = 5
}

for {
time.Sleep(time.Duration(interval) * time.Second)

tok, err := retrieveToken(ctx, c, v)
if err == nil {
return tok, nil
}

errTyp := parseError(err)
switch errTyp {
case errSlowDown:
interval += 5
case errAuthorizationPending:
// Do nothing.
case errAccessDenied, errExpiredToken:
fallthrough
default:
return tok, err
}
}
}

// Client returns an HTTP client using the provided token.
// The token will auto-refresh as necessary. The underlying
// HTTP transport will be obtained using the provided context.
Expand Down

0 comments on commit 038532b

Please sign in to comment.