diff --git a/runtime/lua/vim/diagnostic.lua b/runtime/lua/vim/diagnostic.lua index b42eece4c2..3321b6ad71 100644 --- a/runtime/lua/vim/diagnostic.lua +++ b/runtime/lua/vim/diagnostic.lua @@ -371,6 +371,39 @@ local function to_severity(severity) return severity end +--- @param severity vim.diagnostic.SeverityFilter +--- @return fun(vim.Diagnostic):boolean +local function severity_predicate(severity) + if type(severity) ~= 'table' then + severity = assert(to_severity(severity)) + ---@param d vim.Diagnostic + return function(d) + return d.severity == severity + end + end + if severity.min or severity.max then + --- @cast severity {min:vim.diagnostic.Severity,max:vim.diagnostic.Severity} + local min_severity = to_severity(severity.min) or M.severity.HINT + local max_severity = to_severity(severity.max) or M.severity.ERROR + + --- @param d vim.Diagnostic + return function(d) + return d.severity <= min_severity and d.severity >= max_severity + end + end + + --- @cast severity vim.diagnostic.Severity[] + local severities = {} --- @type table + for _, s in ipairs(severity) do + severities[assert(to_severity(s))] = true + end + + --- @param d vim.Diagnostic + return function(d) + return severities[d.severity] + end +end + --- @param severity vim.diagnostic.SeverityFilter --- @param diagnostics vim.Diagnostic[] --- @return vim.Diagnostic[] @@ -378,37 +411,7 @@ local function filter_by_severity(severity, diagnostics) if not severity then return diagnostics end - - if type(severity) ~= 'table' then - severity = assert(to_severity(severity)) - --- @param t vim.Diagnostic - return vim.tbl_filter(function(t) - return t.severity == severity - end, diagnostics) - end - - if severity.min or severity.max then - --- @cast severity {min:vim.diagnostic.Severity,max:vim.diagnostic.Severity} - local min_severity = to_severity(severity.min) or M.severity.HINT - local max_severity = to_severity(severity.max) or M.severity.ERROR - - --- @param t vim.Diagnostic - return vim.tbl_filter(function(t) - return t.severity <= min_severity and t.severity >= max_severity - end, diagnostics) - end - - --- @cast severity vim.diagnostic.Severity[] - - local severities = {} --- @type table - for _, s in ipairs(severity) do - severities[assert(to_severity(s))] = true - end - - --- @param t vim.Diagnostic - return vim.tbl_filter(function(t) - return severities[t.severity] - end, diagnostics) + return vim.tbl_filter(severity_predicate(severity), diagnostics) end --- @param bufnr integer @@ -714,10 +717,18 @@ local function get_diagnostics(bufnr, opts, clamp) end, }) + local match_severity = opts.severity and severity_predicate(opts.severity) + or function(_) + return true + end + ---@param b integer ---@param d vim.Diagnostic local function add(b, d) - if not opts.lnum or (opts.lnum >= d.lnum and opts.lnum <= (d.end_lnum or d.lnum)) then + if + match_severity(d) + and (not opts.lnum or (opts.lnum >= d.lnum and opts.lnum <= (d.end_lnum or d.lnum))) + then if clamp and api.nvim_buf_is_loaded(b) then local line_count = buf_line_count[b] - 1 if @@ -771,10 +782,6 @@ local function get_diagnostics(bufnr, opts, clamp) end end - if opts.severity then - diagnostics = filter_by_severity(opts.severity, diagnostics) - end - return diagnostics end