feat(iter): add Iter.take (#26525)

This commit is contained in:
Will Hopkins 2023-12-12 12:27:24 -08:00 committed by GitHub
parent 1907abb4c2
commit 69ffbb76c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 81 additions and 0 deletions

View File

@ -3639,6 +3639,25 @@ Iter:slice({first}, {last}) *Iter:slice()*
Return: ~
Iter
Iter:take({n}) *Iter:take()*
Transforms an iterator to yield only the first n values.
Example: >lua
local it = vim.iter({ 1, 2, 3, 4 }):take(2)
it:next()
-- 1
it:next()
-- 2
it:next()
-- nil
<
Parameters: ~
• {n} (integer)
Return: ~
Iter
Iter:totable() *Iter:totable()*
Collect the iterator into a table.

View File

@ -592,6 +592,41 @@ function ListIter.rfind(self, f) -- luacheck: no unused args
self._head = self._tail
end
--- Transforms an iterator to yield only the first n values.
---
--- Example:
---
--- ```lua
--- local it = vim.iter({ 1, 2, 3, 4 }):take(2)
--- it:next()
--- -- 1
--- it:next()
--- -- 2
--- it:next()
--- -- nil
--- ```
---
---@param n integer
---@return Iter
function Iter.take(self, n)
local next = self.next
local i = 0
self.next = function()
if i < n then
i = i + 1
return next(self)
end
end
return self
end
---@private
function ListIter.take(self, n)
local inc = self._head < self._tail and 1 or -1
self._tail = math.min(self._tail, self._head + n * inc)
return self
end
--- "Pops" a value from a |list-iterator| (gets the last value and decrements the tail).
---
--- Example:

View File

@ -203,6 +203,33 @@ describe('vim.iter', function()
matches('skipback%(%) requires a list%-like table', pcall_err(it.nthback, it, 1))
end)
it('take()', function()
do
local t = { 4, 3, 2, 1 }
eq({}, vim.iter(t):take(0):totable())
eq({ 4 }, vim.iter(t):take(1):totable())
eq({ 4, 3 }, vim.iter(t):take(2):totable())
eq({ 4, 3, 2 }, vim.iter(t):take(3):totable())
eq({ 4, 3, 2, 1 }, vim.iter(t):take(4):totable())
eq({ 4, 3, 2, 1 }, vim.iter(t):take(5):totable())
end
do
local t = { 4, 3, 2, 1 }
local it = vim.iter(t)
eq({ 4, 3 }, it:take(2):totable())
-- tail is already set from the previous take()
eq({ 4, 3 }, it:take(3):totable())
end
do
local it = vim.iter(vim.gsplit('a|b|c|d', '|'))
eq({ 'a', 'b' }, it:take(2):totable())
-- non-array iterators are consumed by take()
eq({}, it:take(2):totable())
end
end)
it('any()', function()
local function odd(v)
return v % 2 ~= 0