From c341a8a3e6f3e5215fefe99449306f8c88995a2f Mon Sep 17 00:00:00 2001 From: delphinus Date: Fri, 5 Apr 2024 09:03:35 +0900 Subject: [PATCH] feat: add type annotations for plenary.iterators --- lua/plenary/init.lua | 1 + lua/plenary/iterators.lua | 266 ++++++++++++++++++++++++++++++-------- 2 files changed, 215 insertions(+), 52 deletions(-) diff --git a/lua/plenary/init.lua b/lua/plenary/init.lua index 7b398d73e..560c57a24 100644 --- a/lua/plenary/init.lua +++ b/lua/plenary/init.lua @@ -8,6 +8,7 @@ ---@field filetype PlenaryFiletype ---@field fun PlenaryFun ---@field functional PlenaryFunctional +---@field iterators PlenaryIterators ---@field job PlenaryJob ---@field json PlenaryJson ---@field log PlenaryLog diff --git a/lua/plenary/iterators.lua b/lua/plenary/iterators.lua index 904d2db7f..03d3fe86b 100644 --- a/lua/plenary/iterators.lua +++ b/lua/plenary/iterators.lua @@ -12,12 +12,18 @@ local f = require "plenary.functional" -- Tools -------------------------------------------------------------------------------- +---@class PlenaryIterators local exports = {} ----@class Iterator ----@field gen function ----@field param any ----@field state any +---@generic V +---@alias PlenaryIteratorsIterator fun(param: string|V[], i?: integer): integer?, string|V|nil + +---@class PlenaryIterator +---@field gen PlenaryIteratorsIterator +---@field param string|table +---@field state? integer +---@overload fun(param?: string|table, state?: integer): integer?, any? + local Iterator = {} Iterator.__index = Iterator @@ -30,16 +36,24 @@ Iterator.__index = Iterator ---So instead we do not return param and state as multivals when doing wrap ---This causes the first loop iteration to call param and state with nil because we didn't return them as multivals ---We have to use or to check for nil and default to interal starting state and param +---@generic T +---@param param? string|T[] +---@param state? integer +---@return integer i +---@return string|T v function Iterator:__call(param, state) return self.gen(param or self.param, state or self.state) end +---@return string function Iterator:__tostring() return "" end -- A special hack for zip/chain to skip last two state, if a wrapped iterator -- has been passed +---@param ... any +---@return integer local numargs = function(...) local n = select("#", ...) if n >= 3 then @@ -57,6 +71,9 @@ local numargs = function(...) return n end +---@param state_x? integer +---@param ... any +---@return ... local return_if_not_empty = function(state_x, ...) if state_x == nil then return nil @@ -64,6 +81,11 @@ local return_if_not_empty = function(state_x, ...) return ... end +---@param fun function +---@param state_x? integer +---@param ... any +---@return integer? +---@return ... local call_if_not_empty = function(fun, state_x, ...) if state_x == nil then return nil @@ -74,18 +96,31 @@ end -------------------------------------------------------------------------------- -- Basic Functions -------------------------------------------------------------------------------- +---@param _param any +---@param _state any +---@return nil local nil_gen = function(_param, _state) return nil end local pairs_gen = pairs {} +---@generic K, V +---@param map table +---@param key K +---@return K key +---@return K key +---@return V value local map_gen = function(map, key) local value key, value = pairs_gen(map, key) return key, key, value end +---@param param string +---@param state integer +---@return integer state +---@return string r local string_gen = function(param, state) state = state + 1 if state > #param then @@ -95,6 +130,12 @@ local string_gen = function(param, state) return state, r end +---@param obj string|function|table|PlenaryIterator +---@param param? string|table +---@param state? integer +---@return PlenaryIteratorsIterator gen +---@return string|table|nil param +---@return integer? state local rawiter = function(obj, param, state) assert(obj ~= nil, "invalid iterator") @@ -129,10 +170,10 @@ end ---Wraps the iterator triplet into a table to allow metamethods and calling with method form ---Important! We do not return param and state as multivals like the original luafun ---See the __call metamethod for more information ----@param gen any ----@param param any ----@param state any ----@return Iterator +---@param gen PlenaryIteratorsIterator +---@param param? string|table +---@param state? integer|integer[] +---@return PlenaryIterator local function wrap(gen, param, state) return setmetatable({ gen = gen, @@ -142,19 +183,19 @@ local function wrap(gen, param, state) end ---Unwrap an iterator metatable into the iterator triplet ----@param self Iterator ----@return any ----@return any ----@return any +---@param self PlenaryIterator +---@return PlenaryIteratorsIterator gen +---@return string|table param +---@return integer? state local unwrap = function(self) return self.gen, self.param, self.state end ---Create an iterator from an object ----@param obj any ----@param param any (optional) ----@param state any (optional) ----@return Iterator +---@param obj string|function|table|PlenaryIterator +---@param param? table +---@param state? integer +---@return PlenaryIterator local iter = function(obj, param, state) return wrap(rawiter(obj, param, state)) end @@ -163,6 +204,7 @@ exports.iter = iter exports.wrap = wrap exports.unwrap = unwrap +---@param fn PlenaryIteratorsIterator function Iterator:for_each(fn) local param, state = self.param, self.state repeat @@ -170,6 +212,7 @@ function Iterator:for_each(fn) until state == nil end +---@return PlenaryIterator function Iterator:stateful() return wrap( co.wrap(function() @@ -206,6 +249,10 @@ end -------------------------------------------------------------------------------- -- Generators -------------------------------------------------------------------------------- +---@param param { [1]: integer, [2]: integer } +---@param state? integer +---@return integer? state +---@return integer? state local range_gen = function(param, state) local stop, step = param[1], param[2] state = state + step @@ -215,6 +262,10 @@ local range_gen = function(param, state) return state, state end +---@param param { [1]: integer, [2]: integer } +---@param state? integer +---@return integer? state +---@return integer? state local range_rev_gen = function(param, state) local stop, step = param[1], param[2] state = state + step @@ -225,10 +276,10 @@ local range_rev_gen = function(param, state) end ---Creates a range iterator ----@param start number ----@param stop number ----@param step number ----@return Iterator +---@param start integer +---@param stop? integer +---@param step? integer +---@return PlenaryIterator local range = function(start, stop, step) if step == nil then if stop == nil then @@ -254,22 +305,36 @@ local range = function(start, stop, step) end exports.range = range +---@generic T +---@param param_x T[] +---@param state_x integer +---@return integer state +---@return T ... local duplicate_table_gen = function(param_x, state_x) return state_x + 1, unpack(param_x) end +---@param param_x fun(state: integer): ... +---@param state_x integer +---@return integer state +---@return any ... local duplicate_fun_gen = function(param_x, state_x) return state_x + 1, param_x(state_x) end +---@generic T +---@param param_x T +---@param state_x integer +---@return integer state +---@return T param local duplicate_gen = function(param_x, state_x) return state_x + 1, param_x end ---Creates an infinite iterator that will yield the arguments ---If multiple arguments are passed, the args will be packed and unpacked ----@param ...: the arguments to duplicate ----@return Iterator +---@param ... any the arguments to duplicate +---@return PlenaryIterator local duplicate = function(...) if select("#", ...) <= 1 then return wrap(duplicate_gen, select(1, ...), 0) @@ -281,8 +346,8 @@ exports.duplicate = duplicate ---Creates an iterator from a function ---NOTE: if the function is a closure and modifies state, the resulting iterator will not be stateless ----@param fun function ----@return Iterator +---@param fun fun(state: integer): ... +---@return PlenaryIterator local from_fun = function(fun) assert(type(fun) == "function") return wrap(duplicate_fun_gen, fun, 0) @@ -291,7 +356,7 @@ exports.from_fun = from_fun ---Creates an infinite iterator that will yield zeros. ---This is an alias to calling duplicate(0) ----@return Iterator +---@return PlenaryIterator local zeros = function() return wrap(duplicate_gen, 0, 0) end @@ -299,24 +364,32 @@ exports.zeros = zeros ---Creates an infinite iterator that will yield ones. ---This is an alias to calling duplicate(1) ----@return Iterator +---@return PlenaryIterator local ones = function() return wrap(duplicate_gen, 1, 0) end exports.ones = ones +---@param param_x { [1]: integer, [2]: integer } +---@param _state_x any +---@return 0 +---@return integer local rands_gen = function(param_x, _state_x) return 0, math.random(param_x[1], param_x[2]) end +---@param _param_x any +---@param _state_x any +---@return 0 +---@return float local rands_nil_gen = function(_param_x, _state_x) return 0, math.random() end ---Creates an infinite iterator that will yield random values. ----@param n number ----@param m number ----@return Iterator +---@param n integer +---@param m integer +---@return PlenaryIterator local rands = function(n, m) if n == nil and m == nil then return wrap(rands_nil_gen, 0, 0) @@ -333,6 +406,10 @@ local rands = function(n, m) end exports.rands = rands +---@param param { [1]: string, [2]: string } +---@param state? integer +---@return integer? +---@return string? local split_gen = function(param, state) local input, sep = param[1], param[2] local input_len = #input @@ -353,9 +430,9 @@ local split_gen = function(param, state) end ---Return an iterator of substrings separated by a string ----@param input string: the string to split ----@param sep string: the separator to find and split based on ----@return Iterator +---@param input string the string to split +---@param sep string the separator to find and split based on +---@return PlenaryIterator local split = function(input, sep) return wrap(split_gen, { input, sep }, 1) end @@ -363,13 +440,15 @@ exports.split = split ---Splits a string based on a single space ---An alias for split(input, " ") ----@param input any ----@return any +---@param input string +---@return PlenaryIterator local words = function(input) return split(input, " ") end exports.words = words +---@param input string +---@return PlenaryIterator local lines = function(input) -- TODO: platform specific linebreaks return split(input, "\n") @@ -379,20 +458,33 @@ exports.lines = lines -------------------------------------------------------------------------------- -- Transformations -------------------------------------------------------------------------------- +---@alias PlenaryIteratorsLoop fun(i?: integer, v?: any): integer?, ... + +---@param param { [1]: PlenaryIteratorsIterator, [2]: string|table, [3]: PlenaryIteratorsLoop } +---@param state? integer +---@return integer? +---@return ... local map_gen2 = function(param, state) local gen_x, param_x, fun = param[1], param[2], param[3] return call_if_not_empty(fun, gen_x(param_x, state)) end ---Iterator adapter that maps the previous iterator with a function ----@param fun function: The function to map with. Will be called on each element ----@return Iterator +---@param fun PlenaryIteratorsLoop The function to map with. Will be called on each element +---@return PlenaryIterator function Iterator:map(fun) return wrap(map_gen2, { self.gen, self.param, fun }, self.state) end +---@alias PlenaryIteratorsFlattenParam { [1]: PlenaryIteratorsIterator, [2]: string|table, [3]: integer } + local flatten_gen1 do + ---@param new_iter PlenaryIterator + ---@param state_x? integer + ---@param ... + ---@return PlenaryIteratorsFlattenParam? + ---@return ... local it = function(new_iter, state_x, ...) if state_x == nil then return nil @@ -400,6 +492,11 @@ do return { new_iter.gen, new_iter.param, state_x }, ... end + ---@param state PlenaryIteratorsFlattenParam + ---@param state_x? integer + ---@param ... unknown + ---@return PlenaryIteratorsFlattenParam? + ---@return ... flatten_gen1 = function(state, state_x, ...) if state_x == nil then return nil @@ -410,7 +507,7 @@ do -- experimental part if getmetatable(first_arg) == Iterator then -- attach the iterator to the rest - local new_iter = (first_arg .. wrap(state[1], state[2], state_x)):flatten() + local new_iter = (first_arg .. wrap(state[1], state[2], state_x)):flatten() --[[@as PlenaryIterator]] -- advance the iterator by one return it(new_iter, new_iter.gen(new_iter.param, new_iter.state)) end @@ -419,6 +516,10 @@ do end end +---@param _ any +---@param state? PlenaryIteratorsFlattenParam +---@return PlearyIteratorsFlattenParam? +---@return ... local flatten_gen = function(_, state) if state == nil then return @@ -428,7 +529,7 @@ local flatten_gen = function(_, state) end ---Iterator adapter that will recursivley flatten nested iterator structure ----@return Iterator +---@return PlenaryIterator function Iterator:flatten() return wrap(flatten_gen, false, { self.gen, self.param, self.state }) end @@ -436,6 +537,14 @@ end -------------------------------------------------------------------------------- -- Filtering -------------------------------------------------------------------------------- +---@generic T +---@param fun fun(a: T): boolean +---@param gen_x PlenaryIteratorsIterator +---@param param_x T[] +---@param state_x? integer +---@param a T +---@return integer? state_x +---@return T a local filter1_gen = function(fun, gen_x, param_x, state_x, a) while true do if state_x == nil or fun(a) then @@ -450,10 +559,24 @@ end -- because we can't assign a vararg mutably in a while loop like filter1_gen -- so we have to use recursion in calling both of these functions local filterm_gen + +---@param fun fun(...): boolean +---@param gen_x PlenaryIteratorsIterator +---@param param_x table +---@param state_x? integer +---@return integer? state_x +---@return ... local filterm_gen_shrink = function(fun, gen_x, param_x, state_x) return filterm_gen(fun, gen_x, param_x, gen_x(param_x, state_x)) end +---@param fun fun(...): boolean +---@param gen_x PlenaryIteratorsIterator +---@param param_x table +---@param state_x? integer +---@param ... any +---@return integer? state_x +---@return ... filterm_gen = function(fun, gen_x, param_x, state_x, ...) if state_x == nil then return nil @@ -466,6 +589,13 @@ filterm_gen = function(fun, gen_x, param_x, state_x, ...) return filterm_gen_shrink(fun, gen_x, param_x, state_x) end +---@param fun fun(...): boolean +---@param gen_x PlenaryIteratorsIterator +---@param param_x table +---@param state_x? integer +---@param ... any +---@return integer? state_x +---@return ... local filter_detect = function(fun, gen_x, param_x, state_x, ...) if select("#", ...) < 2 then return filter1_gen(fun, gen_x, param_x, state_x, ...) @@ -474,20 +604,24 @@ local filter_detect = function(fun, gen_x, param_x, state_x, ...) end end +---@param param table +---@param state_x? integer +---@return integer? state_x +---@return ... local filter_gen = function(param, state_x) local gen_x, param_x, fun = param[1], param[2], param[3] return filter_detect(fun, gen_x, param_x, gen_x(param_x, state_x)) end ---Iterator adapter that will filter values ----@param fun function: The function to filter values with. If the function returns true, the value will be kept. ----@return Iterator +---@param fun fun(...): boolean The function to filter values with. If the function returns true, the value will be kept +---@return PlenaryIterator function Iterator:filter(fun) return wrap(filter_gen, { self.gen, self.param, fun }, self.state) end ---Iterator adapter that will provide numbers from 1 to n as the first multival ----@return Iterator +---@return PlenaryIterator function Iterator:enumerate() local i = 0 return self:map(function(...) @@ -501,7 +635,7 @@ end -------------------------------------------------------------------------------- ---Returns true if any of the values in the iterator satisfy a predicate ----@param fun function +---@param fun fun(...): boolean ---@return boolean function Iterator:any(fun) local r @@ -513,7 +647,7 @@ function Iterator:any(fun) end ---Returns true if all of the values in the iterator satisfy a predicate ----@param fun function +---@param fun fun(...): boolean ---@return boolean function Iterator:all(fun) local r @@ -525,7 +659,8 @@ function Iterator:all(fun) end ---Finds a value that is equal to the provided value of satisfies a predicate. ----@param val_or_fn any +---@generic T +---@param val_or_fn T|fun(a: T): boolean ---@return any function Iterator:find(val_or_fn) local gen, param, state = self.gen, self.param, self.state @@ -542,9 +677,10 @@ function Iterator:find(val_or_fn) end ---Folds an iterator into a single value using a function. ----@param init any ----@param fun fun(acc: any, val: any): any ----@return any +---@generic T, U +---@param init T +---@param fun fun(acc: T, val: U): T +---@return T function Iterator:fold(init, fun) local acc = init local gen, param, state = self.gen, self.param, self.state @@ -593,6 +729,13 @@ end -------------------------------------------------------------------------------- -- call each other local chain_gen_r1 + +---@param param table +---@param state { [1]: integer, [2]: integer } +---@param state_x? integer +---@param ... any +---@return { [1]: integer, [2]: integer }? state +---@return ... local chain_gen_r2 = function(param, state, state_x, ...) if state_x == nil then local i = state[1] + 1 @@ -605,6 +748,10 @@ local chain_gen_r2 = function(param, state, state_x, ...) return { state[1], state_x }, ... end +---@param param table +---@param state { [1]: integer, [2]: integer } +---@return { [1]: integer, [2]: integer }? state +---@return ... chain_gen_r1 = function(param, state) local i, state_x = state[1], state[2] local gen_x, param_x = param[3 * i - 2], param[3 * i - 1] @@ -616,8 +763,8 @@ end ---until all of the iterators are exhausted. ---Used for treating consecutive iterators as a single iterator. ---Infinity iterators are supported, but are not recommended. ----@param ...: the iterators to chain ----@return Iterator +---@param ... PlenaryIterator the iterators to chain +---@return PlenaryIterator local chain = function(...) local n = numargs(...) @@ -625,6 +772,7 @@ local chain = function(...) return wrap(nil_gen, nil, nil) end + ---@type (string|integer|table|PlenaryIteratorsIterator)[] local param = { [3 * n] = 0 } local gen_x, param_x, state_x @@ -636,13 +784,22 @@ local chain = function(...) param[3 * i] = state_x end - return wrap(chain_gen_r1, param, { 1, param[3] }) + return wrap(chain_gen_r1, param, { + 1, + param[3] --[[@as integer]], + }) end Iterator.chain = chain Iterator.__concat = chain exports.chain = chain +---@param param table +---@param state integer[] +---@param state_new integer[] +---@param ... any +---@return integer[]? state_x +---@return ... local function zip_gen_r(param, state, state_new, ...) if #state_new == #param / 2 then return state_new, ... @@ -658,6 +815,9 @@ local function zip_gen_r(param, state, state_new, ...) return zip_gen_r(param, state, state_new, r, ...) end +---@param param table +---@param state integer[] +---@return integer[]? state_x local zip_gen = function(param, state) return zip_gen_r(param, state, {}) end @@ -665,18 +825,20 @@ end ---Return a new iterator where i-th return value contains the i-th element from each of the iterators. ---The returned iterator is truncated in length to the length of the shortest iterator. ---For multi-return iterators only the first variable is used. ----@param ...: the iterators to zip ----@return Iterator +---@param ... PlenaryIterator the iterators to zip +---@return PlenaryIterator local zip = function(...) local n = numargs(...) if n == 0 then return wrap(nil_gen, nil, nil) end + ---@type (string|table|PlenaryIteratorsIterator)[] local param = { [2 * n] = 0 } local state = { [n] = 0 } local gen_x, param_x, state_x for i = 1, n, 1 do + ---@type PlenaryIterator local it = select(n - i + 1, ...) gen_x, param_x, state_x = rawiter(it) param[2 * i - 1] = gen_x