Diff

net/dns.lua @ 869:09019c452709

net.dns: Add methods necessary for allowing non-blocking DNS lookups
author Matthew Wild <mwild1@gmail.com>
date Wed, 04 Mar 2009 12:58:56 +0000
parent 759:5cccfb5da6cb
child 896:2c0b9e3c11c3
line wrap: on
line diff
--- a/net/dns.lua	Wed Mar 04 12:44:07 2009 +0000
+++ b/net/dns.lua	Wed Mar 04 12:58:56 2009 +0000
@@ -16,7 +16,7 @@
 
 require 'socket'
 local ztact = require 'util.ztact'
-
+local require = require
 
 local coroutine, io, math, socket, string, table =
       coroutine, io, math, socket, string, table
@@ -253,7 +253,7 @@
 
 function resolver:dword ()    -- - - - - - - - - - - - - - - - - - - - -  dword
   local b1, b2, b3, b4 = self:byte (4)
-  -- print ('dword', b1, b2, b3, b4)
+  --print ('dword', b1, b2, b3, b4)
   return 0x1000000*b1 + 0x10000*b2 + 0x100*b3 + b4
   end
 
@@ -269,7 +269,7 @@
 function resolver:header (force)    -- - - - - - - - - - - - - - - - - - header
 
   local id = self:word ()
-  -- print (string.format (':header  id  %x', id))
+  --print (string.format (':header  id  %x', id))
   if not self.active[id] and not force then  return nil  end
 
   local h = { id = id }
@@ -322,7 +322,7 @@
   local q = {}
   q.name  = self:name ()
   q.type  = dns.type[self:word ()]
-  q.class = dns.type[self:word ()]
+  q.class = dns.class[self:word ()]
   return q
   end
 
@@ -346,7 +346,7 @@
 
 function resolver:LOC_nibble_power ()    -- - - - - - - - - -  LOC_nibble_power
   local b = self:byte ()
-  -- print ('nibbles', ((b-(b%0x10))/0x10), (b%0x10))
+  --print ('nibbles', ((b-(b%0x10))/0x10), (b%0x10))
   return ((b-(b%0x10))/0x10) * (10^(b%0x10))
   end
 
@@ -549,12 +549,12 @@
 
 function resolver:remember (rr, type)    -- - - - - - - - - - - - - -  remember
 
-  -- print ('remember', type, rr.class, rr.type, rr.name)
+  --print ('remember', type, rr.class, rr.type, rr.name)
 
   if type ~= '*' then
     type = rr.type
     local all = get (self.cache, rr.class, '*', rr.name)
-    -- print ('remember all', all)
+    --print ('remember all', all)
     if all then  append (all, rr)  end
     end
 
@@ -599,14 +599,14 @@
 
   qname, qtype, qclass = standardize (qname, qtype, qclass)
 
-  if not self.server then  self:adddefaultnameservers ()  end
+  if not self.server then self:adddefaultnameservers ()  end
 
   local question = encodeQuestion (qname, qtype, qclass)
   local peek = self:peek (qname, qtype, qclass)
   if peek then  return peek  end
 
   local header, id = encodeHeader ()
-  -- print ('query  id', id, qclass, qtype, qname)
+  --print ('query  id', id, qclass, qtype, qname)
   local o = { packet = header..question,
               server = 1,
               delay  = 1,
@@ -621,13 +621,15 @@
   local co = coroutine.running ()
   if co then
     set (self.wanted, qclass, qtype, qname, co, true)
-    set (self.yielded, co, qclass, qtype, qname, true)
-    end  end
+    --set (self.yielded, co, qclass, qtype, qname, true)
+  end
+end
+
 
 
 function resolver:receive (rset)    -- - - - - - - - - - - - - - - - -  receive
 
-  -- print 'receive'  print (self.socket)
+  --print 'receive'  print (self.socket)
   self.time = socket.gettime ()
   rset = rset or self.socket
 
@@ -640,8 +642,8 @@
 
     response = self:decode (packet)
     if response then
-    -- print 'received response'
-    -- self.print (response)
+    --print 'received response'
+    --self.print (response)
 
     for i,section in pairs { 'answer', 'authority', 'additional' } do
       for j,rr in pairs (response[section]) do
@@ -660,7 +662,7 @@
     if cos then
       for co in pairs (cos) do
         set (self.yielded, co, q.class, q.type, q.name, nil)
-	if not self.yielded[co] then  coroutine.resume (co)  end
+	if coroutine.status(co) == "suspended" then  coroutine.resume (co)  end
         end
       set (self.wanted, q.class, q.type, q.name, nil)
       end  end  end  end  end
@@ -669,10 +671,51 @@
   end
 
 
+function resolver:feed(sock, packet)
+  --print 'receive'  print (self.socket)
+  self.time = socket.gettime ()
+
+  local response = self:decode (packet)
+  if response then
+    --print 'received response'
+    --self.print (response)
+
+    for i,section in pairs { 'answer', 'authority', 'additional' } do
+      for j,rr in pairs (response[section]) do
+        self:remember (rr, response.question[1].type)
+      end
+    end
+
+    -- retire the query
+    local queries = self.active[response.header.id]
+    if queries[response.question.raw] then
+      queries[response.question.raw] = nil
+    end
+    if not next (queries) then  self.active[response.header.id] = nil  end
+    if not next (self.active) then  self:closeall ()  end
+
+    -- was the query on the wanted list?
+    local q = response.question[1]
+    if q then
+      local cos = get (self.wanted, q.class, q.type, q.name)
+      if cos then
+        for co in pairs (cos) do
+          set (self.yielded, co, q.class, q.type, q.name, nil)
+          if coroutine.status(co) == "suspended" then coroutine.resume (co)  end
+        end
+        set (self.wanted, q.class, q.type, q.name, nil)
+      end
+    end
+  end 
+
+  return response
+end
+
+
 function resolver:pulse ()    -- - - - - - - - - - - - - - - - - - - - -  pulse
 
-  -- print ':pulse'
-  while self:receive () do end
+  --print ':pulse'
+  while self:receive() do end
   if not next (self.active) then  return nil  end
 
   self.time = socket.gettime ()
@@ -687,12 +730,12 @@
           end
 
         if o.delay > #self.delays then
-          print ('timeout')
+          --print ('timeout')
           queries[question] = nil
           if not next (queries) then  self.active[id] = nil  end
           if not next (self.active) then  return nil  end
         else
-          -- print ('retry', o.server, o.delay)
+          --print ('retry', o.server, o.delay)
           local _a = self.socket[o.server];
           if _a then _a:send (o.packet) end
           o.retry = self.time + self.delays[o.delay]
@@ -706,12 +749,16 @@
 function resolver:lookup (qname, qtype, qclass)    -- - - - - - - - - -  lookup
   self:query (qname, qtype, qclass)
   while self:pulse () do  socket.select (self.socket, nil, 4)  end
-  -- print (self.cache)
+  --print (self.cache)
   return self:peek (qname, qtype, qclass)
   end
 
+function resolver:lookupex (handler, qname, qtype, qclass)    -- - - - - - - - - -  lookup
+  return self:peek (qname, qtype, qclass) or self:query (qname, qtype, qclass)
+  end
 
--- print ---------------------------------------------------------------- print
+
+--print ---------------------------------------------------------------- print
 
 
 local hints = {    -- - - - - - - - - - - - - - - - - - - - - - - - - - - hints
@@ -758,7 +805,7 @@
       for j,t in pairs (rr) do
         if not common[j] then
           tmp = string.format ('%s[%i].%s', s, i, j)
-          print (string.format ('%-30s  %s', tmp, t))
+          print (string.format ('%-30s  %s', tostring(tmp), tostring(t)))
           end  end  end  end  end
 
 
@@ -797,6 +844,9 @@
 function dns.query (...)    -- - - - - - - - - - - - - - - - - - - - - -  query
   return resolve (resolver.query, ...)  end
 
+function dns.feed (...)    -- - - - - - - - - - - - - - - - - - - - - -  feed
+  return resolve (resolver.feed, ...)  end
+
 
 function dns:socket_wrapper_set (...)    -- - - - - - - - -  socket_wrapper_set
   return resolve (resolver.socket_wrapper_set, ...)  end