diff --git a/runtime/doc/lua.txt b/runtime/doc/lua.txt index f7f722bc0e..7e0ad5f4c3 100644 --- a/runtime/doc/lua.txt +++ b/runtime/doc/lua.txt @@ -3639,6 +3639,25 @@ Iter:slice({first}, {last}) *Iter:slice()* Return: ~ Iter +Iter:take({n}) *Iter:take()* + Transforms an iterator to yield only the first n values. + + Example: >lua + local it = vim.iter({ 1, 2, 3, 4 }):take(2) + it:next() + -- 1 + it:next() + -- 2 + it:next() + -- nil +< + + Parameters: ~ + • {n} (integer) + + Return: ~ + Iter + Iter:totable() *Iter:totable()* Collect the iterator into a table. diff --git a/runtime/lua/vim/iter.lua b/runtime/lua/vim/iter.lua index e9c2b66bf2..8e602c406a 100644 --- a/runtime/lua/vim/iter.lua +++ b/runtime/lua/vim/iter.lua @@ -592,6 +592,41 @@ function ListIter.rfind(self, f) -- luacheck: no unused args self._head = self._tail end +--- Transforms an iterator to yield only the first n values. +--- +--- Example: +--- +--- ```lua +--- local it = vim.iter({ 1, 2, 3, 4 }):take(2) +--- it:next() +--- -- 1 +--- it:next() +--- -- 2 +--- it:next() +--- -- nil +--- ``` +--- +---@param n integer +---@return Iter +function Iter.take(self, n) + local next = self.next + local i = 0 + self.next = function() + if i < n then + i = i + 1 + return next(self) + end + end + return self +end + +---@private +function ListIter.take(self, n) + local inc = self._head < self._tail and 1 or -1 + self._tail = math.min(self._tail, self._head + n * inc) + return self +end + --- "Pops" a value from a |list-iterator| (gets the last value and decrements the tail). --- --- Example: diff --git a/test/functional/lua/iter_spec.lua b/test/functional/lua/iter_spec.lua index 2d28395c59..a589474262 100644 --- a/test/functional/lua/iter_spec.lua +++ b/test/functional/lua/iter_spec.lua @@ -203,6 +203,33 @@ describe('vim.iter', function() matches('skipback%(%) requires a list%-like table', pcall_err(it.nthback, it, 1)) end) + it('take()', function() + do + local t = { 4, 3, 2, 1 } + eq({}, vim.iter(t):take(0):totable()) + eq({ 4 }, vim.iter(t):take(1):totable()) + eq({ 4, 3 }, vim.iter(t):take(2):totable()) + eq({ 4, 3, 2 }, vim.iter(t):take(3):totable()) + eq({ 4, 3, 2, 1 }, vim.iter(t):take(4):totable()) + eq({ 4, 3, 2, 1 }, vim.iter(t):take(5):totable()) + end + + do + local t = { 4, 3, 2, 1 } + local it = vim.iter(t) + eq({ 4, 3 }, it:take(2):totable()) + -- tail is already set from the previous take() + eq({ 4, 3 }, it:take(3):totable()) + end + + do + local it = vim.iter(vim.gsplit('a|b|c|d', '|')) + eq({ 'a', 'b' }, it:take(2):totable()) + -- non-array iterators are consumed by take() + eq({}, it:take(2):totable()) + end + end) + it('any()', function() local function odd(v) return v % 2 ~= 0