From 43c8ca9cd5ebe950f0a6c9f8681322ac97bd150b Mon Sep 17 00:00:00 2001 From: Pedro Baracho Date: Thu, 1 Feb 2024 18:11:40 -0800 Subject: [PATCH] Use Arel instead of String for AR Enumerator conditionals --- lib/job-iteration/active_record_cursor.rb | 23 ++++++------------- lib/job-iteration/active_record_enumerator.rb | 12 +++++++--- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/lib/job-iteration/active_record_cursor.rb b/lib/job-iteration/active_record_cursor.rb index 10a8f4ef..c060dcb7 100644 --- a/lib/job-iteration/active_record_cursor.rb +++ b/lib/job-iteration/active_record_cursor.rb @@ -18,23 +18,16 @@ def initialize end end - def initialize(relation, columns = nil, position = nil) - @columns = if columns - Array(columns) - else - Array(relation.primary_key).map { |pk| "#{relation.table_name}.#{pk}" } - end + def initialize(relation, columns, position = nil) + @columns = columns self.position = Array.wrap(position) raise ArgumentError, "Must specify at least one column" if columns.empty? - if relation.joins_values.present? && !@columns.all? { |column| column.to_s.include?(".") } - raise ArgumentError, "You need to specify fully-qualified columns if you join a table" - end if relation.arel.orders.present? || relation.arel.taken.present? raise ConditionNotSupportedError end - @base_relation = relation.reorder(@columns.join(",")) + @base_relation = relation.reorder(*@columns) @reached_end = false end @@ -54,12 +47,10 @@ def position=(position) def update_from_record(record) self.position = @columns.map do |column| - method = column.to_s.split(".").last - if ActiveRecord.version >= Gem::Version.new("7.1.0.alpha") && method == "id" record.id_value else - record.send(method.to_sym) + record.send(column.name) end end end @@ -89,14 +80,14 @@ def conditions i = @position.size - 1 column = @columns[i] conditions = if @columns.size == @position.size - "#{column} > ?" + column.gt(@position[i]) else - "#{column} >= ?" + column.gteq(@position[i]) end while i > 0 i -= 1 column = @columns[i] - conditions = "#{column} > ? OR (#{column} = ? AND (#{conditions}))" + conditions = column.gt(@position[i]).or(column.eq(@position[i]).and(conditions)) end ret = @position.reduce([conditions]) { |params, value| params << value << value } ret.pop diff --git a/lib/job-iteration/active_record_enumerator.rb b/lib/job-iteration/active_record_enumerator.rb index 363a4ecf..8a62f1c9 100644 --- a/lib/job-iteration/active_record_enumerator.rb +++ b/lib/job-iteration/active_record_enumerator.rb @@ -11,9 +11,15 @@ def initialize(relation, columns: nil, batch_size: 100, cursor: nil) @relation = relation @batch_size = batch_size @columns = if columns - Array(columns) + Array(columns).map do |column| + if column.is_a?(Arel::Attributes::Attribute) + column + else + relation.arel_table[column.to_sym] + end + end else - Array(relation.primary_key).map { |pk| "#{relation.table_name}.#{pk}" } + Array(relation.primary_key).map { |pk| relation.arel_table[pk.to_sym] } end @cursor = cursor end @@ -45,7 +51,7 @@ def size def cursor_value(record) positions = @columns.map do |column| - attribute_name = column.to_s.split(".").last + attribute_name = column.name column_value(record, attribute_name) end return positions.first if positions.size == 1