Diff

mod_sasl2_fast/mod_sasl2_fast.lua @ 6211:750d64c47ec6

Merge
author Trần H. Trung <xmpp:trần.h.trung@trung.fun>
date Tue, 18 Mar 2025 00:31:36 +0700 (3 months ago)
parent 6150:f77f5e408d6a
line wrap: on
line diff
--- a/mod_sasl2_fast/mod_sasl2_fast.lua	Tue Mar 18 00:19:25 2025 +0700
+++ b/mod_sasl2_fast/mod_sasl2_fast.lua	Tue Mar 18 00:31:36 2025 +0700
@@ -8,6 +8,11 @@
 local now = require "util.time".now;
 local hash = require "util.hashes";
 
+local sasl_mt = getmetatable(sasl.new("", { mechanisms = {} }));
+local function is_util_sasl(sasl_handler)
+	return getmetatable(sasl_handler) == sasl_mt;
+end
+
 module:depends("sasl2");
 
 -- Tokens expire after 21 days by default
@@ -49,7 +54,7 @@
 			log("debug", "Looking for %s token %s/%s", mechanism, username, key);
 			token = token_store:get(username, key);
 			if token and token.mechanism == mechanism then
-				local expected_hash = hmac_f(token.secret, "Initiator"..cb_data);
+				local expected_hash = hmac_f(token.secret, "Initiator"..(cb_data or ""));
 				if hash.equals(expected_hash, token_hash) then
 					local current_time = now();
 					if token.expires_at < current_time then
@@ -77,7 +82,7 @@
 						log("debug", "FAST token due for rotation (age: %d)", current_time - token.issued_at);
 						rotation_needed = true;
 					end
-					return true, username, hmac_f(token.secret, "Responder"..cb_data), rotation_needed;
+					return true, username, hmac_f(token.secret, "Responder"..(cb_data or "")), rotation_needed;
 				end
 			end
 			if not tried_current_token then
@@ -93,12 +98,19 @@
 	end
 end
 
-function get_sasl_handler()
+-- If FAST fails, we want to restore back to a non-FAST handler
+local function _clean_clone_shim(self)
+	return self.nonfast_sasl_handler:clean_clone();
+end
+
+function get_sasl_handler(username, nonfast_sasl_handler) -- luacheck: ignore 212/username
 	local token_auth_profile = {
 		ht_sha_256 = new_token_tester(hash.hmac_sha256);
 	};
 	local handler = sasl.new(module.host, token_auth_profile);
 	handler.fast = true;
+	handler.nonfast_sasl_handler = nonfast_sasl_handler;
+	handler.clean_clone = _clean_clone_shim;
 	return handler;
 end
 
@@ -110,12 +122,14 @@
 		username = jid.node(event.stream.from);
 		if not username then return; end
 	end
-	local sasl_handler = get_sasl_handler(username);
+	local sasl_handler = get_sasl_handler(username, session.sasl_handler);
 	if not sasl_handler then return; end
 	sasl_handler.fast_auth = true; -- For informational purposes
-	-- Copy channel binding info from primary SASL handler
-	sasl_handler.profile.cb = session.sasl_handler.profile.cb;
-	sasl_handler.userdata = session.sasl_handler.userdata;
+	-- Copy channel binding info from primary SASL handler if it's compatible
+	if is_util_sasl(session.sasl_handler) then
+		sasl_handler.profile.cb = session.sasl_handler.profile.cb;
+		sasl_handler.userdata = session.sasl_handler.userdata;
+	end
 	-- Store this handler, in case we later want to use it for authenticating
 	session.fast_sasl_handler = sasl_handler;
 	local fast = st.stanza("fast", { xmlns = xmlns_fast });
@@ -196,14 +210,17 @@
 		if not authc_username then
 			return "failure", "malformed-request";
 		end
-		if not sasl_handler.profile.cb then
-			module:log("warn", "Attempt to use channel binding %s with SASL profile that does not support any channel binding (FAST: %s)", cb_name, sasl_handler.fast);
-			return "failure", "malformed-request";
-		elseif not sasl_handler.profile.cb[cb_name] then
-			module:log("warn", "SASL profile does not support %s channel binding (FAST: %s)", cb_name, sasl_handler.fast);
-			return "failure", "malformed-request";
+		local cb_data;
+		if cb_name then
+			if not sasl_handler.profile.cb then
+				module:log("warn", "Attempt to use channel binding %s with SASL profile that does not support any channel binding (FAST: %s)", cb_name, sasl_handler.fast);
+				return "failure", "malformed-request";
+			elseif not sasl_handler.profile.cb[cb_name] then
+				module:log("warn", "SASL profile does not support %s channel binding (FAST: %s)", cb_name, sasl_handler.fast);
+				return "failure", "malformed-request";
+			end
+			cb_data = sasl_handler.profile.cb[cb_name](sasl_handler) or "";
 		end
-		local cb_data = cb_name and sasl_handler.profile.cb[cb_name](sasl_handler) or "";
 		local ok, authz_username, response, rotation_needed = backend(
 			mechanism_name,
 			authc_username,