Software /
code /
prosody
Changeset
337:4a1dd1c2c219
We have SRV resolving \o/
author | Matthew Wild <mwild1@gmail.com> |
---|---|
date | Tue, 18 Nov 2008 22:41:04 +0000 |
parents | 336:d97d59cfd1e8 |
children | 338:804f5e62a41f |
files | core/s2smanager.lua net/dns.lua tests/test.lua util/ztact.lua |
diffstat | 4 files changed, 1188 insertions(+), 6 deletions(-) [+] |
line wrap: on
line diff
--- a/core/s2smanager.lua Tue Nov 18 22:40:23 2008 +0000 +++ b/core/s2smanager.lua Tue Nov 18 22:41:04 2008 +0000 @@ -3,7 +3,7 @@ local sessions = sessions; local socket = require "socket"; local format = string.format; -local t_insert = table.insert; +local t_insert, t_sort = table.insert, table.sort; local get_traceback = debug.traceback; local tostring, pairs, ipairs, getmetatable, print, newproxy, error, tonumber = tostring, pairs, ipairs, getmetatable, print, newproxy, error, tonumber; @@ -24,17 +24,19 @@ local dialback_secret = "This is very secret!!! Ha!"; -local srvmap = { ["gmail.com"] = "talk.google.com", ["identi.ca"] = "hampton.controlezvous.ca", ["cdr.se"] = "jabber.cdr.se" }; +local dns = require "net.dns"; module "s2smanager" +local function compare_srv_priorities(a,b) return a.priority < b.priority or a.weight < b.weight; end + function send_to_host(from_host, to_host, data) if data.name then data = tostring(data); end local host = hosts[from_host].s2sout[to_host]; if host then -- We have a connection to this host already if host.type == "s2sout_unauthed" then - host.log("debug", "trying to send over unauthed s2sout to "..to_host..", authing it now..."); + (host.log or log)("debug", "trying to send over unauthed s2sout to "..to_host..", authing it now..."); if not host.notopen and not host.dialback_key then host.log("debug", "dialback had not been initiated"); initiate_dialback(host); @@ -87,11 +89,31 @@ local conn, handler = socket.tcp() --FIXME: Below parameters (ports/ip) are incorrect (use SRV) - to_host = srvmap[to_host] or to_host; + + local connect_host, connect_port = to_host, 5269; + + local answer = dns.lookup("_xmpp-server._tcp."..to_host..".", "SRV"); + + if answer then + log("debug", to_host.." has SRV records, handling..."); + local srv_hosts = {}; + host_session.srv_hosts = srv_hosts; + for _, record in ipairs(answer) do + t_insert(srv_hosts, record.srv); + end + t_sort(srv_hosts, compare_srv_priorities); + + local srv_choice = srv_hosts[1]; + if srv_choice then + log("debug", "Best record found"); + connect_host, connect_port = srv_choice.target or to_host, srv_choice.port or connect_port; + log("debug", "Best record found, will connect to %s:%d", connect_host, connect_port); + end + end conn:settimeout(0); - local success, err = conn:connect(to_host, 5269); - if not success then + local success, err = conn:connect(connect_host, connect_port); + if not success and err ~= "timeout" then log("warn", "s2s connect() failed: %s", err); end
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/net/dns.lua Tue Nov 18 22:41:04 2008 +0000 @@ -0,0 +1,795 @@ + + +-- public domain 20080404 lua@ztact.com + + +-- todo: quick (default) header generation +-- todo: nxdomain, error handling +-- todo: cache results of encodeName + + +-- reference: http://tools.ietf.org/html/rfc1035 +-- reference: http://tools.ietf.org/html/rfc1876 (LOC) + + +require 'socket' +local ztact = require 'util.ztact' + + +local coroutine, io, math, socket, string, table = + coroutine, io, math, socket, string, table + +local ipairs, next, pairs, print, setmetatable, tostring = + ipairs, next, pairs, print, setmetatable, tostring + +local get, set = ztact.get, ztact.set + + +-------------------------------------------------- module dns +module ('dns') +local dns = _M; + + +-- dns type & class codes ------------------------------ dns type & class codes + + +local append = table.insert + + +local function highbyte (i) -- - - - - - - - - - - - - - - - - - - highbyte + return (i-(i%0x100))/0x100 + end + + +local function augment (t) -- - - - - - - - - - - - - - - - - - - - augment + local a = {} + for i,s in pairs (t) do a[i] = s a[s] = s a[string.lower (s)] = s end + return a + end + + +local function encode (t) -- - - - - - - - - - - - - - - - - - - - - encode + local code = {} + for i,s in pairs (t) do + local word = string.char (highbyte (i), i %0x100) + code[i] = word + code[s] = word + code[string.lower (s)] = word + end + return code + end + + +dns.types = { + 'A', 'NS', 'MD', 'MF', 'CNAME', 'SOA', 'MB', 'MG', 'MR', 'NULL', 'WKS', + 'PTR', 'HINFO', 'MINFO', 'MX', 'TXT', + [ 28] = 'AAAA', [ 29] = 'LOC', [ 33] = 'SRV', + [252] = 'AXFR', [253] = 'MAILB', [254] = 'MAILA', [255] = '*' } + + +dns.classes = { 'IN', 'CS', 'CH', 'HS', [255] = '*' } + + +dns.type = augment (dns.types) +dns.class = augment (dns.classes) +dns.typecode = encode (dns.types) +dns.classcode = encode (dns.classes) + + + +local function standardize (qname, qtype, qclass) -- - - - - - - standardize + if string.byte (qname, -1) ~= 0x2E then qname = qname..'.' end + qname = string.lower (qname) + return qname, dns.type[qtype or 'A'], dns.class[qclass or 'IN'] + end + + +local function prune (rrs, time, soft) -- - - - - - - - - - - - - - - prune + + time = time or socket.gettime () + for i,rr in pairs (rrs) do + + if rr.tod then + -- rr.tod = rr.tod - 50 -- accelerated decripitude + rr.ttl = math.floor (rr.tod - time) + if rr.ttl <= 0 then rrs[i] = nil end + + elseif soft == 'soft' then -- What is this? I forget! + assert (rr.ttl == 0) + rrs[i] = nil + end end end + + +-- metatables & co. ------------------------------------------ metatables & co. + + +local resolver = {} +resolver.__index = resolver + + +local SRV_tostring + + +local rr_metatable = {} -- - - - - - - - - - - - - - - - - - - rr_metatable +function rr_metatable.__tostring (rr) + local s0 = string.format ( + '%2s %-5s %6i %-28s', rr.class, rr.type, rr.ttl, rr.name ) + local s1 = '' + if rr.type == 'A' then s1 = ' '..rr.a + elseif rr.type == 'MX' then + s1 = string.format (' %2i %s', rr.pref, rr.mx) + elseif rr.type == 'CNAME' then s1 = ' '..rr.cname + elseif rr.type == 'LOC' then s1 = ' '..resolver.LOC_tostring (rr) + elseif rr.type == 'NS' then s1 = ' '..rr.ns + elseif rr.type == 'SRV' then s1 = ' '..SRV_tostring (rr) + elseif rr.type == 'TXT' then s1 = ' '..rr.txt + else s1 = ' <UNKNOWN RDATA TYPE>' end + return s0..s1 + end + + +local rrs_metatable = {} -- - - - - - - - - - - - - - - - - - rrs_metatable +function rrs_metatable.__tostring (rrs) + t = {} + for i,rr in pairs (rrs) do append (t, tostring (rr)..'\n') end + return table.concat (t) + end + + +local cache_metatable = {} -- - - - - - - - - - - - - - - - cache_metatable +function cache_metatable.__tostring (cache) + local time = socket.gettime () + local t = {} + for class,types in pairs (cache) do + for type,names in pairs (types) do + for name,rrs in pairs (names) do + prune (rrs, time) + append (t, tostring (rrs)) end end end + return table.concat (t) + end + + +function resolver:new () -- - - - - - - - - - - - - - - - - - - - - resolver + local r = { active = {}, cache = {}, unsorted = {} } + setmetatable (r, resolver) + setmetatable (r.cache, cache_metatable) + setmetatable (r.unsorted, { __mode = 'kv' }) + return r + end + + +-- packet layer -------------------------------------------------- packet layer + + +function dns.random (...) -- - - - - - - - - - - - - - - - - - - dns.random + math.randomseed (10000*socket.gettime ()) + dns.random = math.random + return dns.random (...) + end + + +local function encodeHeader (o) -- - - - - - - - - - - - - - - encodeHeader + + o = o or {} + + o.id = o.id or -- 16b (random) id + dns.random (0, 0xffff) + + o.rd = o.rd or 1 -- 1b 1 recursion desired + o.tc = o.tc or 0 -- 1b 1 truncated response + o.aa = o.aa or 0 -- 1b 1 authoritative response + o.opcode = o.opcode or 0 -- 4b 0 query + -- 1 inverse query + -- 2 server status request + -- 3-15 reserved + o.qr = o.qr or 0 -- 1b 0 query, 1 response + + o.rcode = o.rcode or 0 -- 4b 0 no error + -- 1 format error + -- 2 server failure + -- 3 name error + -- 4 not implemented + -- 5 refused + -- 6-15 reserved + o.z = o.z or 0 -- 3b 0 resvered + o.ra = o.ra or 0 -- 1b 1 recursion available + + o.qdcount = o.qdcount or 1 -- 16b number of question RRs + o.ancount = o.ancount or 0 -- 16b number of answers RRs + o.nscount = o.nscount or 0 -- 16b number of nameservers RRs + o.arcount = o.arcount or 0 -- 16b number of additional RRs + + -- string.char() rounds, so prevent roundup with -0.4999 + local header = string.char ( + highbyte (o.id), o.id %0x100, + o.rd + 2*o.tc + 4*o.aa + 8*o.opcode + 128*o.qr, + o.rcode + 16*o.z + 128*o.ra, + highbyte (o.qdcount), o.qdcount %0x100, + highbyte (o.ancount), o.ancount %0x100, + highbyte (o.nscount), o.nscount %0x100, + highbyte (o.arcount), o.arcount %0x100 ) + + return header, o.id + end + + +local function encodeName (name) -- - - - - - - - - - - - - - - - encodeName + local t = {} + for part in string.gmatch (name, '[^.]+') do + append (t, string.char (string.len (part))) + append (t, part) + end + append (t, string.char (0)) + return table.concat (t) + end + + +local function encodeQuestion (qname, qtype, qclass) -- - - - encodeQuestion + qname = encodeName (qname) + qtype = dns.typecode[qtype or 'a'] + qclass = dns.classcode[qclass or 'in'] + return qname..qtype..qclass; + end + + +function resolver:byte (len) -- - - - - - - - - - - - - - - - - - - - - byte + len = len or 1 + local offset = self.offset + local last = offset + len - 1 + if last > #self.packet then + error (string.format ('out of bounds: %i>%i', last, #self.packet)) end + self.offset = offset + len + return string.byte (self.packet, offset, last) + end + + +function resolver:word () -- - - - - - - - - - - - - - - - - - - - - - word + local b1, b2 = self:byte (2) + return 0x100*b1 + b2 + end + + +function resolver:dword () -- - - - - - - - - - - - - - - - - - - - - dword + local b1, b2, b3, b4 = self:byte (4) + -- print ('dword', b1, b2, b3, b4) + return 0x1000000*b1 + 0x10000*b2 + 0x100*b3 + b4 + end + + +function resolver:sub (len) -- - - - - - - - - - - - - - - - - - - - - - sub + len = len or 1 + local s = string.sub (self.packet, self.offset, self.offset + len - 1) + self.offset = self.offset + len + return s + end + + +function resolver:header (force) -- - - - - - - - - - - - - - - - - - header + + local id = self:word () + -- print (string.format (':header id %x', id)) + if not self.active[id] and not force then return nil end + + local h = { id = id } + + local b1, b2 = self:byte (2) + + h.rd = b1 %2 + h.tc = b1 /2%2 + h.aa = b1 /4%2 + h.opcode = b1 /8%16 + h.qr = b1 /128 + + h.rcode = b2 %16 + h.z = b2 /16%8 + h.ra = b2 /128 + + h.qdcount = self:word () + h.ancount = self:word () + h.nscount = self:word () + h.arcount = self:word () + + for k,v in pairs (h) do h[k] = v-v%1 end + + return h + end + + +function resolver:name () -- - - - - - - - - - - - - - - - - - - - - - name + local remember, pointers = nil, 0 + local len = self:byte () + local n = {} + while len > 0 do + if len >= 0xc0 then -- name is "compressed" + pointers = pointers + 1 + if pointers >= 20 then error ('dns error: 20 pointers') end + local offset = ((len-0xc0)*0x100) + self:byte () + remember = remember or self.offset + self.offset = offset + 1 -- +1 for lua + else -- name is not compressed + append (n, self:sub (len)..'.') + end + len = self:byte () + end + self.offset = remember or self.offset + return table.concat (n) + end + + +function resolver:question () -- - - - - - - - - - - - - - - - - - question + local q = {} + q.name = self:name () + q.type = dns.type[self:word ()] + q.class = dns.type[self:word ()] + return q + end + + +function resolver:A (rr) -- - - - - - - - - - - - - - - - - - - - - - - - A + local b1, b2, b3, b4 = self:byte (4) + rr.a = string.format ('%i.%i.%i.%i', b1, b2, b3, b4) + end + + +function resolver:CNAME (rr) -- - - - - - - - - - - - - - - - - - - - CNAME + rr.cname = self:name () + end + + +function resolver:MX (rr) -- - - - - - - - - - - - - - - - - - - - - - - MX + rr.pref = self:word () + rr.mx = self:name () + end + + +function resolver:LOC_nibble_power () -- - - - - - - - - - LOC_nibble_power + local b = self:byte () + -- print ('nibbles', ((b-(b%0x10))/0x10), (b%0x10)) + return ((b-(b%0x10))/0x10) * (10^(b%0x10)) + end + + +function resolver:LOC (rr) -- - - - - - - - - - - - - - - - - - - - - - LOC + rr.version = self:byte () + if rr.version == 0 then + rr.loc = rr.loc or {} + rr.loc.size = self:LOC_nibble_power () + rr.loc.horiz_pre = self:LOC_nibble_power () + rr.loc.vert_pre = self:LOC_nibble_power () + rr.loc.latitude = self:dword () + rr.loc.longitude = self:dword () + rr.loc.altitude = self:dword () + end end + + +local function LOC_tostring_degrees (f, pos, neg) -- - - - - - - - - - - - - + f = f - 0x80000000 + if f < 0 then pos = neg f = -f end + local deg, min, msec + msec = f%60000 + f = (f-msec)/60000 + min = f%60 + deg = (f-min)/60 + return string.format ('%3d %2d %2.3f %s', deg, min, msec/1000, pos) + end + + +function resolver.LOC_tostring (rr) -- - - - - - - - - - - - - LOC_tostring + + local t = {} + + --[[ + for k,name in pairs { 'size', 'horiz_pre', 'vert_pre', + 'latitude', 'longitude', 'altitude' } do + append (t, string.format ('%4s%-10s: %12.0f\n', '', name, rr.loc[name])) + end + --]] + + append ( t, string.format ( + '%s %s %.2fm %.2fm %.2fm %.2fm', + LOC_tostring_degrees (rr.loc.latitude, 'N', 'S'), + LOC_tostring_degrees (rr.loc.longitude, 'E', 'W'), + (rr.loc.altitude - 10000000) / 100, + rr.loc.size / 100, + rr.loc.horiz_pre / 100, + rr.loc.vert_pre / 100 ) ) + + return table.concat (t) + end + + +function resolver:NS (rr) -- - - - - - - - - - - - - - - - - - - - - - - NS + rr.ns = self:name () + end + + +function resolver:SOA (rr) -- - - - - - - - - - - - - - - - - - - - - - SOA + end + + +function resolver:SRV (rr) -- - - - - - - - - - - - - - - - - - - - - - SRV + rr.srv = {} + rr.srv.priority = self:word () + rr.srv.weight = self:word () + rr.srv.port = self:word () + rr.srv.target = self:name () + end + + +function SRV_tostring (rr) -- - - - - - - - - - - - - - - - - - SRV_tostring + local s = rr.srv + return string.format ( '%5d %5d %5d %s', + s.priority, s.weight, s.port, s.target ) + end + + +function resolver:TXT (rr) -- - - - - - - - - - - - - - - - - - - - - - TXT + rr.txt = self:sub (rr.rdlength) + end + + +function resolver:rr () -- - - - - - - - - - - - - - - - - - - - - - - - rr + local rr = {} + setmetatable (rr, rr_metatable) + rr.name = self:name (self) + rr.type = dns.type[self:word ()] or rr.type + rr.class = dns.class[self:word ()] or rr.class + rr.ttl = 0x10000*self:word () + self:word () + rr.rdlength = self:word () + + if rr.ttl == 0 then -- pass + else rr.tod = self.time + rr.ttl end + + local remember = self.offset + local rr_parser = self[dns.type[rr.type]] + if rr_parser then rr_parser (self, rr) end + self.offset = remember + rr.rdata = self:sub (rr.rdlength) + return rr + end + + +function resolver:rrs (count) -- - - - - - - - - - - - - - - - - - - - - rrs + local rrs = {} + for i = 1,count do append (rrs, self:rr ()) end + return rrs + end + + +function resolver:decode (packet, force) -- - - - - - - - - - - - - - decode + + self.packet, self.offset = packet, 1 + local header = self:header (force) + if not header then return nil end + local response = { header = header } + + response.question = {} + local offset = self.offset + for i = 1,response.header.qdcount do + append (response.question, self:question ()) end + response.question.raw = string.sub (self.packet, offset, self.offset - 1) + + if not force then + if not self.active[response.header.id] or + not self.active[response.header.id][response.question.raw] then + return nil end end + + response.answer = self:rrs (response.header.ancount) + response.authority = self:rrs (response.header.nscount) + response.additional = self:rrs (response.header.arcount) + + return response + end + + +-- socket layer -------------------------------------------------- socket layer + + +resolver.delays = { 1, 3, 11, 45 } + + +function resolver:addnameserver (address) -- - - - - - - - - - addnameserver + self.server = self.server or {} + append (self.server, address) + end + + +function resolver:setnameserver (address) -- - - - - - - - - - setnameserver + self.server = {} + self:addnameserver (address) + end + + +function resolver:adddefaultnameservers () -- - - - - adddefaultnameservers + for line in io.lines ('/etc/resolv.conf') do + address = string.match (line, 'nameserver%s+(%d+%.%d+%.%d+%.%d+)') + if address then self:addnameserver (address) end + end end + + +function resolver:getsocket (servernum) -- - - - - - - - - - - - - getsocket + + self.socket = self.socket or {} + self.socketset = self.socketset or {} + + local sock = self.socket[servernum] + if sock then return sock end + + sock = socket.udp () + if self.socket_wrapper then sock = self.socket_wrapper (sock) end + sock:settimeout (0) + -- todo: attempt to use a random port, fallback to 0 + sock:setsockname ('*', 0) + sock:setpeername (self.server[servernum], 53) + self.socket[servernum] = sock + self.socketset[sock] = sock + return sock + end + + +function resolver:socket_wrapper_set (func) -- - - - - - - socket_wrapper_set + self.socket_wrapper = func + end + + +function resolver:closeall () -- - - - - - - - - - - - - - - - - - closeall + for i,sock in ipairs (self.socket) do self.socket[i]:close () end + self.socket = {} + end + + +function resolver:remember (rr, type) -- - - - - - - - - - - - - - remember + + -- print ('remember', type, rr.class, rr.type, rr.name) + + if type ~= '*' then + type = rr.type + local all = get (self.cache, rr.class, '*', rr.name) + -- print ('remember all', all) + if all then append (all, rr) end + end + + self.cache = self.cache or setmetatable ({}, cache_metatable) + local rrs = get (self.cache, rr.class, type, rr.name) or + set (self.cache, rr.class, type, rr.name, setmetatable ({}, rrs_metatable)) + append (rrs, rr) + + if type == 'MX' then self.unsorted[rrs] = true end + end + + +local function comp_mx (a, b) -- - - - - - - - - - - - - - - - - - - comp_mx + return (a.pref == b.pref) and (a.mx < b.mx) or (a.pref < b.pref) + end + + +function resolver:peek (qname, qtype, qclass) -- - - - - - - - - - - - peek + qname, qtype, qclass = standardize (qname, qtype, qclass) + local rrs = get (self.cache, qclass, qtype, qname) + if not rrs then return nil end + if prune (rrs, socket.gettime ()) and qtype == '*' or not next (rrs) then + set (self.cache, qclass, qtype, qname, nil) return nil end + if self.unsorted[rrs] then table.sort (rrs, comp_mx) end + return rrs + end + + +function resolver:purge (soft) -- - - - - - - - - - - - - - - - - - - purge + if soft == 'soft' then + self.time = socket.gettime () + for class,types in pairs (self.cache or {}) do + for type,names in pairs (types) do + for name,rrs in pairs (names) do + prune (rrs, time, 'soft') + end end end + else self.cache = {} end + end + + +function resolver:query (qname, qtype, qclass) -- - - - - - - - - - -- query + + qname, qtype, qclass = standardize (qname, qtype, qclass) + + if not self.server then self:adddefaultnameservers () end + + local question = question or encodeQuestion (qname, qtype, qclass) + local peek = self:peek (qname, qtype, qclass) + if peek then return peek end + + local header, id = encodeHeader () + -- print ('query id', id, qclass, qtype, qname) + local o = { packet = header..question, + server = 1, + delay = 1, + retry = socket.gettime () + self.delays[1] } + self:getsocket (o.server):send (o.packet) + + -- remember the query + self.active[id] = self.active[id] or {} + self.active[id][question] = o + + -- remember which coroutine wants the answer + local co = coroutine.running () + if co then + set (self.wanted, qclass, qtype, qname, co, true) + set (self.yielded, co, qclass, qtype, qname, true) + end end + + +function resolver:receive (rset) -- - - - - - - - - - - - - - - - - receive + + -- print 'receive' print (self.socket) + self.time = socket.gettime () + rset = rset or self.socket + + local response + for i,sock in pairs (rset) do + + if self.socketset[sock] then + local packet = sock:receive () + if packet then + + response = self:decode (packet) + if response then + -- print 'received response' + -- self.print (response) + + for i,section in pairs { 'answer', 'authority', 'additional' } do + for j,rr in pairs (response[section]) do + self:remember (rr, response.question[1].type) end end + + -- retire the query + local queries = self.active[response.header.id] + if queries[response.question.raw] then + queries[response.question.raw] = nil end + if not next (queries) then self.active[response.header.id] = nil end + if not next (self.active) then self:closeall () end + + -- was the query on the wanted list? + local q = response.question + local cos = get (self.wanted, q.class, q.type, q.name) + if cos then + for co in pairs (cos) do + set (self.yielded, co, q.class, q.type, q.name, nil) + if not self.yielded[co] then coroutine.resume (co) end + end + set (self.wanted, q.class, q.type, q.name, nil) + end end end end end + + return response + end + + +function resolver:pulse () -- - - - - - - - - - - - - - - - - - - - - pulse + + -- print ':pulse' + while self:receive () do end + if not next (self.active) then return nil end + + self.time = socket.gettime () + for id,queries in pairs (self.active) do + for question,o in pairs (queries) do + if self.time >= o.retry then + + o.server = o.server + 1 + if o.server > #self.server then + o.server = 1 + o.delay = o.delay + 1 + end + + if o.delay > #self.delays then + print ('timeout') + queries[question] = nil + if not next (queries) then self.active[id] = nil end + if not next (self.active) then return nil end + else + -- print ('retry', o.server, o.delay) + self.socket[o.server]:send (o.packet) + o.retry = self.time + self.delays[o.delay] + end end end end + + if next (self.active) then return true end + return nil + end + + +function resolver:lookup (qname, qtype, qclass) -- - - - - - - - - - lookup + self:query (qname, qtype, qclass) + while self:pulse () do socket.select (self.socket, nil, 4) end + -- print (self.cache) + return self:peek (qname, qtype, qclass) + end + + +-- print ---------------------------------------------------------------- print + + +local hints = { -- - - - - - - - - - - - - - - - - - - - - - - - - - - hints + qr = { [0]='query', 'response' }, + opcode = { [0]='query', 'inverse query', 'server status request' }, + aa = { [0]='non-authoritative', 'authoritative' }, + tc = { [0]='complete', 'truncated' }, + rd = { [0]='recursion not desired', 'recursion desired' }, + ra = { [0]='recursion not available', 'recursion available' }, + z = { [0]='(reserved)' }, + rcode = { [0]='no error', 'format error', 'server failure', 'name error', + 'not implemented' }, + + type = dns.type, + class = dns.class, } + + +local function hint (p, s) -- - - - - - - - - - - - - - - - - - - - - - hint + return (hints[s] and hints[s][p[s]]) or '' end + + +function resolver.print (response) -- - - - - - - - - - - - - resolver.print + + for s,s in pairs { 'id', 'qr', 'opcode', 'aa', 'tc', 'rd', 'ra', 'z', + 'rcode', 'qdcount', 'ancount', 'nscount', 'arcount' } do + print ( string.format ('%-30s', 'header.'..s), + response.header[s], hint (response.header, s) ) + end + + for i,question in ipairs (response.question) do + print (string.format ('question[%i].name ', i), question.name) + print (string.format ('question[%i].type ', i), question.type) + print (string.format ('question[%i].class ', i), question.class) + end + + local common = { name=1, type=1, class=1, ttl=1, rdlength=1, rdata=1 } + local tmp + for s,s in pairs {'answer', 'authority', 'additional'} do + for i,rr in pairs (response[s]) do + for j,t in pairs { 'name', 'type', 'class', 'ttl', 'rdlength' } do + tmp = string.format ('%s[%i].%s', s, i, t) + print (string.format ('%-30s', tmp), rr[t], hint (rr, t)) + end + for j,t in pairs (rr) do + if not common[j] then + tmp = string.format ('%s[%i].%s', s, i, j) + print (string.format ('%-30s %s', tmp, t)) + end end end end end + + +-- module api ------------------------------------------------------ module api + + +local function resolve (func, ...) -- - - - - - - - - - - - - - resolver_get + dns._resolver = dns._resolver or dns.resolver () + return func (dns._resolver, ...) + end + + +function dns.resolver () -- - - - - - - - - - - - - - - - - - - - - resolver + + -- this function seems to be redundant with resolver.new () + + r = { active = {}, cache = {}, unsorted = {}, wanted = {}, yielded = {} } + setmetatable (r, resolver) + setmetatable (r.cache, cache_metatable) + setmetatable (r.unsorted, { __mode = 'kv' }) + return r + end + + +function dns.lookup (...) -- - - - - - - - - - - - - - - - - - - - - lookup + return resolve (resolver.lookup, ...) end + + +function dns.purge (...) -- - - - - - - - - - - - - - - - - - - - - - purge + return resolve (resolver.purge, ...) end + +function dns.peek (...) -- - - - - - - - - - - - - - - - - - - - - - - peek + return resolve (resolver.peek, ...) end + + +function dns.query (...) -- - - - - - - - - - - - - - - - - - - - - - query + return resolve (resolver.query, ...) end + + +function dns:socket_wrapper_set (...) -- - - - - - - - - socket_wrapper_set + return resolve (resolver.socket_wrapper_set, ...) end + + +return dns
--- a/tests/test.lua Tue Nov 18 22:40:23 2008 +0000 +++ b/tests/test.lua Tue Nov 18 22:41:04 2008 +0000 @@ -83,3 +83,4 @@ dotest "util.jid" dotest "core.stanza_router" +dotest "core.s2smanager"
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/util/ztact.lua Tue Nov 18 22:41:04 2008 +0000 @@ -0,0 +1,364 @@ + + +-- public domain 20080410 lua@ztact.com + + +pcall (require, 'lfs') -- lfs may not be installed/necessary. +pcall (require, 'pozix') -- pozix may not be installed/necessary. + + +local getfenv, ipairs, next, pairs, pcall, require, select, tostring, type = + getfenv, ipairs, next, pairs, pcall, require, select, tostring, type +local unpack, xpcall = + unpack, xpcall + +local io, lfs, os, string, table, pozix = io, lfs, os, string, table, pozix + +local assert, print = assert, print + +local error = error + + +module ((...) or 'ztact') ------------------------------------- module ztact + + +-- dir -------------------------------------------------------------------- dir + + +function dir (path) -- - - - - - - - - - - - - - - - - - - - - - - - - - dir + local it = lfs.dir (path) + return function () + repeat + local dir = it () + if dir ~= '.' and dir ~= '..' then return dir end + until not dir + end end + + +function is_file (path) -- - - - - - - - - - - - - - - - - - is_file (path) + local mode = lfs.attributes (path, 'mode') + return mode == 'file' and path + end + + +-- network byte ordering -------------------------------- network byte ordering + + +function htons (word) -- - - - - - - - - - - - - - - - - - - - - - - - htons + return (word-word%0x100)/0x100, word%0x100 + end + + +-- pcall2 -------------------------------------------------------------- pcall2 + + +getfenv ().pcall = pcall -- store the original pcall as ztact.pcall + + +local argc, argv, errorhandler, pcall2_f + + +local function _pcall2 () -- - - - - - - - - - - - - - - - - - - - - _pcall2 + local tmpv = argv + argv = nil + return pcall2_f (unpack (tmpv, 1, argc)) + end + + +function seterrorhandler (func) -- - - - - - - - - - - - - - seterrorhandler + errorhandler = func + end + + +function pcall2 (f, ...) -- - - - - - - - - - - - - - - - - - - - - - pcall2 + + pcall2_f = f + argc = select ('#', ...) + argv = { ... } + + if not errorhandler then + local debug = require ('debug') + errorhandler = debug.traceback + end + + return xpcall (_pcall2, errorhandler) + end + + +function append (t, ...) -- - - - - - - - - - - - - - - - - - - - - - append + local insert = table.insert + for i,v in ipairs {...} do + insert (t, v) + end end + + +function print_r (d, indent) -- - - - - - - - - - - - - - - - - - - print_r + local rep = string.rep (' ', indent or 0) + if type (d) == 'table' then + for k,v in pairs (d) do + if type (v) == 'table' then + io.write (rep, k, '\n') + print_r (v, (indent or 0) + 1) + else io.write (rep, k, ' = ', tostring (v), '\n') end + end + else io.write (d, '\n') end + end + + +function tohex (s) -- - - - - - - - - - - - - - - - - - - - - - - - - tohex + return string.format (string.rep ('%02x ', #s), string.byte (s, 1, #s)) + end + + +function tostring_r (d, indent, tab0) -- - - - - - - - - - - - - tostring_r + + tab1 = tab0 or {} + local rep = string.rep (' ', indent or 0) + if type (d) == 'table' then + for k,v in pairs (d) do + if type (v) == 'table' then + append (tab1, rep, k, '\n') + tostring_r (v, (indent or 0) + 1, tab1) + else append (tab1, rep, k, ' = ', tostring (v), '\n') end + end + else append (tab1, d, '\n') end + + if not tab0 then return table.concat (tab1) end + end + + +-- queue manipulation -------------------------------------- queue manipulation + + +-- Possible queue states. 1 (i.e. queue.p[1]) is head of queue. +-- +-- 1..2 +-- 3..4 1..2 +-- 3..4 1..2 5..6 +-- 1..2 5..6 +-- 1..2 + + +local function print_queue (queue, ...) -- - - - - - - - - - - - print_queue + for i=1,10 do io.write ((queue[i] or '.')..' ') end + io.write ('\t') + for i=1,6 do io.write ((queue.p[i] or '.')..' ') end + print (...) + end + + +function dequeue (queue) -- - - - - - - - - - - - - - - - - - - - - dequeue + + local p = queue.p + if not p and queue[1] then queue.p = { 1, #queue } p = queue.p end + + if not p[1] then return nil end + + local element = queue[p[1]] + queue[p[1]] = nil + + if p[1] < p[2] then p[1] = p[1] + 1 + + elseif p[4] then p[1], p[2], p[3], p[4] = p[3], p[4], nil, nil + + elseif p[5] then p[1], p[2], p[5], p[6] = p[5], p[6], nil, nil + + else p[1], p[2] = nil, nil end + + print_queue (queue, ' de '..element) + return element + end + + +function enqueue (queue, element) -- - - - - - - - - - - - - - - - - enqueue + + local p = queue.p + if not p then queue.p = {} p = queue.p end + + if p[5] then -- p3..p4 p1..p2 p5..p6 + p[6] = p[6]+1 + queue[p[6]] = element + + elseif p[3] then -- p3..p4 p1..p2 + + if p[4]+1 < p[1] then + p[4] = p[4] + 1 + queue[p[4]] = element + + else + p[5] = p[2]+1 + p[6], queue[p[5]] = p[5], element + end + + elseif p[1] then -- p1..p2 + if p[1] == 1 then + p[2] = p[2] + 1 + queue[p[2]] = element + + else + p[3], p[4], queue[1] = 1, 1, element + end + + else -- empty queue + p[1], p[2], queue[1] = 1, 1, element + end + + print_queue (queue, ' '..element) + end + + +local function test_queue () + t = {} + enqueue (t, 1) + enqueue (t, 2) + enqueue (t, 3) + enqueue (t, 4) + enqueue (t, 5) + dequeue (t) + dequeue (t) + enqueue (t, 6) + enqueue (t, 7) + enqueue (t, 8) + enqueue (t, 9) + dequeue (t) + dequeue (t) + dequeue (t) + dequeue (t) + enqueue (t, 'a') + dequeue (t) + enqueue (t, 'b') + enqueue (t, 'c') + dequeue (t) + dequeue (t) + dequeue (t) + dequeue (t) + dequeue (t) + enqueue (t, 'd') + dequeue (t) + dequeue (t) + dequeue (t) + end + + +-- test_queue () + + +function queue_len (queue) + end + + +function queue_peek (queue) + end + + +-- tree manipulation ---------------------------------------- tree manipulation + + +function set (parent, ...) --- - - - - - - - - - - - - - - - - - - - - - set + + -- print ('set', ...) + + local len = select ('#', ...) + local key, value = select (len-1, ...) + local cutpoint, cutkey + + for i=1,len-2 do + + local key = select (i, ...) + local child = parent[key] + + if value == nil then + if child == nil then return + elseif next (child, next (child)) then cutpoint = nil cutkey = nil + elseif cutpoint == nil then cutpoint = parent cutkey = key end + + elseif child == nil then child = {} parent[key] = child end + + parent = child + end + + if value == nil and cutpoint then cutpoint[cutkey] = nil + else parent[key] = value return value end + end + + +function get (parent, ...) --- - - - - - - - - - - - - - - - - - - - - - get + local len = select ('#', ...) + for i=1,len do + parent = parent[select (i, ...)] + if parent == nil then break end + end + return parent + end + + +-- misc ------------------------------------------------------------------ misc + + +function find (path, ...) --------------------------------------------- find + + local dirs, operators = { path }, {...} + for operator in ivalues (operators) do + if not operator (path) then break end end + + while next (dirs) do + local parent = table.remove (dirs) + for child in assert (pozix.opendir (parent)) do + if child and child ~= '.' and child ~= '..' then + local path = parent..'/'..child + if pozix.stat (path, 'is_dir') then table.insert (dirs, path) end + for operator in ivalues (operators) do + if not operator (path) then break end end + end end end end + + +function ivalues (t) ----------------------------------------------- ivalues + local i = 0 + return function () if t[i+1] then i = i + 1 return t[i] end end + end + + +function lson_encode (mixed, f, indent, indents) --------------- lson_encode + + + local capture + if not f then + capture = {} + f = function (s) append (capture, s) end + end + + indent = indent or 0 + indents = indents or {} + indents[indent] = indents[indent] or string.rep (' ', 2*indent) + + local type = type (mixed) + + if type == 'number' then f (mixed) + + else if type == 'string' then f (string.format ('%q', mixed)) + + else if type == 'table' then + f ('{') + for k,v in pairs (mixed) do + f ('\n') + f (indents[indent]) + f ('[') f (lson_encode (k)) f ('] = ') + lson_encode (v, f, indent+1, indents) + f (',') + end + f (' }') + end end end + + if capture then return table.concat (capture) end + end + + +function timestamp (time) ---------------------------------------- timestamp + return os.date ('%Y%m%d.%H%M%S', time) + end + + +function values (t) ------------------------------------------------- values + local k, v + return function () k, v = next (t, k) return v end + end