Diff

plugins/mod_admin_shell.lua @ 13739:347991cd1cc3

Merge 13.0->trunk
author Matthew Wild <mwild1@gmail.com>
date Mon, 17 Feb 2025 19:22:54 +0000
parent 13737:46e7cc4de5e6
child 13769:5cc4a3e0335c
line wrap: on
line diff
--- a/plugins/mod_admin_shell.lua	Mon Feb 17 12:37:58 2025 +0100
+++ b/plugins/mod_admin_shell.lua	Mon Feb 17 19:22:54 2025 +0000
@@ -19,6 +19,7 @@
 local server = require "prosody.net.server";
 local schema = require "prosody.util.jsonschema";
 local st = require "prosody.util.stanza";
+local parse_args = require "prosody.util.argparse".parse;
 
 local _G = _G;
 
@@ -255,6 +256,83 @@
 	return session;
 end
 
+local function process_cmd_line(session, arg_line)
+	local chunk = load("return "..arg_line, "=shell", "t", {});
+	local ok, args = pcall(chunk);
+	if not ok then return nil, args; end
+
+	local section_name, command = args[1], args[2];
+
+	local section_mt = getmetatable(def_env[section_name]);
+	local section_help = section_mt and section_mt.help;
+	local command_help = section_help and section_help.commands[command];
+
+	if not command_help then
+		if commands[section_name] then
+			commands[section_name](session, table.concat(args, " "));
+			return;
+		end
+		if section_help then
+			return nil, "Command not found or necessary module not loaded. Try 'help "..section_name.." for a list of available commands.";
+		end
+		return nil, "Command not found. Is the necessary module loaded?";
+	end
+
+	local fmt = { "%s"; ":%s("; ")" };
+
+	if command_help.flags then
+		local flags, flags_err, flags_err_extra = parse_args(args, command_help.flags);
+		if not flags then
+			if flags_err == "missing-value" then
+				return nil, "Expected value after "..flags_err_extra;
+			elseif flags_err == "param-not-found" then
+				return nil, "Unknown parameter: "..flags_err_extra;
+			end
+			return nil, flags_err;
+		end
+
+		table.remove(flags, 2);
+		table.remove(flags, 1);
+
+		local n_fixed_args = #command_help.args;
+
+		local arg_str = {};
+		for i = 1, n_fixed_args do
+			if flags[i] ~= nil then
+				table.insert(arg_str, ("%q"):format(flags[i]));
+			else
+				table.insert(arg_str, "nil");
+			end
+		end
+
+		table.insert(arg_str, "flags");
+
+		for i = n_fixed_args + 1, #flags do
+			if flags[i] ~= nil then
+				table.insert(arg_str, ("%q"):format(flags[i]));
+			else
+				table.insert(arg_str, "nil");
+			end
+		end
+
+		table.insert(fmt, 3, "%s");
+
+		return "local flags = ...; return "..string.format(table.concat(fmt), section_name, command, table.concat(arg_str, ", ")), flags;
+	end
+
+	for i = 3, #args do
+		if args[i]:sub(1, 1) == ":" then
+			table.insert(fmt, i, ")%s(");
+		elseif i > 3 and fmt[i - 1]:match("%%q$") then
+			table.insert(fmt, i, ", %q");
+		else
+			table.insert(fmt, i, "%q");
+		end
+	end
+
+	return "return "..string.format(table.concat(fmt), table.unpack(args));
+end
+
 local function handle_line(event)
 	local session = event.origin.shell_session;
 	if not session then
@@ -295,23 +373,6 @@
 		session.globalenv = redirect_output(_G, session);
 	end
 
-	local chunkname = "=console";
-	local env = (useglobalenv and session.globalenv) or session.env or nil
-	-- luacheck: ignore 311/err
-	local chunk, err = envload("return "..line, chunkname, env);
-	if not chunk then
-		chunk, err = envload(line, chunkname, env);
-		if not chunk then
-			err = err:gsub("^%[string .-%]:%d+: ", "");
-			err = err:gsub("^:%d+: ", "");
-			err = err:gsub("'<eof>'", "the end of the line");
-			result.attr.type = "error";
-			result:text("Sorry, I couldn't understand that... "..err);
-			event.origin.send(result);
-			return;
-		end
-	end
-
 	local function send_result(taskok, message)
 		if not message then
 			if type(taskok) ~= "string" and useglobalenv then
@@ -328,7 +389,49 @@
 		event.origin.send(result);
 	end
 
-	local taskok, message = chunk();
+	local taskok, message;
+	local env = (useglobalenv and session.globalenv) or session.env or nil;
+	local flags;
+
+	local source;
+	if line:match("^{") then
+		-- Input is a serialized array of strings, typically from
+		-- a command-line invocation of 'prosodyctl shell something'
+		source, flags = process_cmd_line(session, line);
+		if not source then
+			if flags then -- err
+				send_result(false, flags);
+			else -- no err, but nothing more to do
+				-- This happens if it was a "simple" command
+				event.origin.send(result);
+			end
+			return;
+		end
+	end
+
+	local chunkname = "=console";
+	-- luacheck: ignore 311/err
+	local chunk, err = envload(source or ("return "..line), chunkname, env);
+	if not chunk then
+		if not source then
+			chunk, err = envload(line, chunkname, env);
+		end
+		if not chunk then
+			err = err:gsub("^%[string .-%]:%d+: ", "");
+			err = err:gsub("^:%d+: ", "");
+			err = err:gsub("'<eof>'", "the end of the line");
+			result.attr.type = "error";
+			result:text("Sorry, I couldn't understand that... "..err);
+			event.origin.send(result);
+			return;
+		end
+	end
+
+	if not source then
+		session.repl = true;
+	end
+
+	taskok, message = chunk(flags);
 
 	if promise.is_promise(taskok) then
 		taskok:next(function (resolved_message)
@@ -462,9 +565,43 @@
 				for command, command_help in it.sorted_pairs(section_help.commands or {}) do
 					if not command_help.hidden then
 						c = c + 1;
-						local args = array.pluck(command_help.args, "name"):concat(", ");
 						local desc = command_help.desc or command_help.module and ("Provided by mod_"..command_help.module) or "";
-						print(("%s:%s(%s) - %s"):format(section_name, command, args, desc));
+						if self.session.repl then
+							local args = array.pluck(command_help.args, "name"):concat(", ");
+							print(("%s:%s(%s) - %s"):format(section_name, command, args, desc));
+						else
+							local args = array.pluck(command_help.args, "name"):concat("> <");
+							if args ~= "" then
+								args = "<"..args..">";
+							end
+							print(("%s %s %s"):format(section_name, command, args));
+							print(("    %s"):format(desc));
+							if command_help.flags then
+								local flags = command_help.flags;
+								print("");
+								print(("    Flags:"));
+
+								if flags.kv_params then
+									for name in it.sorted_pairs(flags.kv_params) do
+										print("      --"..name:gsub("_", "-"));
+									end
+								end
+
+								if flags.value_params then
+									for name in it.sorted_pairs(flags.value_params) do
+										print("      --"..name:gsub("_", "-").." <"..name..">");
+									end
+								end
+
+								if flags.array_params then
+									for name in it.sorted_pairs(flags.array_params) do
+										print("      --"..name:gsub("_", "-").." <"..name..">, ...");
+									end
+								end
+
+							end
+							print("");
+						end
 					end
 				end
 			elseif help_topics[section_name] then
@@ -2641,10 +2778,20 @@
 			section_mt.help = section_help;
 		end
 
+		if command.flags then
+			if command.flags.stop_on_positional == nil then
+				command.flags.stop_on_positional = false;
+			end
+			if command.flags.strict == nil then
+				command.flags.strict = true;
+			end
+		end
+
 		section_help.commands[command.name] = {
 			desc = command.desc;
 			full = command.help;
 			args = array(command.args);
+			flags = command.flags;
 			module = command._provided_by;
 		};