From 19a793545f15bb7e0bac2fc8f705c600e8f9c9bb Mon Sep 17 00:00:00 2001 From: Lewis Russell Date: Sun, 30 Apr 2023 16:11:38 +0100 Subject: [PATCH] fix(treesitter): redraw added/removed injections properly (#23287) When injections are added or removed make sure to: - invoke 'changedtree' callbacks for when new trees are added. - invoke 'changedtree' callbacks for when trees are invalidated - redraw regions when languagetree children are removed --- runtime/doc/treesitter.txt | 27 +++--- runtime/lua/vim/treesitter/highlighter.lua | 20 +++-- runtime/lua/vim/treesitter/languagetree.lua | 83 ++++++++++++------- test/functional/treesitter/highlight_spec.lua | 82 +++++++++++++++--- 4 files changed, 156 insertions(+), 56 deletions(-) diff --git a/runtime/doc/treesitter.txt b/runtime/doc/treesitter.txt index 94690f0b7f..0168b11499 100644 --- a/runtime/doc/treesitter.txt +++ b/runtime/doc/treesitter.txt @@ -1149,21 +1149,24 @@ LanguageTree:parse({self}) *LanguageTree:parse()* Return: ~ TSTree[] -LanguageTree:register_cbs({self}, {cbs}) *LanguageTree:register_cbs()* + *LanguageTree:register_cbs()* +LanguageTree:register_cbs({self}, {cbs}, {recursive}) Registers callbacks for the |LanguageTree|. Parameters: ~ - • {cbs} (table) An |nvim_buf_attach()|-like table argument with the - following handlers: - • `on_bytes` : see |nvim_buf_attach()|, but this will be called after the parsers callback. - • `on_changedtree` : a callback that will be called every time - the tree has syntactical changes. It will only be passed one - argument, which is a table of the ranges (as node ranges) - that changed. - • `on_child_added` : emitted when a child is added to the - tree. - • `on_child_removed` : emitted when a child is removed from - the tree. + • {cbs} (table) An |nvim_buf_attach()|-like table argument with + the following handlers: + • `on_bytes` : see |nvim_buf_attach()|, but this will be called after the parsers callback. + • `on_changedtree` : a callback that will be called + every time the tree has syntactical changes. It will + only be passed one argument, which is a table of the + ranges (as node ranges) that changed. + • `on_child_added` : emitted when a child is added to + the tree. + • `on_child_removed` : emitted when a child is removed + from the tree. + • {recursive?} boolean Apply callbacks recursively for all children. + Any new children will also inherit the callbacks. • {self} LanguageTree:source({self}) *LanguageTree:source()* diff --git a/runtime/lua/vim/treesitter/highlighter.lua b/runtime/lua/vim/treesitter/highlighter.lua index ac2a929487..4bb764c5c6 100644 --- a/runtime/lua/vim/treesitter/highlighter.lua +++ b/runtime/lua/vim/treesitter/highlighter.lua @@ -76,9 +76,6 @@ function TSHighlighter.new(tree, opts) opts = opts or {} ---@type { queries: table } self.tree = tree tree:register_cbs({ - on_changedtree = function(...) - self:on_changedtree(...) - end, on_bytes = function(...) self:on_bytes(...) end, @@ -87,6 +84,17 @@ function TSHighlighter.new(tree, opts) end, }) + tree:register_cbs({ + on_changedtree = function(...) + self:on_changedtree(...) + end, + on_child_removed = function(child) + child:for_each_tree(function(t) + self:on_changedtree(t:included_ranges(true)) + end) + end, + }, true) + self.bufnr = tree:source() --[[@as integer]] self.edit_count = 0 self.redraw_count = 0 @@ -177,10 +185,10 @@ function TSHighlighter:on_detach() end ---@package ----@param changes integer[][]? +---@param changes Range6[][] function TSHighlighter:on_changedtree(changes) - for _, ch in ipairs(changes or {}) do - api.nvim__buf_redraw_range(self.bufnr, ch[1], ch[3] + 1) + for _, ch in ipairs(changes) do + api.nvim__buf_redraw_range(self.bufnr, ch[1], ch[4] + 1) end end diff --git a/runtime/lua/vim/treesitter/languagetree.lua b/runtime/lua/vim/treesitter/languagetree.lua index 4aa07d1b96..19cea32367 100644 --- a/runtime/lua/vim/treesitter/languagetree.lua +++ b/runtime/lua/vim/treesitter/languagetree.lua @@ -51,8 +51,18 @@ local Range = require('vim.treesitter._range') ---| 'on_child_added' ---| 'on_child_removed' +--- @type table +local TSCallbackNames = { + on_changedtree = 'changedtree', + on_bytes = 'bytes', + on_detach = 'detach', + on_child_added = 'child_added', + on_child_removed = 'child_removed', +} + ---@class LanguageTree ---@field private _callbacks table Callback handlers +---@field package _callbacks_rec table Callback handlers (recursive) ---@field private _children table Injected languages ---@field private _injection_query Query Queries defining injected languages ---@field private _opts table Options @@ -79,7 +89,7 @@ LanguageTree.__index = LanguageTree --- "injected" language parsers, which themselves may inject other languages, recursively. --- ---@param source (integer|string) Buffer or text string to parse ----@param lang string|nil Root language of this tree +---@param lang string Root language of this tree ---@param opts (table|nil) Optional arguments: --- - injections table Map of language to injection query strings. Overrides the --- built-in runtime file searching for language injections. @@ -100,15 +110,15 @@ function LanguageTree.new(source, lang, opts) or query.get(lang, 'injections'), _valid = false, _parser = vim._create_ts_parser(lang), - _callbacks = { - changedtree = {}, - bytes = {}, - detach = {}, - child_added = {}, - child_removed = {}, - }, + _callbacks = {}, + _callbacks_rec = {}, }, LanguageTree) + for _, name in pairs(TSCallbackNames) do + self._callbacks[name] = {} + self._callbacks_rec[name] = {} + end + return self end @@ -121,6 +131,7 @@ local function tcall(f, ...) local start = vim.loop.hrtime() ---@diagnostic disable-next-line local r = { f(...) } + --- @type number local duration = (vim.loop.hrtime() - start) / 1000000 return duration, unpack(r) end @@ -161,6 +172,9 @@ function LanguageTree:invalidate(reload) -- buffer was reloaded, reparse all trees if reload then + for _, t in ipairs(self._trees) do + self:_do_callback('changedtree', t:included_ranges(true), t) + end self._trees = {} end @@ -245,9 +259,12 @@ function LanguageTree:parse() if not self._valid or not self._valid[i] then self._parser:set_included_ranges(ranges) local parse_time, tree, tree_changes = - tcall(self._parser.parse, self._parser, self._trees[i], self._source) + tcall(self._parser.parse, self._parser, self._trees[i], self._source, true) - self:_do_callback('changedtree', tree_changes, tree) + -- Pass ranges if this is an initial parse + local cb_changes = self._trees[i] and tree_changes or ranges + + self:_do_callback('changedtree', cb_changes, tree) self._trees[i] = tree vim.list_extend(changes, tree_changes) @@ -341,7 +358,14 @@ function LanguageTree:add_child(lang) self:remove_child(lang) end - self._children[lang] = LanguageTree.new(self._source, lang, self._opts) + local child = LanguageTree.new(self._source, lang, self._opts) + + -- Inherit recursive callbacks + for nm, cb in pairs(self._callbacks_rec) do + vim.list_extend(child._callbacks_rec[nm], cb) + end + + self._children[lang] = child self:invalidate() self:_do_callback('child_added', self._children[lang]) @@ -453,6 +477,10 @@ function LanguageTree:set_included_regions(new_regions) end if #self:included_regions() ~= #new_regions then + -- TODO(lewis6991): inefficient; invalidate trees incrementally + for _, t in ipairs(self._trees) do + self:_do_callback('changedtree', t:included_ranges(true), t) + end self._trees = {} self:invalidate() else @@ -707,6 +735,9 @@ function LanguageTree:_do_callback(cb_name, ...) for _, cb in ipairs(self._callbacks[cb_name]) do cb(...) end + for _, cb in ipairs(self._callbacks_rec[cb_name]) do + cb(...) + end end ---@package @@ -855,30 +886,26 @@ end --- changed. --- - `on_child_added` : emitted when a child is added to the tree. --- - `on_child_removed` : emitted when a child is removed from the tree. -function LanguageTree:register_cbs(cbs) +--- @param recursive? boolean Apply callbacks recursively for all children. Any new children will +--- also inherit the callbacks. +function LanguageTree:register_cbs(cbs, recursive) ---@cast cbs table if not cbs then return end - if cbs.on_changedtree then - table.insert(self._callbacks.changedtree, cbs.on_changedtree) + local callbacks = recursive and self._callbacks_rec or self._callbacks + + for name, cbname in pairs(TSCallbackNames) do + if cbs[name] then + table.insert(callbacks[cbname], cbs[name]) + end end - if cbs.on_bytes then - table.insert(self._callbacks.bytes, cbs.on_bytes) - end - - if cbs.on_detach then - table.insert(self._callbacks.detach, cbs.on_detach) - end - - if cbs.on_child_added then - table.insert(self._callbacks.child_added, cbs.on_child_added) - end - - if cbs.on_child_removed then - table.insert(self._callbacks.child_removed, cbs.on_child_removed) + if recursive then + self:for_each_child(function(child) + child:register_cbs(cbs, true) + end) end end diff --git a/test/functional/treesitter/highlight_spec.lua b/test/functional/treesitter/highlight_spec.lua index 4e1efec404..dc303c564f 100644 --- a/test/functional/treesitter/highlight_spec.lua +++ b/test/functional/treesitter/highlight_spec.lua @@ -11,7 +11,7 @@ local eq = helpers.eq before_each(clear) -local hl_query = [[ +local hl_query_c = [[ (ERROR) @error "if" @keyword @@ -47,7 +47,7 @@ local hl_query = [[ (comment) @comment ]] -local hl_text = [[ +local hl_text_c = [[ /// Schedule Lua callback on main loop's event queue static int nlua_schedule(lua_State *const lstate) { @@ -64,7 +64,7 @@ static int nlua_schedule(lua_State *const lstate) return 0; }]] -local test_text = [[ +local test_text_c = [[ void ui_refresh(void) { int width = INT_MAX, height = INT_MAX; @@ -85,7 +85,7 @@ void ui_refresh(void) } }]] -describe('treesitter highlighting', function() +describe('treesitter highlighting (C)', function() local screen before_each(function() @@ -105,13 +105,13 @@ describe('treesitter highlighting', function() [11] = {foreground = Screen.colors.Cyan4}; } - exec_lua([[ hl_query = ... ]], hl_query) + exec_lua([[ hl_query = ... ]], hl_query_c) command [[ hi link @error ErrorMsg ]] command [[ hi link @warning WarningMsg ]] end) it('is updated with edits', function() - insert(hl_text) + insert(hl_text_c) screen:expect{grid=[[ /// Schedule Lua callback on main loop's event queue | static int nlua_schedule(lua_State *const lstate) | @@ -274,7 +274,7 @@ describe('treesitter highlighting', function() end) it('is updated with :sort', function() - insert(test_text) + insert(test_text_c) exec_lua [[ local parser = vim.treesitter.get_parser(0, "c") test_hl = vim.treesitter.highlighter.new(parser, {queries = {c = hl_query}}) @@ -351,7 +351,7 @@ describe('treesitter highlighting', function() [1] = {bold = true, foreground = Screen.colors.SeaGreen4}; } - insert(test_text) + insert(test_text_c) screen:expect{ grid= [[ int width = INT_MAX, height = INT_MAX; | @@ -510,7 +510,7 @@ describe('treesitter highlighting', function() end) it("supports highlighting with custom highlight groups", function() - insert(hl_text) + insert(hl_text_c) exec_lua [[ local parser = vim.treesitter.get_parser(0, "c") @@ -692,7 +692,7 @@ describe('treesitter highlighting', function() end) it("supports conceal attribute", function() - insert(hl_text) + insert(hl_text_c) -- conceal can be empty or a single cchar. exec_lua [=[ @@ -753,3 +753,65 @@ describe('treesitter highlighting', function() eq(nil, get_hl"@total.nonsense.but.a.lot.of.dots") end) end) + +describe('treesitter highlighting (help)', function() + local screen + + before_each(function() + screen = Screen.new(40, 6) + screen:attach() + screen:set_default_attr_ids { + [1] = {foreground = Screen.colors.Blue1}; + [2] = {bold = true, foreground = Screen.colors.Blue1}; + [3] = {bold = true, foreground = Screen.colors.Brown}; + [4] = {foreground = Screen.colors.Cyan4}; + [5] = {foreground = Screen.colors.Magenta1}; + } + end) + + it("correctly redraws added/removed injections", function() + insert[[ + >ruby + -- comment + local this_is = 'actually_lua' + < + ]] + + exec_lua [[ + vim.bo.filetype = 'help' + vim.treesitter.start() + ]] + + screen:expect{grid=[[ + {1:>ruby} | + {1: -- comment} | + {1: local this_is = 'actually_lua'} | + < | + ^ | + | + ]]} + + helpers.curbufmeths.set_text(0, 1, 0, 5, {'lua'}) + + screen:expect{grid=[[ + {1:>lua} | + {1: -- comment} | + {1: }{3:local}{1: }{4:this_is}{1: }{3:=}{1: }{5:'actually_lua'} | + < | + ^ | + | + ]]} + + helpers.curbufmeths.set_text(0, 1, 0, 4, {'ruby'}) + + screen:expect{grid=[[ + {1:>ruby} | + {1: -- comment} | + {1: local this_is = 'actually_lua'} | + < | + ^ | + | + ]]} + end) + +end)