Changeset

4349:16fd8061964e

net.server_select: Merge straight-SSL and starttls code paths, also fixes onconnect being called before handshake completion for straight-SSL
author Matthew Wild <mwild1@gmail.com>
date Sat, 20 Aug 2011 15:06:14 -0400
parents 4348:5b240c6b5334
children 4350:0b9ed126286e
files net/server_select.lua
diffstat 1 files changed, 59 insertions(+), 71 deletions(-) [+]
line wrap: on
line diff
--- a/net/server_select.lua	Sat Aug 20 15:04:17 2011 -0400
+++ b/net/server_select.lua	Sat Aug 20 15:06:14 2011 -0400
@@ -525,6 +525,9 @@
 						handler.readbuffer = _readbuffer	-- when handshake is done, replace the handshake function with regular functions
 						handler.sendbuffer = _sendbuffer
 						_ = status and status( handler, "ssl-handshake-complete" )
+						if self.autostart_ssl and listeners.onconnect then
+							listeners.onconnect(self);
+						end
 						_readlistlen = addsocket(_readlist, client, _readlistlen)
 						return true
 					else
@@ -549,74 +552,56 @@
 		)
 	end
 	if luasec then
-		if sslctx then -- ssl?
-			handler:set_sslctx(sslctx);
-			out_put("server.lua: ", "starting ssl handshake")
-			local err
+		handler.starttls = function( self, _sslctx)
+			if _sslctx then
+				handler:set_sslctx(_sslctx);
+			end
+			if bufferqueuelen > 0 then
+				out_put "server.lua: we need to do tls, but delaying until send buffer empty"
+				needtls = true
+				return
+			end
+			out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
+			local oldsocket, err = socket
 			socket, err = ssl_wrap( socket, sslctx )	-- wrap socket
-			if err then
-				out_put( "server.lua: ssl error: ", tostring(err) )
-				--mem_free( )
-				return nil, nil, err	-- fatal error
+			if not socket then
+				out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") )
+				return nil, err -- fatal error
 			end
+
 			socket:settimeout( 0 )
+
+			-- add the new socket to our system
+			send = socket.send
+			receive = socket.receive
+			shutdown = id
+			_socketlist[ socket ] = handler
+			_readlistlen = addsocket(_readlist, socket, _readlistlen)
+			
+			-- remove traces of the old socket
+			_readlistlen = removesocket( _readlist, oldsocket, _readlistlen )
+			_sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen )
+			_socketlist[ oldsocket ] = nil
+
+			handler.starttls = nil
+			needtls = nil
+
+			-- Secure now (if handshake fails connection will close)
+			ssl = true
+
 			handler.readbuffer = handshake
 			handler.sendbuffer = handshake
 			handshake( socket ) -- do handshake
-			if not socket then
-				return nil, nil, "ssl handshake failed";
-			end
-		else
-			local sslctx;
-			handler.starttls = function( self, _sslctx)
-				if _sslctx then
-					sslctx = _sslctx;
-					handler:set_sslctx(sslctx);
-				end
-				if bufferqueuelen > 0 then
-					out_put "server.lua: we need to do tls, but delaying until send buffer empty"
-					needtls = true
-					return
-				end
-				out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
-				local oldsocket, err = socket
-				socket, err = ssl_wrap( socket, sslctx )	-- wrap socket
-				--out_put( "server.lua: sslwrapped socket is " .. tostring( socket ) )
-				if err then
-					out_put( "server.lua: error while starting tls on client: ", tostring(err) )
-					return nil, err -- fatal error
-				end
+		end
+		handler.readbuffer = _readbuffer
+		handler.sendbuffer = _sendbuffer
+		
+		if sslctx then
+			out_put "server.lua: auto-starting ssl negotiation..."
+			handler.autostart_ssl = true;
+			handler:starttls(sslctx);
+		end
 
-				socket:settimeout( 0 )
-	
-				-- add the new socket to our system
-	
-				send = socket.send
-				receive = socket.receive
-				shutdown = id
-
-				_socketlist[ socket ] = handler
-				_readlistlen = addsocket(_readlist, socket, _readlistlen)
-
-				-- remove traces of the old socket
-
-				_readlistlen = removesocket( _readlist, oldsocket, _readlistlen )
-				_sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen )
-				_socketlist[ oldsocket ] = nil
-
-				handler.starttls = nil
-				needtls = nil
-
-				-- Secure now
-				ssl = true
-
-				handler.readbuffer = handshake
-				handler.sendbuffer = handshake
-				handshake( socket ) -- do handshake
-			end
-			handler.readbuffer = _readbuffer
-			handler.sendbuffer = _sendbuffer
-		end
 	else
 		handler.readbuffer = _readbuffer
 		handler.sendbuffer = _sendbuffer
@@ -857,16 +842,19 @@
 local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx )
 	local handler = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx )
 	_socketlist[ socket ] = handler
-	_sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
-	if listeners.onconnect then
-		-- When socket is writeable, call onconnect
-		local _sendbuffer = handler.sendbuffer;
-		handler.sendbuffer = function ()
-			handler.sendbuffer = _sendbuffer;
-			listeners.onconnect(handler);
-			-- If there was data with the incoming packet, handle it now.
-			if #handler:bufferqueue() > 0 then
-				return _sendbuffer();
+	if not sslctx then
+		_sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
+		if listeners.onconnect then
+			-- When socket is writeable, call onconnect
+			local _sendbuffer = handler.sendbuffer;
+			handler.sendbuffer = function ()
+				handler.sendbuffer = _sendbuffer;
+				listeners.onconnect(handler);
+				-- If there was data with the incoming packet, handle it now.
+				if #handler:bufferqueue() > 0 then
+					return _sendbuffer();
+				end
+				_sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
 			end
 		end
 	end