Changeset

12408:acfc51b9530c

net.resolvers.basic: Refactor to remove code duplication ...and prepare for Happy Eyeballs
author Matthew Wild <mwild1@gmail.com>
date Fri, 18 Mar 2022 16:09:22 +0000
parents 12407:b6b01724e04f
children 12409:9f0baf15e792
files net/resolvers/basic.lua
diffstat 1 files changed, 72 insertions(+), 80 deletions(-) [+]
line wrap: on
line diff
--- a/net/resolvers/basic.lua	Fri Mar 18 16:43:06 2022 +0100
+++ b/net/resolvers/basic.lua	Fri Mar 18 16:09:22 2022 +0000
@@ -2,13 +2,51 @@
 local inet_pton = require "util.net".pton;
 local inet_ntop = require "util.net".ntop;
 local idna_to_ascii = require "util.encodings".idna.to_ascii;
-local unpack = table.unpack or unpack; -- luacheck: ignore 113
+local promise = require "util.promise";
+local t_move = require "util.table".move;
 
 local methods = {};
 local resolver_mt = { __index = methods };
 
 -- FIXME RFC 6724
 
+local function do_dns_lookup(self, dns_resolver, record_type, name)
+	return promise.new(function (resolve, reject)
+		local ipv = (record_type == "A" and "4") or (record_type == "AAAA" and "6") or nil;
+		if ipv and self.extra["use_ipv"..ipv] == false then
+			return reject(("IPv%s disabled - %s lookup skipped"):format(ipv, record_type));
+		elseif record_type == "TLSA" and self.extra.use_dane ~= true then
+			return reject("DANE disabled - TLSA lookup skipped");
+		end
+		dns_resolver:lookup(function (answer, err)
+			if not answer then
+				return reject(err);
+			elseif answer.bogus then
+				return reject(("Validation error in %s lookup"):format(record_type));
+			elseif answer.status and #answer == 0 then
+				return reject(("%s in %s lookup"):format(answer.status, record_type));
+			end
+
+			local targets = { secure = answer.secure };
+			for _, record in ipairs(answer) do
+				if ipv then
+					table.insert(targets, { self.conn_type..ipv, record[record_type:lower()], self.port, self.extra });
+				else
+					table.insert(targets, record[record_type:lower()]);
+				end
+			end
+			return resolve(targets);
+		end, name, record_type, "IN");
+	end);
+end
+
+local function merge_targets(ipv4_targets, ipv6_targets)
+	local result = { secure = ipv4_targets.secure and ipv6_targets.secure };
+	t_move(ipv6_targets, 1, #ipv6_targets, 1, result);
+	t_move(ipv4_targets, 1, #ipv4_targets, #result+1, result);
+	return result;
+end
+
 -- Find the next target to connect to, and
 -- pass it to cb()
 function methods:next(cb)
@@ -18,7 +56,7 @@
 			return;
 		end
 		local next_target = table.remove(self.targets, 1);
-		cb(unpack(next_target, 1, 4));
+		cb(next_target[1], next_target[2], next_target[3], next_target[4]);
 		return;
 	end
 
@@ -28,91 +66,45 @@
 		return;
 	end
 
-	local secure = true;
-	local tlsa = {};
-	local targets = {};
-	local n = 3;
-	local function ready()
-		n = n - 1;
-		if n > 0 then return; end
-		self.targets = targets;
+	-- Resolve DNS to target list
+	local dns_resolver = adns.resolver();
+
+	local dns_lookups = {
+		ipv4 = do_dns_lookup(self, dns_resolver, "A", self.hostname);
+		ipv6 = do_dns_lookup(self, dns_resolver, "AAAA", self.hostname);
+		tlsa = do_dns_lookup(self, dns_resolver, "TLSA", ("_%d._%s.%s"):format(self.port, self.conntype, self.hostname));
+	};
+
+	promise.all_settled(dns_lookups):next(function (dns_results)
+		-- Combine targets, assign to self.targets, self:next(cb)
+		local have_ipv4 = dns_results.ipv4.status == "fulfilled";
+		local have_ipv6 = dns_results.ipv6.status == "fulfilled";
+
+		if have_ipv4 and have_ipv6 then
+			self.targets = merge_targets(dns_results.ipv4.value, dns_results.ipv6.value);
+		elseif have_ipv4 then
+			self.targets = dns_results.ipv4.value;
+		elseif have_ipv6 then
+			self.targets = dns_results.ipv6.value;
+		else
+			self.targets = {};
+		end
+
 		if self.extra and self.extra.use_dane then
-			if secure and tlsa[1] then
-				self.extra.tlsa = tlsa;
+			if self.targets.secure and dns_results.tlsa.status == "fulfilled" then
+				self.extra.tlsa = dns_results.tlsa.value;
 				self.extra.dane_hostname = self.hostname;
 			else
 				self.extra.tlsa = nil;
 				self.extra.dane_hostname = nil;
 			end
 		end
+
 		self:next(cb);
-	end
-
-	-- Resolve DNS to target list
-	local dns_resolver = adns.resolver();
-
-	if not self.extra or self.extra.use_ipv4 ~= false then
-		dns_resolver:lookup(function (answer, err)
-			if answer then
-				secure = secure and answer.secure;
-				for _, record in ipairs(answer) do
-					table.insert(targets, { self.conn_type.."4", record.a, self.port, self.extra });
-				end
-				if answer.bogus then
-					self.last_error = "Validation error in A lookup";
-				elseif answer.status then
-					self.last_error = answer.status .. " in A lookup";
-				end
-			else
-				self.last_error = err;
-			end
-			ready();
-		end, self.hostname, "A", "IN");
-	else
-		ready();
-	end
-
-	if not self.extra or self.extra.use_ipv6 ~= false then
-		dns_resolver:lookup(function (answer, err)
-			if answer then
-				secure = secure and answer.secure;
-				for _, record in ipairs(answer) do
-					table.insert(targets, { self.conn_type.."6", record.aaaa, self.port, self.extra });
-				end
-				if answer.bogus then
-					self.last_error = "Validation error in AAAA lookup";
-				elseif answer.status then
-					self.last_error = answer.status .. " in AAAA lookup";
-				end
-			else
-				self.last_error = err;
-			end
-			ready();
-		end, self.hostname, "AAAA", "IN");
-	else
-		ready();
-	end
-
-	if self.extra and self.extra.use_dane == true then
-		dns_resolver:lookup(function (answer, err)
-			if answer then
-				secure = secure and answer.secure;
-				for _, record in ipairs(answer) do
-					table.insert(tlsa, record.tlsa);
-				end
-				if answer.bogus then
-					self.last_error = "Validation error in TLSA lookup";
-				elseif answer.status then
-					self.last_error = answer.status .. " in TLSA lookup";
-				end
-			else
-				self.last_error = err;
-			end
-			ready();
-		end, ("_%d._tcp.%s"):format(self.port, self.hostname), "TLSA", "IN");
-	else
-		ready();
-	end
+	end):catch(function (err)
+		self.last_error = err;
+		self.targets = {};
+	end);
 end
 
 local function new(hostname, port, conn_type, extra)
@@ -137,7 +129,7 @@
 		hostname = ascii_host;
 		port = port;
 		conn_type = conn_type;
-		extra = extra;
+		extra = extra or {};
 		targets = targets;
 	}, resolver_mt);
 end