Software /
code /
prosody-modules
Comparison
mod_firewall/mod_firewall.lua @ 971:53e158e44a44
mod_firewall: Add rate limiting capabilities, and keep zones and throttle objects in shared tables
author | Matthew Wild <mwild1@gmail.com> |
---|---|
date | Sat, 06 Apr 2013 22:20:59 +0100 |
parent | 967:a88f33fe6970 |
child | 980:aeb11522a44f |
comparison
equal
deleted
inserted
replaced
970:adcb751f22f3 | 971:53e158e44a44 |
---|---|
1 | 1 |
2 local resolve_relative_path = require "core.configmanager".resolve_relative_path; | 2 local resolve_relative_path = require "core.configmanager".resolve_relative_path; |
3 local logger = require "util.logger".init; | 3 local logger = require "util.logger".init; |
4 local set = require "util.set"; | 4 local set = require "util.set"; |
| | 5 local it = require "util.iterators"; |
5 local add_filter = require "util.filters".add_filter; | 6 local add_filter = require "util.filters".add_filter; |
6 | 7 local new_throttle = require "util.throttle".create; |
7 zones = {}; | 8 |
8 local zones = zones; | 9 local zones, throttles = module:shared("zones", "throttles"); |
9 setmetatable(zones, { | 10 local active_zones, active_throttles = {}, {}; |
10 __index = function (zones, zone) | |
11 local t = { [zone] = true }; | |
12 rawset(zones, zone, t); | |
13 return t; | |
14 end; | |
15 }); | |
16 | 11 |
17 local chains = { | 12 local chains = { |
18 preroute = { | 13 preroute = { |
19 type = "event"; | 14 type = "event"; |
20 priority = 0.1; | 15 priority = 0.1; |
32 deliver_remote = { | 27 deliver_remote = { |
33 type = "event"; "route/remote"; | 28 type = "event"; "route/remote"; |
34 priority = 0.1; | 29 priority = 0.1; |
35 }; | 30 }; |
36 }; | 31 }; |
| | 32 |
| | 33 local function idsafe(name) |
| | 34 return not not name:match("^%a[%w_]*$") |
| | 35 end |
37 | 36 |
38 -- Dependency locations: | 37 -- Dependency locations: |
39 -- <type lib> | 38 -- <type lib> |
40 -- <type global> | 39 -- <type global> |
41 -- function handler() | 40 -- function handler() |
71 global_code = [[local group_contains = module:depends("groups").group_contains]]; | 70 global_code = [[local group_contains = module:depends("groups").group_contains]]; |
72 }; | 71 }; |
73 is_admin = { global_code = [[local is_admin = require "core.usermanager".is_admin]]}; | 72 is_admin = { global_code = [[local is_admin = require "core.usermanager".is_admin]]}; |
74 core_post_stanza = { global_code = [[local core_post_stanza = prosody.core_post_stanza]] }; | 73 core_post_stanza = { global_code = [[local core_post_stanza = prosody.core_post_stanza]] }; |
75 zone = { global_code = function (zone) | 74 zone = { global_code = function (zone) |
76 assert(zone:match("^%a[%w_]*$"), "Invalid zone name: "..zone); | 75 assert(idsafe(zone), "Invalid zone name: "..zone); |
77 return ("local zone_%s = zones[%q] or {};"):format(zone, zone); | 76 return ("local zone_%s = zones[%q] or {};"):format(zone, zone); |
78 end }; | 77 end }; |
79 date_time = { global_code = [[local os_date = os.date]]; local_code = [[local current_date_time = os_date("*t");]] }; | 78 date_time = { global_code = [[local os_date = os.date]]; local_code = [[local current_date_time = os_date("*t");]] }; |
80 time = { local_code = function (what) | 79 time = { local_code = function (what) |
81 local defs = {}; | 80 local defs = {}; |
82 for field in what:gmatch("%a+") do | 81 for field in what:gmatch("%a+") do |
83 table.insert(defs, ("local current_%s = current_date_time.%s;"):format(field, field)); | 82 table.insert(defs, ("local current_%s = current_date_time.%s;"):format(field, field)); |
84 end | 83 end |
85 return table.concat(defs, " "); | 84 return table.concat(defs, " "); |
86 end, depends = { "date_time" }; }; | 85 end, depends = { "date_time" }; }; |
| | 86 throttle = { |
| | 87 global_code = function (throttle) |
| | 88 assert(idsafe(throttle), "Invalid rate limit name: "..throttle); |
| | 89 assert(throttles[throttle], "Unknown rate limit: "..throttle); |
| | 90 return ("local throttle_%s = throttles.%s;"):format(throttle, throttle); |
| | 91 end; |
| | 92 }; |
87 }; | 93 }; |
88 | 94 |
89 local function include_dep(dep, code) | 95 local function include_dep(dep, code) |
90 local dep, dep_param = dep:match("^([^:]+):?(.*)$"); | 96 local dep, dep_param = dep:match("^([^:]+):?(.*)$"); |
91 local dep_info = available_deps[dep]; | 97 local dep_info = available_deps[dep]; |
186 local zone_member_list = {}; | 192 local zone_member_list = {}; |
187 for member in zone_members:gmatch("[^, ]+") do | 193 for member in zone_members:gmatch("[^, ]+") do |
188 zone_member_list[#zone_member_list+1] = member; | 194 zone_member_list[#zone_member_list+1] = member; |
189 end | 195 end |
190 zones[zone_name] = set.new(zone_member_list)._items; | 196 zones[zone_name] = set.new(zone_member_list)._items; |
| | 197 table.insert(active_zones, zone_name); |
| | 198 elseif not(state) and line:match("^RATE ") then |
| | 199 local name = line:match("^RATE ([^:]+)"); |
| | 200 assert(idsafe(name), "Invalid rate limit name: "..name); |
| | 201 local rate = assert(tonumber(line:match(":%s*([%d.]+)")), "Unable to parse rate"); |
| | 202 local burst = tonumber(line:match("%(%s*burst%s+([%d.]+)%s*%)")) or 1; |
| | 203 throttles[name] = new_throttle(rate*burst, burst); |
| | 204 table.insert(active_throttles, name); |
191 elseif line:match("^[^%s:]+[%.=]") then | 205 elseif line:match("^[^%s:]+[%.=]") then |
192 -- Action | 206 -- Action |
193 if state == nil then | 207 if state == nil then |
194 -- This is a standalone action with no conditions | 208 -- This is a standalone action with no conditions |
195 rule = new_rule(ruleset, chain); | 209 rule = new_rule(ruleset, chain); |
263 .."\n end\n"; | 277 .."\n end\n"; |
264 end | 278 end |
265 table.insert(code, rule_code); | 279 table.insert(code, rule_code); |
266 end | 280 end |
267 | 281 |
268 local code_string = [[return function (zones, fire_event, log) | 282 local code_string = [[return function (zones, throttles, fire_event, log) |
269 ]]..table.concat(code.global_header, "\n")..[[ | 283 ]]..table.concat(code.global_header, "\n")..[[ |
270 local db = require 'util.debug' | 284 local db = require 'util.debug' |
271 return function (event) | 285 return function (event) |
272 local stanza, session = event.stanza, event.origin; | 286 local stanza, session = event.stanza, event.origin; |
273 | 287 |
289 return nil, "Error compiling (probably a compiler bug, please report): "..err; | 303 return nil, "Error compiling (probably a compiler bug, please report): "..err; |
290 end | 304 end |
291 local function fire_event(name, data) | 305 local function fire_event(name, data) |
292 return module:fire_event(name, data); | 306 return module:fire_event(name, data); |
293 end | 307 end |
294 chunk = chunk()(zones, fire_event, logger(filename)); -- Returns event handler with 'zones' upvalue. | 308 chunk = chunk()(zones, throttles, fire_event, logger(filename)); -- Returns event handler with 'zones' upvalue. |
295 return chunk; | 309 return chunk; |
296 end | 310 end |
297 | 311 |
| | 312 local function cleanup(t, active_list) |
| | 313 local unused = set.new(it.to_array(it.keys(t))) - set.new(active_list); |
| | 314 for k in unused do t[k] = nil; end |
| | 315 end |
| | 316 |
298 function module.load() | 317 function module.load() |
| | 318 active_zones, active_throttles = {}, {}; |
299 local firewall_scripts = module:get_option_set("firewall_scripts", {}); | 319 local firewall_scripts = module:get_option_set("firewall_scripts", {}); |
300 for script in firewall_scripts do | 320 for script in firewall_scripts do |
301 script = resolve_relative_path(prosody.paths.config, script); | 321 script = resolve_relative_path(prosody.paths.config, script); |
302 local chain_functions, err = compile_firewall_rules(script) | 322 local chain_functions, err = compile_firewall_rules(script) |
303 | 323 |
320 module:hook("firewall/chains/"..chain, handler); | 340 module:hook("firewall/chains/"..chain, handler); |
321 end | 341 end |
322 end | 342 end |
323 end | 343 end |
324 end | 344 end |
325 end | 345 -- Remove entries from tables that are no longer in use |
| | 346 cleanup(zones, active_zones); |
| | 347 cleanup(throttles, active_throttles); |
| | 348 end |