Diff

mod_rest/mod_rest.lua @ 6211:750d64c47ec6 draft

Merge
author Trần H. Trung <xmpp:trần.h.trung@trung.fun>
date Tue, 18 Mar 2025 00:31:36 +0700
parent 6206:ac7e2992fe6e
child 6244:c71d8bc77c95
line wrap: on
line diff
--- a/mod_rest/mod_rest.lua	Tue Mar 18 00:19:25 2025 +0700
+++ b/mod_rest/mod_rest.lua	Tue Mar 18 00:31:36 2025 +0700
@@ -23,7 +23,7 @@
 -- Lower than the default c2s size limit to account for possible JSON->XML size increase
 local stanza_size_limit = module:get_option_number("rest_stanza_size_limit", 1024 * 192);
 
-local auth_mechanisms = module:get_option_set("rest_auth_mechanisms", { "Basic", "Bearer" });
+local auth_mechanisms = module:get_option_set("rest_auth_mechanisms", { "Basic", "Bearer" }) / string.lower;
 
 local www_authenticate_header;
 do
@@ -34,35 +34,69 @@
 	www_authenticate_header = table.concat(header, ", ");
 end
 
-local function check_credentials(request)
+local post_errors = errors.init("mod_rest", {
+	noauthz = { code = 401; type = "auth"; condition = "not-authorized"; text = "No credentials provided" };
+	unauthz = { code = 403; type = "auth"; condition = "not-authorized"; text = "Credentials not accepted" };
+	malformauthz = { code = 403; type = "auth"; condition = "not-authorized"; text = "Credentials malformed" };
+	prepauthz = { code = 403; type = "auth"; condition = "not-authorized"; text = "Credentials failed stringprep" };
+	parse = { code = 400; type = "modify"; condition = "not-well-formed"; text = "Failed to parse payload" };
+	xmlns = { code = 422; type = "modify"; condition = "invalid-namespace"; text = "'xmlns' attribute must be empty" };
+	name = { code = 422; type = "modify"; condition = "unsupported-stanza-type"; text = "Invalid stanza, must be 'message', 'presence' or 'iq'." };
+	to = { code = 422; type = "modify"; condition = "improper-addressing"; text = "Invalid destination JID" };
+	from = { code = 422; type = "modify"; condition = "invalid-from"; text = "Invalid source JID" };
+	from_auth = { code = 403; type = "auth"; condition = "not-authorized"; text = "Not authorized to send stanza with requested 'from'" };
+	iq_type = { code = 422; type = "modify"; condition = "invalid-xml"; text = "'iq' stanza must be of type 'get' or 'set'" };
+	iq_tags = { code = 422; type = "modify"; condition = "bad-format"; text = "'iq' stanza must have exactly one child tag" };
+	mediatype = { code = 415; type = "cancel"; condition = "bad-format"; text = "Unsupported media type" };
+	size = { code = 413; type = "modify"; condition = "resource-constraint", text = "Payload too large" };
+});
+
+local token_session_errors = errors.init("mod_tokenauth", {
+	["internal-error"] = { code = 500; type = "wait"; condition = "internal-server-error" };
+	["invalid-token-format"] = { code = 403; type = "auth"; condition = "not-authorized"; text = "Credentials malformed" };
+	["not-authorized"] = { code = 403; type = "auth"; condition = "not-authorized"; text = "Credentials not accepted" };
+});
+
+local function check_credentials(request) -- > session | boolean, error
 	local auth_type, auth_data = string.match(request.headers.authorization, "^(%S+)%s(.+)$");
+	auth_type = auth_type and auth_type:lower();
 	if not (auth_type and auth_data) or not auth_mechanisms:contains(auth_type) then
-		return false;
+		return nil, post_errors.new("noauthz", { request = request });
 	end
 
-	if auth_type == "Basic" then
+	if auth_type == "basic" then
 		local creds = base64.decode(auth_data);
-		if not creds then return false; end
+		if not creds then
+			return nil, post_errors.new("malformauthz", { request = request });
+		end
 		local username, password = string.match(creds, "^([^:]+):(.*)$");
-		if not username then return false; end
+		if not username then
+			return nil, post_errors.new("malformauthz", { request = request });
+		end
 		username, password = encodings.stringprep.nodeprep(username), encodings.stringprep.saslprep(password);
-		if not username then return false; end
+		if not username or not password then
+			return false, post_errors.new("prepauthz", { request = request });
+		end
 		if not um.test_password(username, module.host, password) then
-			return false;
+			return false, post_errors.new("unauthz", { request = request });
 		end
-		return { username = username, host = module.host };
-	elseif auth_type == "Bearer" then
+		return { username = username; host = module.host };
+	elseif auth_type == "bearer" then
 		if tokens.get_token_session then
-			return tokens.get_token_session(auth_data);
+			local token_session, err = tokens.get_token_session(auth_data);
+			if not token_session then
+				return false, token_session_errors.new(err or "not-authorized", { request = request });
+			end
+			return token_session;
 		else -- COMPAT w/0.12
 			local token_info = tokens.get_token_info(auth_data);
 			if not token_info or not token_info.session then
-				return false;
+				return false, post_errors.new("unauthz", { request = request });
 			end
 			return token_info.session;
 		end
 	end
-	return nil;
+	return nil, post_errors.new("noauthz", { request = request });
 end
 
 if module:get_option_string("authentication") == "anonymous" and module:get_option_boolean("anonymous_rest") then
@@ -125,7 +159,7 @@
 
 -- (table, string) -> table
 local function amend_from_path(data, path)
-	local st_kind, st_type, st_to = path:match("^([mpi]%w+)/(%w+)/(.*)$");
+	local st_kind, st_type, st_to = path:match("^([mpi]%w+)/([%w_]+)/(.*)$");
 	if not st_kind then return; end
 	if st_kind == "iq" and st_type ~= "get" and st_type ~= "set" then
 		-- GET /iq/disco/jid
@@ -268,21 +302,6 @@
 	error "unsupported encoding";
 end
 
-local post_errors = errors.init("mod_rest", {
-	noauthz = { code = 401; type = "auth"; condition = "not-authorized"; text = "No credentials provided" };
-	unauthz = { code = 403; type = "auth"; condition = "not-authorized"; text = "Credentials not accepted" };
-	parse = { code = 400; type = "modify"; condition = "not-well-formed"; text = "Failed to parse payload" };
-	xmlns = { code = 422; type = "modify"; condition = "invalid-namespace"; text = "'xmlns' attribute must be empty" };
-	name = { code = 422; type = "modify"; condition = "unsupported-stanza-type"; text = "Invalid stanza, must be 'message', 'presence' or 'iq'." };
-	to = { code = 422; type = "modify"; condition = "improper-addressing"; text = "Invalid destination JID" };
-	from = { code = 422; type = "modify"; condition = "invalid-from"; text = "Invalid source JID" };
-	from_auth = { code = 403; type = "auth"; condition = "not-authorized"; text = "Not authorized to send stanza with requested 'from'" };
-	iq_type = { code = 422; type = "modify"; condition = "invalid-xml"; text = "'iq' stanza must be of type 'get' or 'set'" };
-	iq_tags = { code = 422; type = "modify"; condition = "bad-format"; text = "'iq' stanza must have exactly one child tag" };
-	mediatype = { code = 415; type = "cancel"; condition = "bad-format"; text = "Unsupported media type" };
-	size = { code = 413; type = "modify"; condition = "resource-constraint", text = "Payload too large" };
-});
-
 -- GET → iq-get
 local function parse_request(request, path)
 	if path and request.method == "GET" then
@@ -308,9 +327,10 @@
 		response.headers.www_authenticate = www_authenticate_header;
 		return post_errors.new("noauthz");
 	else
-		origin = check_credentials(request);
+		local err;
+		origin, err = check_credentials(request);
 		if not origin then
-			return post_errors.new("unauthz");
+			return err or post_errors.new("unauthz");
 		end
 		from = jid.join(origin.username, origin.host, origin.resource);
 		origin.full_jid = from;
@@ -642,6 +662,17 @@
 	"application/json",
 };
 
+-- strip some stuff, notably the optional traceback table that casues stack overflow in util.json
+local function simplify_error(e)
+	return {
+		type = e.type;
+		condition = e.condition;
+		text = e.text;
+		extra = e.extra;
+		source = e.source;
+	};
+end
+
 local http_server = require "net.http.server";
 module:hook_object_event(http_server, "http-error", function (event)
 	local request, response = event.request, event.response;
@@ -664,7 +695,7 @@
 		end
 		return json.encode({
 				type = "error",
-				error = event.error,
+				error = simplify_error(event.error),
 				code = event.code,
 			});
 	end