Changeset

67:563360207292

Merged local TLS branch
author Matthew Wild <mwild1@gmail.com>
date Sun, 05 Oct 2008 19:16:32 +0100
parents 64:bcd0a3975580 (current diff) 66:018705d57f09 (diff)
children 68:ceb7a55676a4
files
diffstat 4 files changed, 228 insertions(+), 7 deletions(-) [+]
line wrap: on
line diff
--- a/core/modulemanager.lua	Sun Oct 05 02:48:39 2008 +0100
+++ b/core/modulemanager.lua	Sun Oct 05 19:16:32 2008 +0100
@@ -49,6 +49,7 @@
 	load("legacyauth");
 	load("roster");
 	load("register");
+	load("tls");
 end
 
 function load(name)
--- a/core/stanza_router.lua	Sun Oct 05 02:48:39 2008 +0100
+++ b/core/stanza_router.lua	Sun Oct 05 19:16:32 2008 +0100
@@ -11,6 +11,7 @@
 local jid_split = jid.split;
 
 function core_process_stanza(origin, stanza)
+	log("debug", "Received: "..tostring(stanza))
 	local to = stanza.attr.to;
 	
 	if not to or (hosts[to] and hosts[to].type == "local") then
--- a/main.lua	Sun Oct 05 02:48:39 2008 +0100
+++ b/main.lua	Sun Oct 05 19:16:32 2008 +0100
@@ -101,7 +101,7 @@
 local protected_handler = function (conn, data, err) local success, ret = pcall(handler, conn, data, err); if not success then print("ERROR on "..tostring(conn)..": "..ret); conn:close(); end end;
 local protected_disconnect = function (conn, err) local success, ret = pcall(disconnect, conn, err); if not success then print("ERROR on "..tostring(conn).." disconnect: "..ret); conn:close(); end end;
 
-server.add( { listener = protected_handler, disconnect = protected_disconnect }, 5222, "*", 1, nil ) -- server.add will send a status message
-server.add( { listener = protected_handler, disconnect = protected_disconnect }, 5223, "*", 1, ssl_ctx ) -- server.add will send a status message
+server.add( { listener = protected_handler, disconnect = protected_disconnect }, 5222, "*", 1, ssl_ctx ) -- server.add will send a status message
+--server.add( { listener = protected_handler, disconnect = protected_disconnect }, 5223, "*", 1, ssl_ctx ) -- server.add will send a status message
 
 server.loop();
--- a/net/server.lua	Sun Oct 05 02:48:39 2008 +0100
+++ b/net/server.lua	Sun Oct 05 19:16:32 2008 +0100
@@ -39,6 +39,7 @@
 local coroutine_yield = coroutine.yield
 local print = print;
 local out_put = function () end --print;
+local out_put = print;
 local out_error = print;
 
 --// extern libs //--
@@ -105,8 +106,6 @@
 	if sslctx then
 		if not ssl_newcontext then
 			return nil, "luasec not found"
---        elseif not cfg_get "use_ssl" then
---            return nil, "ssl is deactivated"
 		end
 		if type( sslctx ) ~= "table" then
 			out_error "server.lua: wrong server sslctx"
@@ -119,6 +118,7 @@
 			return nil, err
 		end
 		wrapclient = wrapsslclient
+		wrapclient = wraptlsclient
 	else
 		wrapclient = wraptcpclient
 	end
@@ -356,6 +356,220 @@
 	return handler, socket
 end
 
+wraptlsclient = function( listener, socket, ip, serverport, clientport, mode, sslctx )    -- this function wraps a tls cleint
+
+	local dispatch, disconnect = listener.listener, listener.disconnect
+
+	--// transform socket to ssl object //--
+
+	local err
+
+	socket:settimeout( 0 )
+
+	--// private closures of the object //--
+
+	local writequeue = { }    -- buffer for messages to send
+
+	local eol   -- end of buffer
+
+	local sstat, rstat = 0, 0
+
+	--// local import of socket methods //--
+
+	local send = socket.send
+	local receive = socket.receive
+	local close = socket.close
+	--local shutdown = socket.shutdown
+
+	--// public methods of the object //--
+
+	local handler = { }
+
+	handler.getstats = function( )
+		return rstat, sstat
+	end
+
+	handler.listener = function( data, err )
+		return listener( handler, data, err )
+	end
+	handler.ssl = function( )
+		return false
+	end
+	handler.send = function( _, data, i, j )
+			return send( socket, data, i, j )
+	end
+	handler.receive = function( pattern, prefix )
+			return receive( socket, pattern, prefix )
+	end
+	handler.shutdown = function( pattern )
+		--return shutdown( socket, pattern )
+	end
+	handler.close = function( closed )
+		close( socket )
+		writelen = ( eol and removesocket( writelist, socket, writelen ) ) or writelen
+		readlen = removesocket( readlist, socket, readlen )
+		socketlist[ socket ] = nil
+		out_put "server.lua: closed handler and removed socket from list"
+	end
+	handler.ip = function( )
+		return ip
+	end
+	handler.serverport = function( )
+		return serverport
+	end
+	handler.clientport = function( ) 
+		return clientport
+	end
+
+	handler.write = function( data )
+		if not eol then
+			writelen = writelen + 1
+			writelist[ writelen ] = socket
+			eol = 0
+		end
+		eol = eol + 1
+		writequeue[ eol ] = data
+	end
+	handler.writequeue = function( )
+		return writequeue
+	end
+	handler.socket = function( )
+		return socket
+	end
+	handler.mode = function( )
+		return mode
+	end
+	handler._receivedata = function( )
+		local data, err, part = receive( socket, mode )    -- receive data in "mode"
+		if not err or ( err == "timeout" or err == "wantread" ) then    -- received something
+			local data = data or part or ""
+			local count = #data * STAT_UNIT
+			rstat = rstat + count
+			receivestat = receivestat + count
+			--out_put( "server.lua: read data '", data, "', error: ", err )
+			return dispatch( handler, data, err )
+		else    -- connections was closed or fatal error
+			out_put( "server.lua: client ", ip, ":", clientport, " error: ", err )
+			handler.close( )
+			disconnect( handler, err )
+			writequeue = nil
+			handler = nil
+			return false
+		end
+	end
+	handler._dispatchdata = function( )    -- this function writes data to handlers
+		local buffer = table_concat( writequeue, "", 1, eol )
+		local succ, err, byte = send( socket, buffer )
+		local count = ( succ or 0 ) * STAT_UNIT
+		sstat = sstat + count
+		sendstat = sendstat + count
+		out_put( "server.lua: sended '", buffer, "', bytes: ", succ, ", error: ", err, ", part: ", byte, ", to: ", ip, ":", clientport )
+		if succ then    -- sending succesful
+			--writequeue = { }
+			eol = nil
+			writelen = removesocket( writelist, socket, writelen )    -- delete socket from writelist
+			if handler.need_tls then
+				out_put("server.lua: connection is ready for tls handshake");
+				handler.starttls(true);
+				if handler.need_tls then
+					out_put("server.lua: uh-oh... we still want tls, something must be wrong");
+				end
+			end
+			return true
+		elseif byte and ( err == "timeout" or err == "wantwrite" ) then    -- want write
+			buffer = string_sub( buffer, byte + 1, -1 )    -- new buffer
+			writequeue[ 1 ] = buffer    -- insert new buffer in queue
+			eol = 1
+			return true
+		else    -- connection was closed during sending or fatal error
+			out_put( "server.lua: client ", ip, ":", clientport, " error: ", err )
+			handler.close( )
+			disconnect( handler, err )
+			writequeue = nil
+			handler = nil
+			return false
+		end
+	end
+
+	handler.receivedata, handler.dispatchdata = handler._receivedata, handler._dispatchdata;
+	-- // COMPAT // --
+
+	handler.getIp = handler.ip
+	handler.getPort = handler.clientport
+
+	--// handshake //--
+
+	local wrote, read
+	
+	handler.starttls = function (now)
+		if not now then out_put("server.lua: we need to do tls, but delaying until later"); handler.need_tls = true; return; end
+		out_put( "server.lua: attempting to start tls on "..tostring(socket) )
+		socket, err = ssl_wrap( socket, sslctx )    -- wrap socket
+		out_put("sslwrapped socket is "..tostring(socket));
+		if err then
+			out_put( "server.lua: ssl error: ", err )
+			return nil, nil, err    -- fatal error
+		end
+		socket:settimeout( 1 )
+		send = socket.send
+		receive = socket.receive
+		close = socket.close
+		handler.ssl = function( )
+			return true
+		end
+		handler.send = function( _, data, i, j )
+			return send( socket, data, i, j )
+		end
+		handler.receive = function( pattern, prefix )
+			return receive( socket, pattern, prefix )
+		end
+		
+			handler.handshake = coroutine_wrap( function( client )
+					local err
+					for i = 1, 10 do    -- 10 handshake attemps
+						_, err = client:dohandshake( )
+						if not err then
+							out_put( "server.lua: ssl handshake done" )
+							writelen = ( wrote and removesocket( writelist, socket, writelen ) ) or writelen
+							handler.receivedata = handler._receivedata    -- when handshake is done, replace the handshake function with regular functions
+							handler.dispatchdata = handler._dispatchdata
+							handler.need_tls = nil
+							socketlist[ client ] = handler
+							readlen = readlen + 1
+							readlist[ readlen ] = client												
+							return true;
+						else
+							out_put( "server.lua: error during ssl handshake: ", err )
+							if err == "wantwrite" then
+								if wrote == nil then
+									writelen = writelen + 1
+									writelist[ writelen ] = client
+									wrote = true
+								end
+							end
+							coroutine_yield( handler, nil, err )    -- handshake not finished
+						end
+					end
+					_ = err ~= "closed" and close( socket )
+					handler.close( )
+					disconnect( handler, err )
+					writequeue = nil
+					handler = nil
+					return false    -- handshake failed
+				end
+			)
+			handler.receivedata = handler.handshake
+			handler.dispatchdata = handler.handshake
+
+			handler.handshake( socket )    -- do handshake
+		end
+	socketlist[ socket ] = handler
+	readlen = readlen + 1
+	readlist[ readlen ] = socket
+
+	return handler, socket
+end
+
 wraptcpclient = function( listener, socket, ip, serverport, clientport, mode )    -- this function wraps a socket
 
 	local dispatch, disconnect = listener.listener, listener.disconnect
@@ -433,6 +647,7 @@
 	handler.mode = function( )
 		return mode
 	end
+	
 	handler.receivedata = function( )
 		local data, err, part = receive( socket, mode )    -- receive data in "mode"
 		if not err or ( err == "timeout" or err == "wantread" ) then    -- received something
@@ -451,6 +666,7 @@
 			return false
 		end
 	end
+	
 	handler.dispatchdata = function( )    -- this function writes data to handlers
 		local buffer = table_concat( writequeue, "", 1, eol )
 		local succ, err, byte = send( socket, buffer )
@@ -573,6 +789,10 @@
 loop = function( )    -- this is the main loop of the program
 	--signal_set( "hub", "run" )
 	repeat
+		--[[print(readlen, writelen)
+		for _, s in ipairs(readlist) do print("R:", tostring(s)) end
+		for _, s in ipairs(writelist) do print("W:", tostring(s)) end
+		out_put("select()"..os.time())]]
 		local read, write, err = socket_select( readlist, writelist, 1 )    -- 1 sec timeout, nice for timers
 		for i, socket in ipairs( write ) do    -- send data waiting in writequeues
 			local handler = socketlist[ socket ]
@@ -593,9 +813,8 @@
 			end
 		end
 		firetimer( )
-		--collectgarbage "collect"
-	until false --signal_get "hub" ~= "run"
-	return --signal_get "hub"
+	until false
+	return
 end
 
 ----------------------------------// BEGIN //--