Software /
code /
prosody
Diff
net/dns.lua @ 869:09019c452709
net.dns: Add methods necessary for allowing non-blocking DNS lookups
author | Matthew Wild <mwild1@gmail.com> |
---|---|
date | Wed, 04 Mar 2009 12:58:56 +0000 |
parent | 759:5cccfb5da6cb |
child | 896:2c0b9e3c11c3 |
line wrap: on
line diff
--- a/net/dns.lua Wed Mar 04 12:44:07 2009 +0000 +++ b/net/dns.lua Wed Mar 04 12:58:56 2009 +0000 @@ -16,7 +16,7 @@ require 'socket' local ztact = require 'util.ztact' - +local require = require local coroutine, io, math, socket, string, table = coroutine, io, math, socket, string, table @@ -253,7 +253,7 @@ function resolver:dword () -- - - - - - - - - - - - - - - - - - - - - dword local b1, b2, b3, b4 = self:byte (4) - -- print ('dword', b1, b2, b3, b4) + --print ('dword', b1, b2, b3, b4) return 0x1000000*b1 + 0x10000*b2 + 0x100*b3 + b4 end @@ -269,7 +269,7 @@ function resolver:header (force) -- - - - - - - - - - - - - - - - - - header local id = self:word () - -- print (string.format (':header id %x', id)) + --print (string.format (':header id %x', id)) if not self.active[id] and not force then return nil end local h = { id = id } @@ -322,7 +322,7 @@ local q = {} q.name = self:name () q.type = dns.type[self:word ()] - q.class = dns.type[self:word ()] + q.class = dns.class[self:word ()] return q end @@ -346,7 +346,7 @@ function resolver:LOC_nibble_power () -- - - - - - - - - - LOC_nibble_power local b = self:byte () - -- print ('nibbles', ((b-(b%0x10))/0x10), (b%0x10)) + --print ('nibbles', ((b-(b%0x10))/0x10), (b%0x10)) return ((b-(b%0x10))/0x10) * (10^(b%0x10)) end @@ -549,12 +549,12 @@ function resolver:remember (rr, type) -- - - - - - - - - - - - - - remember - -- print ('remember', type, rr.class, rr.type, rr.name) + --print ('remember', type, rr.class, rr.type, rr.name) if type ~= '*' then type = rr.type local all = get (self.cache, rr.class, '*', rr.name) - -- print ('remember all', all) + --print ('remember all', all) if all then append (all, rr) end end @@ -599,14 +599,14 @@ qname, qtype, qclass = standardize (qname, qtype, qclass) - if not self.server then self:adddefaultnameservers () end + if not self.server then self:adddefaultnameservers () end local question = encodeQuestion (qname, qtype, qclass) local peek = self:peek (qname, qtype, qclass) if peek then return peek end local header, id = encodeHeader () - -- print ('query id', id, qclass, qtype, qname) + --print ('query id', id, qclass, qtype, qname) local o = { packet = header..question, server = 1, delay = 1, @@ -621,13 +621,15 @@ local co = coroutine.running () if co then set (self.wanted, qclass, qtype, qname, co, true) - set (self.yielded, co, qclass, qtype, qname, true) - end end + --set (self.yielded, co, qclass, qtype, qname, true) + end +end + function resolver:receive (rset) -- - - - - - - - - - - - - - - - - receive - -- print 'receive' print (self.socket) + --print 'receive' print (self.socket) self.time = socket.gettime () rset = rset or self.socket @@ -640,8 +642,8 @@ response = self:decode (packet) if response then - -- print 'received response' - -- self.print (response) + --print 'received response' + --self.print (response) for i,section in pairs { 'answer', 'authority', 'additional' } do for j,rr in pairs (response[section]) do @@ -660,7 +662,7 @@ if cos then for co in pairs (cos) do set (self.yielded, co, q.class, q.type, q.name, nil) - if not self.yielded[co] then coroutine.resume (co) end + if coroutine.status(co) == "suspended" then coroutine.resume (co) end end set (self.wanted, q.class, q.type, q.name, nil) end end end end end @@ -669,10 +671,51 @@ end +function resolver:feed(sock, packet) + --print 'receive' print (self.socket) + self.time = socket.gettime () + + local response = self:decode (packet) + if response then + --print 'received response' + --self.print (response) + + for i,section in pairs { 'answer', 'authority', 'additional' } do + for j,rr in pairs (response[section]) do + self:remember (rr, response.question[1].type) + end + end + + -- retire the query + local queries = self.active[response.header.id] + if queries[response.question.raw] then + queries[response.question.raw] = nil + end + if not next (queries) then self.active[response.header.id] = nil end + if not next (self.active) then self:closeall () end + + -- was the query on the wanted list? + local q = response.question[1] + if q then + local cos = get (self.wanted, q.class, q.type, q.name) + if cos then + for co in pairs (cos) do + set (self.yielded, co, q.class, q.type, q.name, nil) + if coroutine.status(co) == "suspended" then coroutine.resume (co) end + end + set (self.wanted, q.class, q.type, q.name, nil) + end + end + end + + return response +end + + function resolver:pulse () -- - - - - - - - - - - - - - - - - - - - - pulse - -- print ':pulse' - while self:receive () do end + --print ':pulse' + while self:receive() do end if not next (self.active) then return nil end self.time = socket.gettime () @@ -687,12 +730,12 @@ end if o.delay > #self.delays then - print ('timeout') + --print ('timeout') queries[question] = nil if not next (queries) then self.active[id] = nil end if not next (self.active) then return nil end else - -- print ('retry', o.server, o.delay) + --print ('retry', o.server, o.delay) local _a = self.socket[o.server]; if _a then _a:send (o.packet) end o.retry = self.time + self.delays[o.delay] @@ -706,12 +749,16 @@ function resolver:lookup (qname, qtype, qclass) -- - - - - - - - - - lookup self:query (qname, qtype, qclass) while self:pulse () do socket.select (self.socket, nil, 4) end - -- print (self.cache) + --print (self.cache) return self:peek (qname, qtype, qclass) end +function resolver:lookupex (handler, qname, qtype, qclass) -- - - - - - - - - - lookup + return self:peek (qname, qtype, qclass) or self:query (qname, qtype, qclass) + end --- print ---------------------------------------------------------------- print + +--print ---------------------------------------------------------------- print local hints = { -- - - - - - - - - - - - - - - - - - - - - - - - - - - hints @@ -758,7 +805,7 @@ for j,t in pairs (rr) do if not common[j] then tmp = string.format ('%s[%i].%s', s, i, j) - print (string.format ('%-30s %s', tmp, t)) + print (string.format ('%-30s %s', tostring(tmp), tostring(t))) end end end end end @@ -797,6 +844,9 @@ function dns.query (...) -- - - - - - - - - - - - - - - - - - - - - - query return resolve (resolver.query, ...) end +function dns.feed (...) -- - - - - - - - - - - - - - - - - - - - - - feed + return resolve (resolver.feed, ...) end + function dns:socket_wrapper_set (...) -- - - - - - - - - socket_wrapper_set return resolve (resolver.socket_wrapper_set, ...) end