Diff

net/server_select.lua @ 10563:e8db377a2983

Merge 0.11->trunk
author Kim Alvefur <zash@zash.se>
date Tue, 24 Dec 2019 00:39:45 +0100
parent 10474:175b72700d79
child 10851:6cf16abd0976
line wrap: on
line diff
--- a/net/server_select.lua	Tue Dec 24 00:26:40 2019 +0100
+++ b/net/server_select.lua	Tue Dec 24 00:39:45 2019 +0100
@@ -68,6 +68,7 @@
 local closeall
 local addsocket
 local addserver
+local listen
 local addtimer
 local getserver
 local wrapserver
@@ -123,7 +124,7 @@
 
 _server = { } -- key = port, value = table; list of listening servers
 _readlist = { } -- array with sockets to read from
-_sendlist = { } -- arrary with sockets to write to
+_sendlist = { } -- array with sockets to write to
 _timerlist = { } -- array of timer functions
 _socketlist = { } -- key = socket, value = wrapped socket (handlers)
 _readtimes = { } -- key = handler, value = timestamp of last data reading
@@ -149,7 +150,7 @@
 _sendtimeout = 60000 -- allowed send idle time in secs
 _readtimeout = 14 * 60 -- allowed read idle time in secs
 
-local is_windows = package.config:sub(1,1) == "\\" -- check the directory separator, to detemine whether this is Windows
+local is_windows = package.config:sub(1,1) == "\\" -- check the directory separator, to determine whether this is Windows
 _maxfd = (is_windows and math.huge) or luasocket._SETSIZE or 1024 -- max fd number, limit to 1024 by default to prevent glibc buffer overflow, but not on Windows
 _maxselectlen = luasocket._SETSIZE or 1024 -- But this still applies on Windows
 
@@ -157,7 +158,7 @@
 
 ----------------------------------// PRIVATE //--
 
-wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- this function wraps a server -- FIXME Make sure FD < _maxfd
+wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, ssldirect ) -- this function wraps a server -- FIXME Make sure FD < _maxfd
 
 	if socket:getfd() >= _maxfd then
 		out_error("server.lua: Disallowed FD number: "..socket:getfd())
@@ -183,6 +184,7 @@
 	handler.sslctx = function( )
 		return sslctx
 	end
+	handler.hosts = {} -- sni
 	handler.remove = function( )
 		connections = connections - 1
 		if handler then
@@ -244,13 +246,13 @@
 		local client, err = accept( socket )	-- try to accept
 		if client then
 			local ip, clientport = client:getpeername( )
-			local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx ) -- wrap new client socket
+			local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx, ssldirect ) -- wrap new client socket
 			if err then -- error while wrapping ssl socket
 				return false
 			end
 			connections = connections + 1
 			out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport))
-			if dispatch and not sslctx then -- SSL connections will notify onconnect when handshake completes
+			if dispatch and not ssldirect then -- SSL connections will notify onconnect when handshake completes
 				return dispatch( handler );
 			end
 			return;
@@ -264,7 +266,7 @@
 	return handler
 end
 
-wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx ) -- this function wraps a client to a handler object
+wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx, ssldirect, extra ) -- this function wraps a client to a handler object
 
 	if socket:getfd() >= _maxfd then
 		out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent
@@ -314,6 +316,11 @@
 
 	local handler = bufferqueue -- saves a table ^_^
 
+	handler.extra = extra
+	if extra then
+		handler.servername = extra.servername
+	end
+
 	handler.dispatch = function( )
 		return dispatch
 	end
@@ -424,9 +431,8 @@
 		bufferlen = bufferlen + #data
 		if bufferlen > maxsendlen then
 			_closelist[ handler ] = "send buffer exceeded"	 -- cannot close the client at the moment, have to wait to the end of the cycle
-			handler.write = idfalse -- don't write anymore
 			return false
-		elseif socket and not _sendlist[ socket ] then
+		elseif not nosend and socket and not _sendlist[ socket ] then
 			_sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
 		end
 		bufferqueuelen = bufferqueuelen + 1
@@ -456,49 +462,55 @@
 		maxreadlen = readlen or maxreadlen
 		return bufferlen, maxreadlen, maxsendlen
 	end
-	--TODO: Deprecate
 	handler.lock_read = function (self, switch)
+		out_error( "server.lua, lock_read() is deprecated, use pause() and resume()" )
 		if switch == true then
-			local tmp = _readlistlen
-			_readlistlen = removesocket( _readlist, socket, _readlistlen )
-			_readtimes[ handler ] = nil
-			if _readlistlen ~= tmp then
-				noread = true
-			end
+			return self:pause()
 		elseif switch == false then
-			if noread then
-				noread = false
-				_readlistlen = addsocket(_readlist, socket, _readlistlen)
-				_readtimes[ handler ] = _currenttime
-			end
+			return self:resume()
 		end
 		return noread
 	end
 	handler.pause = function (self)
-		return self:lock_read(true);
+		local tmp = _readlistlen
+		_readlistlen = removesocket( _readlist, socket, _readlistlen )
+		_readtimes[ handler ] = nil
+		if _readlistlen ~= tmp then
+			noread = true
+		end
+		return noread;
 	end
 	handler.resume = function (self)
-		return self:lock_read(false);
+		if noread then
+			noread = false
+			_readlistlen = addsocket(_readlist, socket, _readlistlen)
+			_readtimes[ handler ] = _currenttime
+		end
+		return noread;
 	end
 	handler.lock = function( self, switch )
-		handler.lock_read (switch)
+		out_error( "server.lua, lock() is deprecated" )
+		handler.lock_read (self, switch)
 		if switch == true then
-			handler.write = idfalse
-			local tmp = _sendlistlen
-			_sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
-			_writetimes[ handler ] = nil
-			if _sendlistlen ~= tmp then
-				nosend = true
-			end
+			handler.pause_writes (self)
 		elseif switch == false then
-			handler.write = write
-			if nosend then
-				nosend = false
-				write( "" )
-			end
+			handler.resume_writes (self)
 		end
 		return noread, nosend
 	end
+	handler.pause_writes = function (self)
+		local tmp = _sendlistlen
+		_sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
+		_writetimes[ handler ] = nil
+		nosend = true
+	end
+	handler.resume_writes = function (self)
+		nosend = false
+		if bufferlen > 0 then
+			_sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
+		end
+	end
+
 	local _readbuffer = function( ) -- this function reads data
 		local buffer, err, part = receive( socket, pattern )	-- receive buffer with "pattern"
 		if not err or (err == "wantread" or err == "timeout") then -- received something
@@ -599,7 +611,7 @@
 						coroutine_yield( ) -- handshake not finished
 					end
 				end
-				err = "ssl handshake error: " .. ( err or "handshake too long" );
+				err = ( err or "handshake too long" );
 				out_put( "server.lua: ", err );
 				_ = handler and handler:force_close(err)
 				return false, err -- handshake failed
@@ -619,11 +631,20 @@
 			out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
 			local oldsocket, err = socket
 			socket, err = ssl_wrap( socket, sslctx )	-- wrap socket
+
 			if not socket then
 				out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") )
 				return nil, err -- fatal error
 			end
 
+			if socket.sni then
+				if self.servername then
+					socket:sni(self.servername);
+				elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then
+					socket:sni(self.server().hosts, true);
+				end
+			end
+
 			socket:settimeout( 0 )
 
 			-- add the new socket to our system
@@ -659,7 +680,7 @@
 	_socketlist[ socket ] = handler
 	_readlistlen = addsocket(_readlist, socket, _readlistlen)
 
-	if sslctx and has_luasec then
+	if sslctx and ssldirect and has_luasec then
 		out_put "server.lua: auto-starting ssl negotiation..."
 		handler.autostart_ssl = true;
 		local ok, err = handler:starttls(sslctx);
@@ -734,9 +755,13 @@
 
 ----------------------------------// PUBLIC //--
 
-addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server
+listen = function ( addr, port, listeners, config )
 	addr = addr or "*"
+	config = config or {}
 	local err
+	local sslctx = config.tls_ctx;
+	local ssldirect = config.tls_direct;
+	local pattern = config.read_size;
 	if type( listeners ) ~= "table" then
 		err = "invalid listener table"
 	elseif type ( addr ) ~= "string" then
@@ -757,7 +782,7 @@
 		out_error( "server.lua, [", addr, "]:", port, ": ", err )
 		return nil, err
 	end
-	local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx ) -- wrap new server socket
+	local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx, ssldirect ) -- wrap new server socket
 	if not handler then
 		server:close( )
 		return nil, err
@@ -770,6 +795,14 @@
 	return handler
 end
 
+addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server
+	return listen(addr, port, listeners, {
+		read_size = pattern;
+		tls_ctx = sslctx;
+		tls_direct = sslctx and true or false;
+	});
+end
+
 getserver = function ( addr, port )
 	return _server[ addr..":"..port ];
 end
@@ -977,8 +1010,8 @@
 
 --// EXPERIMENTAL //--
 
-local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx )
-	local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx )
+local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx, extra )
+	local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx, sslctx, extra)
 	if not handler then return nil, err end
 	_socketlist[ socket ] = handler
 	if not sslctx then
@@ -997,7 +1030,7 @@
 	return handler, socket
 end
 
-local addclient = function( address, port, listeners, pattern, sslctx, typ )
+local addclient = function( address, port, listeners, pattern, sslctx, typ, extra )
 	local err
 	if type( listeners ) ~= "table" then
 		err = "invalid listener table"
@@ -1034,7 +1067,7 @@
 	client:settimeout( 0 )
 	local ok, err = client:setpeername( address, port )
 	if ok or err == "timeout" or err == "Operation already in progress" then
-		return wrapclient( client, address, port, listeners, pattern, sslctx )
+		return wrapclient( client, address, port, listeners, pattern, sslctx, extra )
 	else
 		return nil, err
 	end
@@ -1114,6 +1147,7 @@
 	stats = stats,
 	closeall = closeall,
 	addserver = addserver,
+	listen = listen,
 	getserver = getserver,
 	setlogger = setlogger,
 	getsettings = getsettings,