File

mod_anti_spam/trie.lib.lua @ 5884:fdff8cb54302

mod_report_forward: fix address detection when there are multiple field values
author Stephen Paul Weber <singpolyma@singpolyma.net>
date Mon, 11 Mar 2024 19:28:19 -0500
parent 5883:259ffdbf8906
line wrap: on
line source

local bit = require "prosody.util.bitcompat";

-- Method table shared by every trie instance; instances find these methods
-- through trie_mt's __index.
local trie_methods = {};
local trie_mt = {};
trie_mt.__index = trie_methods;

-- Allocate a fresh, empty trie node. Nodes are plain tables: numeric keys
-- (byte values) lead to child nodes, and entries additionally carry the
-- `terminal` / `value` fields written by trie_methods:set.
local function new_node()
	local node = {};
	return node;
end

--- Store `value` under the byte string `item`.
-- Walks one node per byte of `item`, creating missing nodes on the way, and
-- marks the final node as a terminal entry holding `value`.
-- @param item byte string key
-- @param value value to associate with `item`
function trie_methods:set(item, value)
	local cursor = self.root;
	for pos = 1, #item do
		local byte = item:byte(pos);
		local next_node = cursor[byte];
		if next_node == nil then
			next_node = new_node();
			cursor[byte] = next_node;
		end
		cursor = next_node;
	end
	cursor.value = value;
	cursor.terminal = true;
end

-- Recursive helper for trie_methods:remove.
-- Descends from `node` following the bytes of `item` starting at index `i`;
-- on reaching the end of `item` it clears the terminal marker and value.
-- On the way back up, child links to nodes that became empty are dropped.
-- Returns `node` when it still holds anything, or nil when it is now empty
-- (so the caller can unlink it).
local function _remove(node, item, i)
	if i <= #item then
		local byte = item:byte(i);
		local sub = node[byte];
		if sub ~= nil then
			-- Replace the link with whatever survives (nil prunes it).
			node[byte] = _remove(sub, item, i + 1);
		end
	elseif node.terminal then
		node.terminal = nil;
		node.value = nil;
	end

	if next(node) == nil then
		return nil;
	end
	return node;
end

--- Delete the entry stored under `item`, pruning nodes left empty.
-- @param item byte string key
-- @return the root table if it still holds anything, otherwise nil
function trie_methods:remove(item)
	local root = self.root;
	return _remove(root, item, 1);
end

--- Look up `item` in the trie.
-- Without `partial`, the first result is the value stored for exactly
-- `item` (nil if `item` is not an entry). With `partial`, it is the value of
-- the longest entry that is a prefix of `item`, `item` itself included.
-- The second result is the number of leading bytes of `item` that could be
-- followed along existing trie nodes, whether or not they end at an entry.
-- @param item byte string key
-- @param partial if truthy, accept the longest stored prefix of `item`
-- @return value (or nil), number of bytes walked
function trie_methods:get(item, partial)
	local value; -- longest terminal prefix seen so far (partial mode only)
	local node = self.root;
	local len = #item;
	for i = 1, len do
		if partial and node.terminal then
			value = node.value;
		end
		node = node[item:byte(i)];
		if not node then
			return value, i - 1;
		end
	end
	if node.terminal then
		return node.value, len;
	end
	-- The whole item was walked but is not itself an entry: fall back to the
	-- longest terminal prefix found on the way (always nil unless `partial`).
	return value, len;
end

--- Insert `item` as a plain member of the set (its stored value is `true`).
-- @param item byte string key
function trie_methods:add(item)
	local present = true;
	return self:set(item, present);
end

--- Report whether a lookup of `item` yields a non-nil value.
-- @param item byte string key
-- @param partial passed through to trie_methods:get
-- @return boolean
function trie_methods:contains(item, partial)
	local found = self:get(item, partial);
	return found ~= nil;
end

--- Return how many leading bytes of `item` can be followed along existing
-- trie nodes (the second result of trie_methods:get). The path reached need
-- not end at a stored entry.
-- @param item byte string key
-- @return number of bytes walked
function trie_methods:longest_prefix(item)
	local _, depth = self:get(item);
	return depth;
end

--- Record the subnet `item`/`bits` in the trie.
-- The key is the first ceil(bits/8) bytes of `item.packed`. The value at
-- that key is a list of prefix lengths sharing the key, kept in ascending
-- order without duplicates (remove_subnet relies on the ordering).
-- @param item address object exposing a `packed` byte string
-- @param bits prefix length of the subnet
function trie_methods:add_subnet(item, bits)
	item = item.packed:sub(1, math.ceil(bits/8));
	local existing = self:get(item);
	if not existing then
		existing = { bits };
		return self:set(item, existing);
	end

	-- Sorted insertion into the ascending list of prefix lengths.
	for i = 1, #existing do
		local v = existing[i];
		if v == bits then
			return; -- Already in there
		elseif v > bits then
			table.insert(existing, i, bits);
			return;
		end
	end
	-- Longer than every stored prefix: append at the end.
	existing[#existing + 1] = bits;
end

--- Forget the subnet `item`/`bits` previously recorded with add_subnet.
-- Looks up the list of prefix lengths stored under the first ceil(bits/8)
-- bytes of `item.packed` and deletes `bits` from it. Once that list is
-- empty, the key itself is removed from the trie.
-- @param item address object exposing a `packed` byte string
-- @param bits prefix length of the subnet to forget
function trie_methods:remove_subnet(item, bits)
	local key = item.packed:sub(1, math.ceil(bits / 8));
	local lengths = self:get(key);
	if not lengths then
		return;
	end

	-- Linear search. The list is meant to be in ascending order, so meeting a
	-- larger length means `bits` is absent and nothing needs to change.
	for idx = 1, #lengths do
		local len = lengths[idx];
		if len > bits then
			return;
		end
		if len == bits then
			table.remove(lengths, idx);
			break;
		end
	end

	if #lengths == 0 then
		self:remove(key);
	end
end

-- Masks selecting the high-order bits of a byte, indexed by how many prefix
-- bits fall inside that byte (1-7). A full byte (8) needs an exact match.
local partial_byte_mask = { 0x80, 0xC0, 0xE0, 0xF0, 0xF8, 0xFC, 0xFE };

--- Test whether an address lies inside any subnet recorded with add_subnet.
-- A subnet of `bits` is stored at depth ceil(bits/8). At each depth, every
-- terminal child is therefore a subnet whose prefix ends inside the next
-- address byte, and it must be compared under that byte's prefix mask.
-- @param item address object exposing a `packed` byte string
-- @return true if some stored subnet contains the address, false otherwise
function trie_methods:has_ip(item)
	local packed = item.packed;
	local node = self.root;
	if node.terminal then
		return true; -- zero-length prefix covers every address
	end
	for i = 1, #packed do
		local c = packed:byte(i);
		local offset = (i - 1) * 8; -- prefix bits consumed before byte i
		-- Test all terminal children, not only when node[c] is missing: a wider
		-- subnet may sit beside the longer branch we are about to descend.
		for child_byte, child_node in pairs(node) do
			if type(child_byte) == "number" and child_node.terminal then
				if child_byte == c then
					return true; -- every prefix byte matches exactly
				end
				local diff = bit.bxor(c, child_byte);
				local prefixes = child_node.value;
				for j = 1, #prefixes do
					local mask = partial_byte_mask[prefixes[j] - offset];
					if mask and bit.band(diff, mask) == 0 then
						return true;
					end
				end
			end
		end
		node = node[c];
		if not node then
			return false;
		end
	end
	return false;
end

--- Create an empty trie whose methods come from trie_methods.
-- @return a new trie instance
local function new()
	local trie = { root = new_node() };
	return setmetatable(trie, trie_mt);
end

--- Report whether `o` carries the trie metatable (i.e. came from new()).
-- @param o any value
-- @return boolean
local function is_trie(o)
	local mt = getmetatable(o);
	return mt == trie_mt;
end

-- Public interface of this library.
local trie_lib = {
	is_trie = is_trie;
	new = new;
};

return trie_lib;