Diff

mod_rest/mod_rest.lua @ 3910:49efd1323a1b

mod_rest: Add support for token authentication
author Matthew Wild <mwild1@gmail.com>
date Wed, 26 Feb 2020 18:36:40 +0000
parent 3909:eb27e51cf2c9
child 3911:064c32a5be7c
line wrap: on
line diff
--- a/mod_rest/mod_rest.lua	Wed Feb 26 18:04:17 2020 +0000
+++ b/mod_rest/mod_rest.lua	Wed Feb 26 18:36:40 2020 +0000
@@ -4,34 +4,41 @@
 --
 -- This file is MIT/X11 licensed.
 
+local encodings = require "util.encodings";
+local base64 = encodings.base64;
 local errors = require "util.error";
 local http = require "net.http";
 local id = require "util.id";
 local jid = require "util.jid";
 local json = require "util.json";
 local st = require "util.stanza";
+local um = require "core.usermanager";
 local xml = require "util.xml";
 
-local allow_any_source = module:get_host_type() == "component";
-local validate_from_addresses = module:get_option_boolean("validate_from_addresses", true);
-local secret = assert(module:get_option_string("rest_credentials"), "rest_credentials is a required setting");
-local auth_type = assert(secret:match("^%S+"), "Format of rest_credentials MUST be like 'Bearer secret'");
-assert(auth_type == "Bearer" or auth_type == "Basic", "Only 'Bearer' and 'Basic' are supported in rest_credentials");
+local jsonmap = module:require"jsonmap";
+
+local tokens = module:depends("authtokens");
+
+local auth_mechanisms = module:get_option_set("rest_auth_mechanisms", { "Basic", "Bearer" });
 
-local jsonmap = module:require"jsonmap";
+local www_authenticate_header;
+do
+	local header, realm = {}, module.host.."/"..module.name;
+	for mech in auth_mechanisms do
+		header[#header+1] = ("%s realm=%q"):format(mech, realm);
+	end
+	www_authenticate_header = table.concat(header, ", ");
+end
+
 -- Bearer token
 local function check_credentials(request)
-	return request.headers.authorization == secret;
-end
-if secret == "Basic" and module:get_host_type() == "local" then
-	local um = require "core.usermanager";
-	local encodings = require "util.encodings";
-	local base64 = encodings.base64;
+	local auth_type, auth_data = string.match(request.headers.authorization, "^(%S+)%s(.+)$");
+	if not (auth_type and auth_data) or not auth_mechanisms:contains(auth_type) then
+		return false;
+	end
 
-	function check_credentials(request)
-		local creds = string.match(request.headers.authorization, "^Basic%s+([A-Za-z0-9+/]+=?=?)%s*$");
-		if not creds then return false; end
-		creds = base64.decode(creds);
+	if auth_type == "Basic" then
+		local creds = base64.decode(auth_data);
 		if not creds then return false; end
 		local username, password = string.match(creds, "^([^:]+):(.*)$");
 		if not username then return false; end
@@ -40,8 +47,15 @@
 		if not um.test_password(username, module.host, password) then
 			return false;
 		end
-		return jid.join(username, module.host);
+		return { username = username, host = module.host };
+	elseif auth_type == "Bearer" then
+		local token_info = tokens.get_token_info(auth_data);
+		if not token_info or not token_info.session then
+			return false;
+		end
+		return token_info.session;
 	end
+	return nil;
 end
 
 local function parse(mimetype, data)
@@ -84,18 +98,18 @@
 
 local function handle_post(event)
 	local request, response = event.request, event.response;
-	local from = module.host;
+	local from;
+	local origin;
+
 	if not request.headers.authorization then
-		response.headers.www_authenticate = ("%s realm=%q"):format(auth_type, module.host.."/"..module.name);
+		response.headers.www_authenticate = www_authenticate_header;
 		return 401;
 	else
-		local authz = check_credentials(request);
-		if not authz then
+		origin = check_credentials(request);
+		if not origin then
 			return 401;
 		end
-		if type(authz) == "string" then
-			from = authz;
-		end
+		from = jid.join(origin.username, origin.host, origin.resource);
 	end
 	local payload, err = parse(request.headers.content_type, request.body);
 	if not payload then
@@ -111,13 +125,15 @@
 	if not to then
 		return errors.new({ code = 422, text = "Invalid destination JID" });
 	end
-	if allow_any_source and payload.attr.from then
-		from = jid.prep(payload.attr.from);
-		if not from then
+	if payload.attr.from then
+		local requested_from = jid.prep(payload.attr.from);
+		if not requested_from then
 			return errors.new({ code = 422, text = "Invalid source JID" });
 		end
-		if validate_from_addresses and not jid.compare(from, module.host) then
-			return errors.new({ code = 403, text = "Source JID must belong to current host" });
+		if jid.compare(requested_from, from) then
+			from = requested_from;
+		else
+			return errors.new({ code = 403, text = "Not authorized to send from "..requested_from });
 		end
 	end
 	payload.attr = {
@@ -130,12 +146,15 @@
 	module:log("debug", "Received[rest]: %s", payload:top_tag());
 	local send_type = decide_type((request.headers.accept or "") ..",".. request.headers.content_type)
 	if payload.name == "iq" then
+		function origin.send(stanza)
+			prosody.core_route_stanza(nil, stanza);
+		end
 		if payload.attr.type ~= "get" and payload.attr.type ~= "set" then
 			return errors.new({ code = 422, text = "'iq' stanza must be of type 'get' or 'set'" });
 		elseif #payload.tags ~= 1 then
 			return errors.new({ code = 422, text = "'iq' stanza must have exactly one child tag" });
 		end
-		return module:send_iq(payload):next(
+		return module:send_iq(payload, origin):next(
 			function (result)
 				module:log("debug", "Sending[rest]: %s", result.stanza:top_tag());
 				response.headers.content_type = send_type;
@@ -154,7 +173,6 @@
 				end
 			end);
 	else
-		local origin = {};
 		function origin.send(stanza)
 			module:log("debug", "Sending[rest]: %s", stanza:top_tag());
 			response.headers.content_type = send_type;