Software /
code /
prosody
Comparison
net/resolvers/service.lua @ 12401:c029ddcad258
net.resolvers.service: Honour record 'weight' when picking SRV targets
#NotHappyEyeballs
author | Matthew Wild <mwild1@gmail.com> |
---|---|
date | Thu, 17 Mar 2022 18:20:26 +0000 |
parent | 12129:7a68d5828f3b |
child | 12808:12bd40b8e105 |
comparison
equal
deleted
inserted
replaced
12400:728d1c1dc7db | 12401:c029ddcad258 |
---|---|
1 local adns = require "net.adns"; | 1 local adns = require "net.adns"; |
2 local basic = require "net.resolvers.basic"; | 2 local basic = require "net.resolvers.basic"; |
3 local inet_pton = require "util.net".pton; | 3 local inet_pton = require "util.net".pton; |
4 local idna_to_ascii = require "util.encodings".idna.to_ascii; | 4 local idna_to_ascii = require "util.encodings".idna.to_ascii; |
5 local unpack = table.unpack or unpack; -- luacheck: ignore 113 | |
6 | 5 |
7 local methods = {}; | 6 local methods = {}; |
8 local resolver_mt = { __index = methods }; | 7 local resolver_mt = { __index = methods }; |
9 | 8 |
9 local function new_target_selector(rrset) | |
10 local rr_count = rrset and #rrset; | |
11 if not rr_count or rr_count == 0 then | |
12 rrset = nil; | |
13 else | |
14 table.sort(rrset, function (a, b) return a.srv.priority < b.srv.priority end); | |
15 end | |
16 local rrset_pos = 1; | |
17 local priority_bucket, bucket_total_weight, bucket_len, bucket_used; | |
18 return function () | |
19 if not rrset then return; end | |
20 | |
21 if not priority_bucket or bucket_used >= bucket_len then | |
22 if rrset_pos > rr_count then return; end -- Used up all records | |
23 | |
24 -- Going to start on a new priority now. Gather up all the next | |
25 -- records with the same priority and add them to priority_bucket | |
26 priority_bucket, bucket_total_weight, bucket_len, bucket_used = {}, 0, 0, 0; | |
27 local current_priority; | |
28 repeat | |
29 local curr_record = rrset[rrset_pos].srv; | |
30 if not current_priority then | |
31 current_priority = curr_record.priority; | |
32 elseif current_priority ~= curr_record.priority then | |
33 break; | |
34 end | |
35 table.insert(priority_bucket, curr_record); | |
36 bucket_total_weight = bucket_total_weight + curr_record.weight; | |
37 bucket_len = bucket_len + 1; | |
38 rrset_pos = rrset_pos + 1; | |
39 until rrset_pos > rr_count; | |
40 end | |
41 | |
42 bucket_used = bucket_used + 1; | |
43 local n, running_total = math.random(0, bucket_total_weight), 0; | |
44 local target_record; | |
45 for i = 1, bucket_len do | |
46 local candidate = priority_bucket[i]; | |
47 if candidate then | |
48 running_total = running_total + candidate.weight; | |
49 if running_total >= n then | |
50 target_record = candidate; | |
51 bucket_total_weight = bucket_total_weight - candidate.weight; | |
52 priority_bucket[i] = nil; | |
53 break; | |
54 end | |
55 end | |
56 end | |
57 return target_record; | |
58 end; | |
59 end | |
60 | |
10 -- Find the next target to connect to, and | 61 -- Find the next target to connect to, and |
11 -- pass it to cb() | 62 -- pass it to cb() |
12 function methods:next(cb) | 63 function methods:next(cb) |
13 if self.targets then | 64 if self.resolver or self._get_next_target then |
14 if not self.resolver then | 65 if not self.resolver then -- Do we have a basic resolver currently? |
15 if #self.targets == 0 then | 66 -- We don't, so fetch a new SRV target, create a new basic resolver for it |
67 local next_srv_target = self._get_next_target and self._get_next_target(); | |
68 if not next_srv_target then | |
69 -- No more SRV targets left | |
16 cb(nil); | 70 cb(nil); |
17 return; | 71 return; |
18 end | 72 end |
19 local next_target = table.remove(self.targets, 1); | 73 -- Create a new basic resolver for this SRV target |
20 self.resolver = basic.new(unpack(next_target, 1, 4)); | 74 self.resolver = basic.new(next_srv_target.target, next_srv_target.port, self.conn_type, self.extra); |
21 end | 75 end |
76 -- Look up the next (basic) target from the current target's resolver | |
22 self.resolver:next(function (...) | 77 self.resolver:next(function (...) |
23 if self.resolver then | 78 if self.resolver then |
24 self.last_error = self.resolver.last_error; | 79 self.last_error = self.resolver.last_error; |
25 end | 80 end |
26 if ... == nil then | 81 if ... == nil then |
29 else | 84 else |
30 cb(...); | 85 cb(...); |
31 end | 86 end |
32 end); | 87 end); |
33 return; | 88 return; |
89 elseif self.in_progress then | |
90 cb(nil); | |
91 return; | |
34 end | 92 end |
35 | 93 |
36 if not self.hostname then | 94 if not self.hostname then |
37 self.last_error = "hostname failed IDNA"; | 95 self.last_error = "hostname failed IDNA"; |
38 cb(nil); | 96 cb(nil); |
39 return; | 97 return; |
40 end | 98 end |
41 | 99 |
42 local targets = {}; | 100 self.in_progress = true; |
101 | |
43 local function ready() | 102 local function ready() |
44 self.targets = targets; | |
45 self:next(cb); | 103 self:next(cb); |
46 end | 104 end |
47 | 105 |
48 -- Resolve DNS to target list | 106 -- Resolve DNS to target list |
49 local dns_resolver = adns.resolver(); | 107 local dns_resolver = adns.resolver(); |
61 return; | 119 return; |
62 end | 120 end |
63 | 121 |
64 if #answer == 0 then | 122 if #answer == 0 then |
65 if self.extra and self.extra.default_port then | 123 if self.extra and self.extra.default_port then |
66 table.insert(targets, { self.hostname, self.extra.default_port, self.conn_type, self.extra }); | 124 self.resolver = basic.new(self.hostname, self.extra.default_port, self.conn_type, self.extra); |
67 else | 125 else |
68 self.last_error = "zero SRV records found"; | 126 self.last_error = "zero SRV records found"; |
69 end | 127 end |
70 ready(); | 128 ready(); |
71 return; | 129 return; |
75 self.last_error = "service explicitly unavailable"; | 133 self.last_error = "service explicitly unavailable"; |
76 ready(); | 134 ready(); |
77 return; | 135 return; |
78 end | 136 end |
79 | 137 |
80 table.sort(answer, function (a, b) return a.srv.priority < b.srv.priority end); | 138 self._get_next_target = new_target_selector(answer); |
81 for _, record in ipairs(answer) do | |
82 table.insert(targets, { record.srv.target, record.srv.port, self.conn_type, self.extra }); | |
83 end | |
84 else | 139 else |
85 self.last_error = err; | 140 self.last_error = err; |
86 end | 141 end |
87 ready(); | 142 ready(); |
88 end, "_" .. self.service .. "._" .. self.conn_type .. "." .. self.hostname, "SRV", "IN"); | 143 end, "_" .. self.service .. "._" .. self.conn_type .. "." .. self.hostname, "SRV", "IN"); |