Software /
code /
prosody
Diff
net/server_select.lua @ 4349:16fd8061964e
net.server_select: Merge straight-SSL and starttls code paths, also fixes onconnect being called before handshake completion for straight-SSL
author | Matthew Wild <mwild1@gmail.com> |
---|---|
date | Sat, 20 Aug 2011 15:06:14 -0400 |
parent | 4348:5b240c6b5334 |
child | 4353:f600591c87fa |
line wrap: on
line diff
--- a/net/server_select.lua Sat Aug 20 15:04:17 2011 -0400 +++ b/net/server_select.lua Sat Aug 20 15:06:14 2011 -0400 @@ -525,6 +525,9 @@ handler.readbuffer = _readbuffer -- when handshake is done, replace the handshake function with regular functions handler.sendbuffer = _sendbuffer _ = status and status( handler, "ssl-handshake-complete" ) + if self.autostart_ssl and listeners.onconnect then + listeners.onconnect(self); + end _readlistlen = addsocket(_readlist, client, _readlistlen) return true else @@ -549,74 +552,56 @@ ) end if luasec then - if sslctx then -- ssl? - handler:set_sslctx(sslctx); - out_put("server.lua: ", "starting ssl handshake") - local err + handler.starttls = function( self, _sslctx) + if _sslctx then + handler:set_sslctx(_sslctx); + end + if bufferqueuelen > 0 then + out_put "server.lua: we need to do tls, but delaying until send buffer empty" + needtls = true + return + end + out_put( "server.lua: attempting to start tls on " .. tostring( socket ) ) + local oldsocket, err = socket socket, err = ssl_wrap( socket, sslctx ) -- wrap socket - if err then - out_put( "server.lua: ssl error: ", tostring(err) ) - --mem_free( ) - return nil, nil, err -- fatal error + if not socket then + out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") ) + return nil, err -- fatal error end + socket:settimeout( 0 ) + + -- add the new socket to our system + send = socket.send + receive = socket.receive + shutdown = id + _socketlist[ socket ] = handler + _readlistlen = addsocket(_readlist, socket, _readlistlen) + + -- remove traces of the old socket + _readlistlen = removesocket( _readlist, oldsocket, _readlistlen ) + _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen ) + _socketlist[ oldsocket ] = nil + + handler.starttls = nil + needtls = nil + + -- Secure now (if handshake fails connection will close) + ssl = true + handler.readbuffer = handshake handler.sendbuffer = handshake handshake( socket ) -- do handshake - if not socket then - return nil, nil, "ssl handshake failed"; - end - else - local sslctx; - handler.starttls = function( self, _sslctx) - if _sslctx then - sslctx = _sslctx; - handler:set_sslctx(sslctx); - end - if bufferqueuelen > 0 then - out_put "server.lua: we need to do tls, but delaying until send buffer empty" - needtls = true - return - end - out_put( "server.lua: attempting to start tls on " .. tostring( socket ) ) - local oldsocket, err = socket - socket, err = ssl_wrap( socket, sslctx ) -- wrap socket - --out_put( "server.lua: sslwrapped socket is " .. tostring( socket ) ) - if err then - out_put( "server.lua: error while starting tls on client: ", tostring(err) ) - return nil, err -- fatal error - end + end + handler.readbuffer = _readbuffer + handler.sendbuffer = _sendbuffer + + if sslctx then + out_put "server.lua: auto-starting ssl negotiation..." + handler.autostart_ssl = true; + handler:starttls(sslctx); + end - socket:settimeout( 0 ) - - -- add the new socket to our system - - send = socket.send - receive = socket.receive - shutdown = id - - _socketlist[ socket ] = handler - _readlistlen = addsocket(_readlist, socket, _readlistlen) - - -- remove traces of the old socket - - _readlistlen = removesocket( _readlist, oldsocket, _readlistlen ) - _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen ) - _socketlist[ oldsocket ] = nil - - handler.starttls = nil - needtls = nil - - -- Secure now - ssl = true - - handler.readbuffer = handshake - handler.sendbuffer = handshake - handshake( socket ) -- do handshake - end - handler.readbuffer = _readbuffer - handler.sendbuffer = _sendbuffer - end else handler.readbuffer = _readbuffer handler.sendbuffer = _sendbuffer @@ -857,16 +842,19 @@ local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx ) local handler = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx ) _socketlist[ socket ] = handler - _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) - if listeners.onconnect then - -- When socket is writeable, call onconnect - local _sendbuffer = handler.sendbuffer; - handler.sendbuffer = function () - handler.sendbuffer = _sendbuffer; - listeners.onconnect(handler); - -- If there was data with the incoming packet, handle it now. - if #handler:bufferqueue() > 0 then - return _sendbuffer(); + if not sslctx then + _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) + if listeners.onconnect then + -- When socket is writeable, call onconnect + local _sendbuffer = handler.sendbuffer; + handler.sendbuffer = function () + handler.sendbuffer = _sendbuffer; + listeners.onconnect(handler); + -- If there was data with the incoming packet, handle it now. + if #handler:bufferqueue() > 0 then + return _sendbuffer(); + end + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) end end end