-
Notifications
You must be signed in to change notification settings - Fork 38
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added GenerateAuthorizationURLState for returning generated state, co…
…nfig for default StateController, state expiring & state generation function overwrite
- Loading branch information
Showing
5 changed files
with
189 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,43 +1,51 @@ | ||
package oauth2 | ||
|
||
import "github.com/DisgoOrg/disgo/internal/insecurerandstr" | ||
|
||
var _ StateController = (*stateControllerImpl)(nil) | ||
var ( | ||
_ StateController = (*stateControllerImpl)(nil) | ||
) | ||
|
||
// StateController is responsible for generating, storing and validating states | ||
type StateController interface { | ||
// GenerateNewState generates a new random state to be used as a state | ||
GenerateNewState(redirectURI string) string | ||
|
||
// ConsumeState validates a state and returns the redirect url or nil if it is invalid | ||
ConsumeState(state string) *string | ||
ConsumeState(state string) string | ||
} | ||
|
||
// NewStateController returns a new empty StateController | ||
func NewStateController() StateController { | ||
return NewStateControllerWithStates(map[string]string{}) | ||
} | ||
func NewStateController(config *StateControllerConfig) StateController { | ||
if config == nil { | ||
config = &DefaultStateControllerConfig | ||
} | ||
|
||
// NewStateControllerWithStates returns a new StateController with the given states | ||
func NewStateControllerWithStates(states map[string]string) StateController { | ||
return &stateControllerImpl{states: states} | ||
states := NewTTLMap(config.MaxTTL) | ||
for state, url := range config.States { | ||
states.Put(state, url) | ||
} | ||
|
||
return &stateControllerImpl{ | ||
states: states, | ||
newStateFunc: config.NewStateFunc, | ||
} | ||
} | ||
|
||
type stateControllerImpl struct { | ||
states map[string]string | ||
states *TTLMap | ||
newStateFunc func() string | ||
} | ||
|
||
func (c *stateControllerImpl) GenerateNewState(redirectURI string) string { | ||
state := insecurerandstr.RandStr(32) | ||
c.states[state] = redirectURI | ||
state := c.newStateFunc() | ||
c.states.Put(state, redirectURI) | ||
return state | ||
} | ||
|
||
func (c *stateControllerImpl) ConsumeState(state string) *string { | ||
uri, ok := c.states[state] | ||
if !ok { | ||
return nil | ||
func (c *stateControllerImpl) ConsumeState(state string) string { | ||
uri := c.states.Get(state) | ||
if uri == "" { | ||
return "" | ||
} | ||
delete(c.states, state) | ||
return &uri | ||
c.states.Delete(state) | ||
return uri | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
package oauth2 | ||
|
||
import ( | ||
"time" | ||
|
||
"github.com/DisgoOrg/disgo/internal/insecurerandstr" | ||
) | ||
|
||
// DefaultStateControllerConfig is the default configuration for the StateController | ||
var DefaultStateControllerConfig = StateControllerConfig{ | ||
States: map[string]string{}, | ||
NewStateFunc: func() string { return insecurerandstr.RandStr(32) }, | ||
MaxTTL: time.Hour, | ||
} | ||
|
||
// StateControllerConfig is the configuration for the StateController | ||
type StateControllerConfig struct { | ||
States map[string]string | ||
NewStateFunc func() string | ||
MaxTTL time.Duration | ||
} | ||
|
||
// StateControllerConfigOpt is used to pass optional parameters to NewStateController | ||
type StateControllerConfigOpt func(config *StateControllerConfig) | ||
|
||
// Apply applies the given StateControllerConfigOpt(s) to the StateControllerConfig | ||
func (c *StateControllerConfig) Apply(opts []StateControllerConfigOpt) { | ||
for _, opt := range opts { | ||
opt(c) | ||
} | ||
} | ||
|
||
// WithStates loads states from an existing map | ||
//goland:noinspection GoUnusedExportedFunction | ||
func WithStates(states map[string]string) StateControllerConfigOpt { | ||
return func(config *StateControllerConfig) { | ||
config.States = states | ||
} | ||
} | ||
|
||
// WithNewStateFunc sets the function which is used to generate a new random state | ||
//goland:noinspection GoUnusedExportedFunction | ||
func WithNewStateFunc(newStateFunc func() string) StateControllerConfigOpt { | ||
return func(config *StateControllerConfig) { | ||
config.NewStateFunc = newStateFunc | ||
} | ||
} | ||
|
||
// WithMaxTTL sets the maximum time to live for a state | ||
//goland:noinspection GoUnusedExportedFunction | ||
func WithMaxTTL(maxTTL time.Duration) StateControllerConfigOpt { | ||
return func(config *StateControllerConfig) { | ||
config.MaxTTL = maxTTL | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
package oauth2 | ||
|
||
import ( | ||
"sync" | ||
"time" | ||
) | ||
|
||
type value struct { | ||
value string | ||
insertedAt int64 | ||
} | ||
|
||
func NewTTLMap(maxTTL time.Duration) *TTLMap { | ||
m := &TTLMap{ | ||
maxTTL: maxTTL, | ||
m: map[string]value{}, | ||
} | ||
|
||
if maxTTL > 0 { | ||
go func() { | ||
ticker := time.NewTicker(10 * time.Second) | ||
defer ticker.Stop() | ||
for now := range ticker.C { | ||
m.mu.Lock() | ||
for k, v := range m.m { | ||
if now.Unix()-v.insertedAt > int64(m.maxTTL) { | ||
delete(m.m, k) | ||
} | ||
} | ||
m.mu.Unlock() | ||
} | ||
}() | ||
} | ||
|
||
return m | ||
} | ||
|
||
type TTLMap struct { | ||
maxTTL time.Duration | ||
m map[string]value | ||
mu sync.Mutex | ||
} | ||
|
||
func (m *TTLMap) Len() int { | ||
return len(m.m) | ||
} | ||
|
||
func (m *TTLMap) Put(k string, v string) { | ||
m.mu.Lock() | ||
m.m[k] = value{v, time.Now().Unix()} | ||
m.mu.Unlock() | ||
} | ||
|
||
func (m *TTLMap) Get(k string) string { | ||
m.mu.Lock() | ||
v, ok := m.m[k] | ||
m.mu.Unlock() | ||
if ok { | ||
return v.value | ||
} | ||
return "" | ||
} | ||
|
||
func (m *TTLMap) Delete(k string) { | ||
m.mu.Lock() | ||
delete(m.m, k) | ||
m.mu.Unlock() | ||
} |