Changeset

5254:b0ccdd12a70d

mod_http_oauth2: Prepare to handle multiple e.g. non-role scopes This is to prepare to handle scopes like "openid" that don't map to roles.
author Kim Alvefur <zash@zash.se>
date Thu, 16 Mar 2023 17:03:48 +0100
parents 5253:d3b2d42daaee
children 5255:001c8fdc91a4
files mod_http_oauth2/mod_http_oauth2.lua
diffstat 1 files changed, 34 insertions(+), 17 deletions(-) [+]
line wrap: on
line diff
--- a/mod_http_oauth2/mod_http_oauth2.lua	Thu Mar 16 14:27:46 2023 +0100
+++ b/mod_http_oauth2/mod_http_oauth2.lua	Thu Mar 16 17:03:48 2023 +0100
@@ -74,20 +74,33 @@
 	jwt_sign, jwt_verify = jwt.init(registration_algo, registration_key, registration_key, registration_options);
 end
 
+local function parse_scopes(scope_string)
+	return array(scope_string:gmatch("%S+"));
+end
+
 local function filter_scopes(username, host, requested_scope_string)
 	if host ~= module.host then
 		return usermanager.get_jid_role(username.."@"..host, module.host).name;
 	end
 
-	if requested_scope_string then -- Specific role requested
-		-- TODO: The requested scope string is technically a space-delimited list
-		-- of scopes, but for simplicity we're mapping this slot to role names.
-		if usermanager.user_can_assume_role(username, module.host, requested_scope_string) then
-			return requested_scope_string;
+	local selected_role, granted_scopes = nil, array();
+
+	if requested_scope_string then -- Specific role(s) requested
+		local requested_scopes = parse_scopes(requested_scope_string);
+		for _, scope in ipairs(requested_scopes) do
+			if selected_role == nil and usermanager.user_can_assume_role(username, module.host, scope) then
+				selected_role = scope;
+			end
 		end
 	end
 
-	return usermanager.get_user_role(username, module.host).name;
+	if not selected_role then
+		-- By default use the users' default role
+		selected_role = usermanager.get_user_role(username, module.host).name;
+	end
+	granted_scopes:push(selected_role);
+
+	return granted_scopes:concat(" "), selected_role;
 end
 
 local function code_expires_in(code) --> number, seconds until code expires
@@ -140,12 +153,15 @@
 	return { name = client.client_name; uri = client.client_uri };
 end
 
-local function new_access_token(token_jid, scope, ttl, client)
-	local token_data;
+local function new_access_token(token_jid, role, scope, ttl, client)
+	local token_data = {};
 	if client then
-		token_data = { oauth2_client = client_subset(client) };
+		token_data.oauth2_client = client_subset(client);
 	end
-	local token = tokens.create_jid_token(token_jid, token_jid, scope, ttl, token_data, "oauth2");
+	if next(token_data) == nil then
+		token_data = nil;
+	end
+	local token = tokens.create_jid_token(token_jid, token_jid, role, ttl, token_data, "oauth2");
 	return {
 		token_type = "bearer";
 		access_token = token;
@@ -188,19 +204,20 @@
 	end
 
 	local granted_jid = jid.join(request_username, request_host, request_resource);
-	local granted_scopes = filter_scopes(request_username, request_host, params.scope);
-	return json.encode(new_access_token(granted_jid, granted_scopes, nil));
+	local granted_scopes, granted_role = filter_scopes(request_username, request_host, params.scope);
+	return json.encode(new_access_token(granted_jid, granted_role, granted_scopes, nil));
 end
 
 function response_type_handlers.code(client, params, granted_jid)
 	local request_username, request_host = jid.split(granted_jid);
-	local granted_scopes = filter_scopes(request_username, request_host, params.scope);
+	local granted_scopes, granted_role = filter_scopes(request_username, request_host, params.scope);
 
 	local code = id.medium();
 	local ok = codes:set(params.client_id .. "#" .. code, {
 		expires = os.time() + 600;
 		granted_jid = granted_jid;
 		granted_scopes = granted_scopes;
+		granted_role = granted_role;
 	});
 	if not ok then
 		return {status_code = 429};
@@ -245,8 +262,8 @@
 -- Implicit flow
 function response_type_handlers.token(client, params, granted_jid)
 	local request_username, request_host = jid.split(granted_jid);
-	local granted_scopes = filter_scopes(request_username, request_host, params.scope);
-	local token_info = new_access_token(granted_jid, granted_scopes, nil, client);
+	local granted_scopes, granted_role = filter_scopes(request_username, request_host, params.scope);
+	local token_info = new_access_token(granted_jid, granted_role, granted_scopes, nil, client);
 
 	local redirect = url.parse(get_redirect_uri(client, params.redirect_uri));
 	token_info.state = params.state;
@@ -295,7 +312,7 @@
 		return oauth_error("invalid_client", "incorrect credentials");
 	end
 
-	return json.encode(new_access_token(code.granted_jid, code.granted_scopes, nil, client));
+	return json.encode(new_access_token(code.granted_jid, code.granted_role, code.granted_scopes, nil, client));
 end
 
 -- Used to issue/verify short-lived tokens for the authorization process below
@@ -414,7 +431,7 @@
 		end
 		if request_password == component_secret then
 			local granted_jid = jid.join(request_username, request_host, request_resource);
-			return json.encode(new_access_token(granted_jid, nil, nil));
+			return json.encode(new_access_token(granted_jid, nil, nil, nil));
 		end
 		return oauth_error("invalid_grant", "incorrect credentials");
 	end