Comparison

net/resolvers/service.lua @ 12401:c029ddcad258

net.resolvers.service: Honour record 'weight' when picking SRV targets #NotHappyEyeballs
author Matthew Wild <mwild1@gmail.com>
date Thu, 17 Mar 2022 18:20:26 +0000
parent 12129:7a68d5828f3b
child 12808:12bd40b8e105
comparison
equal deleted inserted replaced
12400:728d1c1dc7db 12401:c029ddcad258
1 local adns = require "net.adns"; 1 local adns = require "net.adns";
2 local basic = require "net.resolvers.basic"; 2 local basic = require "net.resolvers.basic";
3 local inet_pton = require "util.net".pton; 3 local inet_pton = require "util.net".pton;
4 local idna_to_ascii = require "util.encodings".idna.to_ascii; 4 local idna_to_ascii = require "util.encodings".idna.to_ascii;
5 local unpack = table.unpack or unpack; -- luacheck: ignore 113
6 5
7 local methods = {}; 6 local methods = {};
8 local resolver_mt = { __index = methods }; 7 local resolver_mt = { __index = methods };
9 8
9 local function new_target_selector(rrset)
10 local rr_count = rrset and #rrset;
11 if not rr_count or rr_count == 0 then
12 rrset = nil;
13 else
14 table.sort(rrset, function (a, b) return a.srv.priority < b.srv.priority end);
15 end
16 local rrset_pos = 1;
17 local priority_bucket, bucket_total_weight, bucket_len, bucket_used;
18 return function ()
19 if not rrset then return; end
20
21 if not priority_bucket or bucket_used >= bucket_len then
22 if rrset_pos > rr_count then return; end -- Used up all records
23
24 -- Going to start on a new priority now. Gather up all the next
25 -- records with the same priority and add them to priority_bucket
26 priority_bucket, bucket_total_weight, bucket_len, bucket_used = {}, 0, 0, 0;
27 local current_priority;
28 repeat
29 local curr_record = rrset[rrset_pos].srv;
30 if not current_priority then
31 current_priority = curr_record.priority;
32 elseif current_priority ~= curr_record.priority then
33 break;
34 end
35 table.insert(priority_bucket, curr_record);
36 bucket_total_weight = bucket_total_weight + curr_record.weight;
37 bucket_len = bucket_len + 1;
38 rrset_pos = rrset_pos + 1;
39 until rrset_pos > rr_count;
40 end
41
42 bucket_used = bucket_used + 1;
43 local n, running_total = math.random(0, bucket_total_weight), 0;
44 local target_record;
45 for i = 1, bucket_len do
46 local candidate = priority_bucket[i];
47 if candidate then
48 running_total = running_total + candidate.weight;
49 if running_total >= n then
50 target_record = candidate;
51 bucket_total_weight = bucket_total_weight - candidate.weight;
52 priority_bucket[i] = nil;
53 break;
54 end
55 end
56 end
57 return target_record;
58 end;
59 end
60
10 -- Find the next target to connect to, and 61 -- Find the next target to connect to, and
11 -- pass it to cb() 62 -- pass it to cb()
12 function methods:next(cb) 63 function methods:next(cb)
13 if self.targets then 64 if self.resolver or self._get_next_target then
14 if not self.resolver then 65 if not self.resolver then -- Do we have a basic resolver currently?
15 if #self.targets == 0 then 66 -- We don't, so fetch a new SRV target, create a new basic resolver for it
67 local next_srv_target = self._get_next_target and self._get_next_target();
68 if not next_srv_target then
69 -- No more SRV targets left
16 cb(nil); 70 cb(nil);
17 return; 71 return;
18 end 72 end
19 local next_target = table.remove(self.targets, 1); 73 -- Create a new basic resolver for this SRV target
20 self.resolver = basic.new(unpack(next_target, 1, 4)); 74 self.resolver = basic.new(next_srv_target.target, next_srv_target.port, self.conn_type, self.extra);
21 end 75 end
76 -- Look up the next (basic) target from the current target's resolver
22 self.resolver:next(function (...) 77 self.resolver:next(function (...)
23 if self.resolver then 78 if self.resolver then
24 self.last_error = self.resolver.last_error; 79 self.last_error = self.resolver.last_error;
25 end 80 end
26 if ... == nil then 81 if ... == nil then
29 else 84 else
30 cb(...); 85 cb(...);
31 end 86 end
32 end); 87 end);
33 return; 88 return;
89 elseif self.in_progress then
90 cb(nil);
91 return;
34 end 92 end
35 93
36 if not self.hostname then 94 if not self.hostname then
37 self.last_error = "hostname failed IDNA"; 95 self.last_error = "hostname failed IDNA";
38 cb(nil); 96 cb(nil);
39 return; 97 return;
40 end 98 end
41 99
42 local targets = {}; 100 self.in_progress = true;
101
43 local function ready() 102 local function ready()
44 self.targets = targets;
45 self:next(cb); 103 self:next(cb);
46 end 104 end
47 105
48 -- Resolve DNS to target list 106 -- Resolve DNS to target list
49 local dns_resolver = adns.resolver(); 107 local dns_resolver = adns.resolver();
61 return; 119 return;
62 end 120 end
63 121
64 if #answer == 0 then 122 if #answer == 0 then
65 if self.extra and self.extra.default_port then 123 if self.extra and self.extra.default_port then
66 table.insert(targets, { self.hostname, self.extra.default_port, self.conn_type, self.extra }); 124 self.resolver = basic.new(self.hostname, self.extra.default_port, self.conn_type, self.extra);
67 else 125 else
68 self.last_error = "zero SRV records found"; 126 self.last_error = "zero SRV records found";
69 end 127 end
70 ready(); 128 ready();
71 return; 129 return;
75 self.last_error = "service explicitly unavailable"; 133 self.last_error = "service explicitly unavailable";
76 ready(); 134 ready();
77 return; 135 return;
78 end 136 end
79 137
80 table.sort(answer, function (a, b) return a.srv.priority < b.srv.priority end); 138 self._get_next_target = new_target_selector(answer);
81 for _, record in ipairs(answer) do
82 table.insert(targets, { record.srv.target, record.srv.port, self.conn_type, self.extra });
83 end
84 else 139 else
85 self.last_error = err; 140 self.last_error = err;
86 end 141 end
87 ready(); 142 ready();
88 end, "_" .. self.service .. "._" .. self.conn_type .. "." .. self.hostname, "SRV", "IN"); 143 end, "_" .. self.service .. "._" .. self.conn_type .. "." .. self.hostname, "SRV", "IN");