diff --git a/lib/omniauth/strategies/oauth2.rb b/lib/omniauth/strategies/oauth2.rb index 0197452..0abfb95 100644 --- a/lib/omniauth/strategies/oauth2.rb +++ b/lib/omniauth/strategies/oauth2.rb @@ -53,14 +53,10 @@ def request_phase end def authorize_params - options.authorize_params[:state] = SecureRandom.hex(24) - params = options.authorize_params.merge(options_for("authorize")) - if OmniAuth.config.test_mode - @env ||= {} - @env["rack.session"] ||= {} - end - session["omniauth.state"] = params[:state] - params + site = self.class.name + state = state_store.fetch(site) { state_store[site] = SecureRandom.hex(24) } + options.authorize_params[:state] = state + options.authorize_params.merge(options_for("authorize")) end def token_params @@ -69,11 +65,18 @@ def token_params def callback_phase # rubocop:disable AbcSize, CyclomaticComplexity, MethodLength, PerceivedComplexity error = request.params["error_reason"] || request.params["error"] + site = self.class.name + expected_state = state_store[site] + actual_state = request.params["state"].to_s + if error - fail!(error, CallbackError.new(request.params["error"], request.params["error_description"] || request.params["error_reason"], request.params["error_uri"])) - elsif !options.provider_ignores_state && (request.params["state"].to_s.empty? || request.params["state"] != session.delete("omniauth.state")) + description = request.params["error_description"] || request.params["error_reason"] + error_uri = request.params["error_uri"] + fail!(error, CallbackError.new(request.params["error"], description, error_uri)) + elsif !options.provider_ignores_state && (actual_state.empty? || actual_state != expected_state) fail!(:csrf_detected, CallbackError.new(:csrf_detected, "CSRF detected")) else + state_store.delete(site) self.access_token = build_access_token self.access_token = access_token.refresh! if access_token.expired? super @@ -109,6 +112,22 @@ def options_for(option) hash end + private + + def state_store + if OmniAuth.config.test_mode + @env ||= {} + @env["rack.session"] ||= {} + end + + state_key = "omniauth.oauth2.state" + state_store = session[state_key] + unless state_store.is_a?(Hash) + state_store = session[state_key] = {} + end + state_store + end + # An error that is indicated in the OAuth 2.0 callback. # This could be a `redirect_uri_mismatch` or other class CallbackError < StandardError diff --git a/spec/omniauth/strategies/oauth2_spec.rb b/spec/omniauth/strategies/oauth2_spec.rb index 0bffcde..4c9db94 100644 --- a/spec/omniauth/strategies/oauth2_spec.rb +++ b/spec/omniauth/strategies/oauth2_spec.rb @@ -53,11 +53,49 @@ def app expect(instance.authorize_params["scope"]).to eq("bar") expect(instance.authorize_params["foo"]).to eq("baz") end + end - it "includes random state in the authorize params" do - instance = subject.new("abc", "def") - expect(instance.authorize_params.keys).to eq(["state"]) - expect(instance.session["omniauth.state"]).not_to be_empty + describe "state handling" do + SocialNetwork = Class.new(OmniAuth::Strategies::OAuth2) + + let(:client_options) { {:site => "https://graph.example.com"} } + let(:instance) { SocialNetwork.new(-> env {}) } + + before do + allow(SecureRandom).to receive(:hex).with(24).and_return("hex-1", "hex-2") + end + + it "includes a state scoped to the client" do + expect(instance.authorize_params["state"]).to eq("hex-1") + expect(instance.session["omniauth.oauth2.state"]).to eq("SocialNetwork" => "hex-1") + end + + context "once a state value has been generated" do + before do + instance.authorize_params + end + + it "does not replace an existing session value" do + expect(instance.authorize_params["state"]).to eq("hex-1") + expect(instance.session["omniauth.oauth2.state"]).to eq("SocialNetwork" => "hex-1") + end + end + + context "on a successful callback" do + let(:request) { double("Request", :params => {"code" => "auth-code", "state" => "hex-1"}) } + let(:access_token) { double("AccessToken", :expired? => false, :expires? => false, :token => "access-token") } + + before do + allow(instance).to receive(:request).and_return(request) + allow(instance).to receive(:build_access_token).and_return(access_token) + + instance.authorize_params + instance.callback_phase + end + + it "removes the value from the session" do + expect(instance.session["omniauth.oauth2.state"]).to eq({}) + end end end