Software /
code /
prosody
Comparison
net/server_select.lua @ 10563:e8db377a2983
Merge 0.11->trunk
author | Kim Alvefur <zash@zash.se> |
---|---|
date | Tue, 24 Dec 2019 00:39:45 +0100 |
parent | 10474:175b72700d79 |
child | 10851:6cf16abd0976 |
comparison
equal
deleted
inserted
replaced
10562:670afc079f68 | 10563:e8db377a2983 |
---|---|
66 local stats | 66 local stats |
67 local idfalse | 67 local idfalse |
68 local closeall | 68 local closeall |
69 local addsocket | 69 local addsocket |
70 local addserver | 70 local addserver |
71 local listen | |
71 local addtimer | 72 local addtimer |
72 local getserver | 73 local getserver |
73 local wrapserver | 74 local wrapserver |
74 local getsettings | 75 local getsettings |
75 local closesocket | 76 local closesocket |
121 | 122 |
122 ----------------------------------// DEFINITION //-- | 123 ----------------------------------// DEFINITION //-- |
123 | 124 |
124 _server = { } -- key = port, value = table; list of listening servers | 125 _server = { } -- key = port, value = table; list of listening servers |
125 _readlist = { } -- array with sockets to read from | 126 _readlist = { } -- array with sockets to read from |
126 _sendlist = { } -- arrary with sockets to write to | 127 _sendlist = { } -- array with sockets to write to |
127 _timerlist = { } -- array of timer functions | 128 _timerlist = { } -- array of timer functions |
128 _socketlist = { } -- key = socket, value = wrapped socket (handlers) | 129 _socketlist = { } -- key = socket, value = wrapped socket (handlers) |
129 _readtimes = { } -- key = handler, value = timestamp of last data reading | 130 _readtimes = { } -- key = handler, value = timestamp of last data reading |
130 _writetimes = { } -- key = handler, value = timestamp of last data writing/sending | 131 _writetimes = { } -- key = handler, value = timestamp of last data writing/sending |
131 _closelist = { } -- handlers to close | 132 _closelist = { } -- handlers to close |
147 | 148 |
148 _checkinterval = 30 -- interval in secs to check idle clients | 149 _checkinterval = 30 -- interval in secs to check idle clients |
149 _sendtimeout = 60000 -- allowed send idle time in secs | 150 _sendtimeout = 60000 -- allowed send idle time in secs |
150 _readtimeout = 14 * 60 -- allowed read idle time in secs | 151 _readtimeout = 14 * 60 -- allowed read idle time in secs |
151 | 152 |
152 local is_windows = package.config:sub(1,1) == "\\" -- check the directory separator, to detemine whether this is Windows | 153 local is_windows = package.config:sub(1,1) == "\\" -- check the directory separator, to determine whether this is Windows |
153 _maxfd = (is_windows and math.huge) or luasocket._SETSIZE or 1024 -- max fd number, limit to 1024 by default to prevent glibc buffer overflow, but not on Windows | 154 _maxfd = (is_windows and math.huge) or luasocket._SETSIZE or 1024 -- max fd number, limit to 1024 by default to prevent glibc buffer overflow, but not on Windows |
154 _maxselectlen = luasocket._SETSIZE or 1024 -- But this still applies on Windows | 155 _maxselectlen = luasocket._SETSIZE or 1024 -- But this still applies on Windows |
155 | 156 |
156 _maxsslhandshake = 30 -- max handshake round-trips | 157 _maxsslhandshake = 30 -- max handshake round-trips |
157 | 158 |
158 ----------------------------------// PRIVATE //-- | 159 ----------------------------------// PRIVATE //-- |
159 | 160 |
160 wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- this function wraps a server -- FIXME Make sure FD < _maxfd | 161 wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, ssldirect ) -- this function wraps a server -- FIXME Make sure FD < _maxfd |
161 | 162 |
162 if socket:getfd() >= _maxfd then | 163 if socket:getfd() >= _maxfd then |
163 out_error("server.lua: Disallowed FD number: "..socket:getfd()) | 164 out_error("server.lua: Disallowed FD number: "..socket:getfd()) |
164 socket:close() | 165 socket:close() |
165 return nil, "fd-too-large" | 166 return nil, "fd-too-large" |
181 return sslctx ~= nil | 182 return sslctx ~= nil |
182 end | 183 end |
183 handler.sslctx = function( ) | 184 handler.sslctx = function( ) |
184 return sslctx | 185 return sslctx |
185 end | 186 end |
187 handler.hosts = {} -- sni | |
186 handler.remove = function( ) | 188 handler.remove = function( ) |
187 connections = connections - 1 | 189 connections = connections - 1 |
188 if handler then | 190 if handler then |
189 handler.resume( ) | 191 handler.resume( ) |
190 end | 192 end |
242 return false | 244 return false |
243 end | 245 end |
244 local client, err = accept( socket ) -- try to accept | 246 local client, err = accept( socket ) -- try to accept |
245 if client then | 247 if client then |
246 local ip, clientport = client:getpeername( ) | 248 local ip, clientport = client:getpeername( ) |
247 local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx ) -- wrap new client socket | 249 local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx, ssldirect ) -- wrap new client socket |
248 if err then -- error while wrapping ssl socket | 250 if err then -- error while wrapping ssl socket |
249 return false | 251 return false |
250 end | 252 end |
251 connections = connections + 1 | 253 connections = connections + 1 |
252 out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport)) | 254 out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport)) |
253 if dispatch and not sslctx then -- SSL connections will notify onconnect when handshake completes | 255 if dispatch and not ssldirect then -- SSL connections will notify onconnect when handshake completes |
254 return dispatch( handler ); | 256 return dispatch( handler ); |
255 end | 257 end |
256 return; | 258 return; |
257 elseif err then -- maybe timeout or something else | 259 elseif err then -- maybe timeout or something else |
258 out_put( "server.lua: error with new client connection: ", tostring(err) ) | 260 out_put( "server.lua: error with new client connection: ", tostring(err) ) |
262 end | 264 end |
263 end | 265 end |
264 return handler | 266 return handler |
265 end | 267 end |
266 | 268 |
267 wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx ) -- this function wraps a client to a handler object | 269 wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx, ssldirect, extra ) -- this function wraps a client to a handler object |
268 | 270 |
269 if socket:getfd() >= _maxfd then | 271 if socket:getfd() >= _maxfd then |
270 out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent | 272 out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent |
271 socket:close( ) -- Should we send some kind of error here? | 273 socket:close( ) -- Should we send some kind of error here? |
272 if server then | 274 if server then |
311 local maxreadlen = _maxreadlen | 313 local maxreadlen = _maxreadlen |
312 | 314 |
313 --// public methods of the object //-- | 315 --// public methods of the object //-- |
314 | 316 |
315 local handler = bufferqueue -- saves a table ^_^ | 317 local handler = bufferqueue -- saves a table ^_^ |
318 | |
319 handler.extra = extra | |
320 if extra then | |
321 handler.servername = extra.servername | |
322 end | |
316 | 323 |
317 handler.dispatch = function( ) | 324 handler.dispatch = function( ) |
318 return dispatch | 325 return dispatch |
319 end | 326 end |
320 handler.disconnect = function( ) | 327 handler.disconnect = function( ) |
422 local write = function( self, data ) | 429 local write = function( self, data ) |
423 if not handler then return false end | 430 if not handler then return false end |
424 bufferlen = bufferlen + #data | 431 bufferlen = bufferlen + #data |
425 if bufferlen > maxsendlen then | 432 if bufferlen > maxsendlen then |
426 _closelist[ handler ] = "send buffer exceeded" -- cannot close the client at the moment, have to wait to the end of the cycle | 433 _closelist[ handler ] = "send buffer exceeded" -- cannot close the client at the moment, have to wait to the end of the cycle |
427 handler.write = idfalse -- don't write anymore | |
428 return false | 434 return false |
429 elseif socket and not _sendlist[ socket ] then | 435 elseif not nosend and socket and not _sendlist[ socket ] then |
430 _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) | 436 _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) |
431 end | 437 end |
432 bufferqueuelen = bufferqueuelen + 1 | 438 bufferqueuelen = bufferqueuelen + 1 |
433 bufferqueue[ bufferqueuelen ] = data | 439 bufferqueue[ bufferqueuelen ] = data |
434 if handler then | 440 if handler then |
454 handler.bufferlen = function( self, readlen, sendlen ) | 460 handler.bufferlen = function( self, readlen, sendlen ) |
455 maxsendlen = sendlen or maxsendlen | 461 maxsendlen = sendlen or maxsendlen |
456 maxreadlen = readlen or maxreadlen | 462 maxreadlen = readlen or maxreadlen |
457 return bufferlen, maxreadlen, maxsendlen | 463 return bufferlen, maxreadlen, maxsendlen |
458 end | 464 end |
459 --TODO: Deprecate | |
460 handler.lock_read = function (self, switch) | 465 handler.lock_read = function (self, switch) |
466 out_error( "server.lua, lock_read() is deprecated, use pause() and resume()" ) | |
461 if switch == true then | 467 if switch == true then |
462 local tmp = _readlistlen | 468 return self:pause() |
463 _readlistlen = removesocket( _readlist, socket, _readlistlen ) | |
464 _readtimes[ handler ] = nil | |
465 if _readlistlen ~= tmp then | |
466 noread = true | |
467 end | |
468 elseif switch == false then | 469 elseif switch == false then |
469 if noread then | 470 return self:resume() |
470 noread = false | |
471 _readlistlen = addsocket(_readlist, socket, _readlistlen) | |
472 _readtimes[ handler ] = _currenttime | |
473 end | |
474 end | 471 end |
475 return noread | 472 return noread |
476 end | 473 end |
477 handler.pause = function (self) | 474 handler.pause = function (self) |
478 return self:lock_read(true); | 475 local tmp = _readlistlen |
476 _readlistlen = removesocket( _readlist, socket, _readlistlen ) | |
477 _readtimes[ handler ] = nil | |
478 if _readlistlen ~= tmp then | |
479 noread = true | |
480 end | |
481 return noread; | |
479 end | 482 end |
480 handler.resume = function (self) | 483 handler.resume = function (self) |
481 return self:lock_read(false); | 484 if noread then |
485 noread = false | |
486 _readlistlen = addsocket(_readlist, socket, _readlistlen) | |
487 _readtimes[ handler ] = _currenttime | |
488 end | |
489 return noread; | |
482 end | 490 end |
483 handler.lock = function( self, switch ) | 491 handler.lock = function( self, switch ) |
484 handler.lock_read (switch) | 492 out_error( "server.lua, lock() is deprecated" ) |
493 handler.lock_read (self, switch) | |
485 if switch == true then | 494 if switch == true then |
486 handler.write = idfalse | 495 handler.pause_writes (self) |
487 local tmp = _sendlistlen | |
488 _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) | |
489 _writetimes[ handler ] = nil | |
490 if _sendlistlen ~= tmp then | |
491 nosend = true | |
492 end | |
493 elseif switch == false then | 496 elseif switch == false then |
494 handler.write = write | 497 handler.resume_writes (self) |
495 if nosend then | |
496 nosend = false | |
497 write( "" ) | |
498 end | |
499 end | 498 end |
500 return noread, nosend | 499 return noread, nosend |
501 end | 500 end |
501 handler.pause_writes = function (self) | |
502 local tmp = _sendlistlen | |
503 _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) | |
504 _writetimes[ handler ] = nil | |
505 nosend = true | |
506 end | |
507 handler.resume_writes = function (self) | |
508 nosend = false | |
509 if bufferlen > 0 then | |
510 _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) | |
511 end | |
512 end | |
513 | |
502 local _readbuffer = function( ) -- this function reads data | 514 local _readbuffer = function( ) -- this function reads data |
503 local buffer, err, part = receive( socket, pattern ) -- receive buffer with "pattern" | 515 local buffer, err, part = receive( socket, pattern ) -- receive buffer with "pattern" |
504 if not err or (err == "wantread" or err == "timeout") then -- received something | 516 if not err or (err == "wantread" or err == "timeout") then -- received something |
505 local buffer = buffer or part or "" | 517 local buffer = buffer or part or "" |
506 local len = #buffer | 518 local len = #buffer |
597 end | 609 end |
598 err = nil; | 610 err = nil; |
599 coroutine_yield( ) -- handshake not finished | 611 coroutine_yield( ) -- handshake not finished |
600 end | 612 end |
601 end | 613 end |
602 err = "ssl handshake error: " .. ( err or "handshake too long" ); | 614 err = ( err or "handshake too long" ); |
603 out_put( "server.lua: ", err ); | 615 out_put( "server.lua: ", err ); |
604 _ = handler and handler:force_close(err) | 616 _ = handler and handler:force_close(err) |
605 return false, err -- handshake failed | 617 return false, err -- handshake failed |
606 end | 618 end |
607 ) | 619 ) |
617 return | 629 return |
618 end | 630 end |
619 out_put( "server.lua: attempting to start tls on " .. tostring( socket ) ) | 631 out_put( "server.lua: attempting to start tls on " .. tostring( socket ) ) |
620 local oldsocket, err = socket | 632 local oldsocket, err = socket |
621 socket, err = ssl_wrap( socket, sslctx ) -- wrap socket | 633 socket, err = ssl_wrap( socket, sslctx ) -- wrap socket |
634 | |
622 if not socket then | 635 if not socket then |
623 out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") ) | 636 out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") ) |
624 return nil, err -- fatal error | 637 return nil, err -- fatal error |
638 end | |
639 | |
640 if socket.sni then | |
641 if self.servername then | |
642 socket:sni(self.servername); | |
643 elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then | |
644 socket:sni(self.server().hosts, true); | |
645 end | |
625 end | 646 end |
626 | 647 |
627 socket:settimeout( 0 ) | 648 socket:settimeout( 0 ) |
628 | 649 |
629 -- add the new socket to our system | 650 -- add the new socket to our system |
657 shutdown = ( ssl and id ) or socket.shutdown | 678 shutdown = ( ssl and id ) or socket.shutdown |
658 | 679 |
659 _socketlist[ socket ] = handler | 680 _socketlist[ socket ] = handler |
660 _readlistlen = addsocket(_readlist, socket, _readlistlen) | 681 _readlistlen = addsocket(_readlist, socket, _readlistlen) |
661 | 682 |
662 if sslctx and has_luasec then | 683 if sslctx and ssldirect and has_luasec then |
663 out_put "server.lua: auto-starting ssl negotiation..." | 684 out_put "server.lua: auto-starting ssl negotiation..." |
664 handler.autostart_ssl = true; | 685 handler.autostart_ssl = true; |
665 local ok, err = handler:starttls(sslctx); | 686 local ok, err = handler:starttls(sslctx); |
666 if ok == false then | 687 if ok == false then |
667 return nil, nil, err | 688 return nil, nil, err |
732 sender:set_mode("*a"); | 753 sender:set_mode("*a"); |
733 end | 754 end |
734 | 755 |
735 ----------------------------------// PUBLIC //-- | 756 ----------------------------------// PUBLIC //-- |
736 | 757 |
737 addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server | 758 listen = function ( addr, port, listeners, config ) |
738 addr = addr or "*" | 759 addr = addr or "*" |
760 config = config or {} | |
739 local err | 761 local err |
762 local sslctx = config.tls_ctx; | |
763 local ssldirect = config.tls_direct; | |
764 local pattern = config.read_size; | |
740 if type( listeners ) ~= "table" then | 765 if type( listeners ) ~= "table" then |
741 err = "invalid listener table" | 766 err = "invalid listener table" |
742 elseif type ( addr ) ~= "string" then | 767 elseif type ( addr ) ~= "string" then |
743 err = "invalid address" | 768 err = "invalid address" |
744 elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then | 769 elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then |
755 local server, err = socket_bind( addr, port, _tcpbacklog ) | 780 local server, err = socket_bind( addr, port, _tcpbacklog ) |
756 if err then | 781 if err then |
757 out_error( "server.lua, [", addr, "]:", port, ": ", err ) | 782 out_error( "server.lua, [", addr, "]:", port, ": ", err ) |
758 return nil, err | 783 return nil, err |
759 end | 784 end |
760 local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx ) -- wrap new server socket | 785 local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx, ssldirect ) -- wrap new server socket |
761 if not handler then | 786 if not handler then |
762 server:close( ) | 787 server:close( ) |
763 return nil, err | 788 return nil, err |
764 end | 789 end |
765 server:settimeout( 0 ) | 790 server:settimeout( 0 ) |
766 _readlistlen = addsocket(_readlist, server, _readlistlen) | 791 _readlistlen = addsocket(_readlist, server, _readlistlen) |
767 _server[ addr..":"..port ] = handler | 792 _server[ addr..":"..port ] = handler |
768 _socketlist[ server ] = handler | 793 _socketlist[ server ] = handler |
769 out_put( "server.lua: new "..(sslctx and "ssl " or "").."server listener on '[", addr, "]:", port, "'" ) | 794 out_put( "server.lua: new "..(sslctx and "ssl " or "").."server listener on '[", addr, "]:", port, "'" ) |
770 return handler | 795 return handler |
796 end | |
797 | |
798 addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server | |
799 return listen(addr, port, listeners, { | |
800 read_size = pattern; | |
801 tls_ctx = sslctx; | |
802 tls_direct = sslctx and true or false; | |
803 }); | |
771 end | 804 end |
772 | 805 |
773 getserver = function ( addr, port ) | 806 getserver = function ( addr, port ) |
774 return _server[ addr..":"..port ]; | 807 return _server[ addr..":"..port ]; |
775 end | 808 end |
975 return "select"; | 1008 return "select"; |
976 end | 1009 end |
977 | 1010 |
978 --// EXPERIMENTAL //-- | 1011 --// EXPERIMENTAL //-- |
979 | 1012 |
980 local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx ) | 1013 local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx, extra ) |
981 local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx ) | 1014 local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx, sslctx, extra) |
982 if not handler then return nil, err end | 1015 if not handler then return nil, err end |
983 _socketlist[ socket ] = handler | 1016 _socketlist[ socket ] = handler |
984 if not sslctx then | 1017 if not sslctx then |
985 _readlistlen = addsocket(_readlist, socket, _readlistlen) | 1018 _readlistlen = addsocket(_readlist, socket, _readlistlen) |
986 _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) | 1019 _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) |
995 end | 1028 end |
996 end | 1029 end |
997 return handler, socket | 1030 return handler, socket |
998 end | 1031 end |
999 | 1032 |
1000 local addclient = function( address, port, listeners, pattern, sslctx, typ ) | 1033 local addclient = function( address, port, listeners, pattern, sslctx, typ, extra ) |
1001 local err | 1034 local err |
1002 if type( listeners ) ~= "table" then | 1035 if type( listeners ) ~= "table" then |
1003 err = "invalid listener table" | 1036 err = "invalid listener table" |
1004 elseif type ( address ) ~= "string" then | 1037 elseif type ( address ) ~= "string" then |
1005 err = "invalid address" | 1038 err = "invalid address" |
1032 return nil, err | 1065 return nil, err |
1033 end | 1066 end |
1034 client:settimeout( 0 ) | 1067 client:settimeout( 0 ) |
1035 local ok, err = client:setpeername( address, port ) | 1068 local ok, err = client:setpeername( address, port ) |
1036 if ok or err == "timeout" or err == "Operation already in progress" then | 1069 if ok or err == "timeout" or err == "Operation already in progress" then |
1037 return wrapclient( client, address, port, listeners, pattern, sslctx ) | 1070 return wrapclient( client, address, port, listeners, pattern, sslctx, extra ) |
1038 else | 1071 else |
1039 return nil, err | 1072 return nil, err |
1040 end | 1073 end |
1041 end | 1074 end |
1042 | 1075 |
1112 link = link, | 1145 link = link, |
1113 step = step, | 1146 step = step, |
1114 stats = stats, | 1147 stats = stats, |
1115 closeall = closeall, | 1148 closeall = closeall, |
1116 addserver = addserver, | 1149 addserver = addserver, |
1150 listen = listen, | |
1117 getserver = getserver, | 1151 getserver = getserver, |
1118 setlogger = setlogger, | 1152 setlogger = setlogger, |
1119 getsettings = getsettings, | 1153 getsettings = getsettings, |
1120 setquitting = setquitting, | 1154 setquitting = setquitting, |
1121 removeserver = removeserver, | 1155 removeserver = removeserver, |