Changeset

13739:347991cd1cc3

Merge 13.0->trunk
author Matthew Wild <mwild1@gmail.com>
date Mon, 17 Feb 2025 19:22:54 +0000
parents 13730:c653c1d3e8da (current diff) 13738:26a0f653793e (diff)
children 13742:47e537e340c4
files
diffstat 5 files changed, 324 insertions(+), 177 deletions(-) [+]
line wrap: on
line diff
--- a/plugins/mod_admin_shell.lua	Mon Feb 17 12:37:58 2025 +0100
+++ b/plugins/mod_admin_shell.lua	Mon Feb 17 19:22:54 2025 +0000
@@ -19,6 +19,7 @@
 local server = require "prosody.net.server";
 local schema = require "prosody.util.jsonschema";
 local st = require "prosody.util.stanza";
+local parse_args = require "prosody.util.argparse".parse;
 
 local _G = _G;
 
@@ -255,6 +256,83 @@
 	return session;
 end
 
+local function process_cmd_line(session, arg_line)
+	local chunk = load("return "..arg_line, "=shell", "t", {});
+	local ok, args = pcall(chunk);
+	if not ok then return nil, args; end
+
+	local section_name, command = args[1], args[2];
+
+	local section_mt = getmetatable(def_env[section_name]);
+	local section_help = section_mt and section_mt.help;
+	local command_help = section_help and section_help.commands[command];
+
+	if not command_help then
+		if commands[section_name] then
+			commands[section_name](session, table.concat(args, " "));
+			return;
+		end
+		if section_help then
+			return nil, "Command not found or necessary module not loaded. Try 'help "..section_name.." for a list of available commands.";
+		end
+		return nil, "Command not found. Is the necessary module loaded?";
+	end
+
+	local fmt = { "%s"; ":%s("; ")" };
+
+	if command_help.flags then
+		local flags, flags_err, flags_err_extra = parse_args(args, command_help.flags);
+		if not flags then
+			if flags_err == "missing-value" then
+				return nil, "Expected value after "..flags_err_extra;
+			elseif flags_err == "param-not-found" then
+				return nil, "Unknown parameter: "..flags_err_extra;
+			end
+			return nil, flags_err;
+		end
+
+		table.remove(flags, 2);
+		table.remove(flags, 1);
+
+		local n_fixed_args = #command_help.args;
+
+		local arg_str = {};
+		for i = 1, n_fixed_args do
+			if flags[i] ~= nil then
+				table.insert(arg_str, ("%q"):format(flags[i]));
+			else
+				table.insert(arg_str, "nil");
+			end
+		end
+
+		table.insert(arg_str, "flags");
+
+		for i = n_fixed_args + 1, #flags do
+			if flags[i] ~= nil then
+				table.insert(arg_str, ("%q"):format(flags[i]));
+			else
+				table.insert(arg_str, "nil");
+			end
+		end
+
+		table.insert(fmt, 3, "%s");
+
+		return "local flags = ...; return "..string.format(table.concat(fmt), section_name, command, table.concat(arg_str, ", ")), flags;
+	end
+
+	for i = 3, #args do
+		if args[i]:sub(1, 1) == ":" then
+			table.insert(fmt, i, ")%s(");
+		elseif i > 3 and fmt[i - 1]:match("%%q$") then
+			table.insert(fmt, i, ", %q");
+		else
+			table.insert(fmt, i, "%q");
+		end
+	end
+
+	return "return "..string.format(table.concat(fmt), table.unpack(args));
+end
+
 local function handle_line(event)
 	local session = event.origin.shell_session;
 	if not session then
@@ -295,23 +373,6 @@
 		session.globalenv = redirect_output(_G, session);
 	end
 
-	local chunkname = "=console";
-	local env = (useglobalenv and session.globalenv) or session.env or nil
-	-- luacheck: ignore 311/err
-	local chunk, err = envload("return "..line, chunkname, env);
-	if not chunk then
-		chunk, err = envload(line, chunkname, env);
-		if not chunk then
-			err = err:gsub("^%[string .-%]:%d+: ", "");
-			err = err:gsub("^:%d+: ", "");
-			err = err:gsub("'<eof>'", "the end of the line");
-			result.attr.type = "error";
-			result:text("Sorry, I couldn't understand that... "..err);
-			event.origin.send(result);
-			return;
-		end
-	end
-
 	local function send_result(taskok, message)
 		if not message then
 			if type(taskok) ~= "string" and useglobalenv then
@@ -328,7 +389,49 @@
 		event.origin.send(result);
 	end
 
-	local taskok, message = chunk();
+	local taskok, message;
+	local env = (useglobalenv and session.globalenv) or session.env or nil;
+	local flags;
+
+	local source;
+	if line:match("^{") then
+		-- Input is a serialized array of strings, typically from
+		-- a command-line invocation of 'prosodyctl shell something'
+		source, flags = process_cmd_line(session, line);
+		if not source then
+			if flags then -- err
+				send_result(false, flags);
+			else -- no err, but nothing more to do
+				-- This happens if it was a "simple" command
+				event.origin.send(result);
+			end
+			return;
+		end
+	end
+
+	local chunkname = "=console";
+	-- luacheck: ignore 311/err
+	local chunk, err = envload(source or ("return "..line), chunkname, env);
+	if not chunk then
+		if not source then
+			chunk, err = envload(line, chunkname, env);
+		end
+		if not chunk then
+			err = err:gsub("^%[string .-%]:%d+: ", "");
+			err = err:gsub("^:%d+: ", "");
+			err = err:gsub("'<eof>'", "the end of the line");
+			result.attr.type = "error";
+			result:text("Sorry, I couldn't understand that... "..err);
+			event.origin.send(result);
+			return;
+		end
+	end
+
+	if not source then
+		session.repl = true;
+	end
+
+	taskok, message = chunk(flags);
 
 	if promise.is_promise(taskok) then
 		taskok:next(function (resolved_message)
@@ -462,9 +565,43 @@
 				for command, command_help in it.sorted_pairs(section_help.commands or {}) do
 					if not command_help.hidden then
 						c = c + 1;
-						local args = array.pluck(command_help.args, "name"):concat(", ");
 						local desc = command_help.desc or command_help.module and ("Provided by mod_"..command_help.module) or "";
-						print(("%s:%s(%s) - %s"):format(section_name, command, args, desc));
+						if self.session.repl then
+							local args = array.pluck(command_help.args, "name"):concat(", ");
+							print(("%s:%s(%s) - %s"):format(section_name, command, args, desc));
+						else
+							local args = array.pluck(command_help.args, "name"):concat("> <");
+							if args ~= "" then
+								args = "<"..args..">";
+							end
+							print(("%s %s %s"):format(section_name, command, args));
+							print(("    %s"):format(desc));
+							if command_help.flags then
+								local flags = command_help.flags;
+								print("");
+								print(("    Flags:"));
+
+								if flags.kv_params then
+									for name in it.sorted_pairs(flags.kv_params) do
+										print("      --"..name:gsub("_", "-"));
+									end
+								end
+
+								if flags.value_params then
+									for name in it.sorted_pairs(flags.value_params) do
+										print("      --"..name:gsub("_", "-").." <"..name..">");
+									end
+								end
+
+								if flags.array_params then
+									for name in it.sorted_pairs(flags.array_params) do
+										print("      --"..name:gsub("_", "-").." <"..name..">, ...");
+									end
+								end
+
+							end
+							print("");
+						end
 					end
 				end
 			elseif help_topics[section_name] then
@@ -2641,10 +2778,20 @@
 			section_mt.help = section_help;
 		end
 
+		if command.flags then
+			if command.flags.stop_on_positional == nil then
+				command.flags.stop_on_positional = false;
+			end
+			if command.flags.strict == nil then
+				command.flags.strict = true;
+			end
+		end
+
 		section_help.commands[command.name] = {
 			desc = command.desc;
 			full = command.help;
 			args = array(command.args);
+			flags = command.flags;
 			module = command._provided_by;
 		};
 
--- a/plugins/mod_invites.lua	Mon Feb 17 12:37:58 2025 +0100
+++ b/plugins/mod_invites.lua	Mon Feb 17 19:22:54 2025 +0000
@@ -244,13 +244,38 @@
 	section_desc = "Create and manage invitations";
 	name = "create_account";
 	desc = "Create an invitation to make an account on this server with the specified JID (supply only a hostname to allow any username)";
-	args = { { name = "user_jid", type = "string" } };
+	args = {
+		{ name = "user_jid", type = "string" };
+	};
 	host_selector = "user_jid";
+	flags = {
+		array_params = { role = true, group = true };
+		value_params = { expires_after = true };
+	};
+
+	handler = function (self, user_jid, opts) --luacheck: ignore 212/self
+		local username = jid_split(user_jid);
+		local roles = opts.role or {};
+		local groups = opts.group or {};
 
-	handler = function (self, user_jid) --luacheck: ignore 212/self
-		local username = jid_split(user_jid);
-		local invite, err = create_account(username);
-		if not invite then return nil, err; end
+		if opts.admin then
+			-- Insert it first since we don't get order out of argparse
+			table.insert(roles, 1, "prosody:admin");
+		end
+
+		local ttl;
+		if opts.expires_after then
+			ttl = human_io.parse_duration(opts.expires_after);
+			if not ttl then
+				return false, "Unable to parse duration: "..opts.expires_after;
+			end
+		end
+
+		local invite = assert(create_account(username, {
+			roles = roles;
+			groups = groups;
+		}, ttl));
+
 		return true, invite.landing_page or invite.uri;
 	end;
 });
@@ -260,12 +285,21 @@
 	section_desc = "Create and manage invitations";
 	name = "create_reset";
 	desc = "Create a password reset link for the specified user";
-	args = { { name = "user_jid", type = "string" }, { name = "duration", type = "string" } };
+	args = { { name = "user_jid", type = "string" } };
 	host_selector = "user_jid";
+	flags = {
+		value_params = { expires_after = true };
+	};
 
-	handler = function (self, user_jid, duration) --luacheck: ignore 212/self
+	handler = function (self, user_jid, opts) --luacheck: ignore 212/self
 		local username = jid_split(user_jid);
-		local duration_sec = require "prosody.util.human.io".parse_duration(duration or "1d");
+		if not username then
+			return nil, "Supply the JID of the account you want to generate a password reset for";
+		end
+		local duration_sec = require "prosody.util.human.io".parse_duration(opts and opts.expires_after or "1d");
+		if not duration_sec then
+			return nil, "Unable to parse duration: "..opts.expires_after;
+		end
 		local invite, err = create_account_reset(username, duration_sec);
 		if not invite then return nil, err; end
 		self.session.print(invite.landing_page or invite.uri);
@@ -278,12 +312,26 @@
 	section_desc = "Create and manage invitations";
 	name = "create_contact";
 	desc = "Create an invitation to become contacts with the specified user";
-	args = { { name = "user_jid", type = "string" }, { name = "allow_registration" } };
+	args = { { name = "user_jid", type = "string" } };
 	host_selector = "user_jid";
+	flags = {
+		value_params = { expires_after = true };
+		kv_params = { allow_registration = true };
+	};
 
-	handler = function (self, user_jid, allow_registration) --luacheck: ignore 212/self
+	handler = function (self, user_jid, opts) --luacheck: ignore 212/self
 		local username = jid_split(user_jid);
-		local invite, err = create_contact(username, allow_registration);
+		if not username then
+			return nil, "Supply the JID of the account you want the recipient to become a contact of";
+		end
+		local ttl;
+		if opts.expires_after then
+			ttl = require "prosody.util.human.io".parse_duration(opts.expires_after);
+			if not ttl then
+				return nil, "Unable to parse duration: "..opts.expires_after;
+			end
+		end
+		local invite, err = create_contact(username, opts.allow_registration, nil, ttl);
 		if not invite then return nil, err; end
 		return true, invite.landing_page or invite.uri;
 	end;
@@ -442,102 +490,7 @@
 	return subcommands[cmd](arg);
 end
 
-function subcommands.generate(arg)
-	local function help(short)
-		print("usage: prosodyctl mod_" .. module.name .. " generate DOMAIN --reset USERNAME")
-		print("usage: prosodyctl mod_" .. module.name .. " generate DOMAIN [--admin] [--role ROLE] [--group GROUPID]...")
-		if short then return 2 end
-		print()
-		print("This command has two modes: password reset and new account.")
-		print("If --reset is given, the command operates in password reset mode and in new account mode otherwise.")
-		print()
-		print("required arguments in password reset mode:")
-		print()
-		print("    --reset USERNAME  Generate a password reset link for the given USERNAME.")
-		print()
-		print("optional arguments in new account mode:")
-		print()
-		print("    --admin           Make the new user privileged")
-		print("                      Equivalent to --role prosody:admin")
-		print("    --role ROLE       Grant the given ROLE to the new user")
-		print("    --group GROUPID   Add the user to the group with the given ID")
-		print("                      Can be specified multiple times")
-		print("    --expires-after T Time until the invite expires (e.g. '1 week')")
-		print()
-		print("--group can be specified multiple times; the user will be added to all groups.")
-		print()
-		print("--reset and the other options cannot be mixed.")
-		return 2
-	end
-
-	local earlyopts = argparse.parse(arg, { short_params = { h = "help"; ["?"] = "help" } });
-	if earlyopts.help or not earlyopts[1] then
-		return help();
-	end
-
-	local sm = require "prosody.core.storagemanager";
-	local mm = require "prosody.core.modulemanager";
-
-	local host = table.remove(arg, 1); -- pop host
-	if not host then return help(true) end
-	sm.initialize_host(host);
-	module.host = host; --luacheck: ignore 122/module
-	token_storage = module:open_store("invite_token", "map");
-
-	local opts = argparse.parse(arg, {
-		short_params = { h = "help"; ["?"] = "help"; g = "group" };
-		value_params = { group = true; reset = true; role = true };
-		array_params = { group = true; role = true };
-	});
-
-	if opts.help then
-		return help();
-	end
-
-	-- Load mod_invites
-	local invites = module:depends("invites");
-	-- Optional community module that if used, needs to be loaded here
-	local invites_page_module = module:get_option_string("invites_page_module", "invites_page");
-	if mm.get_modules_for_host(host):contains(invites_page_module) then
-		module:depends(invites_page_module);
-	end
-
-	local allow_reset;
-
-	if opts.reset then
-		local nodeprep = require "prosody.util.encodings".stringprep.nodeprep;
-		local username = nodeprep(opts.reset)
-		if not username then
-			print("Please supply a valid username to generate a reset link for");
-			return 2;
-		end
-		allow_reset = username;
-	end
-
-	local roles = opts.role or {};
-	local groups = opts.group or {};
-
-	if opts.admin then
-		-- Insert it first since we don't get order out of argparse
-		table.insert(roles, 1, "prosody:admin");
-	end
-
-	local invite;
-	if allow_reset then
-		if roles[1] then
-			print("--role/--admin and --reset are mutually exclusive")
-			return 2;
-		end
-		if #groups > 0 then
-			print("--group and --reset are mutually exclusive")
-		end
-		invite = assert(invites.create_account_reset(allow_reset));
-	else
-		invite = assert(invites.create_account(nil, {
-			roles = roles,
-			groups = groups
-		}, opts.expires_after and human_io.parse_duration(opts.expires_after)));
-	end
-
-	print(invite.landing_page or invite.uri);
+function subcommands.generate()
+	print("This command is deprecated. Please see 'prosodyctl shell help invite' for available commands.");
+	return 1;
 end
--- a/spec/util_argparse_spec.lua	Mon Feb 17 12:37:58 2025 +0100
+++ b/spec/util_argparse_spec.lua	Mon Feb 17 19:22:54 2025 +0000
@@ -24,10 +24,28 @@
 		assert.same({ "bar"; "--baz" }, arg);
 	end);
 
+	it("allows continuation beyond first positional argument", function()
+		local arg = { "--foo"; "bar"; "--baz" };
+		local opts, err = parse(arg, { stop_on_positional = false });
+		assert.falsy(err);
+		assert.same({ foo = true, baz = true, "bar" }, opts);
+		-- All input should have been consumed:
+		assert.same({ }, arg);
+	end);
+
 	it("expands short options", function()
-		local opts, err = parse({ "--foo"; "-b" }, { short_params = { b = "bar" } });
-		assert.falsy(err);
-		assert.same({ foo = true; bar = true }, opts);
+		do
+			local opts, err = parse({ "--foo"; "-b" }, { short_params = { b = "bar" } });
+			assert.falsy(err);
+			assert.same({ foo = true; bar = true }, opts);
+		end
+
+		do
+			-- Same test with strict mode enabled and all parameters declared
+			local opts, err = parse({ "--foo"; "-b" }, { kv_params = { foo = true, bar = true }; short_params = { b = "bar" }, strict = true });
+			assert.falsy(err);
+			assert.same({ foo = true; bar = true }, opts);
+		end
 	end);
 
 	it("supports value arguments", function()
@@ -51,8 +69,30 @@
 	end);
 
 	it("supports array arguments", function ()
-		local opts, err = parse({ "--item"; "foo"; "--item"; "bar" }, { array_params = { item = true } });
+		do
+			local opts, err = parse({ "--item"; "foo"; "--item"; "bar" }, { array_params = { item = true } });
+			assert.falsy(err);
+			assert.same({"foo","bar"}, opts.item);
+		end
+
+		do
+			-- Same test with strict mode enabled
+			local opts, err = parse({ "--item"; "foo"; "--item"; "bar" }, { array_params = { item = true }, strict = true });
+			assert.falsy(err);
+			assert.same({"foo","bar"}, opts.item);
+		end
+	end)
+
+	it("rejects unknown parameters in strict mode", function ()
+		local opts, err, err2 = parse({ "--item"; "foo"; "--item"; "bar", "--foobar" }, { array_params = { item = true }, strict = true });
+		assert.falsy(opts);
+		assert.same("param-not-found", err);
+		assert.same("--foobar", err2);
+	end);
+
+	it("accepts known kv parameters in strict mode", function ()
+		local opts, err = parse({ "--item=foo" }, { kv_params = { item = true }, strict = true });
 		assert.falsy(err);
-		assert.same({"foo","bar"}, opts.item);
-	end)
+		assert.same("foo", opts.item);
+	end);
 end);
--- a/util/argparse.lua	Mon Feb 17 12:37:58 2025 +0100
+++ b/util/argparse.lua	Mon Feb 17 19:22:54 2025 +0000
@@ -2,6 +2,9 @@
 	local short_params = config and config.short_params or {};
 	local value_params = config and config.value_params or {};
 	local array_params = config and config.array_params or {};
+	local kv_params = config and config.kv_params or {};
+	local strict = config and config.strict;
+	local stop_on_positional = not config or config.stop_on_positional ~= false;
 
 	local parsed_opts = {};
 
@@ -15,51 +18,65 @@
 		end
 
 		local prefix = raw_param:match("^%-%-?");
-		if not prefix then
+		if not prefix and stop_on_positional then
 			break;
 		elseif prefix == "--" and raw_param == "--" then
 			table.remove(arg, 1);
 			break;
 		end
-		local param = table.remove(arg, 1):sub(#prefix+1);
-		if #param == 1 and short_params then
-			param = short_params[param];
-		end
 
-		if not param then
-			return nil, "param-not-found", raw_param;
-		end
+		if prefix then
+			local param = table.remove(arg, 1):sub(#prefix+1);
+			if #param == 1 and short_params then
+				param = short_params[param];
+			end
+
+			if not param then
+				return nil, "param-not-found", raw_param;
+			end
+
+			local uparam = param:match("^[^=]*"):gsub("%-", "_");
 
-		local param_k, param_v;
-		if value_params[param] or array_params[param] then
-			param_k, param_v = param, table.remove(arg, 1);
-			if not param_v then
-				return nil, "missing-value", raw_param;
-			end
-		else
-			param_k, param_v = param:match("^([^=]+)=(.+)$");
-			if not param_k then
-				if param:match("^no%-") then
-					param_k, param_v = param:sub(4), false;
-				else
-					param_k, param_v = param, true;
+			local param_k, param_v;
+			if value_params[uparam] or array_params[uparam] then
+				param_k, param_v = uparam, table.remove(arg, 1);
+				if not param_v then
+					return nil, "missing-value", raw_param;
+				end
+			else
+				param_k, param_v = param:match("^([^=]+)=(.+)$");
+				if not param_k then
+					if param:match("^no%-") then
+						param_k, param_v = param:sub(4), false;
+					else
+						param_k, param_v = param, true;
+					end
+				end
+				param_k = param_k:gsub("%-", "_");
+				if strict and not kv_params[param_k] then
+					return nil, "param-not-found", raw_param;
 				end
 			end
-			param_k = param_k:gsub("%-", "_");
-		end
-		if array_params[param] then
-			if parsed_opts[param_k] then
-				table.insert(parsed_opts[param_k], param_v);
+			if array_params[uparam] then
+				if parsed_opts[param_k] then
+					table.insert(parsed_opts[param_k], param_v);
+				else
+					parsed_opts[param_k] = { param_v };
+				end
 			else
-				parsed_opts[param_k] = { param_v };
+				parsed_opts[param_k] = param_v;
 			end
-		else
-			parsed_opts[param_k] = param_v;
+		elseif not stop_on_positional then
+			table.insert(parsed_opts, table.remove(arg, 1));
 		end
 	end
-	for i = 1, #arg do
-		parsed_opts[i] = arg[i];
+
+	if stop_on_positional then
+		for i = 1, #arg do
+			parsed_opts[i] = arg[i];
+		end
 	end
+
 	return parsed_opts;
 end
 
--- a/util/prosodyctl/shell.lua	Mon Feb 17 12:37:58 2025 +0100
+++ b/util/prosodyctl/shell.lua	Mon Feb 17 19:22:54 2025 +0000
@@ -87,17 +87,7 @@
 
 	if arg[1] then
 		if arg[2] then
-			local fmt = { "%s"; ":%s("; ")" };
-			for i = 3, #arg do
-				if arg[i]:sub(1, 1) == ":" then
-					table.insert(fmt, i, ")%s(");
-				elseif i > 3 and fmt[i - 1]:match("%%q$") then
-					table.insert(fmt, i, ", %q");
-				else
-					table.insert(fmt, i, "%q");
-				end
-			end
-			arg[1] = string.format(table.concat(fmt), table.unpack(arg));
+			arg[1] = ("{"..string.rep("%q", #arg, ", ").."}"):format(table.unpack(arg, 1, #arg));
 		end
 
 		client.events.add_handler("connected", function()