From 73bf078992918128a0bd92ecd9f868cb87f5c39d Mon Sep 17 00:00:00 2001 From: Jason Numeroff Date: Sun, 18 Aug 2024 10:52:32 -0400 Subject: [PATCH] Optionally allow overriding the policy class for permitted params --- README.md | 16 ++++++++++++++++ lib/pundit/authorization.rb | 12 ++++-------- lib/pundit/context.rb | 22 ++++++++++++++++++++++ spec/authorization_spec.rb | 14 ++++++++++++++ spec/spec_helper.rb | 4 ++++ 5 files changed, 60 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index a1b59868..d7be91c9 100644 --- a/README.md +++ b/README.md @@ -757,6 +757,22 @@ def pundit_params_for(_record) end ``` +You can pass an argument to override the policy class if necessary. For example: + +```ruby +# app/controllers/posts_controller.rb +class PostsController < ApplicationController + def update + @post = Post.find(params[:id]) + if @post.update(permitted_attributes(@post), policy_class: PostPolicy) + redirect_to @post + else + render :edit + end + end +end +``` + ## RSpec ### Policy Specs diff --git a/lib/pundit/authorization.rb b/lib/pundit/authorization.rb index bc4cc4f4..ea6583f1 100644 --- a/lib/pundit/authorization.rb +++ b/lib/pundit/authorization.rb @@ -121,15 +121,11 @@ def policy(record) # @param record [Object] the object we're retrieving permitted attributes for # @param action [Symbol, String] the name of the action being performed on the record (e.g. `:update`). # If omitted then this defaults to the Rails controller action name. + # @param policy_class [Class] the policy class we want to force use of # @return [Hash{String => Object}] the permitted attributes - def permitted_attributes(record, action = action_name) - policy = policy(record) - method_name = if policy.respond_to?("permitted_attributes_for_#{action}") - "permitted_attributes_for_#{action}" - else - "permitted_attributes" - end - pundit_params_for(record).permit(*policy.public_send(method_name)) + def permitted_attributes(record, action = action_name, policy_class: nil) + required_params = pundit_params_for(record) + pundit.permitted_attributes(record, action: action, required_params: required_params, policy_class: policy_class) end # Retrieves the params for the given record. diff --git a/lib/pundit/context.rb b/lib/pundit/context.rb index a5f86716..24a6ea21 100644 --- a/lib/pundit/context.rb +++ b/lib/pundit/context.rb @@ -99,6 +99,28 @@ def policy!(record) cached_find(record, &:policy!) end + # Retrieves a set of permitted attributes from the policy by instantiating + # the policy class for the given record and calling `permitted_attributes` on + # it, or `permitted_attributes_for_{action}` if `action` is defined. It then infers + # what key the record should have in the params hash and retrieves the + # permitted attributes from the params hash under that key. + # + # @see https://github.com/varvet/pundit#strong-parameters + # @param record [Object] the object we're retrieving permitted attributes for + # @param action [Symbol, String] the name of the action being performed on the record (e.g. `:update`). + # @param required_params [ActionController::Parameters] the params + # @param policy_class [Class] the policy class we want to force use of + # @return [Hash{String => Object}] the permitted attributes + def permitted_attributes(record, action:, required_params:, policy_class: nil) + policy = policy_class ? policy_class.new(user, record) : policy(record) + method_name = if policy.respond_to?("permitted_attributes_for_#{action}") + "permitted_attributes_for_#{action}" + else + "permitted_attributes" + end + required_params.permit(*policy.public_send(method_name)) + end + private def cached_find(record) diff --git a/spec/authorization_spec.rb b/spec/authorization_spec.rb index 8bfa3fcb..00cda217 100644 --- a/spec/authorization_spec.rb +++ b/spec/authorization_spec.rb @@ -208,6 +208,20 @@ def to_params(*args, **kwargs, &block) expect(Controller.new(double, action, params).permitted_attributes(post).to_h).to eq("votes" => 5) end + it "checks different policy for permitted attributes" do + params = to_params( + post: { + title: "Hello", + votes: 5 + } + ) + + action = "update" + + expect(Controller.new(user, action, params) + .permitted_attributes(post, policy_class: PublicationPolicy).to_h).to eq("title" => "Hello") + end + it "checks policy for permitted attributes for record of a ActiveModel type" do customer_post = Customer::Post.new(user) params = to_params( diff --git a/spec/spec_helper.rb b/spec/spec_helper.rb index ff70ac0d..2369f08a 100644 --- a/spec/spec_helper.rb +++ b/spec/spec_helper.rb @@ -161,6 +161,10 @@ def resolve def create? true end + + def permitted_attributes + [:title] + end end class Comment