Diff

net/server_select.lua @ 2541:2febd008214e

net.server_select: Remove startssl parameter to the client/server creation functions - passing a sslctx now indicates you want to use SSL from the start
author Matthew Wild <mwild1@gmail.com>
date Sun, 31 Jan 2010 15:37:08 +0000
parent 2478:7be72eca5666
child 2549:55a50e75c0c0
line wrap: on
line diff
--- a/net/server_select.lua	Sat Jan 30 18:51:07 2010 +0000
+++ b/net/server_select.lua	Sun Jan 31 15:37:08 2010 +0000
@@ -160,7 +160,7 @@
 _maxsslhandshake = 30 -- max handshake round-trips
 ----------------------------------// PRIVATE //--
 
-wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxconnections, startssl )    -- this function wraps a server
+wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxconnections )    -- this function wraps a server
 
     maxconnections = maxconnections or _maxclientsperserver
 
@@ -168,58 +168,6 @@
 
     local dispatch, disconnect = listeners.onincoming, listeners.ondisconnect
 
-    local err
-
-    local ssl = false
-
-    if sslctx then
-        ssl = true
-        if not ssl_newcontext then
-            out_error "luasec not found"
-            ssl = false
-        end
-        if type( sslctx ) ~= "table" then
-            out_error "server.lua: wrong server sslctx"
-            ssl = false
-        end
-        local ctx;
-        ctx, err = ssl_newcontext( sslctx )
-        if not ctx then
-            err = err or "wrong sslctx parameters"
-            local file;
-            file = err:match("^error loading (.-) %(");
-            if file then
-            	if file == "private key" then
-            		file = sslctx.key or "your private key";
-            	elseif file == "certificate" then
-            		file = sslctx.certificate or "your certificate file";
-            	end
-	        local reason = err:match("%((.+)%)$") or "some reason";
-	        if reason == "Permission denied" then
-	        	reason = "Check that the permissions allow Prosody to read this file.";
-	        elseif reason == "No such file or directory" then
-	        	reason = "Check that the path is correct, and the file exists.";
-	        elseif reason == "system lib" then
-	        	reason = "Previous error (see logs), or other system error.";
-	        else
-	        	reason = "Reason: "..tostring(reason or "unknown"):lower();
-	        end
-	        log("error", "SSL/TLS: Failed to load %s: %s", file, reason);
-	    else
-                log("error", "SSL/TLS: Error initialising for port %d: %s", serverport, err );
-            end
-            ssl = false
-        end
-        sslctx = ctx;
-    end
-    if not ssl then
-      sslctx = false;
-      if startssl then
-         log("error", "Failed to listen on port %d due to SSL/TLS to SSL/TLS initialisation errors (see logs)", serverport )
-         return nil, "Cannot start ssl,  see log for details"
-       end
-    end
-
     local accept = socket.accept
 
     --// public methods of the object //--
@@ -229,7 +177,7 @@
     handler.shutdown = function( ) end
 
     handler.ssl = function( )
-        return ssl
+        return sslctx ~= nil
     end
     handler.sslctx = function( )
         return sslctx
@@ -271,7 +219,7 @@
         if client then
             local ip, clientport = client:getpeername( )
             client:settimeout( 0 )
-            local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx, startssl )    -- wrap new client socket
+            local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx )    -- wrap new client socket
             if err then    -- error while wrapping ssl socket
                 return false
             end
@@ -286,7 +234,7 @@
     return handler
 end
 
-wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx, startssl )    -- this function wraps a client to a handler object
+wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx )    -- this function wraps a client to a handler object
 
     socket:settimeout( 0 )
 
@@ -520,7 +468,7 @@
             bufferqueuelen = 0
             bufferlen = 0
             _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )    -- delete socket from writelist
-            _ = needtls and handler:starttls(true)
+            _ = needtls and handler:starttls(nil, true)
             _writetimes[ handler ] = nil
 	    _ = toclose and handler.close( )
             return true
@@ -584,72 +532,69 @@
     end
     if sslctx then    -- ssl?
     	handler:set_sslctx(sslctx);
-        if startssl then    -- ssl now?
-            --out_put("server.lua: ", "starting ssl handshake")
-	    local err
+        out_put("server.lua: ", "starting ssl handshake")
+        local err
+        socket, err = ssl_wrap( socket, sslctx )    -- wrap socket
+        if err then
+            out_put( "server.lua: ssl error: ", tostring(err) )
+            --mem_free( )
+            return nil, nil, err    -- fatal error
+        end
+        socket:settimeout( 0 )
+        handler.readbuffer = handshake
+        handler.sendbuffer = handshake
+        handshake( socket ) -- do handshake
+        if not socket then
+            return nil, nil, "ssl handshake failed";
+        end
+    else
+        local sslctx;
+        handler.starttls = function( self, _sslctx, now )
+            if _sslctx then
+                sslctx = _sslctx;
+            	handler:set_sslctx(sslctx);
+            end
+            if not now then
+                out_put "server.lua: we need to do tls, but delaying until later"
+                needtls = true
+                return
+            end
+            out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
+            local oldsocket, err = socket
             socket, err = ssl_wrap( socket, sslctx )    -- wrap socket
+            --out_put( "server.lua: sslwrapped socket is " .. tostring( socket ) )
             if err then
-                out_put( "server.lua: ssl error: ", tostring(err) )
-                --mem_free( )
-                return nil, nil, err    -- fatal error
+                out_put( "server.lua: error while starting tls on client: ", tostring(err) )
+                return nil, err    -- fatal error
             end
+            
             socket:settimeout( 0 )
+
+            -- add the new socket to our system
+
+            send = socket.send
+            receive = socket.receive
+            shutdown = id
+
+            _socketlist[ socket ] = handler
+            _readlistlen = addsocket(_readlist, socket, _readlistlen)
+
+            -- remove traces of the old socket
+
+            _readlistlen = removesocket( _readlist, oldsocket, _readlistlen )
+            _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen )
+            _socketlist[ oldsocket ] = nil
+
+            handler.starttls = nil
+            needtls = nil
+                
+            -- Secure now
+            ssl = true
+
             handler.readbuffer = handshake
             handler.sendbuffer = handshake
-            handshake( socket ) -- do handshake
-            if not socket then
-                return nil, nil, "ssl handshake failed";
-            end
-        else
-            -- We're not automatically doing SSL, so we're not secure (yet)
-            ssl = false
-            handler.starttls = function( self, now )
-                if not now then
-                    --out_put "server.lua: we need to do tls, but delaying until later"
-                    needtls = true
-                    return
-                end
-                --out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
-                local oldsocket, err = socket
-                socket, err = ssl_wrap( socket, sslctx )    -- wrap socket
-                --out_put( "server.lua: sslwrapped socket is " .. tostring( socket ) )
-                if err then
-                    out_put( "server.lua: error while starting tls on client: ", tostring(err) )
-                    return nil, err    -- fatal error
-                end
-
-                socket:settimeout( 0 )
-
-                -- add the new socket to our system
-
-                send = socket.send
-                receive = socket.receive
-                shutdown = id
-
-                _socketlist[ socket ] = handler
-                _readlistlen = addsocket(_readlist, socket, _readlistlen)
-
-                -- remove traces of the old socket
-
-                _readlistlen = removesocket( _readlist, oldsocket, _readlistlen )
-                _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen )
-                _socketlist[ oldsocket ] = nil
-
-                handler.starttls = nil
-                needtls = nil
-                
-                -- Secure now
-                ssl = true
-
-                handler.readbuffer = handshake
-                handler.sendbuffer = handshake
-                handshake( socket )    -- do handshake
-            end
-            handler.readbuffer = _readbuffer
-            handler.sendbuffer = _sendbuffer
+            handshake( socket )    -- do handshake
         end
-    else    -- normal connection
-        ssl = false
         handler.readbuffer = _readbuffer
         handler.sendbuffer = _sendbuffer
     end
@@ -705,9 +650,8 @@
 
 ----------------------------------// PUBLIC //--
 
-addserver = function( addr, port, listeners, pattern, sslctx, startssl )    -- this function provides a way for other scripts to reg a server
+addserver = function( addr, port, listeners, pattern, sslctx )    -- this function provides a way for other scripts to reg a server
     local err
-    --out_put("server.lua: autossl on ", port, " is ", startssl)
     if type( listeners ) ~= "table" then
         err = "invalid listener table"
     end
@@ -728,7 +672,7 @@
         out_error( "server.lua, port ", port, ": ", err )
         return nil, err
     end
-    local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx, _maxclientsperserver, startssl )    -- wrap new server socket
+    local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx, _maxclientsperserver )    -- wrap new server socket
     if not handler then
         server:close( )
         return nil, err
@@ -857,14 +801,14 @@
 
 --// EXPERIMENTAL //--
 
-local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx, startssl )
-    local handler = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx, startssl )
+local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx )
+    local handler = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx )
     _socketlist[ socket ] = handler
     _sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
     return handler, socket
 end
 
-local addclient = function( address, port, listeners, pattern, sslctx, startssl )
+local addclient = function( address, port, listeners, pattern, sslctx )
     local client, err = luasocket.tcp( )
     if err then
         return nil, err
@@ -874,7 +818,7 @@
     if err then    -- try again
         local handler = wrapclient( client, address, port, listeners )
     else
-        wrapconnection( nil, listeners, client, address, port, "clientport", pattern, sslctx, startssl )
+        wrapconnection( nil, listeners, client, address, port, "clientport", pattern, sslctx )
     end
 end