diff --git a/lib/pundit.rb b/lib/pundit.rb index 5eab334f..a70aca53 100644 --- a/lib/pundit.rb +++ b/lib/pundit.rb @@ -8,6 +8,7 @@ require "active_support/core_ext/module/introspection" require "active_support/dependencies/autoload" require "pundit/authorization" +require "pundit/context" # @api private # To avoid name clashes with common Error naming when mixing in Pundit, @@ -64,105 +65,22 @@ def self.included(base) end class << self - # Retrieves the policy for the given record, initializing it with the - # record and user and finally throwing an error if the user is not - # authorized to perform the given action. - # - # @param user [Object] the user that initiated the action - # @param possibly_namespaced_record [Object, Array] the object we're checking permissions of - # @param query [Symbol, String] the predicate method to check on the policy (e.g. `:show?`) - # @param policy_class [Class] the policy class we want to force use of - # @param cache [#[], #[]=] a Hash-like object to cache the found policy instance in - # @raise [NotAuthorizedError] if the given query method returned false - # @return [Object] Always returns the passed object record - def authorize(user, possibly_namespaced_record, query, policy_class: nil, cache: {}) - record = pundit_model(possibly_namespaced_record) - policy = if policy_class - policy_class.new(user, record) - else - cache[possibly_namespaced_record] ||= policy!(user, possibly_namespaced_record) - end - - raise NotAuthorizedError, query: query, record: record, policy: policy unless policy.public_send(query) - - record + # @see [Pundit::Context#authorize] + def authorize(user, record, query, policy_class: nil, cache: {}) + Context.new(user: user, policy_cache: cache).authorize(record, query: query, policy_class: policy_class) end - # Retrieves the policy scope for the given record. - # - # @see https://github.com/varvet/pundit#scopes - # @param user [Object] the user that initiated the action - # @param scope [Object] the object we're retrieving the policy scope for - # @raise [InvalidConstructorError] if the policy constructor called incorrectly - # @return [Scope{#resolve}, nil] instance of scope class which can resolve to a scope - def policy_scope(user, scope) - policy_scope_class = PolicyFinder.new(scope).scope - return unless policy_scope_class - - begin - policy_scope = policy_scope_class.new(user, pundit_model(scope)) - rescue ArgumentError - raise InvalidConstructorError, "Invalid #<#{policy_scope_class}> constructor is called" - end + # @see [Pundit::Context#policy_scope] + def policy_scope(user, ...) = Context.new(user: user).policy_scope(...) - policy_scope.resolve - end + # @see [Pundit::Context#policy_scope!] + def policy_scope!(user, ...) = Context.new(user: user).policy_scope!(...) - # Retrieves the policy scope for the given record. - # - # @see https://github.com/varvet/pundit#scopes - # @param user [Object] the user that initiated the action - # @param scope [Object] the object we're retrieving the policy scope for - # @raise [NotDefinedError] if the policy scope cannot be found - # @raise [InvalidConstructorError] if the policy constructor called incorrectly - # @return [Scope{#resolve}] instance of scope class which can resolve to a scope - def policy_scope!(user, scope) - policy_scope_class = PolicyFinder.new(scope).scope! - return unless policy_scope_class - - begin - policy_scope = policy_scope_class.new(user, pundit_model(scope)) - rescue ArgumentError - raise InvalidConstructorError, "Invalid #<#{policy_scope_class}> constructor is called" - end + # @see [Pundit::Context#policy] + def policy(user, ...) = Context.new(user: user).policy(...) - policy_scope.resolve - end - - # Retrieves the policy for the given record. - # - # @see https://github.com/varvet/pundit#policies - # @param user [Object] the user that initiated the action - # @param record [Object] the object we're retrieving the policy for - # @raise [InvalidConstructorError] if the policy constructor called incorrectly - # @return [Object, nil] instance of policy class with query methods - def policy(user, record) - policy = PolicyFinder.new(record).policy - policy&.new(user, pundit_model(record)) - rescue ArgumentError - raise InvalidConstructorError, "Invalid #<#{policy}> constructor is called" - end - - # Retrieves the policy for the given record. - # - # @see https://github.com/varvet/pundit#policies - # @param user [Object] the user that initiated the action - # @param record [Object] the object we're retrieving the policy for - # @raise [NotDefinedError] if the policy cannot be found - # @raise [InvalidConstructorError] if the policy constructor called incorrectly - # @return [Object] instance of policy class with query methods - def policy!(user, record) - policy = PolicyFinder.new(record).policy! - policy.new(user, pundit_model(record)) - rescue ArgumentError - raise InvalidConstructorError, "Invalid #<#{policy}> constructor is called" - end - - private - - def pundit_model(record) - record.is_a?(Array) ? record.last : record - end + # @see [Pundit::Context#policy!] + def policy!(user, ...) = Context.new(user: user).policy!(...) end # @api private diff --git a/lib/pundit/authorization.rb b/lib/pundit/authorization.rb index 1231f2a7..61f0b76b 100644 --- a/lib/pundit/authorization.rb +++ b/lib/pundit/authorization.rb @@ -15,6 +15,15 @@ module Authorization protected + # @return [Pundit::Core] a new instance of {Pundit::Core} with the current user + def pundit + @pundit ||= Pundit::Context.new( + user: pundit_user, + policy_cache: policies, + scope_cache: policy_scopes + ) + end + # @return [Boolean] whether authorization has been performed, i.e. whether # one {#authorize} or {#skip_authorization} has been called def pundit_policy_authorized? @@ -64,7 +73,7 @@ def authorize(record, query = nil, policy_class: nil) @_pundit_policy_authorized = true - Pundit.authorize(pundit_user, record, query, policy_class: policy_class, cache: policies) + pundit.authorize(record, query: query, policy_class: policy_class) end # Allow this action not to perform authorization. @@ -100,7 +109,7 @@ def policy_scope(scope, policy_scope_class: nil) # @param record [Object] the object we're retrieving the policy for # @return [Object, nil] instance of policy class with query methods def policy(record) - policies[record] ||= Pundit.policy!(pundit_user, record) + policies[record] ||= pundit.policy!(record) end # Retrieves a set of permitted attributes from the policy by instantiating @@ -115,7 +124,7 @@ def policy(record) # If omitted then this defaults to the Rails controller action name. # @return [Hash{String => Object}] the permitted attributes def permitted_attributes(record, action = action_name) - policy = policy(record) + policy = pundit.policy(record) method_name = if policy.respond_to?("permitted_attributes_for_#{action}") "permitted_attributes_for_#{action}" else @@ -162,7 +171,7 @@ def pundit_user private def pundit_policy_scope(scope) - policy_scopes[scope] ||= Pundit.policy_scope!(pundit_user, scope) + policy_scopes[scope] ||= pundit.policy_scope!(scope) end end end diff --git a/lib/pundit/context.rb b/lib/pundit/context.rb new file mode 100644 index 00000000..2a332462 --- /dev/null +++ b/lib/pundit/context.rb @@ -0,0 +1,127 @@ +# frozen_string_literal: true + +module Pundit + class Context + def initialize(user:, policy_cache: {}, scope_cache: {}) + @user = user + + @policy_cache = policy_cache + @scope_cache = scope_cache + end + + attr_reader :user + + def with_user(new_user) + clone.tap { _1.instance_variable_set(:@user, new_user) } + end + + # @api private + attr_reader :policy_cache + + # @api private + attr_reader :scope_cache + + # Retrieves the policy for the given record, initializing it with the + # record and user and finally throwing an error if the user is not + # authorized to perform the given action. + # + # @param user [Object] the user that initiated the action + # @param possibly_namespaced_record [Object, Array] the object we're checking permissions of + # @param query [Symbol, String] the predicate method to check on the policy (e.g. `:show?`) + # @param policy_class [Class] the policy class we want to force use of + # @raise [NotAuthorizedError] if the given query method returned false + # @return [Object] Always returns the passed object record + def authorize(possibly_namespaced_record, query:, policy_class:) + record = pundit_model(possibly_namespaced_record) + policy = if policy_class + policy_class.new(user, record) + else + policy_cache[possibly_namespaced_record] ||= policy!(possibly_namespaced_record) + end + + raise NotAuthorizedError, query: query, record: record, policy: policy unless policy.public_send(query) + + record + end + + # Retrieves the policy scope for the given record. + # + # @see https://github.com/varvet/pundit#scopes + # @param user [Object] the user that initiated the action + # @param scope [Object] the object we're retrieving the policy scope for + # @raise [InvalidConstructorError] if the policy constructor called incorrectly + # @return [Scope{#resolve}, nil] instance of scope class which can resolve to a scope + def policy_scope(scope) + policy_scope_class = policy_finder(scope).scope + return unless policy_scope_class + + begin + policy_scope = policy_scope_class.new(user, pundit_model(scope)) + rescue ArgumentError + raise InvalidConstructorError, "Invalid #<#{policy_scope_class}> constructor is called" + end + + policy_scope.resolve + end + + # Retrieves the policy scope for the given record. + # + # @see https://github.com/varvet/pundit#scopes + # @param user [Object] the user that initiated the action + # @param scope [Object] the object we're retrieving the policy scope for + # @raise [NotDefinedError] if the policy scope cannot be found + # @raise [InvalidConstructorError] if the policy constructor called incorrectly + # @return [Scope{#resolve}] instance of scope class which can resolve to a scope + def policy_scope!(scope) + policy_scope_class = policy_finder(scope).scope! + return unless policy_scope_class + + begin + policy_scope = policy_scope_class.new(user, pundit_model(scope)) + rescue ArgumentError + raise InvalidConstructorError, "Invalid #<#{policy_scope_class}> constructor is called" + end + + policy_scope.resolve + end + + # Retrieves the policy for the given record. + # + # @see https://github.com/varvet/pundit#policies + # @param user [Object] the user that initiated the action + # @param record [Object] the object we're retrieving the policy for + # @raise [InvalidConstructorError] if the policy constructor called incorrectly + # @return [Object, nil] instance of policy class with query methods + def policy(record) + policy = policy_finder(record).policy + policy&.new(user, pundit_model(record)) + rescue ArgumentError + raise InvalidConstructorError, "Invalid #<#{policy}> constructor is called" + end + + # Retrieves the policy for the given record. + # + # @see https://github.com/varvet/pundit#policies + # @param user [Object] the user that initiated the action + # @param record [Object] the object we're retrieving the policy for + # @raise [NotDefinedError] if the policy cannot be found + # @raise [InvalidConstructorError] if the policy constructor called incorrectly + # @return [Object] instance of policy class with query methods + def policy!(record) + policy = policy_finder(record).policy! + policy.new(user, pundit_model(record)) + rescue ArgumentError + raise InvalidConstructorError, "Invalid #<#{policy}> constructor is called" + end + + private + + def policy_finder(...) + PolicyFinder.new(...) + end + + def pundit_model(record) + record.is_a?(Array) ? record.last : record + end + end +end