Skip to content

Commit

Permalink
added GenerateAuthorizationURLState for returning generated state, co…
Browse files Browse the repository at this point in the history
…nfig for default StateController, state expiring & state generation function overwrite
  • Loading branch information
topi314 committed Nov 30, 2021
1 parent ddea2e3 commit c1fb9c0
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 31 deletions.
19 changes: 13 additions & 6 deletions oauth2/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func New(id discord.Snowflake, secret string, opts ...ConfigOpt) *Client {
config.SessionController = NewSessionController()
}
if config.StateController == nil {
config.StateController = NewStateController()
config.StateController = NewStateController(config.StateControllerConfig)
}

return &Client{ID: id, Secret: secret, Config: *config}
Expand All @@ -53,14 +53,21 @@ type Client struct {
Config
}

// GenerateAuthorizationURL generates an authorization URL with the given redirect URI & scopes, state is automatically generated
// GenerateAuthorizationURL generates an authorization URL with the given redirect URI, permissions, guildID, disableGuildSelect & scopes, state is automatically generated
func (c *Client) GenerateAuthorizationURL(redirectURI string, permissions discord.Permissions, guildID discord.Snowflake, disableGuildSelect bool, scopes ...discord.ApplicationScope) string {
url, _ := c.GenerateAuthorizationURLState(redirectURI, permissions, guildID, disableGuildSelect, scopes...)
return url
}

// GenerateAuthorizationURLState generates an authorization URL with the given redirect URI, permissions, guildID, disableGuildSelect & scopes, state is automatically generated & returned
func (c *Client) GenerateAuthorizationURLState(redirectURI string, permissions discord.Permissions, guildID discord.Snowflake, disableGuildSelect bool, scopes ...discord.ApplicationScope) (string, string) {
state := c.StateController.GenerateNewState(redirectURI)
values := route.QueryValues{
"client_id": c.ID,
"redirect_uri": redirectURI,
"response_type": "code",
"scope": discord.JoinScopes(scopes),
"state": c.StateController.GenerateNewState(redirectURI),
"state": state,
}
if permissions != discord.PermissionsNone {
values["permissions"] = permissions
Expand All @@ -72,16 +79,16 @@ func (c *Client) GenerateAuthorizationURL(redirectURI string, permissions discor
values["disable_guild_select"] = true
}
compiledRoute, _ := route.Authorize.Compile(values)
return compiledRoute.URL()
return compiledRoute.URL(), state
}

// StartSession starts a new session with the given authorization code & state
func (c *Client) StartSession(code string, state string, identifier string, opts ...rest.RequestOpt) (Session, error) {
redirectURI := c.StateController.ConsumeState(state)
if redirectURI == nil {
if redirectURI == "" {
return nil, ErrStateNotFound
}
exchange, err := c.OAuth2Service.GetAccessToken(c.ID, c.Secret, code, *redirectURI, opts...)
exchange, err := c.OAuth2Service.GetAccessToken(c.ID, c.Secret, code, redirectURI, opts...)
if err != nil {
return nil, err
}
Expand Down
32 changes: 26 additions & 6 deletions oauth2/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ var DefaultConfig = Config{

// Config is the configuration for the OAuth2 client
type Config struct {
Logger log.Logger
RestClient rest.Client
RestClientConfig *rest.Config
OAuth2Service rest.OAuth2Service
SessionController SessionController
StateController StateController
Logger log.Logger
RestClient rest.Client
RestClientConfig *rest.Config
OAuth2Service rest.OAuth2Service
SessionController SessionController
StateControllerConfig *StateControllerConfig
StateController StateController
}

// ConfigOpt can be used to supply optional parameters to New
Expand Down Expand Up @@ -88,3 +89,22 @@ func WithStateController(stateController StateController) ConfigOpt {
config.StateController = stateController
}
}

// WithStateControllerConfig applies a custom StateControllerConfig to the SessionController
//goland:noinspection GoUnusedExportedFunction
func WithStateControllerConfig(stateControllerConfig StateControllerConfig) ConfigOpt {
return func(config *Config) {
config.StateControllerConfig = &stateControllerConfig
}
}

// WithStateControllerOpts applies all StateControllerConfigOpt(s) to the StateController
//goland:noinspection GoUnusedExportedFunction
func WithStateControllerOpts(opts ...StateControllerConfigOpt) ConfigOpt {
return func(config *Config) {
if config.StateControllerConfig == nil {
config.StateControllerConfig = &DefaultStateControllerConfig
}
config.StateControllerConfig.Apply(opts)
}
}
46 changes: 27 additions & 19 deletions oauth2/state_controller.go
Original file line number Diff line number Diff line change
@@ -1,43 +1,51 @@
package oauth2

import "github.com/DisgoOrg/disgo/internal/insecurerandstr"

var _ StateController = (*stateControllerImpl)(nil)
var (
_ StateController = (*stateControllerImpl)(nil)
)

// StateController is responsible for generating, storing and validating states
type StateController interface {
// GenerateNewState generates a new random state to be used as a state
GenerateNewState(redirectURI string) string

// ConsumeState validates a state and returns the redirect url or nil if it is invalid
ConsumeState(state string) *string
ConsumeState(state string) string
}

// NewStateController returns a new empty StateController
func NewStateController() StateController {
return NewStateControllerWithStates(map[string]string{})
}
func NewStateController(config *StateControllerConfig) StateController {
if config == nil {
config = &DefaultStateControllerConfig
}

// NewStateControllerWithStates returns a new StateController with the given states
func NewStateControllerWithStates(states map[string]string) StateController {
return &stateControllerImpl{states: states}
states := NewTTLMap(config.MaxTTL)
for state, url := range config.States {
states.Put(state, url)
}

return &stateControllerImpl{
states: states,
newStateFunc: config.NewStateFunc,
}
}

type stateControllerImpl struct {
states map[string]string
states *TTLMap
newStateFunc func() string
}

func (c *stateControllerImpl) GenerateNewState(redirectURI string) string {
state := insecurerandstr.RandStr(32)
c.states[state] = redirectURI
state := c.newStateFunc()
c.states.Put(state, redirectURI)
return state
}

func (c *stateControllerImpl) ConsumeState(state string) *string {
uri, ok := c.states[state]
if !ok {
return nil
func (c *stateControllerImpl) ConsumeState(state string) string {
uri := c.states.Get(state)
if uri == "" {
return ""
}
delete(c.states, state)
return &uri
c.states.Delete(state)
return uri
}
55 changes: 55 additions & 0 deletions oauth2/state_controller_config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package oauth2

import (
"time"

"github.com/DisgoOrg/disgo/internal/insecurerandstr"
)

// DefaultStateControllerConfig is the default configuration for the StateController
var DefaultStateControllerConfig = StateControllerConfig{
States: map[string]string{},
NewStateFunc: func() string { return insecurerandstr.RandStr(32) },
MaxTTL: time.Hour,
}

// StateControllerConfig is the configuration for the StateController
type StateControllerConfig struct {
States map[string]string
NewStateFunc func() string
MaxTTL time.Duration
}

// StateControllerConfigOpt is used to pass optional parameters to NewStateController
type StateControllerConfigOpt func(config *StateControllerConfig)

// Apply applies the given StateControllerConfigOpt(s) to the StateControllerConfig
func (c *StateControllerConfig) Apply(opts []StateControllerConfigOpt) {
for _, opt := range opts {
opt(c)
}
}

// WithStates loads states from an existing map
//goland:noinspection GoUnusedExportedFunction
func WithStates(states map[string]string) StateControllerConfigOpt {
return func(config *StateControllerConfig) {
config.States = states
}
}

// WithNewStateFunc sets the function which is used to generate a new random state
//goland:noinspection GoUnusedExportedFunction
func WithNewStateFunc(newStateFunc func() string) StateControllerConfigOpt {
return func(config *StateControllerConfig) {
config.NewStateFunc = newStateFunc
}
}

// WithMaxTTL sets the maximum time to live for a state
//goland:noinspection GoUnusedExportedFunction
func WithMaxTTL(maxTTL time.Duration) StateControllerConfigOpt {
return func(config *StateControllerConfig) {
config.MaxTTL = maxTTL
}
}
68 changes: 68 additions & 0 deletions oauth2/ttl_map.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package oauth2

import (
"sync"
"time"
)

type value struct {
value string
insertedAt int64
}

func NewTTLMap(maxTTL time.Duration) *TTLMap {
m := &TTLMap{
maxTTL: maxTTL,
m: map[string]value{},
}

if maxTTL > 0 {
go func() {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for now := range ticker.C {
m.mu.Lock()
for k, v := range m.m {
if now.Unix()-v.insertedAt > int64(m.maxTTL) {
delete(m.m, k)
}
}
m.mu.Unlock()
}
}()
}

return m
}

type TTLMap struct {
maxTTL time.Duration
m map[string]value
mu sync.Mutex
}

func (m *TTLMap) Len() int {
return len(m.m)
}

func (m *TTLMap) Put(k string, v string) {
m.mu.Lock()
m.m[k] = value{v, time.Now().Unix()}
m.mu.Unlock()
}

func (m *TTLMap) Get(k string) string {
m.mu.Lock()
v, ok := m.m[k]
m.mu.Unlock()
if ok {
return v.value
}
return ""
}

func (m *TTLMap) Delete(k string) {
m.mu.Lock()
delete(m.m, k)
m.mu.Unlock()
}

0 comments on commit c1fb9c0

Please sign in to comment.