Comparison

net/dns.lua @ 869:09019c452709

net.dns: Add methods necessary for allowing non-blocking DNS lookups
author Matthew Wild <mwild1@gmail.com>
date Wed, 04 Mar 2009 12:58:56 +0000
parent 759:5cccfb5da6cb
child 896:2c0b9e3c11c3
comparison
equal deleted inserted replaced
868:9e058e51ecaf 869:09019c452709
14 -- reference: http://tools.ietf.org/html/rfc1876 (LOC) 14 -- reference: http://tools.ietf.org/html/rfc1876 (LOC)
15 15
16 16
17 require 'socket' 17 require 'socket'
18 local ztact = require 'util.ztact' 18 local ztact = require 'util.ztact'
19 19 local require = require
20 20
21 local coroutine, io, math, socket, string, table = 21 local coroutine, io, math, socket, string, table =
22 coroutine, io, math, socket, string, table 22 coroutine, io, math, socket, string, table
23 23
24 local ipairs, next, pairs, print, setmetatable, tostring, assert, error = 24 local ipairs, next, pairs, print, setmetatable, tostring, assert, error =
251 end 251 end
252 252
253 253
254 function resolver:dword () -- - - - - - - - - - - - - - - - - - - - - dword 254 function resolver:dword () -- - - - - - - - - - - - - - - - - - - - - dword
255 local b1, b2, b3, b4 = self:byte (4) 255 local b1, b2, b3, b4 = self:byte (4)
256 -- print ('dword', b1, b2, b3, b4) 256 --print ('dword', b1, b2, b3, b4)
257 return 0x1000000*b1 + 0x10000*b2 + 0x100*b3 + b4 257 return 0x1000000*b1 + 0x10000*b2 + 0x100*b3 + b4
258 end 258 end
259 259
260 260
261 function resolver:sub (len) -- - - - - - - - - - - - - - - - - - - - - - sub 261 function resolver:sub (len) -- - - - - - - - - - - - - - - - - - - - - - sub
267 267
268 268
269 function resolver:header (force) -- - - - - - - - - - - - - - - - - - header 269 function resolver:header (force) -- - - - - - - - - - - - - - - - - - header
270 270
271 local id = self:word () 271 local id = self:word ()
272 -- print (string.format (':header id %x', id)) 272 --print (string.format (':header id %x', id))
273 if not self.active[id] and not force then return nil end 273 if not self.active[id] and not force then return nil end
274 274
275 local h = { id = id } 275 local h = { id = id }
276 276
277 local b1, b2 = self:byte (2) 277 local b1, b2 = self:byte (2)
320 320
321 function resolver:question () -- - - - - - - - - - - - - - - - - - question 321 function resolver:question () -- - - - - - - - - - - - - - - - - - question
322 local q = {} 322 local q = {}
323 q.name = self:name () 323 q.name = self:name ()
324 q.type = dns.type[self:word ()] 324 q.type = dns.type[self:word ()]
325 q.class = dns.type[self:word ()] 325 q.class = dns.class[self:word ()]
326 return q 326 return q
327 end 327 end
328 328
329 329
330 function resolver:A (rr) -- - - - - - - - - - - - - - - - - - - - - - - - A 330 function resolver:A (rr) -- - - - - - - - - - - - - - - - - - - - - - - - A
344 end 344 end
345 345
346 346
347 function resolver:LOC_nibble_power () -- - - - - - - - - - LOC_nibble_power 347 function resolver:LOC_nibble_power () -- - - - - - - - - - LOC_nibble_power
348 local b = self:byte () 348 local b = self:byte ()
349 -- print ('nibbles', ((b-(b%0x10))/0x10), (b%0x10)) 349 --print ('nibbles', ((b-(b%0x10))/0x10), (b%0x10))
350 return ((b-(b%0x10))/0x10) * (10^(b%0x10)) 350 return ((b-(b%0x10))/0x10) * (10^(b%0x10))
351 end 351 end
352 352
353 353
354 function resolver:LOC (rr) -- - - - - - - - - - - - - - - - - - - - - - LOC 354 function resolver:LOC (rr) -- - - - - - - - - - - - - - - - - - - - - - LOC
547 end 547 end
548 548
549 549
550 function resolver:remember (rr, type) -- - - - - - - - - - - - - - remember 550 function resolver:remember (rr, type) -- - - - - - - - - - - - - - remember
551 551
552 -- print ('remember', type, rr.class, rr.type, rr.name) 552 --print ('remember', type, rr.class, rr.type, rr.name)
553 553
554 if type ~= '*' then 554 if type ~= '*' then
555 type = rr.type 555 type = rr.type
556 local all = get (self.cache, rr.class, '*', rr.name) 556 local all = get (self.cache, rr.class, '*', rr.name)
557 -- print ('remember all', all) 557 --print ('remember all', all)
558 if all then append (all, rr) end 558 if all then append (all, rr) end
559 end 559 end
560 560
561 self.cache = self.cache or setmetatable ({}, cache_metatable) 561 self.cache = self.cache or setmetatable ({}, cache_metatable)
562 local rrs = get (self.cache, rr.class, type, rr.name) or 562 local rrs = get (self.cache, rr.class, type, rr.name) or
597 597
598 function resolver:query (qname, qtype, qclass) -- - - - - - - - - - -- query 598 function resolver:query (qname, qtype, qclass) -- - - - - - - - - - -- query
599 599
600 qname, qtype, qclass = standardize (qname, qtype, qclass) 600 qname, qtype, qclass = standardize (qname, qtype, qclass)
601 601
602 if not self.server then self:adddefaultnameservers () end 602 if not self.server then self:adddefaultnameservers () end
603 603
604 local question = encodeQuestion (qname, qtype, qclass) 604 local question = encodeQuestion (qname, qtype, qclass)
605 local peek = self:peek (qname, qtype, qclass) 605 local peek = self:peek (qname, qtype, qclass)
606 if peek then return peek end 606 if peek then return peek end
607 607
608 local header, id = encodeHeader () 608 local header, id = encodeHeader ()
609 -- print ('query id', id, qclass, qtype, qname) 609 --print ('query id', id, qclass, qtype, qname)
610 local o = { packet = header..question, 610 local o = { packet = header..question,
611 server = 1, 611 server = 1,
612 delay = 1, 612 delay = 1,
613 retry = socket.gettime () + self.delays[1] } 613 retry = socket.gettime () + self.delays[1] }
614 self:getsocket (o.server):send (o.packet) 614 self:getsocket (o.server):send (o.packet)
619 619
620 -- remember which coroutine wants the answer 620 -- remember which coroutine wants the answer
621 local co = coroutine.running () 621 local co = coroutine.running ()
622 if co then 622 if co then
623 set (self.wanted, qclass, qtype, qname, co, true) 623 set (self.wanted, qclass, qtype, qname, co, true)
624 set (self.yielded, co, qclass, qtype, qname, true) 624 --set (self.yielded, co, qclass, qtype, qname, true)
625 end end 625 end
626 end
627
626 628
627 629
628 function resolver:receive (rset) -- - - - - - - - - - - - - - - - - receive 630 function resolver:receive (rset) -- - - - - - - - - - - - - - - - - receive
629 631
630 -- print 'receive' print (self.socket) 632 --print 'receive' print (self.socket)
631 self.time = socket.gettime () 633 self.time = socket.gettime ()
632 rset = rset or self.socket 634 rset = rset or self.socket
633 635
634 local response 636 local response
635 for i,sock in pairs (rset) do 637 for i,sock in pairs (rset) do
638 local packet = sock:receive () 640 local packet = sock:receive ()
639 if packet then 641 if packet then
640 642
641 response = self:decode (packet) 643 response = self:decode (packet)
642 if response then 644 if response then
643 -- print 'received response' 645 --print 'received response'
644 -- self.print (response) 646 --self.print (response)
645 647
646 for i,section in pairs { 'answer', 'authority', 'additional' } do 648 for i,section in pairs { 'answer', 'authority', 'additional' } do
647 for j,rr in pairs (response[section]) do 649 for j,rr in pairs (response[section]) do
648 self:remember (rr, response.question[1].type) end end 650 self:remember (rr, response.question[1].type) end end
649 651
658 local q = response.question 660 local q = response.question
659 local cos = get (self.wanted, q.class, q.type, q.name) 661 local cos = get (self.wanted, q.class, q.type, q.name)
660 if cos then 662 if cos then
661 for co in pairs (cos) do 663 for co in pairs (cos) do
662 set (self.yielded, co, q.class, q.type, q.name, nil) 664 set (self.yielded, co, q.class, q.type, q.name, nil)
663 if not self.yielded[co] then coroutine.resume (co) end 665 if coroutine.status(co) == "suspended" then coroutine.resume (co) end
664 end 666 end
665 set (self.wanted, q.class, q.type, q.name, nil) 667 set (self.wanted, q.class, q.type, q.name, nil)
666 end end end end end 668 end end end end end
667 669
668 return response 670 return response
669 end 671 end
670 672
671 673
674 function resolver:feed(sock, packet)
675 --print 'receive' print (self.socket)
676 self.time = socket.gettime ()
677
678 local response = self:decode (packet)
679 if response then
680 --print 'received response'
681 --self.print (response)
682
683 for i,section in pairs { 'answer', 'authority', 'additional' } do
684 for j,rr in pairs (response[section]) do
685 self:remember (rr, response.question[1].type)
686 end
687 end
688
689 -- retire the query
690 local queries = self.active[response.header.id]
691 if queries[response.question.raw] then
692 queries[response.question.raw] = nil
693 end
694 if not next (queries) then self.active[response.header.id] = nil end
695 if not next (self.active) then self:closeall () end
696
697 -- was the query on the wanted list?
698 local q = response.question[1]
699 if q then
700 local cos = get (self.wanted, q.class, q.type, q.name)
701 if cos then
702 for co in pairs (cos) do
703 set (self.yielded, co, q.class, q.type, q.name, nil)
704 if coroutine.status(co) == "suspended" then coroutine.resume (co) end
705 end
706 set (self.wanted, q.class, q.type, q.name, nil)
707 end
708 end
709 end
710
711 return response
712 end
713
714
672 function resolver:pulse () -- - - - - - - - - - - - - - - - - - - - - pulse 715 function resolver:pulse () -- - - - - - - - - - - - - - - - - - - - - pulse
673 716
674 -- print ':pulse' 717 --print ':pulse'
675 while self:receive () do end 718 while self:receive() do end
676 if not next (self.active) then return nil end 719 if not next (self.active) then return nil end
677 720
678 self.time = socket.gettime () 721 self.time = socket.gettime ()
679 for id,queries in pairs (self.active) do 722 for id,queries in pairs (self.active) do
680 for question,o in pairs (queries) do 723 for question,o in pairs (queries) do
685 o.server = 1 728 o.server = 1
686 o.delay = o.delay + 1 729 o.delay = o.delay + 1
687 end 730 end
688 731
689 if o.delay > #self.delays then 732 if o.delay > #self.delays then
690 print ('timeout') 733 --print ('timeout')
691 queries[question] = nil 734 queries[question] = nil
692 if not next (queries) then self.active[id] = nil end 735 if not next (queries) then self.active[id] = nil end
693 if not next (self.active) then return nil end 736 if not next (self.active) then return nil end
694 else 737 else
695 -- print ('retry', o.server, o.delay) 738 --print ('retry', o.server, o.delay)
696 local _a = self.socket[o.server]; 739 local _a = self.socket[o.server];
697 if _a then _a:send (o.packet) end 740 if _a then _a:send (o.packet) end
698 o.retry = self.time + self.delays[o.delay] 741 o.retry = self.time + self.delays[o.delay]
699 end end end end 742 end end end end
700 743
704 747
705 748
706 function resolver:lookup (qname, qtype, qclass) -- - - - - - - - - - lookup 749 function resolver:lookup (qname, qtype, qclass) -- - - - - - - - - - lookup
707 self:query (qname, qtype, qclass) 750 self:query (qname, qtype, qclass)
708 while self:pulse () do socket.select (self.socket, nil, 4) end 751 while self:pulse () do socket.select (self.socket, nil, 4) end
709 -- print (self.cache) 752 --print (self.cache)
710 return self:peek (qname, qtype, qclass) 753 return self:peek (qname, qtype, qclass)
711 end 754 end
712 755
713 756 function resolver:lookupex (handler, qname, qtype, qclass) -- - - - - - - - - - lookup
714 -- print ---------------------------------------------------------------- print 757 return self:peek (qname, qtype, qclass) or self:query (qname, qtype, qclass)
758 end
759
760
761 --print ---------------------------------------------------------------- print
715 762
716 763
717 local hints = { -- - - - - - - - - - - - - - - - - - - - - - - - - - - hints 764 local hints = { -- - - - - - - - - - - - - - - - - - - - - - - - - - - hints
718 qr = { [0]='query', 'response' }, 765 qr = { [0]='query', 'response' },
719 opcode = { [0]='query', 'inverse query', 'server status request' }, 766 opcode = { [0]='query', 'inverse query', 'server status request' },
756 print (string.format ('%-30s', tmp), rr[t], hint (rr, t)) 803 print (string.format ('%-30s', tmp), rr[t], hint (rr, t))
757 end 804 end
758 for j,t in pairs (rr) do 805 for j,t in pairs (rr) do
759 if not common[j] then 806 if not common[j] then
760 tmp = string.format ('%s[%i].%s', s, i, j) 807 tmp = string.format ('%s[%i].%s', s, i, j)
761 print (string.format ('%-30s %s', tmp, t)) 808 print (string.format ('%-30s %s', tostring(tmp), tostring(t)))
762 end end end end end 809 end end end end end
763 810
764 811
765 -- module api ------------------------------------------------------ module api 812 -- module api ------------------------------------------------------ module api
766 813
795 842
796 843
797 function dns.query (...) -- - - - - - - - - - - - - - - - - - - - - - query 844 function dns.query (...) -- - - - - - - - - - - - - - - - - - - - - - query
798 return resolve (resolver.query, ...) end 845 return resolve (resolver.query, ...) end
799 846
847 function dns.feed (...) -- - - - - - - - - - - - - - - - - - - - - - feed
848 return resolve (resolver.feed, ...) end
849
800 850
801 function dns:socket_wrapper_set (...) -- - - - - - - - - socket_wrapper_set 851 function dns:socket_wrapper_set (...) -- - - - - - - - - socket_wrapper_set
802 return resolve (resolver.socket_wrapper_set, ...) end 852 return resolve (resolver.socket_wrapper_set, ...) end
803 853
804 854