This commit is contained in:
Riley Bruins 2024-09-16 10:41:37 -06:00 committed by GitHub
commit 2677e26cf8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 191 additions and 25 deletions

View File

@ -3,9 +3,6 @@ local M = {}
-- TODO(lewis6991): Private for now until: -- TODO(lewis6991): Private for now until:
-- - There are other places in the codebase that could benefit from this -- - There are other places in the codebase that could benefit from this
-- (e.g. LSP), but might require other changes to accommodate. -- (e.g. LSP), but might require other changes to accommodate.
-- - Invalidation of the cache needs to be controllable. Using weak tables
-- is an acceptable invalidation policy, but it shouldn't be the only
-- one.
-- - I don't think the story around `hash` is completely thought out. We -- - I don't think the story around `hash` is completely thought out. We
-- may be able to have a good default hash by hashing each argument, -- may be able to have a good default hash by hashing each argument,
-- so basically a better 'concat'. -- so basically a better 'concat'.
@ -17,6 +14,10 @@ local M = {}
--- Internally uses a |lua-weaktable| to cache the results of {fn} meaning the --- Internally uses a |lua-weaktable| to cache the results of {fn} meaning the
--- cache will be invalidated whenever Lua does garbage collection. --- cache will be invalidated whenever Lua does garbage collection.
--- ---
--- The cache can also be manually invalidated by calling `:clear()` on the returned object.
--- Calling this function with no arguments clears the entire cache; otherwise, the arguments will
--- be interpreted as function inputs, and only the cache entry at their hash will be cleared.
---
--- The memoized function returns shared references so be wary about --- The memoized function returns shared references so be wary about
--- mutating return values. --- mutating return values.
--- ---
@ -32,11 +33,12 @@ local M = {}
--- first n arguments passed to {fn}. --- first n arguments passed to {fn}.
--- ---
--- @param fn F Function to memoize. --- @param fn F Function to memoize.
--- @param strong? boolean Do not use a weak table --- @param weak? boolean Use a weak table (default `true`)
--- @return F # Memoized version of {fn} --- @return F # Memoized version of {fn}
--- @nodoc --- @nodoc
function M._memoize(hash, fn, strong) function M._memoize(hash, fn, weak)
return require('vim.func._memoize')(hash, fn, strong) -- this is wrapped in a function to lazily require the module
return require('vim.func._memoize')(hash, fn, weak)
end end
return M return M

View File

@ -1,5 +1,7 @@
--- Module for private utility functions --- Module for private utility functions
--- @alias vim.func.MemoObj { _hash: (fun(...): any), _weak: boolean?, _cache: table<any> }
--- @param argc integer? --- @param argc integer?
--- @return fun(...): any --- @return fun(...): any
local function concat_hash(argc) local function concat_hash(argc)
@ -33,31 +35,51 @@ local function resolve_hash(hash)
return hash return hash
end end
--- @param weak boolean?
--- @return table
local create_cache = function(weak)
return setmetatable({}, {
__mode = weak ~= false and 'kv',
})
end
--- @generic F: function --- @generic F: function
--- @param hash integer|string|fun(...): any --- @param hash integer|string|fun(...): any
--- @param fn F --- @param fn F
--- @param strong? boolean --- @param weak? boolean
--- @return F --- @return F
return function(hash, fn, strong) return function(hash, fn, weak)
vim.validate({ vim.validate({
hash = { hash, { 'number', 'string', 'function' } }, hash = { hash, { 'number', 'string', 'function' } },
fn = { fn, 'function' }, fn = { fn, 'function' },
weak = { weak, 'boolean', true },
}) })
---@type table<any,table<any,any>> --- @type vim.func.MemoObj
local cache = {} local obj = {
if not strong then _cache = create_cache(weak),
setmetatable(cache, { __mode = 'kv' }) _hash = resolve_hash(hash),
end _weak = weak,
--- @param self vim.func.MemoObj
clear = function(self, ...)
if select('#', ...) == 0 then
self._cache = create_cache(self._weak)
return
end
local key = self._hash(...)
self._cache[key] = nil
end,
}
hash = resolve_hash(hash) return setmetatable(obj, {
--- @param self vim.func.MemoObj
return function(...) __call = function(self, ...)
local key = hash(...) local key = self._hash(...)
if cache[key] == nil then local cache = self._cache
cache[key] = vim.F.pack_len(fn(...)) if cache[key] == nil then
end cache[key] = vim.F.pack_len(fn(...))
end
return vim.F.unpack_len(cache[key]) return vim.F.unpack_len(cache[key])
end end,
})
end end

View File

@ -861,8 +861,8 @@ function Query:iter_captures(node, source, start, stop)
local cursor = vim._create_ts_querycursor(node, self.query, start, stop, { match_limit = 256 }) local cursor = vim._create_ts_querycursor(node, self.query, start, stop, { match_limit = 256 })
local apply_directives = memoize(match_id_hash, self.apply_directives, true) local apply_directives = memoize(match_id_hash, self.apply_directives, false)
local match_preds = memoize(match_id_hash, self.match_preds, true) local match_preds = memoize(match_id_hash, self.match_preds, false)
local function iter(end_line) local function iter(end_line)
local capture, captured_node, match = cursor:next_capture() local capture, captured_node, match = cursor:next_capture()

View File

@ -0,0 +1,142 @@
local t = require('test.testutil')
local n = require('test.functional.testnvim')()
local clear = n.clear
local exec_lua = n.exec_lua
local eq = t.eq
describe('vim.func._memoize', function()
before_each(clear)
it('caches function results based on their parameters', function()
exec_lua([[
_G.count = 0
local adder = vim.func._memoize('concat', function(arg1, arg2)
_G.count = _G.count + 1
return arg1 + arg2
end)
collectgarbage('stop')
adder(3, -4)
adder(3, -4)
adder(3, -4)
adder(3, -4)
adder(3, -4)
collectgarbage('restart')
]])
eq(1, exec_lua([[return _G.count]]))
end)
it('caches function results using a weak table by default', function()
exec_lua([[
_G.count = 0
local adder = vim.func._memoize('concat-2', function(arg1, arg2)
_G.count = _G.count + 1
return arg1 + arg2
end)
adder(3, -4)
collectgarbage()
adder(3, -4)
collectgarbage()
adder(3, -4)
]])
eq(3, exec_lua([[return _G.count]]))
end)
it('can cache using a strong table', function()
exec_lua([[
_G.count = 0
local adder = vim.func._memoize('concat-2', function(arg1, arg2)
_G.count = _G.count + 1
return arg1 + arg2
end, false)
adder(3, -4)
collectgarbage()
adder(3, -4)
collectgarbage()
adder(3, -4)
]])
eq(1, exec_lua([[return _G.count]]))
end)
it('can clear a single cache entry', function()
exec_lua([[
_G.count = 0
local adder = vim.func._memoize(function(arg1, arg2)
return tostring(arg1) .. '%%' .. tostring(arg2)
end, function(arg1, arg2)
_G.count = _G.count + 1
return arg1 + arg2
end)
collectgarbage('stop')
adder(3, -4)
adder(3, -4)
adder(3, -4)
adder(3, -4)
adder(3, -4)
adder:clear(3, -4)
adder(3, -4)
collectgarbage('restart')
]])
eq(2, exec_lua([[return _G.count]]))
end)
it('can clear the entire cache', function()
exec_lua([[
_G.count = 0
local adder = vim.func._memoize(function(arg1, arg2)
return tostring(arg1) .. '%%' .. tostring(arg2)
end, function(arg1, arg2)
_G.count = _G.count + 1
return arg1 + arg2
end)
collectgarbage('stop')
adder(1, 2)
adder(3, -4)
adder(1, 2)
adder(3, -4)
adder(1, 2)
adder(3, -4)
adder:clear()
adder(1, 2)
adder(3, -4)
collectgarbage('restart')
]])
eq(4, exec_lua([[return _G.count]]))
end)
it('can cache functions that return nil', function()
exec_lua([[
_G.count = 0
local adder = vim.func._memoize('concat', function(arg1, arg2)
_G.count = _G.count + 1
return nil
end)
collectgarbage('stop')
adder(1, 2)
adder(1, 2)
adder(1, 2)
adder(1, 2)
adder:clear()
adder(1, 2)
collectgarbage('restart')
]])
eq(2, exec_lua([[return _G.count]]))
end)
end)