fix(iter): add tag to packed table

If pack() is called with a single value, it does not create a table; it
simply returns the value it is passed. When unpack is called with a
table argument, it interprets that table as a list of values that were
packed together into a table.

This causes a problem when the single value being packed is _itself_ a
table. pack() will not place it into another table, but unpack() sees
the table argument and tries to unpack it.

To fix this, we add a simple "tag" to packed table values so that
unpack() only attempts to unpack tables that have this tag. Other tables
are left alone. The tag is simply the length of the table.
This commit is contained in:
Gregory Anders 2023-04-19 06:45:56 -06:00
parent 0a3645a723
commit 6b96122453
2 changed files with 47 additions and 5 deletions

View File

@ -28,16 +28,17 @@ end
---@private
local function unpack(t)
if type(t) == 'table' then
return _G.unpack(t)
if type(t) == 'table' and t.__n ~= nil then
return _G.unpack(t, 1, t.__n)
end
return t
end
---@private
local function pack(...)
if select('#', ...) > 1 then
return { ... }
local n = select('#', ...)
if n > 1 then
return { __n = n, ... }
end
return ...
end
@ -210,6 +211,12 @@ function Iter.totable(self)
if args == nil then
break
end
if type(args) == 'table' then
-- Removed packed table tag if it exists
args.__n = nil
end
t[#t + 1] = args
end
return t
@ -218,6 +225,14 @@ end
---@private
function ListIter.totable(self)
if self._head == 1 and self._tail == #self._table + 1 and self.next == ListIter.next then
-- Remove any packed table tags
for i = 1, #self._table do
local v = self._table[i]
if type(v) == 'table' then
v.__n = nil
self._table[i] = v
end
end
return self._table
end
@ -747,7 +762,7 @@ function ListIter.enumerate(self)
local inc = self._head < self._tail and 1 or -1
for i = self._head, self._tail - inc, inc do
local v = self._table[i]
self._table[i] = { i, v }
self._table[i] = pack(i, v)
end
return self
end

View File

@ -3381,6 +3381,33 @@ describe('lua stdlib', function()
end
end)
eq({ A = 2, C = 6 }, it:totable())
it('handles table values mid-pipeline', function()
local map = {
item = {
file = 'test',
},
item_2 = {
file = 'test',
},
item_3 = {
file = 'test',
},
}
local output = vim.iter(map):map(function(key, value)
return { [key] = value.file }
end):totable()
table.sort(output, function(a, b)
return next(a) < next(b)
end)
eq({
{ item = 'test' },
{ item_2 = 'test' },
{ item_3 = 'test' },
}, output)
end)
end)
end)