diff --git a/runtime/lua/vim/treesitter/_range.lua b/runtime/lua/vim/treesitter/_range.lua index f4db5016ac..35081c6400 100644 --- a/runtime/lua/vim/treesitter/_range.lua +++ b/runtime/lua/vim/treesitter/_range.lua @@ -143,6 +143,29 @@ function M.contains(r1, r2) return true end +--- @param source integer|string +--- @param index integer +--- @return integer +local function get_offset(source, index) + if index == 0 then + return 0 + end + + if type(source) == 'number' then + return api.nvim_buf_get_offset(source, index) + end + + local byte = 0 + local next_offset = source:gmatch('()\n') + local line = 1 + while line <= index do + byte = next_offset() --[[@as integer]] + line = line + 1 + end + + return byte +end + ---@private ---@param source integer|string ---@param range Range @@ -152,19 +175,10 @@ function M.add_bytes(source, range) return range --[[@as Range6]] end - local start_row, start_col, end_row, end_col = range[1], range[2], range[3], range[4] - local start_byte = 0 - local end_byte = 0 + local start_row, start_col, end_row, end_col = M.unpack4(range) -- TODO(vigoux): proper byte computation here, and account for EOL ? - if type(source) == 'number' then - -- Easy case, this is a buffer parser - start_byte = api.nvim_buf_get_offset(source, start_row) + start_col - end_byte = api.nvim_buf_get_offset(source, end_row) + end_col - elseif type(source) == 'string' then - -- string parser, single `\n` delimited string - start_byte = vim.fn.byteidx(source, start_col) - end_byte = vim.fn.byteidx(source, end_col) - end + local start_byte = get_offset(source, start_row) + start_col + local end_byte = get_offset(source, end_row) + end_col return { start_row, start_col, start_byte, end_row, end_col, end_byte } end diff --git a/test/functional/treesitter/parser_spec.lua b/test/functional/treesitter/parser_spec.lua index 60042ab2fd..14de07639b 100644 --- a/test/functional/treesitter/parser_spec.lua +++ b/test/functional/treesitter/parser_spec.lua @@ -486,7 +486,6 @@ end]] eq({ 'any-of?', 'contains?', 'eq?', 'has-ancestor?', 'has-parent?', 'is-main?', 'lua-match?', 'match?', 'vim-match?' }, res_list) end) - it('allows to set simple ranges', function() insert(test_text) @@ -528,6 +527,7 @@ end]] eq(range_tbl, { { { 0, 0, 0, 17, 1, 508 } } }) end) + it("allows to set complex ranges", function() insert(test_text) @@ -992,4 +992,58 @@ int x = INT_MAX; }, run_query()) end) + + it('handles ranges when source is a multiline string (#20419)', function() + local source = [==[ + vim.cmd[[ + set number + set cmdheight=2 + set lastsatus=2 + ]] + + set query = [[;; query + ((function_call + name: [ + (identifier) @_cdef_identifier + (_ _ (identifier) @_cdef_identifier) + ] + arguments: (arguments (string content: _ @injection.content))) + (#set! injection.language "c") + (#eq? @_cdef_identifier "cdef")) + ]] + ]==] + + local r = exec_lua([[ + local parser = vim.treesitter.get_string_parser(..., 'lua') + parser:parse() + local ranges = {} + parser:for_each_tree(function(tstree, tree) + ranges[tree:lang()] = { tstree:root():range(true) } + end) + return ranges + ]], source) + + eq({ + lua = { 0, 6, 6, 16, 4, 438 }, + query = { 6, 20, 113, 15, 6, 431 }, + vim = { 1, 0, 16, 4, 6, 89 } + }, r) + + -- The above ranges are provided directly from treesitter, however query directives may mutate + -- the ranges but only provide a Range4. Strip the byte entries from the ranges and make sure + -- add_bytes() produces the same result. + + local rb = exec_lua([[ + local r, source = ... + local add_bytes = require('vim.treesitter._range').add_bytes + for lang, range in pairs(r) do + r[lang] = {range[1], range[2], range[4], range[5]} + r[lang] = add_bytes(source, r[lang]) + end + return r + ]], r, source) + + eq(rb, r) + + end) end)