repotool/lib/cli.lua
2025-06-30 19:56:36 -05:00

943 lines
23 KiB
Lua

--
-- cmd is a way to declaratively describe command line interfaces
--
-- It is inspired by the excellent cmdliner[1] library for OCaml.
--
-- The main idea is to define command line interfaces with "terms".
--
-- There are "primitive terms" such as "options" and "arguments". Terms can
-- then be composed further with the "application" or "table" term combinators.
--
-- Effectively this forms a tree of terms which describes (a) how to parse
-- command line arguments and then (b) how to compute a Lua value, finally (c)
-- it can be used to automatically generate help messages and man pages.
--
-- [1]: https://github.com/dbuenzli/cmdliner
--
--
-- PRELUDE
--
-- {{{
-- Capture the interpreter's global `arg` (the script's command line) right
-- away: further down this file `arg` is redeclared as a local (the positional
-- argument term constructor), which shadows the global for all later code.
local argv = arg
--
-- Iterate a table's string keys in sorted order, then its array indices
-- (1..n, as ipairs sees them), yielding `key, value` pairs. Other keys
-- (non-string, non-sequence) are skipped.
--
local function spairs(t)
  local keys = {}
  for key in pairs(t) do
    if type(key) == 'string' then
      keys[#keys + 1] = key
    end
  end
  table.sort(keys)
  for idx in ipairs(t) do
    keys[#keys + 1] = idx
  end
  local pos = 0
  return function()
    pos = pos + 1
    -- `key` is a proper local: the previous version assigned an undeclared
    -- (global) `k` here.
    local key = keys[pos]
    if key == nil then return nil end
    return key, t[key]
  end
end
-- }}}
--
-- PARSING AND EVALUATION
--
-- {{{
--
-- Split a text line into an array of shell words, respecting quoting.
--
-- Words are whitespace-separated. A word that opens with ' or " is glued to
-- the following words (joined by a single space) until a word ends with the
-- same quote character not escaped by an odd number of backslashes. The
-- delimiting quotes of a finished word are stripped; backslashes are kept
-- verbatim. An unterminated quoted word is returned as-is (opening quote
-- included).
--
local function shell_split(text)
  local words = {}
  local open_pat, close_pat = [=[^(['"])]=], [=[(['"])$]=]
  local buf, quote
  -- True when `word` ends with the quote `q` and that quote is not escaped.
  local function closes(word, q)
    if word:match(close_pat) ~= q then return false end
    local slashes = word:match([=[(\*)['"]$]=])
    return #slashes % 2 == 0
  end
  local function strip(word)
    return (word:gsub(open_pat, ""):gsub(close_pat, ""))
  end
  for word in text:gmatch("%S+") do
    if buf then
      buf = buf .. ' ' .. word
      if closes(word, quote) then
        table.insert(words, strip(buf))
        buf, quote = nil, nil
      end
    else
      local opening = word:match(open_pat)
      -- A lone quote character, or a word whose closing quote is missing,
      -- different or escaped, starts a multi-word quoted string.
      if opening and not (#word > 1 and closes(word, opening)) then
        buf, quote = word, opening
      else
        table.insert(words, strip(word))
      end
    end
  end
  if buf then table.insert(words, buf) end
  return words
end
--
-- zsh completion script, printed by `run` when COMP_PROG and COMP_SHELL=zsh
-- are set; every occurrence of NAME is replaced with $COMP_PROG. The generated
-- `_NAME` function re-invokes the program with COMP_WORDS (the current words,
-- normally space-joined) and COMP_CWORD (zsh's 1-based $CURRENT), then reads
-- the output back as triples of lines: type, key, description. `item` entries
-- become candidates; `dir` / `file` fall back to zsh's path completion.
--
local ZSH_COMPLETION_SCRIPT = [=[
function _NAME {
local -a completions
response=("${(@f)$(env COMP_WORDS="${words[*]}" COMP_CWORD="${CURRENT}" "NAME")}")
for type key desc in ${response}; do
if [[ "$type" == "item" ]]; then
completions+=("$key":"$desc")
elif [[ "$type" == "dir" ]]; then
_path_files -/
elif [[ "$type" == "file" ]]; then
_path_files -f
fi
done
if [ -n "$completions" ]; then
_describe -V unsorted completions -U
fi
}
compdef _NAME NAME
]=]
--
-- bash completion script, printed by `run` when COMP_PROG and COMP_SHELL=bash
-- are set; every occurrence of NAME is replaced with $COMP_PROG. The program is
-- re-invoked with COMP_WORDS and COMP_CWORD (bash's 0-based index shifted by
-- one) and its output is read back as type / value / description line triples.
-- `item` values are added to COMPREPLY; `dir` / `file` switch to bash's
-- built-in dirname / default completion when `compopt` is available.
-- NOTE(review): the `IFS=','` prefix on `echo` does not influence how
-- "${COMP_WORDS[*]}" is expanded (arguments are expanded before prefix
-- assignments take effect), so words are joined with the function's local
-- IFS ($'\n'). shell_split splits on any whitespace, so this still works —
-- confirm the comma was not meant to matter.
--
local BASH_COMPLETION_SCRIPT = [=[
_NAME() {
local IFS=$'\n'
while read type; read value; read _desc; do
if [[ $type == 'dir' ]] && (type compopt &> /dev/null); then
COMPREPLY=()
compopt -o dirnames
elif [[ $type == 'file' ]] && (type compopt &> /dev/null); then
COMPREPLY=()
compopt -o default
elif [[ $type == 'item' ]]; then
COMPREPLY+=($value)
fi
done < <(env COMP_WORDS="$(IFS=',' echo "${COMP_WORDS[*]}")" COMP_CWORD=$((COMP_CWORD+1)) "NAME")
return 0
}
_NAME_setup() {
complete -F _NAME NAME
}
_NAME_setup;
]=]
--
-- Report an error meant for the user. The message is built with
-- string.format(msg, ...) and handed to the driver in parse_and_eval by
-- yielding an `{ action = 'error' }` record, so this has to be called from
-- inside the parsing/evaluation coroutine.
--
local function err(msg, ...)
  local text = string.format(msg, ...)
  coroutine.yield({ message = text, action = 'error' })
end
--
-- Traverse terms and yield primitive terms (options and arguments).
--
-- The walk is depth-first and left-to-right (table terms in `spairs` order),
-- mirroring the recursive descent `eval` performs over `app` arguments, so
-- positional arguments are registered in the order evaluation consumes them.
-- `term` may be nil, in which case nothing is yielded.
--
local function primitives(term)
  local found = {}
  local function visit(t)
    local kind = t.type
    if kind == 'opt' or kind == 'arg' then
      found[#found + 1] = t
    elseif kind == 'app' then
      -- Only the argument terms matter; `t.func` is a plain Lua function.
      for _, a in ipairs(t.args) do visit(a) end
    elseif kind == 'all' then
      for _, v in spairs(t.spec) do visit(v) end
    elseif kind ~= 'val' then
      error(string.format("unknown term type '%s'", tostring(kind)))
    end
  end
  if term ~= nil then visit(term) end
  local i = 0
  return function()
    i = i + 1
    return found[i]
  end
end
--
-- A special token within line which specifies the position for completion.
-- `run` substitutes it for the word under the cursor when driven by shell
-- completion; the parser yields a 'completion' action when it reaches it.
--
local __COMPLETE__ = '__COMPLETE__'
-- Forward declarations: the `arg` and `opt` term constructors are defined in
-- the TERMS section below, but parse_and_eval already uses them to synthesize
-- dummy terms when recovering from parse errors.
local arg, opt
--
-- Parse `line` (an array of argument strings) against command `cmd` and
-- evaluate its term(s). Returns one of:
--   'value', v1, v2, ...  -- one value per command on the path (parent
--                            commands first, then the selected subcommand)
--   'error', message      -- first user-facing error encountered
--   'help', {cmd = c}     -- --help was given for command c
--   'version', {cmd = c}  -- --version was given for command c
--   'completion', iter    -- line contains the completion marker
--
local function parse_and_eval(cmd, line)
  -- `unpack` is a global in Lua 5.1/LuaJIT but only `table.unpack` in 5.2+.
  local unpack = table.unpack or unpack
  -- `run` sets `line.cword` (the partial word under the cursor) only when it
  -- is driven by shell completion.
  local is_completion = line.cword ~= nil
  --
  -- Iterator which, given a command `cmd` and a start `idx` into `line`,
  -- yields `idx`, `term`, `value` triples: `idx` points just past the
  -- consumed token(s), `term` is the matched option/argument primitive and
  -- `value` its raw string value (`true` for flags). Problems are reported by
  -- yielding 'error' actions (parsing then continues with a dummy term);
  -- reaching the completion marker yields a 'completion' action.
  --
  local function terms(cmd, idx)
    -- Work on copies: positional args are popped as they get matched.
    local opts = {}
    for k, v in pairs(cmd.lookup.opts) do opts[k] = v end
    local args = {unpack(cmd.lookup.args)}
    idx = idx or 1
    local function step()
      local v = line[idx]
      if v == nil then return end
      if v:sub(1, 1) == "-" then
        local name, value = v, nil
        -- Check if option is supplied as '--name=value' and parse it accordingly.
        local sep = v:find("=", 1, true)
        if sep then
          name, value = v:sub(1, sep - 1), v:sub(sep + 1)
        end
        local o = opts[name]
        if o == nil then
          coroutine.yield {
            action = 'error',
            message = string.format("unknown option '%s'", name)
          }
          -- Recover by synthesizing a dummy option flag
          o = opt { "--ERROR", flag = true }
        end
        if o.flag then
          if value ~= nil then
            coroutine.yield {
              action = 'error',
              message = string.format("unexpected value for option '%s'", name)
            }
          end
          idx = idx + 1
          return idx, o, true
        end
        if not value then
          -- Value is the following token: `--name value`.
          idx, value = idx + 1, line[idx + 1]
          if value == __COMPLETE__ then
            coroutine.yield {
              action = 'completion',
              cmd = cmd,
              opt = o,
              arg = nil,
            }
          end
          if value == nil then
            coroutine.yield {
              action = 'error',
              message = string.format("missing value for option '%s'", name)
            }
          end
        end
        idx = idx + 1
        return idx, o, value
      end
      local a = args[1]
      if v == __COMPLETE__ then
        coroutine.yield {
          action = 'completion',
          cmd = cmd,
          opt = nil,
          arg = a,
        }
      elseif a == nil then
        coroutine.yield {
          action = 'error',
          message = string.format("unexpected argument '%s'", v)
        }
        -- Recover by synthesizing a dummy arg
        a = arg "ERROR"
      end
      -- A plural argument keeps matching all remaining positional values.
      if not a.plural then
        table.remove(args, 1)
      end
      idx = idx + 1
      return idx, a, v
    end
    return step
  end
  --
  -- Eval `term` given values for `opts` (keyed by canonical option name) and
  -- `args` (positional values, consumed front to back in traversal order).
  --
  local function eval(term, opts, args)
    if term.type == 'opt' then
      local v = opts[term.name]
      if term.plural and v == nil then v = {} end
      if term.flag and v == nil then v = false end
      return v
    elseif term.type == 'arg' then
      local v = table.remove(args, 1)
      if term.plural and v == nil then v = {} end
      return v
    elseif term.type == 'val' then
      return term.v
    elseif term.type == 'app' then
      local n = #term.args
      local v = {}
      for i = 1, n do
        v[i] = eval(term.args[i], opts, args)
      end
      -- Pass an explicit count: a nil result (e.g. an unset option) must not
      -- truncate the arguments that follow it.
      return term.func(unpack(v, 1, n))
    elseif term.type == 'all' then
      local v, seen = {}, {}
      -- Evaluate in `spairs` order (the order `primitives` registers
      -- positional arguments in) so each argument receives the right value;
      -- any remaining keys `spairs` skips are evaluated afterwards.
      for k, t in spairs(term.spec) do
        seen[k] = true
        v[k] = eval(t, opts, args)
      end
      for k, t in pairs(term.spec) do
        if not seen[k] then
          v[k] = eval(t, opts, args)
        end
      end
      return v
    else
      assert(false)
    end
  end
  --
  -- Parse the tokens of `cmd` starting at `start_idx`, yielding 'value'
  -- records (term + collected opts/args) for `cmd` and, recursively, for the
  -- selected subcommand, plus 'help' / 'version' / 'error' records.
  --
  local function run(cmd, start_idx)
    local has_subs = cmd.subs and #cmd.subs > 0
    assert(not has_subs or #cmd.lookup.args == 1)
    local opts, args = {}, {}
    for idx, term, value in terms(cmd, start_idx) do
      if term.type == 'opt' then
        if not cmd.disable_help and term.name == "--help" then
          coroutine.yield {
            action = 'help',
            cmd = cmd,
          }
        elseif not cmd.disable_version and term.name == "--version" then
          coroutine.yield {
            action = 'version',
            cmd = cmd,
          }
        elseif term.plural then
          if opts[term.name] ~= nil then
            table.insert(opts[term.name], value)
          else
            opts[term.name] = {value}
          end
        elseif opts[term.name] ~= nil then
          coroutine.yield {
            action = 'error',
            message = string.format("supplied multiple values for option '%s'", term.name)
          }
        else
          opts[term.name] = value
        end
      elseif term.type == 'arg' then
        if has_subs then
          -- First (and the only) argument for the command with subcommands is a
          -- subcommand name. Lookup subcommand and continue running with the
          -- subcommand.
          local next_cmd = cmd.lookup.subs[value]
          -- The lookup table also holds the boolean `required` flag, so only
          -- accept actual command tables.
          if type(next_cmd) ~= 'table' then
            coroutine.yield {
              action = 'error',
              message = string.format("unknown subcommand '%s'", value)
            }
          else
            coroutine.yield {
              action = 'value',
              term = cmd.term,
              opts = opts,
              args = {},
            }
            return run(next_cmd, idx)
          end
        elseif term.plural then
          -- Consecutive values of a plural argument are grouped in a list.
          if type(args[#args]) == 'table' then
            table.insert(args[#args], value)
          else
            table.insert(args, {value})
          end
        else
          table.insert(args, value)
        end
      end
    end
    if has_subs and cmd.lookup.subs.required then
      coroutine.yield {
        action = 'error',
        message = 'missing a subcommand',
      }
    else
      local next_arg = cmd.lookup.args[#args + 1]
      if next_arg and next_arg.required then
        coroutine.yield {
          action = 'error',
          message = string.format("missing a required argument '%s'", next_arg.name),
        }
      end
      coroutine.yield {
        action = 'value',
        term = cmd.term,
        opts = opts,
        args = args,
      }
    end
  end
  local values = {n = 0}
  local errors = {}
  local show_version, show_help
  -- Phase 1: parse, collecting values, errors and help/version requests.
  do
    local co = coroutine.create(function() run(cmd) end)
    while coroutine.status(co) ~= 'dead' do
      local ok, val = coroutine.resume(co)
      if not ok then
        error(tostring(val) .. '\n' .. debug.traceback(co))
      elseif not val then
        -- do nothing
      elseif val.action == 'value' then
        values.n = values.n + 1
        values[values.n] = val
      elseif val.action == 'error' then
        table.insert(errors, val.message)
      elseif val.action == 'help' then
        show_help = {cmd = val.cmd}
      elseif val.action == 'version' then
        show_version = {cmd = val.cmd}
      elseif val.action == 'completion' then
        assert(is_completion)
        return 'completion', val.cmd:completion(line.cword, val.opt, val.arg)
      else
        assert(false)
      end
    end
  end
  if show_help ~= nil then
    return 'help', show_help
  end
  if show_version ~= nil then
    return 'version', show_version
  end
  if #errors > 0 then
    return 'error', errors[1]
  end
  -- Phase 2: evaluate terms; term functions may report errors via `err`.
  do
    local co = coroutine.create(function()
      for i = 1, values.n do
        local val = values[i]
        values[i] = eval(val.term, val.opts, val.args)
      end
    end)
    while coroutine.status(co) ~= 'dead' do
      local ok, val = coroutine.resume(co)
      if not ok then
        error(tostring(val) .. '\n' .. debug.traceback(co))
      elseif not val then
        -- do nothing
      elseif val.action == 'error' then
        return 'error', val.message
      else
        assert(false)
      end
    end
  end
  local function unwind(i)
    if i > values.n then return nil
    else return values[i], unwind(i + 1)
    end
  end
  return 'value', unwind(1)
end
--
-- Entry point: parse and evaluate `line` (default: the script's arguments)
-- against `cmd` and return the resulting value(s). Errors, --help, --version
-- and shell completion are handled here by printing and exiting the process.
-- With COMP_PROG and COMP_SHELL set, prints the shell completion script
-- instead; with COMP_WORDS and COMP_CWORD set, answers a completion request.
--
local function run(cmd, line)
  line = line or argv
  -- Print shell completion script onto stdout and exit
  local comp_prog = os.getenv "COMP_PROG"
  local comp_shell = os.getenv "COMP_SHELL"
  if comp_prog ~= nil and comp_shell ~= nil then
    local scripts = { zsh = ZSH_COMPLETION_SCRIPT, bash = BASH_COMPLETION_SCRIPT }
    local script = scripts[comp_shell]
    if script == nil then
      io.stderr:write(string.format(
        "unsupported COMP_SHELL '%s' (expected 'zsh' or 'bash')\n", comp_shell))
      os.exit(1)
    end
    -- A function replacement inserts the program name literally; a string
    -- replacement would interpret any '%' in it as a capture reference.
    print((script:gsub("NAME", function() return comp_prog end)))
    os.exit(0)
  end
  -- Check if we are running completion
  do
    local comp_words = os.getenv "COMP_WORDS"
    local comp_cword = tonumber(os.getenv "COMP_CWORD")
    if comp_words ~= nil and comp_cword ~= nil then
      line = shell_split(comp_words)
      -- Remember the partial word under the cursor, then replace it with the
      -- marker so the parser knows where completion was requested.
      line.cword = line[comp_cword] or ""
      line[comp_cword] = __COMPLETE__
      -- Drop the program name itself.
      table.remove(line, 1)
    end
  end
  local function handle(kind, v, ...)
    if kind == 'value' then
      return v, ...
    elseif kind == 'error' then
      cmd:print_error(v)
      os.exit(1)
    elseif kind == 'help' then
      v.cmd:print_help()
      os.exit(0)
    elseif kind == 'version' then
      v.cmd:print_version()
      os.exit(0)
    elseif kind == 'completion' then
      -- One candidate per three lines: type, name, description.
      for item_type, name, desc in v do
        print(item_type); print(name or ""); print(desc or "")
      end
      os.exit(0)
    end
  end
  return handle(parse_and_eval(cmd, line))
end
-- }}}
--
-- TERMS
--
-- Terms form an algebra: there are primitive terms and term compositions
-- (which are also terms!), so you can build more complex terms out of simpler
-- ones.
--
-- The primitive terms represent command line options and arguments.
--
-- {{{
-- Forward declarations: `app` is needed by term:and_then and `cmd` by
-- term:parse_and_eval (both defined just below), ahead of their definitions.
local app
local cmd
--
-- Completion function which completes nothing: iterating it yields no
-- candidates.
--
local function empty_complete()
  return pairs({})
end
--
-- Build a completion function whose iterator yields a single `kind` marker
-- (no name, no description) and then stops. The shell-side completion scripts
-- turn such markers into native file / directory completion.
--
local function marker_complete(kind)
  return function()
    local done = false
    return function()
      if done then return end
      done = true
      return kind, nil, nil
    end
  end
end
-- Completion function which completes filenames.
local file_complete = marker_complete("file")
-- Completion function which completes dirnames.
local dir_complete = marker_complete("dir")
--
-- Metatable shared by every term; its __index table holds the term methods.
--
local term_methods = {}
local term_mt = { __index = term_methods }
--
-- `term:and_then(func)` is shorthand for `app(func, term)`.
--
function term_methods.and_then(self, func)
  return app(func, self)
end
--
-- Parse and evaluate `line` against this term alone, wrapped in an unnamed
-- command with the built-in --help / --version handling turned off.
--
function term_methods.parse_and_eval(self, line)
  local wrapper = cmd { term = self, disable_help = true, disable_version = true }
  return parse_and_eval(wrapper, line)
end
--
-- Construct a term that ignores the command line and evaluates to the given
-- Lua value `v`.
--
local function val(v)
  local term = { type = 'val' }
  term.v = v
  return setmetatable(term, term_mt)
end
--
-- Construct a term which, once its argument terms are evaluated, calls `func`
-- with their values (in order) and evaluates to the result.
--
app = function(func, ...)
  local term = { type = 'app', func = func }
  term.args = { ... }
  return setmetatable(term, term_mt)
end
--
-- Construct a term which evaluates into a table with the same keys as
-- `spec`, each holding the value of the corresponding term.
--
local function all(spec)
  local term = { spec = spec }
  term.type = 'all'
  return setmetatable(term, term_mt)
end
--
-- Construct a term which represents a command line option.
--
-- `spec` is a name or a table whose array part lists the option's names (the
-- first one is canonical). Bare names are prefixed: a single letter gets '-',
-- longer names get '--'; names already starting with '-' are kept. Fields:
-- `flag` (takes no value, evaluates to a boolean), `plural` (may be repeated,
-- evaluates to a list), `desc`, `vdesc` (value placeholder shown in help) and
-- `complete` ("file", "dir" or a completion function).
--
function opt(spec)
  if type(spec) == 'string' then spec = {spec} end
  assert(
    not (spec.flag and spec.plural),
    "opt { plural = true, flag = true, ...} does not make sense"
  )
  local names = {}
  for _, n in ipairs(spec) do
    -- Previously written as `not n:sub(1, 1) == "-"`, which parses as
    -- `(not n:sub(1, 1)) == "-"` and is always false.
    if n:sub(1, 1) ~= "-" then
      n = (#n == 1 and "-" or "--") .. n
    end
    names[#names + 1] = n
  end
  assert(#names > 0, "missing opt name")
  local complete
  if spec.complete == "file" then
    complete = file_complete
  elseif spec.complete == "dir" then
    complete = dir_complete
  elseif spec.complete then
    complete = spec.complete
  else
    complete = empty_complete
  end
  return setmetatable({
    type = 'opt',
    name = names[1],
    names = names,
    desc = spec.desc or "NOT DOCUMENTED",
    vdesc = spec.vdesc or 'VALUE',
    flag = spec.flag or false,
    plural = spec.plural or false,
    complete = complete,
  }, term_mt)
end
--
-- Construct a term which represents a command line positional argument.
--
-- `spec` is a name or a table `{ name, desc =, required =, plural =,
-- complete = }`. Arguments are required unless `required = false` is given;
-- a `plural` argument collects all remaining positional values into a list.
-- `complete` is "file", "dir" or a completion function.
--
function arg(spec)
  -- Accept a bare name as shorthand for `{ name }` (as `opt` and `cmd` do),
  -- so the field reads below hit a real table instead of silently going
  -- through the string metatable.
  if type(spec) == 'string' then spec = {spec} end
  local name = spec[1]
  assert(name, "missing arg name")
  local complete
  if spec.complete == "file" then
    complete = file_complete
  elseif spec.complete == "dir" then
    complete = dir_complete
  elseif spec.complete then
    complete = spec.complete
  else
    complete = empty_complete
  end
  return setmetatable({
    type = 'arg',
    name = name,
    desc = spec.desc or "NOT DOCUMENTED",
    complete = complete,
    required = spec.required ~= false,
    plural = spec.plural or false,
  }, term_mt)
end
-- }}}
--
-- COMMANDS
--
-- A command wraps a term and adds some convenience like automatic parsing and
-- processing of --help and --version options, handling of user errors.
--
-- Commands can contain subcommands, enabling command line interfaces in the
-- style of git or kubectl.
--
-- {{{
--
-- Build one of the built-in flag options: `desc` followed by its names.
--
local function builtin_flag(desc, ...)
  local spec = { ... }
  spec.flag = true
  spec.desc = desc
  return opt(spec)
end
-- Added to every command unless `disable_help` / `disable_version` is set.
local help_opt = builtin_flag('Show this message and exit', '--help', '-h')
local version_opt = builtin_flag('Print version and exit', '--version')
--
-- Metatable for commands created by `cmd`; __index holds the command methods.
--
local cmd_mt = {
  __index = {
    run = run,
    parse_and_eval = parse_and_eval,
    --
    -- Write `err` to stderr, prefixed with the command name when it has one
    -- (commands built by term:parse_and_eval have no name).
    --
    print_error = function(self, err)
      if self.name ~= nil then
        io.stderr:write(string.format("%s: error: %s\n", self.name, err))
      else
        io.stderr:write(string.format("error: %s\n", err))
      end
    end,
    --
    -- Print the command's version string.
    --
    print_version = function(self)
      print(self.version)
    end,
    --
    -- Print the help screen: name (and version), description, then the
    -- Options, Arguments (only without subcommands) and Commands sections.
    --
    print_help = function(self)
      -- Print `rows` (arrays of strings) as aligned columns; each column is
      -- preceded by `opts.margin` (default 2) spaces beyond the widest entry
      -- of the previous column.
      local function print_tabular(rows, opts)
        opts = opts or {}
        local margin = opts.margin or 2
        local width = {}
        for _, row in ipairs(rows) do
          for i, col in ipairs(row) do
            if #col > (width[i] or 0) then width[i] = #col end
          end
        end
        for _, row in ipairs(rows) do
          local line = ''
          local prev_col_width = 0
          for i, col in ipairs(row) do
            local padding = (' '):rep((width[i - 1] or 0) - prev_col_width + margin)
            line = line .. padding .. col
            prev_col_width = #col
          end
          print(line)
        end
      end
      -- Same test `cmd` uses: an empty `subs` table means no subcommands.
      local has_subs = self.subs and #self.subs > 0
      if self.version and self.name then
        print(string.format("%s v%s", self.name, self.version))
      elseif self.name then
        print(self.name)
      end
      if self.desc then
        print("")
        print(self.desc)
      end
      local opts, args = {}, {}
      for item in primitives(self.term) do
        if item.type == 'opt' then
          table.insert(opts, item)
        elseif item.type == 'arg' then
          table.insert(args, item)
        end
      end
      if not self.disable_help then
        table.insert(opts, help_opt)
      end
      if not self.disable_version then
        table.insert(opts, version_opt)
      end
      if #opts > 0 then
        print('\nOptions:')
        local rows = {}
        for _, o in ipairs(opts) do
          local name = table.concat(o.names, ',')
          if not o.flag then
            name = name .. ' ' .. o.vdesc
          end
          table.insert(rows, {name, o.desc})
        end
        print_tabular(rows)
      end
      if not has_subs and #args > 0 then
        print('\nArguments:')
        local rows = {}
        for _, a in ipairs(args) do
          table.insert(rows, {a.name, a.desc})
        end
        print_tabular(rows)
      end
      if has_subs then
        print('\nCommands:')
        local rows = {}
        for _, c in ipairs(self.subs) do
          local name = table.concat(c.names, ',')
          table.insert(rows, {name, c.desc or ''})
        end
        print_tabular(rows)
      end
    end,
    --
    -- Return an iterator of (type, name, desc) completion candidates whose
    -- name starts with `cword`: values of `opt` when it is given; otherwise
    -- subcommand names, option names (including built-in --help/--version)
    -- and, when `arg` is given, that argument's values. Candidates with a nil
    -- name (file/dir markers) always pass the prefix filter.
    --
    completion = function(self, cword, opt, arg)
      local co = coroutine.create(function()
        local function out(type, name, desc)
          if name == nil or name:sub(1, #cword) == cword then
            coroutine.yield(type, name, desc)
          end
        end
        local function complete_term(term)
          for type, name, desc in term.complete(cword) do
            out(type, name, desc)
          end
        end
        if opt then
          -- Complete option value
          -- TODO(andreypopp): handle `--name=value` syntax here
          complete_term(opt)
        else
          if self.subs and #self.subs > 0 then
            -- Complete subcommands
            for _, c in ipairs(self.subs) do
              out("item", c.name, c.desc)
            end
          end
          -- Complete option names
          for t in primitives(self.term) do
            if t.type == 'opt' then
              out("item", t.name, t.desc)
            end
          end
          if not self.disable_help then
            out("item", help_opt.name, help_opt.desc)
          end
          if not self.disable_version then
            out("item", version_opt.name, version_opt.desc)
          end
          if arg then
            -- Finally complete argument value
            complete_term(arg)
          end
        end
      end)
      return function()
        local ok, type, name, desc = coroutine.resume(co)
        if not ok then error(type) end
        return type, name, desc
      end
    end,
  }
}
--
-- Construct a command. `spec` is a name or a table whose array part lists
-- the command's names, with optional fields: `term` (what the command
-- evaluates), `subs` (array of subcommands; `subs.required = false` makes
-- choosing one optional), `desc`, `version`, `disable_help`,
-- `disable_version`. Options are indexed under every name; a command with
-- subcommands gets a single synthetic 'subcommand' positional argument and
-- may not declare other arguments.
-- NOTE(review): `required` shares the subcommand lookup table, so a
-- subcommand literally named "required" would overwrite that flag.
--
function cmd(spec)
  if type(spec) == 'string' then
    spec = { spec }
  end
  local lookup_opts, lookup_args, lookup_subs = {}, {}, {}
  for prim in primitives(spec.term) do
    if prim.type == 'opt' then
      for _, alias in ipairs(prim.names) do
        lookup_opts[alias] = prim
      end
    elseif prim.type == 'arg' then
      lookup_args[#lookup_args + 1] = prim
    end
  end
  local subcommands = spec.subs
  if subcommands and #subcommands > 0 then
    -- Subcommands are required unless `required = false` is given explicitly.
    lookup_subs.required = subcommands.required ~= false
    for _, sub in ipairs(subcommands) do
      for _, alias in ipairs(sub.names) do
        lookup_subs[alias] = sub
      end
    end
  end
  if next(lookup_subs) ~= nil then
    assert(#lookup_args == 0, "a command with subcommands cannot accept arguments")
    lookup_args[#lookup_args + 1] = arg { 'subcommand', required = lookup_subs.required }
  end
  if not spec.disable_help then
    lookup_opts['--help'], lookup_opts['-h'] = help_opt, help_opt
  end
  if not spec.disable_version then
    lookup_opts['--version'] = version_opt
  end
  local names = {}
  for i, n in ipairs(spec) do
    names[i] = n
  end
  return setmetatable({
    name = names[1],
    names = names,
    version = spec.version or "0.0.0",
    desc = spec.desc or "NOT DOCUMENTED",
    term = spec.term or val(),
    subs = spec.subs,
    lookup = { opts = lookup_opts, args = lookup_args, subs = lookup_subs },
    disable_help = spec.disable_help,
    disable_version = spec.disable_version,
  }, cmd_mt)
end
-- }}}
--
-- EXPORTS
--
local exports = {
  -- Term constructors and combinators.
  all = all,
  app = app,
  arg = arg,
  opt = opt,
  val = val,
  -- Command constructor, and the helper that reports user-facing errors from
  -- within parsing/evaluation.
  cmd = cmd,
  err = err,
}
-- These are exported for testing purposes only.
exports.__COMPLETE__ = __COMPLETE__
exports.shell_split = shell_split
return exports