Comparison

mod_firewall/mod_firewall.lua @ 971:53e158e44a44

mod_firewall: Add rate limiting capabilities, and keep zones and throttle objects in shared tables
author Matthew Wild <mwild1@gmail.com>
date Sat, 06 Apr 2013 22:20:59 +0100
parent 967:a88f33fe6970
child 980:aeb11522a44f
comparison
equal deleted inserted replaced
970:adcb751f22f3 971:53e158e44a44
1 1
2 local resolve_relative_path = require "core.configmanager".resolve_relative_path; 2 local resolve_relative_path = require "core.configmanager".resolve_relative_path;
3 local logger = require "util.logger".init; 3 local logger = require "util.logger".init;
4 local set = require "util.set"; 4 local set = require "util.set";
5 local it = require "util.iterators";
5 local add_filter = require "util.filters".add_filter; 6 local add_filter = require "util.filters".add_filter;
6 7 local new_throttle = require "util.throttle".create;
7 zones = {}; 8
8 local zones = zones; 9 local zones, throttles = module:shared("zones", "throttles");
9 setmetatable(zones, { 10 local active_zones, active_throttles = {}, {};
10 __index = function (zones, zone)
11 local t = { [zone] = true };
12 rawset(zones, zone, t);
13 return t;
14 end;
15 });
16 11
17 local chains = { 12 local chains = {
18 preroute = { 13 preroute = {
19 type = "event"; 14 type = "event";
20 priority = 0.1; 15 priority = 0.1;
32 deliver_remote = { 27 deliver_remote = {
33 type = "event"; "route/remote"; 28 type = "event"; "route/remote";
34 priority = 0.1; 29 priority = 0.1;
35 }; 30 };
36 }; 31 };
32
33 local function idsafe(name)
34 return not not name:match("^%a[%w_]*$")
35 end
37 36
38 -- Dependency locations: 37 -- Dependency locations:
39 -- <type lib> 38 -- <type lib>
40 -- <type global> 39 -- <type global>
41 -- function handler() 40 -- function handler()
71 global_code = [[local group_contains = module:depends("groups").group_contains]]; 70 global_code = [[local group_contains = module:depends("groups").group_contains]];
72 }; 71 };
73 is_admin = { global_code = [[local is_admin = require "core.usermanager".is_admin]]}; 72 is_admin = { global_code = [[local is_admin = require "core.usermanager".is_admin]]};
74 core_post_stanza = { global_code = [[local core_post_stanza = prosody.core_post_stanza]] }; 73 core_post_stanza = { global_code = [[local core_post_stanza = prosody.core_post_stanza]] };
75 zone = { global_code = function (zone) 74 zone = { global_code = function (zone)
76 assert(zone:match("^%a[%w_]*$"), "Invalid zone name: "..zone); 75 assert(idsafe(zone), "Invalid zone name: "..zone);
77 return ("local zone_%s = zones[%q] or {};"):format(zone, zone); 76 return ("local zone_%s = zones[%q] or {};"):format(zone, zone);
78 end }; 77 end };
79 date_time = { global_code = [[local os_date = os.date]]; local_code = [[local current_date_time = os_date("*t");]] }; 78 date_time = { global_code = [[local os_date = os.date]]; local_code = [[local current_date_time = os_date("*t");]] };
80 time = { local_code = function (what) 79 time = { local_code = function (what)
81 local defs = {}; 80 local defs = {};
82 for field in what:gmatch("%a+") do 81 for field in what:gmatch("%a+") do
83 table.insert(defs, ("local current_%s = current_date_time.%s;"):format(field, field)); 82 table.insert(defs, ("local current_%s = current_date_time.%s;"):format(field, field));
84 end 83 end
85 return table.concat(defs, " "); 84 return table.concat(defs, " ");
86 end, depends = { "date_time" }; }; 85 end, depends = { "date_time" }; };
86 throttle = {
87 global_code = function (throttle)
88 assert(idsafe(throttle), "Invalid rate limit name: "..throttle);
89 assert(throttles[throttle], "Unknown rate limit: "..throttle);
90 return ("local throttle_%s = throttles.%s;"):format(throttle, throttle);
91 end;
92 };
87 }; 93 };
88 94
89 local function include_dep(dep, code) 95 local function include_dep(dep, code)
90 local dep, dep_param = dep:match("^([^:]+):?(.*)$"); 96 local dep, dep_param = dep:match("^([^:]+):?(.*)$");
91 local dep_info = available_deps[dep]; 97 local dep_info = available_deps[dep];
186 local zone_member_list = {}; 192 local zone_member_list = {};
187 for member in zone_members:gmatch("[^, ]+") do 193 for member in zone_members:gmatch("[^, ]+") do
188 zone_member_list[#zone_member_list+1] = member; 194 zone_member_list[#zone_member_list+1] = member;
189 end 195 end
190 zones[zone_name] = set.new(zone_member_list)._items; 196 zones[zone_name] = set.new(zone_member_list)._items;
197 table.insert(active_zones, zone_name);
198 elseif not(state) and line:match("^RATE ") then
199 local name = line:match("^RATE ([^:]+)");
200 assert(idsafe(name), "Invalid rate limit name: "..name);
201 local rate = assert(tonumber(line:match(":%s*([%d.]+)")), "Unable to parse rate");
202 local burst = tonumber(line:match("%(%s*burst%s+([%d.]+)%s*%)")) or 1;
203 throttles[name] = new_throttle(rate*burst, burst);
204 table.insert(active_throttles, name);
191 elseif line:match("^[^%s:]+[%.=]") then 205 elseif line:match("^[^%s:]+[%.=]") then
192 -- Action 206 -- Action
193 if state == nil then 207 if state == nil then
194 -- This is a standalone action with no conditions 208 -- This is a standalone action with no conditions
195 rule = new_rule(ruleset, chain); 209 rule = new_rule(ruleset, chain);
263 .."\n end\n"; 277 .."\n end\n";
264 end 278 end
265 table.insert(code, rule_code); 279 table.insert(code, rule_code);
266 end 280 end
267 281
268 local code_string = [[return function (zones, fire_event, log) 282 local code_string = [[return function (zones, throttles, fire_event, log)
269 ]]..table.concat(code.global_header, "\n")..[[ 283 ]]..table.concat(code.global_header, "\n")..[[
270 local db = require 'util.debug' 284 local db = require 'util.debug'
271 return function (event) 285 return function (event)
272 local stanza, session = event.stanza, event.origin; 286 local stanza, session = event.stanza, event.origin;
273 287
289 return nil, "Error compiling (probably a compiler bug, please report): "..err; 303 return nil, "Error compiling (probably a compiler bug, please report): "..err;
290 end 304 end
291 local function fire_event(name, data) 305 local function fire_event(name, data)
292 return module:fire_event(name, data); 306 return module:fire_event(name, data);
293 end 307 end
294 chunk = chunk()(zones, fire_event, logger(filename)); -- Returns event handler with 'zones' upvalue. 308 chunk = chunk()(zones, throttles, fire_event, logger(filename)); -- Returns event handler with 'zones' upvalue.
295 return chunk; 309 return chunk;
296 end 310 end
297 311
312 local function cleanup(t, active_list)
313 local unused = set.new(it.to_array(it.keys(t))) - set.new(active_list);
314 for k in unused do t[k] = nil; end
315 end
316
298 function module.load() 317 function module.load()
318 active_zones, active_throttles = {}, {};
299 local firewall_scripts = module:get_option_set("firewall_scripts", {}); 319 local firewall_scripts = module:get_option_set("firewall_scripts", {});
300 for script in firewall_scripts do 320 for script in firewall_scripts do
301 script = resolve_relative_path(prosody.paths.config, script); 321 script = resolve_relative_path(prosody.paths.config, script);
302 local chain_functions, err = compile_firewall_rules(script) 322 local chain_functions, err = compile_firewall_rules(script)
303 323
320 module:hook("firewall/chains/"..chain, handler); 340 module:hook("firewall/chains/"..chain, handler);
321 end 341 end
322 end 342 end
323 end 343 end
324 end 344 end
325 end 345 -- Remove entries from tables that are no longer in use
346 cleanup(zones, active_zones);
347 cleanup(throttles, active_throttles);
348 end