diff --git a/runtime/lua/vim/shared.lua b/runtime/lua/vim/shared.lua index 4d753d727a..200ac44e86 100644 --- a/runtime/lua/vim/shared.lua +++ b/runtime/lua/vim/shared.lua @@ -796,6 +796,61 @@ do return type(val) == t or (t == 'callable' and vim.is_callable(val)) end + --- @param param_name string + --- @param spec vim.validate.Spec + --- @return string? + local function is_param_valid(param_name, spec) + if type(spec) ~= 'table' then + return string.format('opt[%s]: expected table, got %s', param_name, type(spec)) + end + + local val = spec[1] -- Argument value + local types = spec[2] -- Type name, or callable + local optional = (true == spec[3]) + + if type(types) == 'string' then + types = { types } + end + + if vim.is_callable(types) then + -- Check user-provided validation function + local valid, optional_message = types(val) + if not valid then + local error_message = + string.format('%s: expected %s, got %s', param_name, (spec[3] or '?'), tostring(val)) + if optional_message ~= nil then + error_message = string.format('%s. Info: %s', error_message, optional_message) + end + + return error_message + end + elseif type(types) == 'table' then + local success = false + for i, t in ipairs(types) do + local t_name = type_names[t] + if not t_name then + return string.format('invalid type name: %s', t) + end + types[i] = t_name + + if (optional and val == nil) or _is_type(val, t_name) then + success = true + break + end + end + if not success then + return string.format( + '%s: expected %s, got %s', + param_name, + table.concat(types, '|'), + type(val) + ) + end + else + return string.format('invalid type name: %s', tostring(types)) + end + end + --- @param opt table --- @return boolean, string? local function is_valid(opt) @@ -803,56 +858,19 @@ do return false, string.format('opt: expected table, got %s', type(opt)) end - for param_name, spec in vim.spairs(opt) do - if type(spec) ~= 'table' then - return false, string.format('opt[%s]: expected table, got %s', param_name, type(spec)) + local report --- @type table? + + for param_name, spec in pairs(opt) do + local msg = is_param_valid(param_name, spec) + if msg then + report = report or {} + report[param_name] = msg end + end - local val = spec[1] -- Argument value - local types = spec[2] -- Type name, or callable - local optional = (true == spec[3]) - - if type(types) == 'string' then - types = { types } - end - - if vim.is_callable(types) then - -- Check user-provided validation function - local valid, optional_message = types(val) - if not valid then - local error_message = - string.format('%s: expected %s, got %s', param_name, (spec[3] or '?'), tostring(val)) - if optional_message ~= nil then - error_message = error_message .. string.format('. Info: %s', optional_message) - end - - return false, error_message - end - elseif type(types) == 'table' then - local success = false - for i, t in ipairs(types) do - local t_name = type_names[t] - if not t_name then - return false, string.format('invalid type name: %s', t) - end - types[i] = t_name - - if (optional and val == nil) or _is_type(val, t_name) then - success = true - break - end - end - if not success then - return false, - string.format( - '%s: expected %s, got %s', - param_name, - table.concat(types, '|'), - type(val) - ) - end - else - return false, string.format('invalid type name: %s', tostring(types)) + if report then + for _, msg in vim.spairs(report) do -- luacheck: ignore + return false, msg end end