Software /
code /
prosody-modules
File
mod_firewall/mod_firewall.lua @ 2494:d300ae5dba87
mod_smacks: Fix some bugs with smacks-ack-delayed event triggering.
The old code had several flaws which are addressed here.
First of all, this fixes the if statement guarding the event generation.
There were some timing glitches addressed by this commit as well.
author | tmolitor <thilo@eightysoft.de> |
---|---|
date | Sun, 12 Feb 2017 21:23:22 +0100 |
parent | 2464:01babf1caa4a |
child | 2517:3990b1bca308 |
line wrap: on
line source
-- mod_firewall: compiles firewall rule scripts into Lua event handlers.
--
-- Pipeline: parse_firewall_rules() reads a script into a ruleset,
-- process_firewall_rules() turns each chain of the ruleset into Lua source
-- code, and compile_handler() loads that source into an event handler which
-- module.load() hooks onto the events listed for the chain.

local resolve_relative_path = require "core.configmanager".resolve_relative_path;
local logger = require "util.logger".init;
local it = require "util.iterators";

-- Table shared through module:shared(); module.load() copies the active
-- definitions into it.
local definitions = module:shared("definitions");
local active_definitions = {
	ZONE = {
		-- Default zone that includes all local hosts
		["$local"] = setmetatable({}, { __index = prosody.hosts });
	}
};

-- Built-in chains. The array part of each entry lists the events the chain's
-- handler is hooked to; 'priority' is the hook priority.
local default_chains = {
	preroute = {
		type = "event";
		priority = 0.1;
		"pre-message/bare", "pre-message/full", "pre-message/host";
		"pre-presence/bare", "pre-presence/full", "pre-presence/host";
		"pre-iq/bare", "pre-iq/full", "pre-iq/host";
	};
	deliver = {
		type = "event";
		priority = 0.1;
		"message/bare", "message/full", "message/host";
		"presence/bare", "presence/full", "presence/host";
		"iq/bare", "iq/full", "iq/host";
	};
	deliver_remote = {
		type = "event"; "route/remote";
		priority = 0.1;
	};
};

local extra_chains = module:get_option("firewall_extra_chains", {});

-- Effective chain table: defaults first, then configured extras (which may
-- override a default of the same name).
local chains = {};
for k,v in pairs(default_chains) do
	chains[k] = v;
end
for k,v in pairs(extra_chains) do
	chains[k] = v;
end

-- Returns the input if it is safe to be used as a variable name, otherwise nil
-- NOTE(review): deliberately global (like meta() below); presumably used by the
-- libraries loaded with module:require() further down - confirm before
-- making it local.
function idsafe(name)
	return name:match("^%a[%w_]*$")
end

-- Run quoted (%q) strings through this to allow them to contain code.
-- e.g.: LOG=Received: $(stanza:top_tag())
--
-- Supported substitutions:
--   $(expr)           -> tostring(expr)
--   $<path>           -> stanza:find(path), or "<undefined>" if that is nil
--   $<path||default>  -> as above, with a custom default string
--   $<@attr>          -> stanza.attr[attr] (also accepts ||default)
--   $$name            -> looked up in the 'extra' table (left alone if absent)
-- Finally a leading `""..` and a trailing `..""` are removed from the result.
function meta(s, extra)
	return (s:gsub("$(%b())", function (expr)
			expr = expr:gsub("\\(.)", "%1");
			return [["..tostring(]]..expr..[[).."]];
		end)
		:gsub("$(%b<>)", function (expr)
			expr = expr:sub(2,-2); -- Strip the surrounding <>
			local default = expr:match("||([^|]+)$");
			if default then
				-- Remove the default AND both characters of the "||"
				-- separator (previously one "|" was left on the expression)
				expr = expr:sub(1, -(#default+3));
			else
				default = "<undefined>";
			end
			if expr:match("^@") then
				return "\"..(stanza.attr["..("%q"):format(expr:sub(2)).."] or "..("%q"):format(default)..")..\"";
			end
			return "\"..(stanza:find("..("%q"):format(expr)..") or "..("%q"):format(default)..")..\"";
		end)
		:gsub("$$(%a+)", extra or {})
		:gsub([[^""%.%.]], "")
		:gsub([[%.%.""$]], ""));
end

-- Dependency locations:
-- <type lib>
-- <type global>
-- function handler()
--   <local deps>
--   if <conditions> then
--     <actions>
--   end
-- end

-- Code snippets that conditions/actions can request by name.
--   global_code: emitted once in the header of the generated chunk
--   local_code:  emitted inside the generated event handler
--   depends:     other dependencies that must be included first
-- A snippet may be a string, or a function taking the dependency parameter
-- (the part after ':' in e.g. "zone:myzone") and returning a string.
local available_deps = {
	st = { global_code = [[local st = require "util.stanza";]]};
	jid_split = {
		global_code = [[local jid_split = require "util.jid".split;]];
	};
	jid_bare = {
		global_code = [[local jid_bare = require "util.jid".bare;]];
	};
	to = { local_code = [[local to = stanza.attr.to or jid_bare(session.full_jid);]]; depends = { "jid_bare" } };
	from = { local_code = [[local from = stanza.attr.from;]] };
	type = { local_code = [[local type = stanza.attr.type;]] };
	name = { local_code = [[local name = stanza.name;]] };
	split_to = { -- The stanza's split to address
		depends = { "jid_split", "to" };
		local_code = [[local to_node, to_host, to_resource = jid_split(to);]];
	};
	split_from = { -- The stanza's split from address
		depends = { "jid_split", "from" };
		local_code = [[local from_node, from_host, from_resource = jid_split(from);]];
	};
	bare_to = { depends = { "jid_bare", "to" }, local_code = "local bare_to = jid_bare(to)"};
	bare_from = { depends = { "jid_bare", "from" }, local_code = "local bare_from = jid_bare(from)"};
	group_contains = {
		global_code = [[local group_contains = module:depends("groups").group_contains]];
	};
	is_admin = { global_code = [[local is_admin = require "core.usermanager".is_admin;]]};
	core_post_stanza = { global_code = [[local core_post_stanza = prosody.core_post_stanza;]] };
	zone = { global_code = function (zone)
		assert(idsafe(zone), "Invalid zone name: "..zone);
		return ("local zone_%s = zones[%q] or {};"):format(zone, zone);
	end };
	date_time = { global_code = [[local os_date = os.date]]; local_code = [[local current_date_time = os_date("*t");]] };
	time = { local_code = function (what)
		local defs = {};
		for field in what:gmatch("%a+") do
			table.insert(defs, ("local current_%s = current_date_time.%s;"):format(field, field));
		end
		return table.concat(defs, " ");
	end, depends = { "date_time" }; };
	timestamp = { global_code = [[local get_time = require "socket".gettime;]]; local_code = [[local current_timestamp = get_time();]]; };
	globalthrottle = {
		global_code = function (throttle)
			assert(idsafe(throttle), "Invalid rate limit name: "..throttle);
			assert(active_definitions.RATE[throttle], "Unknown rate limit: "..throttle);
			return ("local global_throttle_%s = rates.%s:single();"):format(throttle, throttle);
		end;
	};
	multithrottle = {
		global_code = function (throttle)
			assert(pcall(require, "util.cache"), "Using LIMIT with 'on' requires Prosody 0.10 or higher");
			assert(idsafe(throttle), "Invalid rate limit name: "..throttle);
			assert(active_definitions.RATE[throttle], "Unknown rate limit: "..throttle);
			return ("local multi_throttle_%s = rates.%s:multi();"):format(throttle, throttle);
		end;
	};
	rostermanager = {
		global_code = [[local rostermanager = require "core.rostermanager";]];
	};
	roster_entry = {
		local_code = [[local roster_entry = (to_node and rostermanager.load_roster(to_node, to_host) or {})[bare_from];]];
		depends = { "rostermanager", "split_to", "bare_from" };
	};
};

-- Expands a dependency's code template: a function template is called with
-- the dependency parameter, anything else is returned untouched.
local function expand_dep_code(template, dep_param)
	if type(template) == "function" then
		return template(dep_param);
	end
	return template;
end

-- Adds the code for 'dependency' (and, recursively, everything it depends on)
-- to 'code'.
--   dependency: "name" or "name:param", where name is a key of available_deps
--   code: table with 'included_deps' and 'global_header' fields; local code
--         is appended to its array part, global code to code.global_header
-- Unknown and circular dependencies are logged and skipped.
local function include_dep(dependency, code)
	local dep, dep_param = dependency:match("^([^:]+):?(.*)$");
	local dep_info = available_deps[dep];
	if not dep_info then
		module:log("error", "Dependency not found: %s", tostring(dep or dependency));
		return;
	end
	-- Bookkeeping is keyed on the full "name:param" string so that e.g.
	-- "zone:a" and "zone:b" are both included. The state is compared against
	-- nil because the pending marker is `false`.
	local dep_state = code.included_deps[dependency];
	if dep_state ~= nil then
		if dep_state ~= true then
			module:log("error", "Circular dependency on %s", dep);
		end
		return;
	end
	code.included_deps[dependency] = false; -- Pending flag (used to detect circular references)
	for _, dep_dep in ipairs(dep_info.depends or {}) do
		include_dep(dep_dep, code);
	end
	if dep_info.global_code then
		table.insert(code.global_header, expand_dep_code(dep_info.global_code, dep_param));
	end
	if dep_info.local_code then
		table.insert(code, "\n\t\t-- "..dep.."\n\t\t"..expand_dep_code(dep_info.local_code, dep_param).."\n");
	end
	code.included_deps[dependency] = true;
end

local definition_handlers = module:require("definitions");
local condition_handlers = module:require("conditions");
local action_handlers = module:require("actions");

-- Creates an empty rule, appends it to ruleset[chain] and returns it.
local function new_rule(ruleset, chain)
	assert(chain, "no chain specified");
	local rule = { conditions = {}, actions = {}, deps = {} };
	table.insert(ruleset[chain], rule);
	return rule;
end

-- Parses the firewall script at 'filename'.
-- Returns a ruleset (chain name -> array of rules, each rule having
-- 'conditions', 'actions' and 'deps' arrays), or nil plus an error message.
-- Definitions ('%' lines) are stored in active_definitions as a side effect.
local function parse_firewall_rules(filename)
	local line_no = 0;

	local function errmsg(err)
		return "Error compiling "..filename.." on line "..line_no..": "..tostring(err);
	end

	local ruleset = {
		deliver = {};
	};

	local chain = "deliver"; -- Default chain
	local rule;

	local file, err = io.open(filename);
	if not file then return nil, err; end

	-- Read everything up front so the handle is closed on every return path
	local lines = {};
	for line in file:lines() do
		lines[#lines+1] = line;
	end
	file:close();

	local state; -- nil -> "rules" -> "actions" -> nil -> ...

	local line_hold; -- Accumulates lines continued with a trailing backslash
	for _, raw_line in ipairs(lines) do
		local line = raw_line:match("^%s*(.-)%s*$");
		if line_hold and line:sub(-1,-1) ~= "\\" then
			line = line_hold..line;
			line_hold = nil;
		elseif line:sub(-1,-1) == "\\" then
			line_hold = (line_hold or "")..line:sub(1,-2);
		end
		line_no = line_no + 1;

		if line_hold or line:find("^[#;]") then -- luacheck: ignore 542
			-- No action; comment or partial line
		elseif line == "" then
			if state == "rules" then
				return nil, ("Expected an action on line %d for preceding criteria")
					:format(line_no);
			end
			state = nil;
		elseif not(state) and line:sub(1, 2) == "::" then
			chain = line:gsub("^::%s*", "");
			local chain_info = chains[chain];
			if not chain_info then
				if chain:match("^user/") then
					-- User-defined chains are created on first use
					chains[chain] = { type = "event", priority = 1, "firewall/chains/"..chain };
				else
					return nil, errmsg("Unknown chain: "..chain);
				end
			elseif chain_info.type ~= "event" then
				return nil, errmsg("Only event chains supported at the moment");
			end
			ruleset[chain] = ruleset[chain] or {};
		elseif not(state) and line:sub(1,1) == "%" then -- Definition (zone, limit, etc.)
			local what, name = line:match("^%%%s*(%w+) +([^ :]+)");
			if not what or not definition_handlers[what] then
				return nil, errmsg("Definition of unknown object: "..(what or "<unknown>"));
			elseif not name or not idsafe(name) then
				return nil, errmsg("Invalid "..what.." name");
			end

			local val;
			-- "%TYPE name:< path" reads the value from a file. This has to be
			-- tested before the inline form, which matches any line
			-- containing a colon.
			local fn = line:match("^[^:]*:< ?(.-)%s*$");
			if fn then -- Read from file
				if fn == "" then
					return nil, errmsg("Unable to parse filename");
				end
				local f, open_err = io.open(fn);
				if not f then return nil, errmsg(open_err); end
				local contents = f:read("*a");
				f:close();
				-- Join lines with spaces and drop trailing whitespace
				val = contents:gsub("\r?\n", " "):gsub("%s+$", "");
			else
				val = line:match(": ?(.*)$");
			end
			if not val then
				return nil, errmsg("No value given for definition");
			end

			local ok, ret = pcall(definition_handlers[what], name, val);
			if not ok then
				return nil, errmsg(ret);
			end

			if not active_definitions[what] then
				active_definitions[what] = {};
			end
			active_definitions[what][name] = ret;
		elseif line:find("^[^%s:]+[%.=]") then -- Action
			if state == nil then
				-- This is a standalone action with no conditions
				rule = new_rule(ruleset, chain);
			end
			state = "actions";
			-- Action handlers?
			local action = line:match("^[%w_]+");
			if not action or not action_handlers[action] then
				return nil, ("Unknown action on line %d: %s"):format(line_no, action or "<unknown>");
			end
			table.insert(rule.actions, "-- "..line)
			local ok, action_string, action_deps = pcall(action_handlers[action], line:match("=(.+)$"));
			if not ok then
				return nil, errmsg(action_string);
			end
			table.insert(rule.actions, action_string);
			for _, dep in ipairs(action_deps or {}) do
				table.insert(rule.deps, dep);
			end
		elseif state == "actions" then -- state is actions but action pattern did not match
			-- The rule was already added to the ruleset by new_rule(); adding
			-- it a second time made process_firewall_rules() visit it twice.
			-- NOTE(review): the current line is discarded here rather than
			-- being parsed as the start of the next rule - confirm intended.
			state = nil; -- Awaiting next rule, etc.
			rule = nil;
		else
			if not state then
				state = "rules";
				rule = new_rule(ruleset, chain);
			end
			-- Check standard modifiers for the condition (e.g. NOT)
			local negated;
			local condition = line:match("^[^:=%.?]*");
			if condition:find("%f[%w]NOT%f[^%w]") then
				local s, e = condition:match("%f[%w]()NOT()%f[^%w]");
				-- Cut out "NOT" plus the character following it, then trim
				condition = (condition:sub(1,s-1)..condition:sub(e+1, -1)):match("^%s*(.-)%s*$");
				negated = true;
			end
			condition = condition:gsub(" ", "_");
			if not condition_handlers[condition] then
				return nil, ("Unknown condition on line %d: %s"):format(line_no, (condition:gsub("_", " ")));
			end
			-- Get the code for this condition
			local ok, condition_code, condition_deps = pcall(condition_handlers[condition], line:match(":%s?(.+)$"));
			if not ok then
				return nil, errmsg(condition_code);
			end
			if negated then condition_code = "not("..condition_code..")"; end
			table.insert(rule.conditions, condition_code);
			for _, dep in ipairs(condition_deps or {}) do
				table.insert(rule.deps, dep);
			end
		end
	end
	return ruleset;
end

-- Compiles a parsed ruleset into Lua source code.
-- Returns a table mapping chain name -> source string. Each string, when
-- loaded and run, returns a function (definitions, fire_event, log) which in
-- turn returns the event handler for that chain.
local function process_firewall_rules(ruleset)
	-- Compile ruleset and return complete code

	local chain_handlers = {};

	-- Loop through the chains in the parsed ruleset (e.g. incoming, outgoing)
	for chain_name, rules in pairs(ruleset) do
		local code = { included_deps = {}, global_header = {} };
		local condition_uses = {};
		-- This inner loop assumes chain is an event-based, not a filter-based
		-- chain (filter-based will be added later)
		-- Count how often each (un-negated) condition occurs in the chain
		for _, rule in ipairs(rules) do
			for _, condition in ipairs(rule.conditions) do
				if condition:find("^not%(.+%)$") then
					condition = condition:match("^not%((.+)%)$");
				end
				condition_uses[condition] = (condition_uses[condition] or 0) + 1;
			end
		end

		local condition_cache, n_conditions = {}, 0;
		for _, rule in ipairs(rules) do
			for _, dep in ipairs(rule.deps) do
				include_dep(dep, code);
			end
			table.insert(code, "\n\t\t");
			local rule_code;
			if #rule.conditions > 0 then
				-- Built into a fresh list so the parsed rule is left untouched
				local compiled_conditions = {};
				for i, condition in ipairs(rule.conditions) do
					local negated = condition:match("^not%(.+%)$");
					if negated then
						condition = condition:match("^not%((.+)%)$");
					end
					if condition_uses[condition] > 1 then
						-- Conditions used more than once are evaluated into
						-- a local the first time and referenced by name after
						local name = condition_cache[condition];
						if not name then
							n_conditions = n_conditions + 1;
							name = "condition"..n_conditions;
							condition_cache[condition] = name;
							table.insert(code, "local "..name.." = "..condition..";\n\t\t");
						end
						compiled_conditions[i] = (negated and "not(" or "")..name..(negated and ")" or "");
					else
						compiled_conditions[i] = (negated and "not(" or "(")..condition..")";
					end
				end

				rule_code = "if "..table.concat(compiled_conditions, " and ").." then\n\t\t\t"
					..table.concat(rule.actions, "\n\t\t\t")
					.."\n\t\tend\n";
			else
				rule_code = table.concat(rule.actions, "\n\t\t");
			end
			table.insert(code, rule_code);
		end

		-- Expose each definition type as a local, e.g. ZONE -> 'zones'
		for name in pairs(definition_handlers) do
			table.insert(code.global_header, 1, "local "..name:lower().."s = definitions."..name..";");
		end

		local code_string = "return function (definitions, fire_event, log)\n\t"
			..table.concat(code.global_header, "\n\t")
			.."\n\tlocal db = require 'util.debug';\n\n\t"
			.."return function (event)\n\t\t"
			.."local stanza, session = event.stanza, event.origin;\n"
			..table.concat(code, "")
			.."\n\tend;\nend";

		chain_handlers[chain_name] = code_string;
	end

	return chain_handlers;
end

-- Parses and compiles the script at 'filename'.
-- Returns chain name -> source string, or nil plus an error message.
local function compile_firewall_rules(filename)
	local ruleset, err = parse_firewall_rules(filename);
	if not ruleset then return nil, err; end
	local chain_handlers = process_firewall_rules(ruleset);
	return chain_handlers;
end

-- Loads generated source code and returns the event handler it produces, or
-- nil plus an error message if the code does not compile. 'filename' is used
-- as the chunk name and as the logger name.
local function compile_handler(code_string, filename)
	-- Prepare event handler function
	local chunk, err = loadstring(code_string, "="..filename);
	if not chunk then
		return nil, "Error compiling (probably a compiler bug, please report): "..err;
	end
	local function fire_event(name, data)
		return module:fire_event(name, data);
	end
	chunk = chunk()(active_definitions, fire_event, logger(filename)); -- Returns event handler with 'zones' upvalue.
	return chunk;
end

-- Resolves a script path from the configuration. Paths are taken relative to
-- the config directory, unless prefixed with "module:", in which case they
-- are relative to the directory containing this module.
local function resolve_script_path(script_path)
	local relative_to = prosody.paths.config;
	if script_path:match("^module:") then
		-- Strip "mod_<name>.lua" from the module path, keeping the trailing "/"
		relative_to = module.path:sub(1, -#("/mod_"..module.name..".lua"));
		script_path = script_path:match("^module:(.+)$");
	end
	return resolve_relative_path(relative_to, script_path);
end

-- Compiles every script listed in the 'firewall_scripts' option, hooks the
-- resulting handlers and publishes the definitions in the shared table.
function module.load()
	if not prosody.arg then return end -- Don't run in prosodyctl
	-- NOTE(review): this also discards the default "$local" zone set up at
	-- the top of the file - confirm that is intended.
	active_definitions = {};
	local firewall_scripts = module:get_option_set("firewall_scripts", {});
	for script in firewall_scripts do
		script = resolve_script_path(script);
		local chain_functions, err = compile_firewall_rules(script)

		if not chain_functions then
			module:log("error", "Error compiling %s: %s", script, err or "unknown error");
		else
			for chain, handler_code in pairs(chain_functions) do
				local handler, handler_err = compile_handler(handler_code, "mod_firewall::"..chain);
				if not handler then
					module:log("error", "Compilation error for %s: %s", script, handler_err);
				else
					local chain_definition = chains[chain];
					if chain_definition and chain_definition.type == "event" then
						for _, event_name in ipairs(chain_definition) do
							module:hook(event_name, handler, chain_definition.priority);
						end
					elseif chain:sub(1, 5) ~= "user/" then
						-- Was `not chain:sub(1, 5) == "user/"`, which compares
						-- a boolean with a string and is therefore never true
						module:log("warn", "Unknown chain %q", chain);
					end
					module:hook("firewall/chains/"..chain, handler);
				end
			end
		end
	end
	-- Replace contents of definitions table (shared) with active definitions
	for k in it.keys(definitions) do definitions[k] = nil; end
	for k,v in pairs(active_definitions) do definitions[k] = v; end
end

-- prosodyctl command: compiles the given firewall scripts and prints the
-- generated Lua code. With "-v" as first argument the output also includes
-- the definitions and the module:hook() calls.
-- Returns 1 after printing usage when called without arguments or with --help.
function module.command(arg)
	if not arg[1] or arg[1] == "--help" then
		require"util.prosodyctl".show_usage([[mod_firewall <firewall.pfw>]], [[Compile files with firewall rules to Lua code]]);
		return 1;
	end
	local verbose = arg[1] == "-v";
	if verbose then table.remove(arg, 1); end

	local serialize = require "util.serialization".serialize;
	if verbose then
		print("local logger = require \"util.logger\".init;");
		print();
		-- Same wrapper as in compile_handler(): the result of the event must
		-- be passed back to the generated code
		print("local function fire_event(name, data)\n\treturn module:fire_event(name, data)\nend");
		print();
	end

	for _, filename in ipairs(arg) do
		filename = resolve_script_path(filename);
		print("do -- File "..filename);
		local chain_functions = assert(compile_firewall_rules(filename));
		if verbose then
			print();
			print("local active_definitions = "..serialize(active_definitions)..";");
			print();
		end
		local c = 0;
		for chain, handler_code in pairs(chain_functions) do
			c = c + 1;
			print("---- Chain "..(chain:gsub("_", " ")));
			local chain_func_name = "chain_"..tostring(c).."_"..(chain:gsub("%p", "_"));
			-- handler_code starts with "return "; sub(8) strips that prefix
			if not verbose then
				print(("%s = %s;"):format(chain_func_name, handler_code:sub(8)));
			else
				print(("local %s = (%s)(active_definitions, fire_event, logger(%q));"):format(chain_func_name, handler_code:sub(8), filename));
				print();

				local chain_definition = chains[chain];
				-- Priorities may be fractional (e.g. 0.1), so they are
				-- formatted with %s; %d would truncate them or raise
				local priority = tostring(chain_definition and chain_definition.priority or 0);
				if chain_definition and chain_definition.type == "event" then
					for _, event_name in ipairs(chain_definition) do
						print(("module:hook(%q, %s, %s);"):format(event_name, chain_func_name, priority));
					end
				end
				print(("module:hook(%q, %s, %s);"):format("firewall/chains/"..chain, chain_func_name, priority));
			end
			print("---- End of chain "..chain);
			print();
		end
		print("end -- End of file "..filename);
	end
end