From 6debb1852355e0112ce75a5b4aed714ba1469ddb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maria=20Jos=C3=A9=20Solano?= Date: Fri, 15 Sep 2023 12:45:40 -0700 Subject: [PATCH] refactor(treesitter): remove duplicated diagnostic code (#24976) * refactor(treesitter): remove duplicated diagnostic code * fixup!: fix type errors * fixup!: add type namespace --- runtime/lua/vim/treesitter/_query_linter.lua | 177 +++++-------------- 1 file changed, 48 insertions(+), 129 deletions(-) diff --git a/runtime/lua/vim/treesitter/_query_linter.lua b/runtime/lua/vim/treesitter/_query_linter.lua index 3dd0177a81..abf0bf345d 100644 --- a/runtime/lua/vim/treesitter/_query_linter.lua +++ b/runtime/lua/vim/treesitter/_query_linter.lua @@ -1,8 +1,6 @@ local api = vim.api local namespace = api.nvim_create_namespace('vim.treesitter.query_linter') --- those node names exist for every language -local BUILT_IN_NODE_NAMES = { '_', 'ERROR' } local M = {} @@ -10,11 +8,13 @@ local M = {} --- @field langs string[] --- @field clear boolean +--- @alias vim.treesitter.ParseError {msg: string, range: Range4} + --- @private --- Caches parse results for queries for each language. --- Entries of parse_cache[lang][query_text] will either be true for successful parse or contain the ---- error message of the parse ---- @type table> +--- message and range of the parse error. +--- @type table> local parse_cache = {} --- Contains language dependent context for the query linter @@ -26,20 +26,16 @@ local parse_cache = {} --- @private --- Adds a diagnostic for node in the query buffer --- @param diagnostics Diagnostic[] ---- @param node TSNode ---- @param buf integer +--- @param range Range4 --- @param lint string --- @param lang string? -local function add_lint_for_node(diagnostics, node, buf, lint, lang) - local node_text = vim.treesitter.get_node_text(node, buf):gsub('\n', ' ') - --- @type string - local message = lint .. ': ' .. node_text - local error_range = { node:range() } +local function add_lint_for_node(diagnostics, range, lint, lang) + local message = lint:gsub('\n', ' ') diagnostics[#diagnostics + 1] = { - lnum = error_range[1], - end_lnum = error_range[3], - col = error_range[2], - end_col = error_range[4], + lnum = range[1], + end_lnum = range[3], + col = range[2], + end_col = range[4], severity = vim.diagnostic.ERROR, message = message, source = lang, @@ -91,6 +87,31 @@ local lint_query = [[;; query (ERROR) @error ]] +--- @private +--- @param err string +--- @param node TSNode +--- @return vim.treesitter.ParseError +local function get_error_entry(err, node) + local start_line, start_col = node:range() + local line_offset, col_offset, msg = err:gmatch('.-:%d+: Query error at (%d+):(%d+)%. ([^:]+)')() ---@type string, string, string + start_line, start_col = + start_line + tonumber(line_offset) - 1, start_col + tonumber(col_offset) - 1 + local end_line, end_col = start_line, start_col + if msg:match('^Invalid syntax') or msg:match('^Impossible') then + -- Use the length of the underlined node + local underlined = vim.split(err, '\n')[2] + end_col = end_col + #underlined + elseif msg:match('^Invalid') then + -- Use the length of the problematic type/capture/field + end_col = end_col + #msg:match('"([^"]+)"') + end + + return { + msg = msg, + range = { start_line, start_col, end_line, end_col }, + } +end + --- @private --- @param node TSNode --- @param buf integer @@ -106,104 +127,19 @@ local function check_toplevel(node, buf, lang, diagnostics) local lang_cache = parse_cache[lang] if lang_cache[query_text] == nil then - local ok, err = pcall(vim.treesitter.query.parse, lang, query_text) + local cache_val, err = pcall(vim.treesitter.query.parse, lang, query_text) ---@type boolean|vim.treesitter.ParseError, string|Query - if not ok and type(err) == 'string' then - err = err:match('.-:%d+: (.+)') + if not cache_val and type(err) == 'string' then + cache_val = get_error_entry(err, node) end - lang_cache[query_text] = ok or err + lang_cache[query_text] = cache_val end local cache_entry = lang_cache[query_text] - if type(cache_entry) == 'string' then - add_lint_for_node(diagnostics, node, buf, cache_entry, lang) - end -end - ---- @private ---- @param node TSNode ---- @param buf integer ---- @param lang string ---- @param parser_info table ---- @param diagnostics Diagnostic[] -local function check_field(node, buf, lang, parser_info, diagnostics) - local field_name = vim.treesitter.get_node_text(node, buf) - if not vim.tbl_contains(parser_info.fields, field_name) then - add_lint_for_node(diagnostics, node, buf, 'Invalid field', lang) - end -end - ---- @private ---- @param node TSNode ---- @param buf integer ---- @param lang string ---- @param parser_info (table) ---- @param diagnostics Diagnostic[] -local function check_node(node, buf, lang, parser_info, diagnostics) - local node_type = vim.treesitter.get_node_text(node, buf) - local is_named = node_type:sub(1, 1) ~= '"' - - if not is_named then - node_type = node_type:gsub('"(.*)".*$', '%1'):gsub('\\(.)', '%1') - end - - local found = vim.tbl_contains(BUILT_IN_NODE_NAMES, node_type) - or vim.tbl_contains(parser_info.symbols, function(s) - return vim.deep_equal(s, { node_type, is_named }) - end, { predicate = true }) - - if not found then - add_lint_for_node(diagnostics, node, buf, 'Invalid node type', lang) - end -end - ---- @private ---- @param node TSNode ---- @param buf integer ---- @param is_predicate boolean ---- @return string -local function get_predicate_name(node, buf, is_predicate) - local name = vim.treesitter.get_node_text(node, buf) - if is_predicate then - if vim.startswith(name, 'not-') then - --- @type string - name = name:sub(string.len('not-') + 1) - end - return name .. '?' - end - return name .. '!' -end - ---- @private ---- @param predicate_node TSNode ---- @param predicate_type_node TSNode ---- @param buf integer ---- @param lang string? ---- @param diagnostics Diagnostic[] -local function check_predicate(predicate_node, predicate_type_node, buf, lang, diagnostics) - local type_string = vim.treesitter.get_node_text(predicate_type_node, buf) - - -- Quirk of the query grammar that directives are also predicates! - if type_string == '?' then - if - not vim.tbl_contains( - vim.treesitter.query.list_predicates(), - get_predicate_name(predicate_node, buf, true) - ) - then - add_lint_for_node(diagnostics, predicate_node, buf, 'Unknown predicate', lang) - end - elseif type_string == '!' then - if - not vim.tbl_contains( - vim.treesitter.query.list_directives(), - get_predicate_name(predicate_node, buf, false) - ) - then - add_lint_for_node(diagnostics, predicate_node, buf, 'Unknown directive', lang) - end + if type(cache_entry) ~= 'boolean' then + add_lint_for_node(diagnostics, cache_entry.range, cache_entry.msg, lang) end end @@ -214,8 +150,6 @@ end --- @param lang_context QueryLinterLanguageContext --- @param diagnostics Diagnostic[] local function lint_match(buf, match, query, lang_context, diagnostics) - local predicate --- @type TSNode - local predicate_type --- @type TSNode local lang = lang_context.lang local parser_info = lang_context.parser_info @@ -223,31 +157,16 @@ local function lint_match(buf, match, query, lang_context, diagnostics) local cap_id = query.captures[id] -- perform language-independent checks only for first lang - if lang_context.is_first_lang then - if cap_id == 'error' then - add_lint_for_node(diagnostics, node, buf, 'Syntax error') - elseif cap_id == 'predicate.name' then - predicate = node - elseif cap_id == 'predicate.type' then - predicate_type = node - end + if lang_context.is_first_lang and cap_id == 'error' then + local node_text = vim.treesitter.get_node_text(node, buf):gsub('\n', ' ') + add_lint_for_node(diagnostics, { node:range() }, 'Syntax error: ' .. node_text) end -- other checks rely on Neovim parser introspection - if lang and parser_info then - if cap_id == 'toplevel' then - check_toplevel(node, buf, lang, diagnostics) - elseif cap_id == 'field' then - check_field(node, buf, lang, parser_info, diagnostics) - elseif cap_id == 'node.named' or cap_id == 'node.anonymous' then - check_node(node, buf, lang, parser_info, diagnostics) - end + if lang and parser_info and cap_id == 'toplevel' then + check_toplevel(node, buf, lang, diagnostics) end end - - if predicate and predicate_type then - check_predicate(predicate, predicate_type, buf, lang, diagnostics) - end end --- @private @@ -339,7 +258,7 @@ function M.omnifunc(findstart, base) end end for _, s in pairs(parser_info.symbols) do - local text = s[2] and s[1] or '"' .. s[1]:gsub([[\]], [[\\]]) .. '"' + local text = s[2] and s[1] or '"' .. s[1]:gsub([[\]], [[\\]]) .. '"' ---@type string if text:find(base, 1, true) then table.insert(items, text) end