Comparison

net/server_select.lua @ 10563:e8db377a2983

Merge 0.11->trunk
author Kim Alvefur <zash@zash.se>
date Tue, 24 Dec 2019 00:39:45 +0100
parent 10474:175b72700d79
child 10851:6cf16abd0976
comparison
equal deleted inserted replaced
10562:670afc079f68 10563:e8db377a2983
66 local stats 66 local stats
67 local idfalse 67 local idfalse
68 local closeall 68 local closeall
69 local addsocket 69 local addsocket
70 local addserver 70 local addserver
71 local listen
71 local addtimer 72 local addtimer
72 local getserver 73 local getserver
73 local wrapserver 74 local wrapserver
74 local getsettings 75 local getsettings
75 local closesocket 76 local closesocket
121 122
122 ----------------------------------// DEFINITION //-- 123 ----------------------------------// DEFINITION //--
123 124
124 _server = { } -- key = port, value = table; list of listening servers 125 _server = { } -- key = port, value = table; list of listening servers
125 _readlist = { } -- array with sockets to read from 126 _readlist = { } -- array with sockets to read from
126 _sendlist = { } -- arrary with sockets to write to 127 _sendlist = { } -- array with sockets to write to
127 _timerlist = { } -- array of timer functions 128 _timerlist = { } -- array of timer functions
128 _socketlist = { } -- key = socket, value = wrapped socket (handlers) 129 _socketlist = { } -- key = socket, value = wrapped socket (handlers)
129 _readtimes = { } -- key = handler, value = timestamp of last data reading 130 _readtimes = { } -- key = handler, value = timestamp of last data reading
130 _writetimes = { } -- key = handler, value = timestamp of last data writing/sending 131 _writetimes = { } -- key = handler, value = timestamp of last data writing/sending
131 _closelist = { } -- handlers to close 132 _closelist = { } -- handlers to close
147 148
148 _checkinterval = 30 -- interval in secs to check idle clients 149 _checkinterval = 30 -- interval in secs to check idle clients
149 _sendtimeout = 60000 -- allowed send idle time in secs 150 _sendtimeout = 60000 -- allowed send idle time in secs
150 _readtimeout = 14 * 60 -- allowed read idle time in secs 151 _readtimeout = 14 * 60 -- allowed read idle time in secs
151 152
152 local is_windows = package.config:sub(1,1) == "\\" -- check the directory separator, to detemine whether this is Windows 153 local is_windows = package.config:sub(1,1) == "\\" -- check the directory separator, to determine whether this is Windows
153 _maxfd = (is_windows and math.huge) or luasocket._SETSIZE or 1024 -- max fd number, limit to 1024 by default to prevent glibc buffer overflow, but not on Windows 154 _maxfd = (is_windows and math.huge) or luasocket._SETSIZE or 1024 -- max fd number, limit to 1024 by default to prevent glibc buffer overflow, but not on Windows
154 _maxselectlen = luasocket._SETSIZE or 1024 -- But this still applies on Windows 155 _maxselectlen = luasocket._SETSIZE or 1024 -- But this still applies on Windows
155 156
156 _maxsslhandshake = 30 -- max handshake round-trips 157 _maxsslhandshake = 30 -- max handshake round-trips
157 158
158 ----------------------------------// PRIVATE //-- 159 ----------------------------------// PRIVATE //--
159 160
160 wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- this function wraps a server -- FIXME Make sure FD < _maxfd 161 wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, ssldirect ) -- this function wraps a server -- FIXME Make sure FD < _maxfd
161 162
162 if socket:getfd() >= _maxfd then 163 if socket:getfd() >= _maxfd then
163 out_error("server.lua: Disallowed FD number: "..socket:getfd()) 164 out_error("server.lua: Disallowed FD number: "..socket:getfd())
164 socket:close() 165 socket:close()
165 return nil, "fd-too-large" 166 return nil, "fd-too-large"
181 return sslctx ~= nil 182 return sslctx ~= nil
182 end 183 end
183 handler.sslctx = function( ) 184 handler.sslctx = function( )
184 return sslctx 185 return sslctx
185 end 186 end
187 handler.hosts = {} -- sni
186 handler.remove = function( ) 188 handler.remove = function( )
187 connections = connections - 1 189 connections = connections - 1
188 if handler then 190 if handler then
189 handler.resume( ) 191 handler.resume( )
190 end 192 end
242 return false 244 return false
243 end 245 end
244 local client, err = accept( socket ) -- try to accept 246 local client, err = accept( socket ) -- try to accept
245 if client then 247 if client then
246 local ip, clientport = client:getpeername( ) 248 local ip, clientport = client:getpeername( )
247 local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx ) -- wrap new client socket 249 local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx, ssldirect ) -- wrap new client socket
248 if err then -- error while wrapping ssl socket 250 if err then -- error while wrapping ssl socket
249 return false 251 return false
250 end 252 end
251 connections = connections + 1 253 connections = connections + 1
252 out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport)) 254 out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport))
253 if dispatch and not sslctx then -- SSL connections will notify onconnect when handshake completes 255 if dispatch and not ssldirect then -- SSL connections will notify onconnect when handshake completes
254 return dispatch( handler ); 256 return dispatch( handler );
255 end 257 end
256 return; 258 return;
257 elseif err then -- maybe timeout or something else 259 elseif err then -- maybe timeout or something else
258 out_put( "server.lua: error with new client connection: ", tostring(err) ) 260 out_put( "server.lua: error with new client connection: ", tostring(err) )
262 end 264 end
263 end 265 end
264 return handler 266 return handler
265 end 267 end
266 268
267 wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx ) -- this function wraps a client to a handler object 269 wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx, ssldirect, extra ) -- this function wraps a client to a handler object
268 270
269 if socket:getfd() >= _maxfd then 271 if socket:getfd() >= _maxfd then
270 out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent 272 out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent
271 socket:close( ) -- Should we send some kind of error here? 273 socket:close( ) -- Should we send some kind of error here?
272 if server then 274 if server then
311 local maxreadlen = _maxreadlen 313 local maxreadlen = _maxreadlen
312 314
313 --// public methods of the object //-- 315 --// public methods of the object //--
314 316
315 local handler = bufferqueue -- saves a table ^_^ 317 local handler = bufferqueue -- saves a table ^_^
318
319 handler.extra = extra
320 if extra then
321 handler.servername = extra.servername
322 end
316 323
317 handler.dispatch = function( ) 324 handler.dispatch = function( )
318 return dispatch 325 return dispatch
319 end 326 end
320 handler.disconnect = function( ) 327 handler.disconnect = function( )
422 local write = function( self, data ) 429 local write = function( self, data )
423 if not handler then return false end 430 if not handler then return false end
424 bufferlen = bufferlen + #data 431 bufferlen = bufferlen + #data
425 if bufferlen > maxsendlen then 432 if bufferlen > maxsendlen then
426 _closelist[ handler ] = "send buffer exceeded" -- cannot close the client at the moment, have to wait to the end of the cycle 433 _closelist[ handler ] = "send buffer exceeded" -- cannot close the client at the moment, have to wait to the end of the cycle
427 handler.write = idfalse -- don't write anymore
428 return false 434 return false
429 elseif socket and not _sendlist[ socket ] then 435 elseif not nosend and socket and not _sendlist[ socket ] then
430 _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) 436 _sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
431 end 437 end
432 bufferqueuelen = bufferqueuelen + 1 438 bufferqueuelen = bufferqueuelen + 1
433 bufferqueue[ bufferqueuelen ] = data 439 bufferqueue[ bufferqueuelen ] = data
434 if handler then 440 if handler then
454 handler.bufferlen = function( self, readlen, sendlen ) 460 handler.bufferlen = function( self, readlen, sendlen )
455 maxsendlen = sendlen or maxsendlen 461 maxsendlen = sendlen or maxsendlen
456 maxreadlen = readlen or maxreadlen 462 maxreadlen = readlen or maxreadlen
457 return bufferlen, maxreadlen, maxsendlen 463 return bufferlen, maxreadlen, maxsendlen
458 end 464 end
459 --TODO: Deprecate
460 handler.lock_read = function (self, switch) 465 handler.lock_read = function (self, switch)
466 out_error( "server.lua, lock_read() is deprecated, use pause() and resume()" )
461 if switch == true then 467 if switch == true then
462 local tmp = _readlistlen 468 return self:pause()
463 _readlistlen = removesocket( _readlist, socket, _readlistlen )
464 _readtimes[ handler ] = nil
465 if _readlistlen ~= tmp then
466 noread = true
467 end
468 elseif switch == false then 469 elseif switch == false then
469 if noread then 470 return self:resume()
470 noread = false
471 _readlistlen = addsocket(_readlist, socket, _readlistlen)
472 _readtimes[ handler ] = _currenttime
473 end
474 end 471 end
475 return noread 472 return noread
476 end 473 end
477 handler.pause = function (self) 474 handler.pause = function (self)
478 return self:lock_read(true); 475 local tmp = _readlistlen
476 _readlistlen = removesocket( _readlist, socket, _readlistlen )
477 _readtimes[ handler ] = nil
478 if _readlistlen ~= tmp then
479 noread = true
480 end
481 return noread;
479 end 482 end
480 handler.resume = function (self) 483 handler.resume = function (self)
481 return self:lock_read(false); 484 if noread then
485 noread = false
486 _readlistlen = addsocket(_readlist, socket, _readlistlen)
487 _readtimes[ handler ] = _currenttime
488 end
489 return noread;
482 end 490 end
483 handler.lock = function( self, switch ) 491 handler.lock = function( self, switch )
484 handler.lock_read (switch) 492 out_error( "server.lua, lock() is deprecated" )
493 handler.lock_read (self, switch)
485 if switch == true then 494 if switch == true then
486 handler.write = idfalse 495 handler.pause_writes (self)
487 local tmp = _sendlistlen
488 _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
489 _writetimes[ handler ] = nil
490 if _sendlistlen ~= tmp then
491 nosend = true
492 end
493 elseif switch == false then 496 elseif switch == false then
494 handler.write = write 497 handler.resume_writes (self)
495 if nosend then
496 nosend = false
497 write( "" )
498 end
499 end 498 end
500 return noread, nosend 499 return noread, nosend
501 end 500 end
501 handler.pause_writes = function (self)
502 local tmp = _sendlistlen
503 _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
504 _writetimes[ handler ] = nil
505 nosend = true
506 end
507 handler.resume_writes = function (self)
508 nosend = false
509 if bufferlen > 0 then
510 _sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
511 end
512 end
513
502 local _readbuffer = function( ) -- this function reads data 514 local _readbuffer = function( ) -- this function reads data
503 local buffer, err, part = receive( socket, pattern ) -- receive buffer with "pattern" 515 local buffer, err, part = receive( socket, pattern ) -- receive buffer with "pattern"
504 if not err or (err == "wantread" or err == "timeout") then -- received something 516 if not err or (err == "wantread" or err == "timeout") then -- received something
505 local buffer = buffer or part or "" 517 local buffer = buffer or part or ""
506 local len = #buffer 518 local len = #buffer
597 end 609 end
598 err = nil; 610 err = nil;
599 coroutine_yield( ) -- handshake not finished 611 coroutine_yield( ) -- handshake not finished
600 end 612 end
601 end 613 end
602 err = "ssl handshake error: " .. ( err or "handshake too long" ); 614 err = ( err or "handshake too long" );
603 out_put( "server.lua: ", err ); 615 out_put( "server.lua: ", err );
604 _ = handler and handler:force_close(err) 616 _ = handler and handler:force_close(err)
605 return false, err -- handshake failed 617 return false, err -- handshake failed
606 end 618 end
607 ) 619 )
617 return 629 return
618 end 630 end
619 out_put( "server.lua: attempting to start tls on " .. tostring( socket ) ) 631 out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
620 local oldsocket, err = socket 632 local oldsocket, err = socket
621 socket, err = ssl_wrap( socket, sslctx ) -- wrap socket 633 socket, err = ssl_wrap( socket, sslctx ) -- wrap socket
634
622 if not socket then 635 if not socket then
623 out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") ) 636 out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") )
624 return nil, err -- fatal error 637 return nil, err -- fatal error
638 end
639
640 if socket.sni then
641 if self.servername then
642 socket:sni(self.servername);
643 elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then
644 socket:sni(self.server().hosts, true);
645 end
625 end 646 end
626 647
627 socket:settimeout( 0 ) 648 socket:settimeout( 0 )
628 649
629 -- add the new socket to our system 650 -- add the new socket to our system
657 shutdown = ( ssl and id ) or socket.shutdown 678 shutdown = ( ssl and id ) or socket.shutdown
658 679
659 _socketlist[ socket ] = handler 680 _socketlist[ socket ] = handler
660 _readlistlen = addsocket(_readlist, socket, _readlistlen) 681 _readlistlen = addsocket(_readlist, socket, _readlistlen)
661 682
662 if sslctx and has_luasec then 683 if sslctx and ssldirect and has_luasec then
663 out_put "server.lua: auto-starting ssl negotiation..." 684 out_put "server.lua: auto-starting ssl negotiation..."
664 handler.autostart_ssl = true; 685 handler.autostart_ssl = true;
665 local ok, err = handler:starttls(sslctx); 686 local ok, err = handler:starttls(sslctx);
666 if ok == false then 687 if ok == false then
667 return nil, nil, err 688 return nil, nil, err
732 sender:set_mode("*a"); 753 sender:set_mode("*a");
733 end 754 end
734 755
735 ----------------------------------// PUBLIC //-- 756 ----------------------------------// PUBLIC //--
736 757
737 addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server 758 listen = function ( addr, port, listeners, config )
738 addr = addr or "*" 759 addr = addr or "*"
760 config = config or {}
739 local err 761 local err
762 local sslctx = config.tls_ctx;
763 local ssldirect = config.tls_direct;
764 local pattern = config.read_size;
740 if type( listeners ) ~= "table" then 765 if type( listeners ) ~= "table" then
741 err = "invalid listener table" 766 err = "invalid listener table"
742 elseif type ( addr ) ~= "string" then 767 elseif type ( addr ) ~= "string" then
743 err = "invalid address" 768 err = "invalid address"
744 elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then 769 elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then
755 local server, err = socket_bind( addr, port, _tcpbacklog ) 780 local server, err = socket_bind( addr, port, _tcpbacklog )
756 if err then 781 if err then
757 out_error( "server.lua, [", addr, "]:", port, ": ", err ) 782 out_error( "server.lua, [", addr, "]:", port, ": ", err )
758 return nil, err 783 return nil, err
759 end 784 end
760 local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx ) -- wrap new server socket 785 local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx, ssldirect ) -- wrap new server socket
761 if not handler then 786 if not handler then
762 server:close( ) 787 server:close( )
763 return nil, err 788 return nil, err
764 end 789 end
765 server:settimeout( 0 ) 790 server:settimeout( 0 )
766 _readlistlen = addsocket(_readlist, server, _readlistlen) 791 _readlistlen = addsocket(_readlist, server, _readlistlen)
767 _server[ addr..":"..port ] = handler 792 _server[ addr..":"..port ] = handler
768 _socketlist[ server ] = handler 793 _socketlist[ server ] = handler
769 out_put( "server.lua: new "..(sslctx and "ssl " or "").."server listener on '[", addr, "]:", port, "'" ) 794 out_put( "server.lua: new "..(sslctx and "ssl " or "").."server listener on '[", addr, "]:", port, "'" )
770 return handler 795 return handler
796 end
797
798 addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server
799 return listen(addr, port, listeners, {
800 read_size = pattern;
801 tls_ctx = sslctx;
802 tls_direct = sslctx and true or false;
803 });
771 end 804 end
772 805
773 getserver = function ( addr, port ) 806 getserver = function ( addr, port )
774 return _server[ addr..":"..port ]; 807 return _server[ addr..":"..port ];
775 end 808 end
975 return "select"; 1008 return "select";
976 end 1009 end
977 1010
978 --// EXPERIMENTAL //-- 1011 --// EXPERIMENTAL //--
979 1012
980 local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx ) 1013 local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx, extra )
981 local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx ) 1014 local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx, sslctx, extra)
982 if not handler then return nil, err end 1015 if not handler then return nil, err end
983 _socketlist[ socket ] = handler 1016 _socketlist[ socket ] = handler
984 if not sslctx then 1017 if not sslctx then
985 _readlistlen = addsocket(_readlist, socket, _readlistlen) 1018 _readlistlen = addsocket(_readlist, socket, _readlistlen)
986 _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) 1019 _sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
995 end 1028 end
996 end 1029 end
997 return handler, socket 1030 return handler, socket
998 end 1031 end
999 1032
1000 local addclient = function( address, port, listeners, pattern, sslctx, typ ) 1033 local addclient = function( address, port, listeners, pattern, sslctx, typ, extra )
1001 local err 1034 local err
1002 if type( listeners ) ~= "table" then 1035 if type( listeners ) ~= "table" then
1003 err = "invalid listener table" 1036 err = "invalid listener table"
1004 elseif type ( address ) ~= "string" then 1037 elseif type ( address ) ~= "string" then
1005 err = "invalid address" 1038 err = "invalid address"
1032 return nil, err 1065 return nil, err
1033 end 1066 end
1034 client:settimeout( 0 ) 1067 client:settimeout( 0 )
1035 local ok, err = client:setpeername( address, port ) 1068 local ok, err = client:setpeername( address, port )
1036 if ok or err == "timeout" or err == "Operation already in progress" then 1069 if ok or err == "timeout" or err == "Operation already in progress" then
1037 return wrapclient( client, address, port, listeners, pattern, sslctx ) 1070 return wrapclient( client, address, port, listeners, pattern, sslctx, extra )
1038 else 1071 else
1039 return nil, err 1072 return nil, err
1040 end 1073 end
1041 end 1074 end
1042 1075
1112 link = link, 1145 link = link,
1113 step = step, 1146 step = step,
1114 stats = stats, 1147 stats = stats,
1115 closeall = closeall, 1148 closeall = closeall,
1116 addserver = addserver, 1149 addserver = addserver,
1150 listen = listen,
1117 getserver = getserver, 1151 getserver = getserver,
1118 setlogger = setlogger, 1152 setlogger = setlogger,
1119 getsettings = getsettings, 1153 getsettings = getsettings,
1120 setquitting = setquitting, 1154 setquitting = setquitting,
1121 removeserver = removeserver, 1155 removeserver = removeserver,