fix(treesitter): correctly calculate bytes for text sources (#23655)

Fixes #20419
This commit is contained in:
Lewis Russell 2023-05-16 16:41:47 +01:00 committed by GitHub
parent d36dd2bae8
commit 6b19170d44
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 81 additions and 13 deletions

View File

@ -143,6 +143,29 @@ function M.contains(r1, r2)
return true
end
--- @param source integer|string
--- @param index integer
--- @return integer
local function get_offset(source, index)
if index == 0 then
return 0
end
if type(source) == 'number' then
return api.nvim_buf_get_offset(source, index)
end
local byte = 0
local next_offset = source:gmatch('()\n')
local line = 1
while line <= index do
byte = next_offset() --[[@as integer]]
line = line + 1
end
return byte
end
---@private
---@param source integer|string
---@param range Range
@ -152,19 +175,10 @@ function M.add_bytes(source, range)
return range --[[@as Range6]]
end
local start_row, start_col, end_row, end_col = range[1], range[2], range[3], range[4]
local start_byte = 0
local end_byte = 0
local start_row, start_col, end_row, end_col = M.unpack4(range)
-- TODO(vigoux): proper byte computation here, and account for EOL ?
if type(source) == 'number' then
-- Easy case, this is a buffer parser
start_byte = api.nvim_buf_get_offset(source, start_row) + start_col
end_byte = api.nvim_buf_get_offset(source, end_row) + end_col
elseif type(source) == 'string' then
-- string parser, single `\n` delimited string
start_byte = vim.fn.byteidx(source, start_col)
end_byte = vim.fn.byteidx(source, end_col)
end
local start_byte = get_offset(source, start_row) + start_col
local end_byte = get_offset(source, end_row) + end_col
return { start_row, start_col, start_byte, end_row, end_col, end_byte }
end

View File

@ -486,7 +486,6 @@ end]]
eq({ 'any-of?', 'contains?', 'eq?', 'has-ancestor?', 'has-parent?', 'is-main?', 'lua-match?', 'match?', 'vim-match?' }, res_list)
end)
it('allows to set simple ranges', function()
insert(test_text)
@ -528,6 +527,7 @@ end]]
eq(range_tbl, { { { 0, 0, 0, 17, 1, 508 } } })
end)
it("allows to set complex ranges", function()
insert(test_text)
@ -992,4 +992,58 @@ int x = INT_MAX;
}, run_query())
end)
it('handles ranges when source is a multiline string (#20419)', function()
local source = [==[
vim.cmd[[
set number
set cmdheight=2
set lastsatus=2
]]
set query = [[;; query
((function_call
name: [
(identifier) @_cdef_identifier
(_ _ (identifier) @_cdef_identifier)
]
arguments: (arguments (string content: _ @injection.content)))
(#set! injection.language "c")
(#eq? @_cdef_identifier "cdef"))
]]
]==]
local r = exec_lua([[
local parser = vim.treesitter.get_string_parser(..., 'lua')
parser:parse()
local ranges = {}
parser:for_each_tree(function(tstree, tree)
ranges[tree:lang()] = { tstree:root():range(true) }
end)
return ranges
]], source)
eq({
lua = { 0, 6, 6, 16, 4, 438 },
query = { 6, 20, 113, 15, 6, 431 },
vim = { 1, 0, 16, 4, 6, 89 }
}, r)
-- The above ranges are provided directly from treesitter, however query directives may mutate
-- the ranges but only provide a Range4. Strip the byte entries from the ranges and make sure
-- add_bytes() produces the same result.
local rb = exec_lua([[
local r, source = ...
local add_bytes = require('vim.treesitter._range').add_bytes
for lang, range in pairs(r) do
r[lang] = {range[1], range[2], range[4], range[5]}
r[lang] = add_bytes(source, r[lang])
end
return r
]], r, source)
eq(rb, r)
end)
end)