diff --git a/runtime/doc/lua.txt b/runtime/doc/lua.txt index b3be11efdb..e36ff9d8d8 100644 --- a/runtime/doc/lua.txt +++ b/runtime/doc/lua.txt @@ -3080,7 +3080,7 @@ Iter:map({self}, {f}) *Iter:map()* • {f} function(...):any Mapping function. Takes all values returned from the previous stage in the pipeline as arguments and returns one or more new values, which are used in the next pipeline - stage. Nil return values returned are filtered from the output. + stage. Nil return values are filtered from the output. Return: ~ Iter diff --git a/runtime/lua/vim/iter.lua b/runtime/lua/vim/iter.lua index c2e2c5bd9f..bda3508262 100644 --- a/runtime/lua/vim/iter.lua +++ b/runtime/lua/vim/iter.lua @@ -1,13 +1,14 @@ ---@defgroup lua-iter --- ---- The \*vim.iter\* module provides a generic "iterator" interface over tables and iterator ---- functions. +--- The \*vim.iter\* module provides a generic "iterator" interface over tables +--- and iterator functions. --- ---- \*vim.iter()\* wraps its table or function argument into an \*Iter\* object with methods (such ---- as |Iter:filter()| and |Iter:map()|) that transform the underlying source data. These methods ---- can be chained together to create iterator "pipelines". Each pipeline stage receives as input ---- the output values from the prior stage. The values used in the first stage of the pipeline ---- depend on the type passed to this function: +--- \*vim.iter()\* wraps its table or function argument into an \*Iter\* object +--- with methods (such as |Iter:filter()| and |Iter:map()|) that transform the +--- underlying source data. These methods can be chained together to create +--- iterator "pipelines". Each pipeline stage receives as input the output +--- values from the prior stage. The values used in the first stage of the +--- pipeline depend on the type passed to this function: --- --- - List tables pass only the value of each element --- - Non-list tables pass both the key and value of each element @@ -47,8 +48,8 @@ --- -- true --- --- ---- In addition to the |vim.iter()| function, the |vim.iter| module provides convenience functions ---- like |vim.iter.filter()| and |vim.iter.totable()|. +--- In addition to the |vim.iter()| function, the |vim.iter| module provides +--- convenience functions like |vim.iter.filter()| and |vim.iter.totable()|. local M = {} @@ -61,9 +62,9 @@ end --- Special case implementations for iterators on list tables. ---@class ListIter : Iter ----@field _table table Underlying table data (table iterators only) ----@field _head number Index to the front of a table iterator (table iterators only) ----@field _tail number Index to the end of a table iterator (table iterators only) +---@field _table table Underlying table data +---@field _head number Index to the front of a table iterator +---@field _tail number Index to the end of a table iterator local ListIter = {} ListIter.__index = setmetatable(ListIter, Iter) ListIter.__call = function(self) @@ -75,7 +76,7 @@ local packedmt = {} ---@private local function unpack(t) - if getmetatable(t) == packedmt then + if type(t) == 'table' and getmetatable(t) == packedmt then return _G.unpack(t, 1, t.n) end return t @@ -92,13 +93,47 @@ end ---@private local function sanitize(t) - if getmetatable(t) == packedmt then + if type(t) == 'table' and getmetatable(t) == packedmt then -- Remove length tag t.n = nil end return t end +--- Determine if the current iterator stage should continue. +--- +--- If any arguments are passed to this function, then return those arguments +--- and stop the current iterator stage. Otherwise, return true to signal that +--- the current stage should continue. +--- +---@param ... any Function arguments. +---@return boolean True if the iterator stage should continue, false otherwise +---@return any Function arguments. +---@private +local function continue(...) + if select('#', ...) > 0 then + return false, ... + end + return true +end + +--- If no input arguments are given return false, indicating the current +--- iterator stage should stop. Otherwise, apply the arguments to the function +--- f. If that function returns no values, the current iterator stage continues. +--- Otherwise, those values are returned. +--- +---@param f function Function to call with the given arguments +---@param ... any Arguments to apply to f +---@return boolean True if the iterator pipeline should continue, false otherwise +---@return any Return values of f +---@private +local function apply(f, ...) + if select('#', ...) > 0 then + return continue(f(...)) + end + return false +end + --- Add a filter step to the iterator pipeline. --- --- Example: @@ -106,33 +141,16 @@ end --- local bufs = vim.iter(vim.api.nvim_list_bufs()):filter(vim.api.nvim_buf_is_loaded) --- --- ----@param f function(...):bool Takes all values returned from the previous stage in the pipeline and ---- returns false or nil if the current iterator element should be ---- removed. +---@param f function(...):bool Takes all values returned from the previous stage +--- in the pipeline and returns false or nil if the +--- current iterator element should be removed. ---@return Iter function Iter.filter(self, f) - ---@private - local function fn(...) - local result = nil - if select(1, ...) ~= nil then - if not f(...) then - return true, nil - else - result = pack(...) - end + return self:map(function(...) + if f(...) then + return ... end - return false, result - end - - local next = self.next - self.next = function(this) - local cont, result - repeat - cont, result = fn(next(this)) - until not cont - return unpack(result) - end - return self + end) end ---@private @@ -165,31 +183,52 @@ end --- -- { 6, 12 } --- --- ----@param f function(...):any Mapping function. Takes all values returned from the previous stage ---- in the pipeline as arguments and returns one or more new values, ---- which are used in the next pipeline stage. Nil return values returned ---- are filtered from the output. +---@param f function(...):any Mapping function. Takes all values returned from +--- the previous stage in the pipeline as arguments +--- and returns one or more new values, which are used +--- in the next pipeline stage. Nil return values +--- are filtered from the output. ---@return Iter function Iter.map(self, f) - ---@private - local function fn(...) - local result = nil - if select(1, ...) ~= nil then - result = pack(f(...)) - if result == nil then - return true, nil - end - end - return false, result - end + -- Implementation note: the reader may be forgiven for observing that this + -- function appears excessively convoluted. The problem to solve is that each + -- stage of the iterator pipeline can return any number of values, and the + -- number of values could even change per iteration. And the return values + -- must be checked to determine if the pipeline has ended, so we cannot + -- naively forward them along to the next stage. + -- + -- A simple approach is to pack all of the return values into a table, check + -- for nil, then unpack the table for the next stage. However, packing and + -- unpacking tables is quite slow. There is no other way in Lua to handle an + -- unknown number of function return values than to simply forward those + -- values along to another function. Hence the intricate function passing you + -- see here. local next = self.next - self.next = function(this) - local cont, result - repeat - cont, result = fn(next(this)) - until not cont - return unpack(result) + + --- Drain values from the upstream iterator source until a value can be + --- returned. + --- + --- This is a recursive function. The base case is when the first argument is + --- false, which indicates that the rest of the arguments should be returned + --- as the values for the current iteration stage. + --- + ---@param cont boolean If true, the current iterator stage should continue to + --- pull values from its upstream pipeline stage. + --- Otherwise, this stage is complete and returns the + --- values passed. + ---@param ... any Values to return if cont is false. + ---@return any + ---@private + local function fn(cont, ...) + if cont then + return fn(apply(f, next(self))) + end + return ... + end + + self.next = function() + return fn(apply(f, next(self))) end return self end @@ -211,17 +250,18 @@ end --- Call a function once for each item in the pipeline. --- ---- This is used for functions which have side effects. To modify the values in the iterator, use ---- |Iter:map()|. +--- This is used for functions which have side effects. To modify the values in +--- the iterator, use |Iter:map()|. --- --- This function drains the iterator. --- ----@param f function(...) Function to execute for each item in the pipeline. Takes all of the ---- values returned by the previous stage in the pipeline as arguments. +---@param f function(...) Function to execute for each item in the pipeline. +--- Takes all of the values returned by the previous stage +--- in the pipeline as arguments. function Iter.each(self, f) ---@private local function fn(...) - if select(1, ...) ~= nil then + if select('#', ...) > 0 then f(...) return true end