Diff

net/server.lua @ 65:9c471840acb9 tls

TLS: Handshake works, no data after that
author Matthew Wild <mwild1@gmail.com>
date Sun, 05 Oct 2008 17:33:38 +0100
parent 64:bcd0a3975580
child 66:018705d57f09
line wrap: on
line diff
--- a/net/server.lua	Sun Oct 05 02:48:39 2008 +0100
+++ b/net/server.lua	Sun Oct 05 17:33:38 2008 +0100
@@ -39,6 +39,7 @@
 local coroutine_yield = coroutine.yield
 local print = print;
 local out_put = function () end --print;
+local out_put = print;
 local out_error = print;
 
 --// extern libs //--
@@ -105,8 +106,6 @@
 	if sslctx then
 		if not ssl_newcontext then
 			return nil, "luasec not found"
---        elseif not cfg_get "use_ssl" then
---            return nil, "ssl is deactivated"
 		end
 		if type( sslctx ) ~= "table" then
 			out_error "server.lua: wrong server sslctx"
@@ -119,6 +118,7 @@
 			return nil, err
 		end
 		wrapclient = wrapsslclient
+		wrapclient = wraptlsclient
 	else
 		wrapclient = wraptcpclient
 	end
@@ -356,6 +356,216 @@
 	return handler, socket
 end
 
+wraptlsclient = function( listener, socket, ip, serverport, clientport, mode, sslctx )    -- this function wraps a tls cleint
+
+	local dispatch, disconnect = listener.listener, listener.disconnect
+
+	--// transform socket to ssl object //--
+
+	local err
+
+	socket:settimeout( 0 )
+
+	--// private closures of the object //--
+
+	local writequeue = { }    -- buffer for messages to send
+
+	local eol   -- end of buffer
+
+	local sstat, rstat = 0, 0
+
+	--// local import of socket methods //--
+
+	local send = socket.send
+	local receive = socket.receive
+	local close = socket.close
+	--local shutdown = socket.shutdown
+
+	--// public methods of the object //--
+
+	local handler = { }
+
+	handler.getstats = function( )
+		return rstat, sstat
+	end
+
+	handler.listener = function( data, err )
+		return listener( handler, data, err )
+	end
+	handler.ssl = function( )
+		return false
+	end
+	handler.send = function( _, data, i, j )
+			return send( socket, data, i, j )
+	end
+	handler.receive = function( pattern, prefix )
+			return receive( socket, pattern, prefix )
+	end
+	handler.shutdown = function( pattern )
+		--return shutdown( socket, pattern )
+	end
+	handler.close = function( closed )
+		close( socket )
+		writelen = ( eol and removesocket( writelist, socket, writelen ) ) or writelen
+		readlen = removesocket( readlist, socket, readlen )
+		socketlist[ socket ] = nil
+		out_put "server.lua: closed handler and removed socket from list"
+	end
+	handler.ip = function( )
+		return ip
+	end
+	handler.serverport = function( )
+		return serverport
+	end
+	handler.clientport = function( ) 
+		return clientport
+	end
+
+	handler.write = function( data )
+		if not eol then
+			writelen = writelen + 1
+			writelist[ writelen ] = socket
+			eol = 0
+		end
+		eol = eol + 1
+		writequeue[ eol ] = data
+	end
+	handler.writequeue = function( )
+		return writequeue
+	end
+	handler.socket = function( )
+		return socket
+	end
+	handler.mode = function( )
+		return mode
+	end
+	handler._receivedata = function( )
+		local data, err, part = receive( socket, mode )    -- receive data in "mode"
+		if not err or ( err == "timeout" or err == "wantread" ) then    -- received something
+			local data = data or part or ""
+			local count = #data * STAT_UNIT
+			rstat = rstat + count
+			receivestat = receivestat + count
+			out_put( "server.lua: read data '", data, "', error: ", err )
+			return dispatch( handler, data, err )
+		else    -- connections was closed or fatal error
+			out_put( "server.lua: client ", ip, ":", clientport, " error: ", err )
+			handler.close( )
+			disconnect( handler, err )
+			writequeue = nil
+			handler = nil
+			return false
+		end
+	end
+	handler._dispatchdata = function( )    -- this function writes data to handlers
+		local buffer = table_concat( writequeue, "", 1, eol )
+		local succ, err, byte = send( socket, buffer )
+		local count = ( succ or 0 ) * STAT_UNIT
+		sstat = sstat + count
+		sendstat = sendstat + count
+		out_put( "server.lua: sended '", buffer, "', bytes: ", succ, ", error: ", err, ", part: ", byte, ", to: ", ip, ":", clientport )
+		if succ then    -- sending succesful
+			--writequeue = { }
+			eol = nil
+			writelen = removesocket( writelist, socket, writelen )    -- delete socket from writelist
+			if handler.need_tls then
+				out_put("server.lua: connection is ready for tls handshake");
+				handler.need_tls = not handler.starttls(true);
+			end
+			return true
+		elseif byte and ( err == "timeout" or err == "wantwrite" ) then    -- want write
+			buffer = string_sub( buffer, byte + 1, -1 )    -- new buffer
+			writequeue[ 1 ] = buffer    -- insert new buffer in queue
+			eol = 1
+			return true
+		else    -- connection was closed during sending or fatal error
+			out_put( "server.lua: client ", ip, ":", clientport, " error: ", err )
+			handler.close( )
+			disconnect( handler, err )
+			writequeue = nil
+			handler = nil
+			return false
+		end
+	end
+
+	handler.receivedata, handler.dispatchdata = handler._receivedata, handler._dispatchdata;
+	-- // COMPAT // --
+
+	handler.getIp = handler.ip
+	handler.getPort = handler.clientport
+
+	--// handshake //--
+
+	local wrote, read
+	
+	handler.starttls = function (now)
+		if not now then handler.need_tls = true; return; end
+		out_put( "server.lua: attempting to start tls on "..tostring(socket) )
+		socket, err = ssl_wrap( socket, sslctx )    -- wrap socket
+		out_put("sslwrapped socket is "..tostring(socket));
+		if err then
+			out_put( "server.lua: ssl error: ", err )
+			return nil, nil, err    -- fatal error
+		end
+		socket:settimeout( 1 )
+		send = socket.send
+		receive = socket.receive
+		close = socket.close
+		print(readlen, writelen)
+		for _, s in ipairs(readlist) do print("R:", tostring(s)) end
+		for _, s in ipairs(writelist) do print("W:", tostring(s)) end
+		handler.ssl = function( )
+			return true
+		end
+		handler.send = function( _, data, i, j )
+			return send( socket, data, i, j )
+		end
+		handler.receive = function( pattern, prefix )
+			return receive( socket, pattern, prefix )
+		end
+	
+		handler.handshake = function (conn)
+							local succ, msg
+							out_put("ssl handshaking on socket "..tostring(conn))
+							conn:settimeout()
+							while not succ do
+								succ, msg = conn:dohandshake()
+								out_put("msg: "..tostring(msg))
+								if msg == 'wantread' then
+									socket_select({conn}, nil)
+								elseif msg == 'wantwrite' then
+									socket_select(nil, {conn})
+								elseif not succ then
+									-- other error
+									_ = err ~= "closed" and close( socket )
+									handler.close( )
+									disconnect( handler, err )
+									writequeue = nil
+									handler = nil
+									out_error("server.lua: ssl handshake failed");
+									return false    -- handshake failed
+								end
+					
+							end
+							out_put("server.lua: ssl handshake succeeded!");
+							handler.receivedata = handler._receivedata;
+							handler.dispatchdata = handler._dispatchdata;
+							return true;
+						end
+		
+		handler.receivedata = handler.handshake
+		handler.dispatchdata = handler.handshake
+
+		return handler.handshake( socket )    -- do handshake
+	end
+	
+	socketlist[ socket ] = handler
+	readlen = readlen + 1
+	readlist[ readlen ] = socket
+
+	return handler, socket
+end
+
 wraptcpclient = function( listener, socket, ip, serverport, clientport, mode )    -- this function wraps a socket
 
 	local dispatch, disconnect = listener.listener, listener.disconnect
@@ -433,6 +643,7 @@
 	handler.mode = function( )
 		return mode
 	end
+	
 	handler.receivedata = function( )
 		local data, err, part = receive( socket, mode )    -- receive data in "mode"
 		if not err or ( err == "timeout" or err == "wantread" ) then    -- received something
@@ -451,6 +662,7 @@
 			return false
 		end
 	end
+	
 	handler.dispatchdata = function( )    -- this function writes data to handlers
 		local buffer = table_concat( writequeue, "", 1, eol )
 		local succ, err, byte = send( socket, buffer )
@@ -573,6 +785,7 @@
 loop = function( )    -- this is the main loop of the program
 	--signal_set( "hub", "run" )
 	repeat
+		out_put("select()")
 		local read, write, err = socket_select( readlist, writelist, 1 )    -- 1 sec timeout, nice for timers
 		for i, socket in ipairs( write ) do    -- send data waiting in writequeues
 			local handler = socketlist[ socket ]
@@ -593,9 +806,8 @@
 			end
 		end
 		firetimer( )
-		--collectgarbage "collect"
-	until false --signal_get "hub" ~= "run"
-	return --signal_get "hub"
+	until false
+	return
 end
 
 ----------------------------------// BEGIN //--