Diff

net/unbound.lua @ 10962:92f30e8ecdfc

net.unbound: Async DNS resolver library based on libunbound via luaunbound
author Kim Alvefur <zash@zash.se>
date Sat, 09 Mar 2019 21:19:24 +0100
child 10967:67aabf83230b
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/net/unbound.lua	Sat Mar 09 21:19:24 2019 +0100
@@ -0,0 +1,245 @@
+-- libunbound based net.adns replacement for Prosody IM
+-- Copyright (C) 2013-2015 Kim Alvefur
+--
+-- This file is MIT licensed.
+--
+-- luacheck: ignore prosody
+
+local setmetatable = setmetatable;
+local tostring = tostring;
+local t_concat = table.concat;
+local s_format = string.format;
+local s_lower = string.lower;
+local s_upper = string.upper;
+local noop = function() end;
+local zero = function() return 0 end;
+local truop = function() return true; end;
+
+local log = require "util.logger".init("unbound");
+local net_server = require "net.server";
+local libunbound = require"lunbound";
+local have_promise, promise = pcall(require, "util.promise");
+
+local gettime = require"socket".gettime;
+local dns_utils = require"util.dns";
+local classes, types, errors = dns_utils.classes, dns_utils.types, dns_utils.errors;
+local parsers = dns_utils.parsers;
+
+local function add_defaults(conf)
+	if conf then
+		for option, default in pairs(libunbound.config) do
+			if conf[option] == nil then
+				conf[option] = default;
+			end
+		end
+	end
+	return conf;
+end
+
+local unbound_config;
+if prosody then
+	local config = require"core.configmanager";
+	unbound_config = add_defaults(config.get("*", "unbound"));
+	prosody.events.add_handler("config-reloaded", function()
+		unbound_config = add_defaults(config.get("*", "unbound"));
+	end);
+end
+-- Note: libunbound will default to using root hints if resolvconf is unset
+
+local function connect_server(unbound, server)
+	if server.watchfd then
+		return server.watchfd(unbound, function ()
+			unbound:process()
+		end);
+	elseif server.event and server.addevent then
+		local EV_READ = server.event.EV_READ;
+		local function event_callback()
+			unbound:process();
+			return EV_READ;
+		end
+		return server.addevent(unbound:getfd(), EV_READ, event_callback)
+	elseif server.wrapclient then
+		local conn = {
+			getfd = function()
+				return unbound:getfd();
+			end,
+
+			send = zero,
+			receive = noop,
+			settimeout = noop,
+			close = truop,
+		}
+
+		local function process()
+			unbound:process();
+		end
+		local listener = {
+			onincoming = process,
+
+			onconnect = noop,
+			ondisconnect = noop,
+			onreadtimeout = truop,
+		};
+		return server.wrapclient(conn, "dns", 0, listener, "*a" );
+	end
+end
+
+local unbound = libunbound.new(unbound_config);
+
+local server_conn = connect_server(unbound, net_server);
+
+local answer_mt = {
+	__tostring = function(self)
+		if self._string then return self._string end
+		local h = s_format("Status: %s", errors[self.status]);
+		if self.secure then
+			h = h .. ", Secure";
+		elseif self.bogus then
+			h = h .. s_format(", Bogus: %s", self.bogus);
+		end
+		local t = { h };
+		for i = 1, #self do
+			t[i+1]=self.qname.."\t"..classes[self.qclass].."\t"..types[self.qtype].."\t"..tostring(self[i]);
+		end
+		local _string = t_concat(t, "\n");
+		self._string = _string;
+		return _string;
+	end;
+};
+
+local waiting_queries = {};
+
+local function prep_answer(a)
+	if not a then return end
+	local status = errors[a.rcode];
+	local qclass = classes[a.qclass];
+	local qtype = types[a.qtype];
+	a.status, a.class, a.type = status, qclass, qtype;
+
+	local t = s_lower(qtype);
+	local rr_mt = { __index = a, __tostring = function(self) return tostring(self[t]) end };
+	local parser = parsers[qtype];
+	for i = 1, #a do
+		if a.bogus then
+			-- Discard bogus data
+			a[i] = nil;
+		else
+			a[i] = setmetatable({[t] = parser(a[i])}, rr_mt);
+		end
+	end
+	return setmetatable(a, answer_mt);
+end
+
+local function lookup(callback, qname, qtype, qclass)
+	qtype = qtype and s_upper(qtype) or "A";
+	qclass = qclass and s_upper(qclass) or "IN";
+	local ntype, nclass = types[qtype], classes[qclass];
+	local startedat = gettime();
+	local ret;
+	local function callback_wrapper(a, err)
+		local gotdataat = gettime();
+		waiting_queries[ret] = nil;
+		if a then
+			prep_answer(a);
+			log("debug", "Results for %s %s %s: %s (%s, %f sec)", qname, qclass, qtype, a.rcode == 0 and (#a .. " items") or a.status,
+				a.secure and "Secure" or a.bogus or "Insecure", gotdataat - startedat); -- Insecure as in unsigned
+		else
+			log("error", "Results for %s %s %s: %s", qname, qclass, qtype, tostring(err));
+		end
+		local ok, cerr = pcall(callback, a, err);
+		if not ok then log("error", "Error in callback: %s", cerr); end
+	end
+	log("debug", "Resolve %s %s %s", qname, qclass, qtype);
+	local err;
+	ret, err = unbound:resolve_async(callback_wrapper, qname, ntype, nclass);
+	if ret then
+		waiting_queries[ret] = callback;
+	else
+		log("warn", err);
+	end
+	return ret, err;
+end
+
+local function lookup_sync(qname, qtype, qclass)
+	qtype = qtype and s_upper(qtype) or "A";
+	qclass = qclass and s_upper(qclass) or "IN";
+	local ntype, nclass = types[qtype], classes[qclass];
+	local a, err = unbound:resolve(qname, ntype, nclass);
+	if not a then return a, err; end
+	return prep_answer(a);
+end
+
+local function cancel(id)
+	local cb = waiting_queries[id];
+	unbound:cancel(id);
+	if cb then
+		cb(nil, "canceled");
+		waiting_queries[id] = nil;
+	end
+	return true;
+end
+
+-- Reinitiate libunbound context, drops cache
+local function purge()
+	for id in pairs(waiting_queries) do cancel(id); end
+	if server_conn then server_conn:close(); end
+	unbound = libunbound.new(unbound_config);
+	server_conn = connect_server(unbound, net_server);
+	return true;
+end
+
+local function not_implemented()
+	error "not implemented";
+end
+-- Public API
+local _M = {
+	lookup = lookup;
+	cancel = cancel;
+	new_async_socket = not_implemented;
+	dns = {
+		lookup = lookup_sync;
+		cancel = cancel;
+		cache = noop;
+		socket_wrapper_set = noop;
+		settimeout = noop;
+		query = noop;
+		purge = purge;
+		random = noop;
+		peek = noop;
+
+		types = types;
+		classes = classes;
+	};
+};
+
+local lookup_promise;
+if have_promise then
+	function lookup_promise(_, qname, qtype, qclass)
+		return promise.new(function (resolve, reject)
+			local function callback(answer, err)
+				if err then
+					return reject(err);
+				else
+					return resolve(answer);
+				end
+			end
+			local ret, err = lookup(callback, qname, qtype, qclass)
+			if not ret then reject(err); end
+		end);
+	end
+end
+
+local wrapper = {
+	lookup = function (_, callback, qname, qtype, qclass)
+		return lookup(callback, qname, qtype, qclass)
+	end;
+	lookup_promise = lookup_promise;
+	_resolver = {
+		settimeout = function () end;
+		closeall = function () end;
+	};
+}
+
+function _M.resolver() return wrapper; end
+
+return _M;