diff --git a/lua/gitlinker.lua b/lua/gitlinker.lua index ed232f3..e203ccc 100644 --- a/lua/gitlinker.lua +++ b/lua/gitlinker.lua @@ -3,8 +3,8 @@ local str = require("gitlinker.commons.str") local num = require("gitlinker.commons.num") local LogLevels = require("gitlinker.commons.logging").LogLevels local logging = require("gitlinker.commons.logging") +local async = require("gitlinker.commons.async") -local async = require("gitlinker.async") local configs = require("gitlinker.configs") local range = require("gitlinker.range") local linker = require("gitlinker.linker") @@ -221,7 +221,7 @@ local _link = function(opts) lk.rev = opts.rev end - async.scheduler() + async.schedule() local ok, url = pcall(opts.router, lk, true) -- logger:debug( -- "|link| ok:%s, url:%s, router:%s", @@ -273,7 +273,7 @@ local _link = function(opts) end --- @type fun(opts:{action:gitlinker.Action?,router:gitlinker.Router,lstart:integer,lend:integer,remote:string?,file:string?,rev:string?}):string? -local _void_link = async.void(_link) +local _sync_link = async.sync(1, _link) --- @param args string? --- @return {router_type:string,remote:string?,file:string?,rev:string?} @@ -332,7 +332,7 @@ local function setup(opts) local lstart = math.min(r.lstart, r.lend, command_opts.line1, command_opts.line2) local lend = math.max(r.lstart, r.lend, command_opts.line1, command_opts.line2) local parsed = _parse_args(args) - _void_link({ + _sync_link({ action = command_opts.bang and require("gitlinker.actions").system or require("gitlinker.actions").clipboard, router = function(lk) @@ -392,7 +392,7 @@ local function link_api(opts) opts.lend = math.max(r.lstart, r.lend) end - _void_link({ + _sync_link({ action = opts.action, router = opts.router, lstart = opts.lstart, @@ -408,7 +408,7 @@ end local M = { _url_template_engine = _url_template_engine, _worker = _worker, - _void_link = _void_link, + _sync_link = _sync_link, _router = _router, _browse = _browse, _blame = _blame, diff --git a/lua/gitlinker/async.lua b/lua/gitlinker/async.lua deleted file mode 100644 index 1fad174..0000000 --- a/lua/gitlinker/async.lua +++ /dev/null @@ -1,277 +0,0 @@ ----@diagnostic disable: luadoc-miss-module-name, undefined-doc-name ---- Small async library for Neovim plugins ---- @module async - --- Store all the async threads in a weak table so we don't prevent them from --- being garbage collected -local handles = setmetatable({}, { __mode = "k" }) - -local M = {} - --- Note: coroutine.running() was changed between Lua 5.1 and 5.2: --- - 5.1: Returns the running coroutine, or nil when called by the main thread. --- - 5.2: Returns the running coroutine plus a boolean, true when the running --- coroutine is the main one. --- --- For LuaJIT, 5.2 behaviour is enabled with LUAJIT_ENABLE_LUA52COMPAT --- --- We need to handle both. - ---- Returns whether the current execution context is async. ---- ---- @treturn boolean? -function M.running() - local current = coroutine.running() - if current and handles[current] then - return true - end -end - -local function is_Async_T(handle) - if - handle - and type(handle) == "table" - and vim.is_callable(handle.cancel) - and vim.is_callable(handle.is_cancelled) - then - return true - end -end - -local Async_T = {} - --- Analogous to uv.close -function Async_T:cancel(cb) - -- Cancel anything running on the event loop - if self._current and not self._current:is_cancelled() then - self._current:cancel(cb) - end -end - -function Async_T.new(co) - local handle = setmetatable({}, { __index = Async_T }) - handles[co] = handle - return handle -end - --- Analogous to uv.is_closing -function Async_T:is_cancelled() - return self._current and self._current:is_cancelled() -end - ---- Run a function in an async context. ---- @tparam function func ---- @tparam function callback ---- @tparam any ... Arguments for func ---- @treturn async_t Handle -function M.run(func, callback, ...) - vim.validate({ - func = { func, "function" }, - callback = { callback, "function", true }, - }) - - local co = coroutine.create(func) - local handle = Async_T.new(co) - - local function step(...) - local ret = { coroutine.resume(co, ...) } - local ok = ret[1] - - if not ok then - local err = ret[2] - error( - string.format("The coroutine failed with this message:\n%s\n%s", err, debug.traceback(co)) - ) - end - - if coroutine.status(co) == "dead" then - if callback then - callback(unpack(ret, 4, table.maxn(ret))) - end - return - end - - local nargs, fn = ret[2], ret[3] - local args = { select(4, unpack(ret)) } - - assert(type(fn) == "function", "type error :: expected func") - - args[nargs] = step - - local r = fn(unpack(args, 1, nargs)) - if is_Async_T(r) then - handle._current = r - end - end - - step(...) - return handle -end - -local function wait(argc, func, ...) - vim.validate({ - argc = { argc, "number" }, - func = { func, "function" }, - }) - - -- Always run the wrapped functions in xpcall and re-raise the error in the - -- coroutine. This makes pcall work as normal. - local function pfunc(...) - local args = { ... } - local cb = args[argc] - args[argc] = function(...) - cb(true, ...) - end - xpcall(func, function(err) - cb(false, err, debug.traceback()) - end, unpack(args, 1, argc)) - end - - local ret = { coroutine.yield(argc, pfunc, ...) } - - local ok = ret[1] - if not ok then - local _, err, traceback = unpack(ret) - error(string.format("Wrapped function failed: %s\n%s", err, traceback)) - end - - return unpack(ret, 2, table.maxn(ret)) -end - ---- Wait on a callback style function ---- ---- @tparam integer? argc The number of arguments of func. ---- @tparam function func callback style function to execute ---- @tparam any ... Arguments for func -function M.wait(...) - if type(select(1, ...)) == "number" then - return wait(...) - end - - -- Assume argc is equal to the number of passed arguments. - return wait(select("#", ...) - 1, ...) -end - ---- Use this to create a function which executes in an async context but ---- called from a non-async context. Inherently this cannot return anything ---- since it is non-blocking ---- @tparam function func ---- @tparam number argc The number of arguments of func. Defaults to 0 ---- @tparam boolean strict Error when called in non-async context ---- @treturn function(...):async_t -function M.create(func, argc, strict) - vim.validate({ - func = { func, "function" }, - argc = { argc, "number", true }, - }) - argc = argc or 0 - return function(...) - if M.running() then - if strict then - error("This function must run in a non-async context") - end - return func(...) - end - local callback = select(argc + 1, ...) - return M.run(func, callback, unpack({ ... }, 1, argc)) - end -end - ---- Create a function which executes in an async context but ---- called from a non-async context. ---- @tparam function func ---- @tparam boolean strict Error when called in non-async context -function M.void(func, strict) - vim.validate({ func = { func, "function" } }) - return function(...) - if M.running() then - if strict then - error("This function must run in a non-async context") - end - return func(...) - end - return M.run(func, nil, ...) - end -end - ---- Creates an async function with a callback style function. ---- ---- @tparam function func A callback style function to be converted. The last argument must be the callback. ---- @tparam integer argc The number of arguments of func. Must be included. ---- @tparam boolean strict Error when called in non-async context ---- @treturn function Returns an async function -function M.wrap(func, argc, strict) - vim.validate({ - argc = { argc, "number" }, - }) - return function(...) - if not M.running() then - if strict then - error("This function must run in an async context") - end - return func(...) - end - return M.wait(argc, func, ...) - end -end - ---- Run a collection of async functions (`thunks`) concurrently and return when ---- all have finished. ---- @tparam function[] thunks ---- @tparam integer n Max number of thunks to run concurrently ---- @tparam function interrupt_check Function to abort thunks between calls -function M.join(thunks, n, interrupt_check) - local function run(finish) - if #thunks == 0 then - return finish() - end - - local remaining = { select(n + 1, unpack(thunks)) } - local to_go = #thunks - - local ret = {} - - local function cb(...) - ret[#ret + 1] = { ... } - to_go = to_go - 1 - if to_go == 0 then - finish(ret) - elseif not interrupt_check or not interrupt_check() then - if #remaining > 0 then - local next_task = table.remove(remaining) - next_task(cb) - end - end - end - - for i = 1, math.min(n, #thunks) do - thunks[i](cb) - end - end - - if not M.running() then - return run - end - return M.wait(1, false, run) -end - ---- Partially applying arguments to an async function ---- @tparam function fn ---- @param ... arguments to apply to `fn` -function M.curry(fn, ...) - local args = { ... } - local nargs = select("#", ...) - return function(...) - local other = { ... } - for i = 1, select("#", ...) do - args[nargs + i] = other[i] - end - fn(unpack(args)) - end -end - ---- An async function that when called will yield to the Neovim scheduler to be ---- able to call the neovim API. -M.scheduler = M.wrap(vim.schedule, 1, false) - -return M diff --git a/lua/gitlinker/commons/async.lua b/lua/gitlinker/commons/async.lua new file mode 100644 index 0000000..833ad37 --- /dev/null +++ b/lua/gitlinker/commons/async.lua @@ -0,0 +1,189 @@ +---@diagnostic disable +--- Small async library for Neovim plugins + +local function validate_callback(func, callback) + if callback and type(callback) ~= 'function' then + local info = debug.getinfo(func, 'nS') + error( + string.format( + 'Callback is not a function for %s, got: %s', + info.short_src .. ':' .. info.linedefined, + vim.inspect(callback) + ) + ) + end +end + +-- Coroutine.running() was changed between Lua 5.1 and 5.2: +-- - 5.1: Returns the running coroutine, or nil when called by the main thread. +-- - 5.2: Returns the running coroutine plus a boolean, true when the running +-- coroutine is the main one. +-- +-- For LuaJIT, 5.2 behaviour is enabled with LUAJIT_ENABLE_LUA52COMPAT +-- +-- We need to handle both. +local _main_co_or_nil = coroutine.running() + +--- Executes a future with a callback when it is done +--- @param func function +--- @param callback function? +--- @param ... any +local function run(func, callback, ...) + validate_callback(func, callback) + + local co = coroutine.create(func) + + local function step(...) + local ret = { coroutine.resume(co, ...) } + local stat = ret[1] + + if not stat then + local err = ret[2] --[[@as string]] + error( + string.format('The coroutine failed with this message: %s\n%s', err, debug.traceback(co)) + ) + end + + if coroutine.status(co) == 'dead' then + if callback then + callback(unpack(ret, 2, table.maxn(ret))) + end + return + end + + --- @type integer, fun(...: any): any + local nargs, fn = ret[2], ret[3] + assert(type(fn) == 'function', 'type error :: expected func') + + --- @type any[] + local args = { unpack(ret, 4, table.maxn(ret)) } + args[nargs] = step + fn(unpack(args, 1, nargs)) + end + + step(...) +end + +local M = {} + +---Use this to create a function which executes in an async context but +---called from a non-async context. Inherently this cannot return anything +---since it is non-blocking +--- @generic F: function +--- @param argc integer +--- @param func async F +--- @return F +function M.sync(argc, func) + return function(...) + assert(not coroutine.running()) + local callback = select(argc + 1, ...) + run(func, callback, unpack({ ... }, 1, argc)) + end +end + +--- @param argc integer +--- @param func function +--- @param ... any +--- @return any ... +function M.wait(argc, func, ...) + -- Always run the wrapped functions in xpcall and re-raise the error in the + -- coroutine. This makes pcall work as normal. + local function pfunc(...) + local args = { ... } --- @type any[] + local cb = args[argc] + args[argc] = function(...) + cb(true, ...) + end + xpcall(func, function(err) + cb(false, err, debug.traceback()) + end, unpack(args, 1, argc)) + end + + local ret = { coroutine.yield(argc, pfunc, ...) } + + local ok = ret[1] + if not ok then + --- @type string, string + local err, traceback = ret[2], ret[3] + error(string.format('Wrapped function failed: %s\n%s', err, traceback)) + end + + return unpack(ret, 2, table.maxn(ret)) +end + +function M.run(func, ...) + return run(func, nil, ...) +end + +--- Creates an async function with a callback style function. +--- @param argc integer +--- @param func function +--- @return function +function M.wrap(argc, func) + assert(type(argc) == 'number') + assert(type(func) == 'function') + return function(...) + return M.wait(argc, func, ...) + end +end + +--- @generic R +--- @param n integer Mx number of jobs to run concurrently +--- @param thunks (fun(cb: function): R)[] +--- @param interrupt_check fun()? +--- @param callback fun(ret: R[][]) +M.join = M.wrap(4, function(n, thunks, interrupt_check, callback) + n = math.min(n, #thunks) + + local ret = {} --- @type any[][] + + if #thunks == 0 then + callback(ret) + return + end + + local remaining = { unpack(thunks, n + 1) } + local to_go = #thunks + + local function cb(...) + ret[#ret + 1] = { ... } + to_go = to_go - 1 + if to_go == 0 then + callback(ret) + elseif not interrupt_check or not interrupt_check() then + if #remaining > 0 then + local next_thunk = table.remove(remaining, 1) + next_thunk(cb) + end + end + end + + for i = 1, n do + thunks[i](cb) + end +end) + +---Useful for partially applying arguments to an async function +--- @param fn function +--- @param ... any +--- @return function +function M.curry(fn, ...) + --- @type integer, any[] + local nargs, args = select('#', ...), { ... } + + return function(...) + local other = { ... } + for i = 1, select('#', ...) do + args[nargs + i] = other[i] + end + return fn(unpack(args)) + end +end + +if vim.schedule then + --- An async function that when called will yield to the Neovim scheduler to be + --- able to call the API. + M.schedule = M.wrap(1, vim.schedule) +end + +return M diff --git a/lua/gitlinker/commons/version.txt b/lua/gitlinker/commons/version.txt index a8f5438..e9dbbda 100644 --- a/lua/gitlinker/commons/version.txt +++ b/lua/gitlinker/commons/version.txt @@ -1 +1 @@ -21.0.1 +21.1.0 diff --git a/lua/gitlinker/git.lua b/lua/gitlinker/git.lua index c5eab97..2d7933a 100644 --- a/lua/gitlinker/git.lua +++ b/lua/gitlinker/git.lua @@ -1,7 +1,7 @@ local logging = require("gitlinker.commons.logging") local spawn = require("gitlinker.commons.spawn") local uv = require("gitlinker.commons.uv") -local async = require("gitlinker.async") +local async = require("gitlinker.commons.async") --- @class gitlinker.CmdResult --- @field stdout string[] @@ -42,7 +42,7 @@ function CmdResult:print_err(default) end --- NOTE: async functions can't have optional parameters so wrap it into another function without '_' -local _run_cmd = async.wrap(function(args, cwd, callback) +local _run_cmd = async.wrap(3, function(args, cwd, callback) local result = CmdResult:new() local logger = logging.get("gitlinker") logger:debug(string.format("|_run_cmd| args:%s, cwd:%s", vim.inspect(args), vim.inspect(cwd))) @@ -63,7 +63,7 @@ local _run_cmd = async.wrap(function(args, cwd, callback) logger:debug(string.format("|_run_cmd| result:%s", vim.inspect(result))) callback(result) end) -end, 3) +end) -- wrap the git command to do the right thing always --- @package diff --git a/lua/gitlinker/linker.lua b/lua/gitlinker/linker.lua index a3b1f51..c841a6b 100644 --- a/lua/gitlinker/linker.lua +++ b/lua/gitlinker/linker.lua @@ -1,9 +1,10 @@ local logging = require("gitlinker.commons.logging") local str = require("gitlinker.commons.str") +local async = require("gitlinker.commons.async") + local git = require("gitlinker.git") local path = require("gitlinker.path") local giturlparser = require("gitlinker.giturlparser") -local async = require("gitlinker.async") --- @return string? local function _get_buf_dir() @@ -88,7 +89,7 @@ local function make_linker(remote, file, rev) end -- logger.debug("|linker - Linker:make| rev:%s", vim.inspect(rev)) - async.scheduler() + async.schedule() if not file_provided then local buf_path_on_root = path.buffer_relpath(root) --[[@as string]] @@ -113,7 +114,7 @@ local function make_linker(remote, file, rev) -- vim.inspect(file_in_rev_result) -- ) - async.scheduler() + async.schedule() local file_changed = false if not file_provided then diff --git a/spec/gitlinker/git_spec.lua b/spec/gitlinker/git_spec.lua index e3da6ba..6801ef0 100644 --- a/spec/gitlinker/git_spec.lua +++ b/spec/gitlinker/git_spec.lua @@ -11,7 +11,7 @@ describe("gitlinker.git", function() vim.cmd([[ edit lua/gitlinker.lua ]]) end) - local async = require("gitlinker.async") + local async = require("gitlinker.commons.async") local git = require("gitlinker.git") local path = require("gitlinker.path") local gitlinker = require("gitlinker") diff --git a/spec/gitlinker/linker_spec.lua b/spec/gitlinker/linker_spec.lua index 240243e..e3a3cd6 100644 --- a/spec/gitlinker/linker_spec.lua +++ b/spec/gitlinker/linker_spec.lua @@ -13,7 +13,7 @@ describe("gitlinker.linker", function() vim.cmd([[ edit lua/gitlinker.lua ]]) end) - local async = require("gitlinker.async") + local async = require("gitlinker.commons.async") local github_actions = os.getenv("GITHUB_ACTIONS") == "true" local linker = require("gitlinker.linker") describe("[make_linker]", function() diff --git a/spec/gitlinker_spec.lua b/spec/gitlinker_spec.lua index ed72a84..bd8e5ce 100644 --- a/spec/gitlinker_spec.lua +++ b/spec/gitlinker_spec.lua @@ -696,9 +696,9 @@ describe("gitlinker", function() end) end) - describe("[_void_link]", function() + describe("[_sync_link]", function() it("link browse", function() - gitlinker._void_link({ + gitlinker._sync_link({ action = require("gitlinker.actions").clipboard, router = function(lk) return require("gitlinker")._router("browse", lk) @@ -708,7 +708,7 @@ describe("gitlinker", function() }) end) it("link blame", function() - gitlinker._void_link({ + gitlinker._sync_link({ action = require("gitlinker.actions").clipboard, router = function(lk) return require("gitlinker")._router("blame", lk)