Changeset

6191:94399ad6b5ab

mod_invites_register_api: Use set_password() for password resets Previously the code relied on the (weird) behaviour of create_user(), which would update the password for a user account if it already existed. This has several issues, and we plan to deprecate this behaviour of create_user(). The larger issue is that this route does not trigger the user-password-changed event, which can be a security problem. For example, it did not disconnect existing user sessions (this occurs in mod_c2s in response to the event). Switching to set_password() is the right thing to do.
author Matthew Wild <mwild1@gmail.com>
date Thu, 06 Feb 2025 10:13:39 +0000
parents 6190:aa240145aa22
children 6192:76ae646563ea
files mod_anti_spam/mod_anti_spam.lua mod_anti_spam/trie.lib.lua mod_invites_register_api/mod_invites_register_api.lua
diffstat 3 files changed, 200 insertions(+), 47 deletions(-) [+]
line wrap: on
line diff
--- a/mod_anti_spam/mod_anti_spam.lua	Wed Feb 05 11:04:15 2025 -0500
+++ b/mod_anti_spam/mod_anti_spam.lua	Thu Feb 06 10:13:39 2025 +0000
@@ -1,5 +1,7 @@
+local cache = require "util.cache";
 local ip = require "util.ip";
 local jid_bare = require "util.jid".bare;
+local jid_host = require "util.jid".host;
 local jid_split = require "util.jid".split;
 local set = require "util.set";
 local sha256 = require "util.hashes".sha256;
@@ -11,10 +13,27 @@
 
 local new_rtbl_subscription = module:require("rtbl").new_rtbl_subscription;
 local trie = module:require("trie");
+local pset = module:require("pset");
 
-local spam_source_domains = set.new();
-local spam_source_ips = trie.new();
-local spam_source_jids = set.new();
+-- { [service_jid] = set, ... }
+local spam_source_domains_by_service = {};
+local spam_source_ips_by_service = {};
+local spam_source_jids_by_service = {};
+
+local service_probabilities = {
+	-- if_present = probability the address is a spammer if they are on the list
+	-- if_absent (optional): probability the address is a spammer if they are not on the list
+	-- [service_jid] = { if_present = 0.9, if_absent = 0.5 };
+};
+
+
+-- These "probabilistic sets" combine the multiple lists according to their weights
+local p_spam_source_domains = pset.new(spam_source_domains_by_service, service_probabilities);
+local p_spam_source_ips = pset.new(spam_source_ips_by_service, service_probabilities);
+local p_spam_source_jids = pset.new(spam_source_jids_by_service, service_probabilities);
+
+local domain_local_report_threshold = module:get_option_number("anti_spam_local_report_threshold", 2);
+
 local default_spam_action = module:get_option("anti_spam_default_action", "bounce");
 local custom_spam_actions = module:get_option("anti_spam_actions", {});
 
@@ -92,20 +111,20 @@
 end
 
 function is_spammy_server(session)
-	if spam_source_domains:contains(session.from_host) then
+	if p_spam_source_domains:contains(session.from_host) then
 		return true;
 	end
 	local raw_ip = session.ip;
 	local parsed_ip = raw_ip and ip.new_ip(session.ip);
 	-- Not every session has an ip - for example, stanzas sent from a
 	-- local host session
-	if parsed_ip and spam_source_ips:contains_ip(parsed_ip) then
+	if parsed_ip and p_spam_source_ips:contains_ip(parsed_ip) then
 		return true;
 	end
 end
 
 function is_spammy_sender(sender_jid)
-	return spam_source_jids:contains(sha256(sender_jid, true));
+	return p_spam_source_jids:contains(sha256(sender_jid, true));
 end
 
 local spammy_strings = module:get_option_array("anti_spam_block_strings");
@@ -140,6 +159,16 @@
 local anti_spam_services = module:get_option_array("anti_spam_services", {});
 
 for _, rtbl_service_jid in ipairs(anti_spam_services) do
+	service_probabilities[rtbl_service_jid] = { if_present = 0.95 };
+
+	local spam_source_domains = set.new();
+	local spam_source_ips = trie.new();
+	local spam_source_jids = set.new();
+
+	spam_source_domains_by_service[rtbl_service_jid] = spam_source_domains;
+	spam_source_ips_by_service[rtbl_service_jid] = spam_source_ips;
+	spam_source_jids_by_service[rtbl_service_jid] = spam_source_jids;
+
 	new_rtbl_subscription(rtbl_service_jid, "spam_source_domains", {
 		added = function (item)
 			spam_source_domains:add(item);
@@ -174,6 +203,68 @@
 	});
 end
 
+-- And local reports...
+
+do
+	local spam_source_domains = set.new();
+	local spam_source_ips = set.new();
+
+	local domain_counts = cache.new(100);
+
+	service_probabilities[module.host] = { if_present = 0.6, if_absent = 0.4 };
+
+	module:hook("mod_spam_reporting/spam-report", function (event)
+		-- TODO: check for >= prosody:member
+		local reported_jid = event.jid;
+		local reported_domain = jid_host(reported_jid);
+		local report_count = (domain_counts:get(reported_domain) or 0) + 1;
+		domain_counts:set(reported_domain, report_count);
+
+		if report_count >= domain_local_report_threshold then
+			spam_source_domains:add(reported_domain);
+		end
+	end);
+
+	module:add_item("shell-command", {
+		section = "antispam";
+		section_desc = "Anti-spam management commands";
+		name = "filter_domain";
+		desc = "Restrict interactions from a remote domain to a virtual host";
+		args = {
+			{ name = "host", type = "string" };
+			{ name = "remote_domain", type = "string" };
+		};
+		host_selector = "host";
+		handler = function(self, host, remote_domain) --luacheck: ignore 212/self 212/host
+			spam_source_domains:add(remote_domain);
+			return true, "Remote domain now restricted: "..remote_domain;
+		end;
+	});
+
+	module:add_item("shell-command", {
+		section = "antispam";
+		section_desc = "Anti-spam management commands";
+		name = "filter_ip";
+		desc = "Restrict interactions from a remote IP/CIDR to a virtual host";
+		args = {
+			{ name = "host", type = "string" };
+			{ name = "remote_ip", type = "string" };
+		};
+		host_selector = "host";
+		handler = function(self, host, remote_ip) --luacheck: ignore 212/self 212/host
+			local subnet_ip, subnet_bits = ip.parse_cidr(remote_ip);
+			if not subnet_ip then
+				return false, subnet_bits; -- false, err
+			end
+
+			spam_source_ips:add_subnet(subnet_ip, subnet_bits);
+
+			return true, "Remote IP now restricted: "..remote_ip;
+		end;
+	});
+
+end
+
 module:hook("message/bare", function (event)
 	local to_user, to_host = jid_split(event.stanza.attr.to);
 
@@ -237,3 +328,4 @@
 
 	module:log("debug", "Allowing subscription request through");
 end, 500);
+
--- a/mod_anti_spam/trie.lib.lua	Wed Feb 05 11:04:15 2025 -0500
+++ b/mod_anti_spam/trie.lib.lua	Thu Feb 06 10:13:39 2025 +0000
@@ -120,6 +120,29 @@
 	end
 end
 
+local function find_match_in_descendents(node, item, len, i)
+	for child_byte, child_node in pairs(node) do
+		if type(child_byte) == "number" then
+			if child_node.terminal then
+				local bits = child_node.value;
+				for j = #bits, 1, -1 do
+					local b = bits[j]-((i-1)*8);
+					if b ~= 8 then
+						local mask = bit.bnot(2^b-1);
+						if bit.band(bit.bxor(c, child_byte), mask) == 0 then
+							return true;
+						end
+					end
+				end
+			else
+				
+			end
+		end
+	end
+	return false;
+end
+
+--
 function trie_methods:contains_ip(item)
 	item = item.packed;
 	local node = self.root;
@@ -132,25 +155,57 @@
 		local c = item:byte(i);
 		local child = node[c];
 		if not child then
-			for child_byte, child_node in pairs(node) do
-				if type(child_byte) == "number" and child_node.terminal then
-					local bits = child_node.value;
-					for j = #bits, 1, -1 do
-						local b = bits[j]-((i-1)*8);
-						if b ~= 8 then
-							local mask = bit.bnot(2^b-1);
-							if bit.band(bit.bxor(c, child_byte), mask) == 0 then
-								return true;
-							end
-						end
-					end
-				end
-			end
-			return false;
+			return find_match_in_descendents(node, item, len, i);
 		end
 		node = child;
 	end
 end
+--]]
+
+--[[
+function trie_methods:contains_ip(item)
+	item = item.packed
+	local node = self.root
+	local len = #item
+
+	print(string.byte(item, 1, 4))
+
+	local function search(node, index)
+		if node.terminal then
+			print("S", "TERM")
+			return true
+		end
+
+		if index > len then
+			print("S", "MAX LEN")
+			return false
+		end
+
+		local c = item:byte(index)
+		local child = node[c]
+
+		print("S", (" "):rep(index), ("item[%d] = %d, has_child = %s"):format(index, c, not not child));
+
+		if child then
+			-- Continue searching down the current path
+			return search(child, index + 1)
+		else
+			-- Check all children for a terminal node
+			for child_byte, child_node in pairs(node) do
+				if type(child_byte) == "number" and child_byte then
+					if search(child_node, index + 1) then
+						return true
+					end
+				end
+			end
+		end
+
+		return false
+	end
+
+	return search(node, 1)
+end
+--]]
 
 local function new()
 	return setmetatable({
--- a/mod_invites_register_api/mod_invites_register_api.lua	Wed Feb 05 11:04:15 2025 -0500
+++ b/mod_invites_register_api/mod_invites_register_api.lua	Thu Feb 06 10:13:39 2025 +0000
@@ -75,39 +75,45 @@
 		if reset_for ~= prepped_username then
 			return 403; -- Attempt to use reset invite for incorrect user
 		end
+		local ok, err = usermanager.set_password(prepped_username, password, module.host);
+		if not ok then
+			module:log("error", "Unable to reset password for %s@%s: %s", prepped_username, module.host, err);
+			return 500;
+		end
+		module:fire_event("user-password-reset", user);
 	elseif usermanager.user_exists(prepped_username, module.host) then
 		return 409; -- Conflict
-	end
+	else
+		local registering = {
+			validated_invite = invite;
+			username = prepped_username;
+			host = module.host;
+			ip = request.ip;
+			allowed = true;
+		};
 
-	local registering = {
-		validated_invite = invite;
-		username = prepped_username;
-		host = module.host;
-		ip = request.ip;
-		allowed = true;
-	};
+		module:fire_event("user-registering", registering);
 
-	module:fire_event("user-registering", registering);
-
-	if not registering.allowed then
-		return 403;
-	end
+		if not registering.allowed then
+			return 403;
+		end
 
-	local ok, err = usermanager.create_user(prepped_username, password, module.host);
+		local ok, err = usermanager.create_user(prepped_username, password, module.host);
 
-	if not ok then
-		local err_id = id.short();
-		module:log("warn", "Registration failed (%s): %s", err_id, tostring(err));
-		return 500;
-	end
+		if not ok then
+			local err_id = id.short();
+			module:log("warn", "Registration failed (%s): %s", err_id, tostring(err));
+			return 500;
+		end
 
-	module:fire_event("user-registered", {
-		username = prepped_username;
-		host = module.host;
-		source = "mod_"..module.name;
-		validated_invite = invite;
-		ip = request.ip;
-	});
+		module:fire_event("user-registered", {
+			username = prepped_username;
+			host = module.host;
+			source = "mod_"..module.name;
+			validated_invite = invite;
+			ip = request.ip;
+		});
+	end
 
 	return json.encode({
 		jid = prepped_username .. "@" .. module.host;