Software /
code /
prosody-modules
File
mod_firewall/mod_firewall.lua @ 996:37af655ca575
mod_firewall: Cache conditions, so that they are only calculated once per chain execution
author | Matthew Wild <mwild1@gmail.com> |
---|---|
date | Tue, 07 May 2013 09:28:20 +0100 |
parent | 980:aeb11522a44f |
child | 998:6fdcebbd2284 |
line wrap: on
line source
local resolve_relative_path = require "core.configmanager".resolve_relative_path;
local logger = require "util.logger".init;
local set = require "util.set";
local it = require "util.iterators";
local add_filter = require "util.filters".add_filter;
local new_throttle = require "util.throttle".create;

-- Zone member sets and throttle objects, keyed by name. They live in
-- module:shared() tables (presumably so they persist across reloads); entries
-- not defined by the currently loaded scripts are removed by module.load().
local zones, throttles = module:shared("zones", "throttles");
-- Names of the zones/throttles defined by the scripts being loaded now
local active_zones, active_throttles = {}, {};

-- Chains a script can attach rules to. For "event" chains, the array part
-- lists the events the compiled handler is hooked to, at `priority`.
local chains = {
	preroute = {
		type = "event";
		priority = 0.1;
		"pre-message/bare", "pre-message/full", "pre-message/host";
		"pre-presence/bare", "pre-presence/full", "pre-presence/host";
		"pre-iq/bare", "pre-iq/full", "pre-iq/host";
	};
	deliver = {
		type = "event";
		priority = 0.1;
		"message/bare", "message/full", "message/host";
		"presence/bare", "presence/full", "presence/host";
		"iq/bare", "iq/full", "iq/host";
	};
	deliver_remote = {
		type = "event"; "route/remote";
		priority = 0.1;
	};
};

-- Returns true if name can be embedded in a Lua identifier in generated
-- code: a letter followed by letters, digits or underscores.
local function idsafe(name)
	return not not name:match("^%a[%w_]*$")
end

-- Dependency locations:
-- <type lib>
-- <type global>
-- function handler()
--   <local deps>
--   if <conditions> then
--     <actions>
--   end
-- end

-- Code snippets that conditions/actions may depend on. global_code goes into
-- the chunk header (once per chain), local_code into the per-event handler.
-- When either is a function it receives the parameter from "name:param".
local available_deps = {
	st = { global_code = [[local st = require "util.stanza"]]};
	jid_split = {
		global_code = [[local jid_split = require "util.jid".split;]];
	};
	jid_bare = {
		global_code = [[local jid_bare = require "util.jid".bare;]];
	};
	to = { local_code = [[local to = stanza.attr.to;]] };
	from = { local_code = [[local from = stanza.attr.from;]] };
	type = { local_code = [[local type = stanza.attr.type;]] };
	name = { local_code = [[local name = stanza.name]] };
	split_to = { -- The stanza's split to address
		depends = { "jid_split", "to" };
		local_code = [[local to_node, to_host, to_resource = jid_split(to);]];
	};
	split_from = { -- The stanza's split from address
		depends = { "jid_split", "from" };
		local_code = [[local from_node, from_host, from_resource = jid_split(from);]];
	};
	bare_to = { depends = { "jid_bare", "to" }, local_code = "local bare_to = jid_bare(to)"};
	bare_from = { depends = { "jid_bare", "from" }, local_code = "local bare_from = jid_bare(from)"};
	group_contains = {
		global_code = [[local group_contains = module:depends("groups").group_contains]];
	};
	is_admin = { global_code = [[local is_admin = require "core.usermanager".is_admin]]};
	core_post_stanza = { global_code = [[local core_post_stanza = prosody.core_post_stanza]] };
	zone = {
		global_code = function (zone)
			assert(idsafe(zone), "Invalid zone name: "..zone);
			return ("local zone_%s = zones[%q] or {};"):format(zone, zone);
		end;
	};
	date_time = { global_code = [[local os_date = os.date]]; local_code = [[local current_date_time = os_date("*t");]] };
	time = {
		local_code = function (what)
			local defs = {};
			for field in what:gmatch("%a+") do
				table.insert(defs, ("local current_%s = current_date_time.%s;"):format(field, field));
			end
			return table.concat(defs, " ");
		end;
		depends = { "date_time" };
	};
	throttle = {
		global_code = function (throttle)
			assert(idsafe(throttle), "Invalid rate limit name: "..throttle);
			assert(throttles[throttle], "Unknown rate limit: "..throttle);
			return ("local throttle_%s = throttles.%s;"):format(throttle, throttle);
		end;
	};
};

-- Adds a dependency, after everything it depends on, to the chain being
-- compiled. `dependency` is "name" or "name:param". Global code is appended
-- to code.global_header, local code to the array part of `code`.
-- code.included_deps tracks state per full "name:param" string:
--   nil = not yet included, false = inclusion in progress, true = done.
local function include_dep(dependency, code)
	local dep, dep_param = dependency:match("^([^:]+):?(.*)$");
	local dep_info = available_deps[dep];
	if not dep_info then
		module:log("error", "Dependency not found: %s", dep);
		return;
	end
	-- Keyed on the full string so that e.g. zone:a and zone:b are both emitted
	local included = code.included_deps[dependency];
	if included ~= nil then
		-- false means we re-entered a dependency still being included
		if included ~= true then
			module:log("error", "Circular dependency on %s", dependency);
		end
		return;
	end
	code.included_deps[dependency] = false; -- Pending flag (used to detect circular references)
	for _, dep_dep in ipairs(dep_info.depends or {}) do
		include_dep(dep_dep, code);
	end
	if dep_info.global_code then
		if dep_param ~= "" then
			table.insert(code.global_header, dep_info.global_code(dep_param));
		else
			table.insert(code.global_header, dep_info.global_code);
		end
	end
	if dep_info.local_code then
		if dep_param ~= "" then
			table.insert(code, "\n\t-- "..dep.."\n\t"..dep_info.local_code(dep_param).."\n\n\t");
		else
			table.insert(code, "\n\t-- "..dep.."\n\t"..dep_info.local_code.."\n\n\t");
		end
	end
	code.included_deps[dependency] = true;
end

local condition_handlers = module:require("conditions");
local action_handlers = module:require("actions");

-- Creates an empty rule and appends it to the given chain of the ruleset.
local function new_rule(ruleset, chain)
	assert(chain, "no chain specified");
	local rule = { conditions = {}, actions = {}, deps = {} };
	table.insert(ruleset[chain], rule);
	return rule;
end

-- Lines of this shape are actions (e.g. "DROP." or "LOG=message")
local action_pattern = "^[^%s:]+[%.=]";

-- Parses the lines of an open firewall script.
-- Returns a ruleset (chain name -> list of { conditions, actions, deps }),
-- or nil and an error message. ZONE and RATE definitions are stored directly
-- in the shared zones/throttles tables.
local function parse_firewall_rules(file, filename)
	local line_no = 0;
	local function errmsg(err)
		return "Error compiling "..filename.." on line "..line_no..": "..err;
	end
	local ruleset = {
		deliver = {};
	};

	local chain = "deliver"; -- Default chain
	local rule;

	local state; -- nil -> "rules" -> "actions" -> nil -> ...

	local line_hold;
	for line in file:lines() do
		line = line:match("^%s*(.-)%s*$");
		-- A trailing backslash joins this line with the next one
		if line_hold and line:sub(-1,-1) ~= "\\" then
			line = line_hold..line;
			line_hold = nil;
		elseif line:sub(-1,-1) == "\\" then
			line_hold = (line_hold or "")..line:sub(1,-2);
		end
		line_no = line_no + 1;

		local is_comment = line:match("^[#;]");
		local is_action = line:match(action_pattern);

		if state == "actions" and not (line_hold or is_comment or is_action or line == "") then
			-- Any other line ends the current rule (new_rule() already added it
			-- to the ruleset) and is then parsed below as what it is: the first
			-- condition of a new rule, a chain switch, ZONE or RATE.
			state = nil;
			rule = nil;
		end

		if line_hold or is_comment then
			-- No action; comment or partial line
		elseif line == "" then
			if state == "rules" then
				return nil, ("Expected an action on line %d for preceding criteria"):format(line_no);
			end
			state = nil;
		elseif not(state) and line:match("^::") then
			chain = (line:gsub("^::%s*", ""));
			local chain_info = chains[chain];
			if not chain_info then
				return nil, errmsg("Unknown chain: "..chain);
			elseif chain_info.type ~= "event" then
				return nil, errmsg("Only event chains supported at the moment");
			end
			ruleset[chain] = ruleset[chain] or {};
		elseif not(state) and line:match("^ZONE ") then
			local zone_name = line:match("^ZONE ([^:]+)");
			if not (zone_name and idsafe(zone_name)) then
				return nil, errmsg("Invalid character(s) in zone name: "..(zone_name or ""));
			end
			local zone_members = line:match("^ZONE .-: ?(.*)");
			if not zone_members then
				return nil, errmsg("Expected ':' and a member list after zone name: "..zone_name);
			end
			local zone_member_list = {};
			for member in zone_members:gmatch("[^, ]+") do
				zone_member_list[#zone_member_list+1] = member;
			end
			zones[zone_name] = set.new(zone_member_list)._items;
			table.insert(active_zones, zone_name);
		elseif not(state) and line:match("^RATE ") then
			local name = line:match("^RATE ([^:]+)");
			if not (name and idsafe(name)) then
				return nil, errmsg("Invalid rate limit name: "..(name or ""));
			end
			local rate = tonumber(line:match(":%s*([%d.]+)"));
			if not rate then
				return nil, errmsg("Unable to parse rate");
			end
			local burst = tonumber(line:match("%(%s*burst%s+([%d.]+)%s*%)")) or 1;
			throttles[name] = new_throttle(rate*burst, burst);
			table.insert(active_throttles, name);
		elseif is_action then
			if state == nil then
				-- This is a standalone action with no conditions
				rule = new_rule(ruleset, chain);
			end
			state = "actions";
			local action = line:match("^%P+");
			if not action_handlers[action] then
				return nil, ("Unknown action on line %d: %s"):format(line_no, action or "<unknown>");
			end
			table.insert(rule.actions, "-- "..line)
			local ok, action_string, action_deps = pcall(action_handlers[action], line:match("=(.+)$"));
			if not ok then
				return nil, errmsg(action_string);
			end
			table.insert(rule.actions, action_string);
			for _, dep in ipairs(action_deps or {}) do
				table.insert(rule.deps, dep);
			end
		else
			if not state then
				state = "rules";
				rule = new_rule(ruleset, chain);
			end
			-- Check standard modifiers for the condition (e.g. NOT)
			local negated;
			local condition = line:match("^[^:=%.]*");
			if condition:match("%f[%w]NOT%f[^%w]") then
				local s, e = condition:match("%f[%w]()NOT()%f[^%w]");
				condition = (condition:sub(1,s-1)..condition:sub(e+1, -1)):match("^%s*(.-)%s*$");
				negated = true;
			end
			condition = (condition:gsub(" ", ""));
			if not condition_handlers[condition] then
				return nil, ("Unknown condition on line %d: %s"):format(line_no, condition);
			end
			-- Get the code for this condition
			local ok, condition_code, condition_deps = pcall(condition_handlers[condition], line:match(":%s?(.+)$"));
			if not ok then
				return nil, errmsg(condition_code);
			end
			if negated then
				condition_code = "not("..condition_code..")";
			end
			table.insert(rule.conditions, condition_code);
			for _, dep in ipairs(condition_deps or {}) do
				table.insert(rule.deps, dep);
			end
		end
	end

	-- Input ending mid-rule or mid-line would otherwise be dropped silently
	if line_hold then
		return nil, errmsg("Unexpected end of file after line continuation");
	end
	if state == "rules" then
		return nil, errmsg("Expected an action for preceding criteria before end of file");
	end
	return ruleset;
end

-- Generates the Lua source of one event handler per chain in the ruleset.
-- Returns a table mapping chain names to code strings.
local function compile_ruleset(ruleset)
	local chain_handlers = {};
	-- Loop through the chains in the parsed ruleset (e.g. incoming, outgoing)
	for chain_name, rules in pairs(ruleset) do
		local code = { included_deps = {}, global_header = {} };
		-- Condition expression -> name of the local holding its value, so
		-- each distinct condition is evaluated only once per chain execution
		local condition_cache, n_conditions = {}, 0;
		-- This inner loop assumes chain is an event-based, not a filter-based
		-- chain (filter-based will be added later)
		for _, rule in ipairs(rules) do
			for _, dep in ipairs(rule.deps) do
				include_dep(dep, code);
			end
			local rule_code = table.concat(rule.actions, "\n\t");
			if #rule.conditions > 0 then
				local condition_names = {};
				for i, condition in ipairs(rule.conditions) do
					-- Cache the positive form so NOT X reuses the value of X
					local negated = condition:match("^not%b()$");
					if negated then
						condition = condition:match("^not%((.+)%)$");
					end
					local name = condition_cache[condition];
					if not name then
						n_conditions = n_conditions + 1;
						name = "condition"..n_conditions;
						condition_cache[condition] = name;
						table.insert(code, "local "..name.." = "..condition..";\n\t");
					end
					condition_names[i] = negated and "not("..name..")" or name;
				end
				rule_code = "if "..table.concat(condition_names, " and ").." then\n\t"
					..rule_code
					.."\n end\n";
			else
				-- Own block, like conditioned rules, so an action ending in
				-- `return` is still the last statement of its block
				rule_code = "do\n\t"..rule_code.."\n end\n";
			end
			table.insert(code, rule_code);
		end
		local code_string = [[return function (zones, throttles, fire_event, log)
	]]..table.concat(code.global_header, "\n")..[[

	local db = require 'util.debug'
	return function (event)
		local stanza, session = event.stanza, event.origin;

		]]..table.concat(code, " ")..[[

	end;
end]];
		chain_handlers[chain_name] = code_string;
	end
	return chain_handlers;
end

-- Parses and compiles a firewall script.
-- Returns a table mapping chain names to handler source code,
-- or nil and an error message.
local function compile_firewall_rules(filename)
	local file, err = io.open(filename);
	if not file then
		return nil, err;
	end
	-- file:lines() never closes the file, so close it here on every path
	local ok, ruleset, parse_err = pcall(parse_firewall_rules, file, filename);
	file:close();
	if not ok then
		error(ruleset, 0); -- re-raise unexpected errors unchanged
	elseif not ruleset then
		return nil, parse_err;
	end
	return compile_ruleset(ruleset);
end

-- Loads generated chain code and returns the event handler it produces,
-- or nil and an error message.
local function compile_handler(code_string, filename)
	module:log("debug", "Generated code for %s:\n%s", filename, code_string);
	-- Prepare event handler function
	local chunk, err = loadstring(code_string, "="..filename);
	if not chunk then
		return nil, "Error compiling (probably a compiler bug, please report): "..err;
	end
	local function fire_event(name, data)
		return module:fire_event(name, data);
	end
	chunk = chunk()(zones, throttles, fire_event, logger(filename)); -- Returns event handler with 'zones' upvalue.
	return chunk;
end

-- Removes from t every key not listed in active_list.
local function cleanup(t, active_list)
	local unused = set.new(it.to_array(it.keys(t))) - set.new(active_list);
	for k in unused do t[k] = nil; end
end

function module.load()
	active_zones, active_throttles = {}, {};
	local firewall_scripts = module:get_option_set("firewall_scripts", {});
	for script in firewall_scripts do
		script = resolve_relative_path(prosody.paths.config, script);
		local chain_functions, err = compile_firewall_rules(script)

		if not chain_functions then
			module:log("error", "Error compiling %s: %s", script, err or "unknown error");
		else
			for chain, handler_code in pairs(chain_functions) do
				local handler, handler_err = compile_handler(handler_code, "mod_firewall::"..chain);
				if not handler then
					module:log("error", "Compilation error for %s: %s", script, handler_err);
				else
					local chain_definition = chains[chain];
					if chain_definition and chain_definition.type == "event" then
						for _, event_name in ipairs(chain_definition) do
							module:hook(event_name, handler, chain_definition.priority);
						end
					elseif not chain:match("^user/") then
						module:log("warn", "Unknown chain %q", chain);
					end
					module:hook("firewall/chains/"..chain, handler);
				end
			end
		end
	end
	-- Remove entries from tables that are no longer in use
	cleanup(zones, active_zones);
	cleanup(throttles, active_throttles);
end