diff --git a/lib/pundit.rb b/lib/pundit.rb index 5eab334f..137fdc00 100644 --- a/lib/pundit.rb +++ b/lib/pundit.rb @@ -80,7 +80,8 @@ def authorize(user, possibly_namespaced_record, query, policy_class: nil, cache: policy = if policy_class policy_class.new(user, record) else - cache[possibly_namespaced_record] ||= policy!(user, possibly_namespaced_record) + cache[{ policy_class: policy_class, + record: possibly_namespaced_record }] ||= policy!(user, possibly_namespaced_record) end raise NotAuthorizedError, query: query, record: record, policy: policy unless policy.public_send(query) diff --git a/lib/pundit/authorization.rb b/lib/pundit/authorization.rb index 1231f2a7..2c7974cc 100644 --- a/lib/pundit/authorization.rb +++ b/lib/pundit/authorization.rb @@ -98,9 +98,14 @@ def policy_scope(scope, policy_scope_class: nil) # # @see https://github.com/varvet/pundit#policies # @param record [Object] the object we're retrieving the policy for + # @param policy_class [Class] the policy class we want to force use of # @return [Object, nil] instance of policy class with query methods - def policy(record) - policies[record] ||= Pundit.policy!(pundit_user, record) + def policy(record, policy_class: nil) + policies[{ policy_class: policy_class, record: record }] ||= if policy_class + policy_class.new(pundit_user, record) + else + Pundit.policy!(pundit_user, record) + end end # Retrieves a set of permitted attributes from the policy by instantiating @@ -113,9 +118,10 @@ def policy(record) # @param record [Object] the object we're retrieving permitted attributes for # @param action [Symbol, String] the name of the action being performed on the record (e.g. `:update`). # If omitted then this defaults to the Rails controller action name. + # @param policy_class [Class] the policy class we want to force use of # @return [Hash{String => Object}] the permitted attributes - def permitted_attributes(record, action = action_name) - policy = policy(record) + def permitted_attributes(record, action = action_name, policy_class: nil) + policy = policy(record, policy_class: policy_class) method_name = if policy.respond_to?("permitted_attributes_for_#{action}") "permitted_attributes_for_#{action}" else diff --git a/spec/authorization_spec.rb b/spec/authorization_spec.rb index 2995bfdb..b567cf96 100644 --- a/spec/authorization_spec.rb +++ b/spec/authorization_spec.rb @@ -104,9 +104,9 @@ end it "caches the policy" do - expect(controller.policies[post]).to be_nil + expect(controller.policies[{ policy_class: nil, record: post }]).to be_nil controller.authorize(post) - expect(controller.policies[post]).not_to be_nil + expect(controller.policies[{ policy_class: nil, record: post }]).not_to be_nil end it "raises an error when the given record is nil" do @@ -155,7 +155,7 @@ it "allows policy to be injected" do new_policy = OpenStruct.new - controller.policies[post] = new_policy + controller.policies[{ policy_class: nil, record: post }] = new_policy expect(controller.policy(post)).to eq new_policy end