diff --git a/lib/tapioca/dsl/compilers/job_iteration.rb b/lib/tapioca/dsl/compilers/job_iteration.rb index a7af5482..a3d3f860 100644 --- a/lib/tapioca/dsl/compilers/job_iteration.rb +++ b/lib/tapioca/dsl/compilers/job_iteration.rb @@ -3,6 +3,8 @@ return unless defined?(JobIteration::Iteration) +require "prism" + module Tapioca module Dsl module Compilers @@ -54,24 +56,64 @@ def decorate typed_param.param.name == "cursor" end + expanded_parameters = parameters.flat_map do |typed_param| + hash_param = typed_param.type.match(/\A\{.*\}\z/) + if hash_param + key_value_pairs = parse_hash_parameter(typed_param) + key_value_pairs.map do |key, value| + create_param(key, type: value) + end + else + typed_param + end + end + return_type = compile_method_return_type_to_rbi(method) job.create_method( "perform_later", - parameters: parameters, + parameters: expanded_parameters, return_type: "T.any(#{constant_name}, FalseClass)", class_method: true, ) job.create_method( "perform_now", - parameters: parameters, + parameters: expanded_parameters, return_type: return_type, class_method: true, ) end end + private + + def parse_hash_parameter(typed_param) + parse_result = Prism.parse(typed_param.type) + return "T.untyped" if parse_result.failure? + + visitor = HashParamVisitor.new + parse_result.value.accept(visitor) + visitor.key_value_pairs + end + + class HashParamVisitor < Prism::Visitor + attr_reader :key_value_pairs + + def initialize + super + @key_value_pairs = [] + end + + def visit_hash_node(node) + node.elements.each do |element| + key = element.key.unescaped + value = element.value.slice + @key_value_pairs << [key, value] + end + end + end + class << self extend T::Sig diff --git a/test/tapioca/dsl/compilers/job_iteration_test.rb b/test/tapioca/dsl/compilers/job_iteration_test.rb index 44e56938..31314b87 100644 --- a/test/tapioca/dsl/compilers/job_iteration_test.rb +++ b/test/tapioca/dsl/compilers/job_iteration_test.rb @@ -131,6 +131,109 @@ def perform_now(user_id, name); end RBI assert_equal(expected, rbi_for(:NotifyJob)) end + + def test_generates_correct_rbi_file_for_job_with_build_enumerator_method_with_aliased_hash_parameter + add_ruby_file("job.rb", <<~RUBY) + class NotifyJob < ActiveJob::Base + include JobIteration::Iteration + + Params = T.type_alias { { user_id: Integer, name: String } } + + extend T::Sig + sig { params(params: Params, cursor: T.untyped).void } + def build_enumerator(params, cursor:) + # ... + end + end + RUBY + + expected = template(<<~RBI) + # typed: strong + + class NotifyJob + class << self + sig { params(user_id: Integer, name: String).returns(T.any(NotifyJob, FalseClass)) } + def perform_later(user_id, name); end + + sig { params(user_id: Integer, name: String).void } + def perform_now(user_id, name); end + end + end + RBI + assert_equal(expected, rbi_for(:NotifyJob)) + end + + def test_generates_correct_rbi_file_for_job_with_build_enumerator_method_with_nested_hash_parameter + add_ruby_file("job.rb", <<~RUBY) + class ResourceType; end + class Locale; end + + class NotifyJob < ActiveJob::Base + include JobIteration::Iteration + + extend T::Sig + sig { params(params: { shop_id: Integer, resource_types: T::Array[ResourceType], locale: Locale, metadata: T.nilable(String) }, cursor: T.untyped).void } + def build_enumerator(params, cursor:) + # ... + end + end + RUBY + + expected = template(<<~RBI) + # typed: strong + + class NotifyJob + class << self + sig { params(shop_id: Integer, resource_types: T::Array[ResourceType], locale: Locale, metadata: T.nilable(String)).returns(T.any(NotifyJob, FalseClass)) } + def perform_later(shop_id, resource_types, locale, metadata); end + + sig { params(shop_id: Integer, resource_types: T::Array[ResourceType], locale: Locale, metadata: T.nilable(String)).void } + def perform_now(shop_id, resource_types, locale, metadata); end + end + end + RBI + assert_equal(expected, rbi_for(:NotifyJob)) + end + + def test_generates_correct_rbi_file_for_job_with_build_enumerator_method_with_complex_hash_parameter + add_ruby_file("job.rb", <<~RUBY) + class NotifyJob < ActiveJob::Base + include JobIteration::Iteration + + extend T::Sig + sig do + params( + params: { + shop_ids: T.any(Integer, T::Array[Integer]), + profile_ids: T.any(Integer, T::Array[Integer]), + extension_ids: T.any(Integer, T::Array[Integer]), + foo: Symbol, + bar: String + }, + cursor: T.untyped + ).void + end + def build_enumerator(params, cursor:) + # ... + end + end + RUBY + + expected = template(<<~RBI) + # typed: strong + + class NotifyJob + class << self + sig { params(shop_ids: T.any(Integer, T::Array[Integer]), profile_ids: T.any(Integer, T::Array[Integer]), extension_ids: T.any(Integer, T::Array[Integer]), foo: Symbol, bar: String).returns(T.any(NotifyJob, FalseClass)) } + def perform_later(shop_ids, profile_ids, extension_ids, foo, bar); end + + sig { params(shop_ids: T.any(Integer, T::Array[Integer]), profile_ids: T.any(Integer, T::Array[Integer]), extension_ids: T.any(Integer, T::Array[Integer]), foo: Symbol, bar: String).void } + def perform_now(shop_ids, profile_ids, extension_ids, foo, bar); end + end + end + RBI + assert_equal(expected, rbi_for(:NotifyJob)) + end end end end