Diff

net/server_epoll.lua @ 9430:412ff404bf58

net.server_epoll: Delay wrapping sockets in TLS until just before first handshake
author Kim Alvefur <zash@zash.se>
date Fri, 14 Sep 2018 01:34:38 +0200
parent 9387:33e52f727f0f
child 9431:c3c0523a37c6
line wrap: on
line diff
--- a/net/server_epoll.lua	Wed Oct 03 16:41:37 2018 +0200
+++ b/net/server_epoll.lua	Fri Sep 14 01:34:38 2018 +0200
@@ -440,15 +440,30 @@
 
 function interface:starttls(tls_ctx)
 	if tls_ctx then self.tls_ctx = tls_ctx; end
+	self.starttls = false;
 	if self.writebuffer and self.writebuffer[1] then
 		log("debug", "Start TLS on %s after write", self);
 		self.ondrain = interface.starttls;
-		self.starttls = false;
 		self:set(nil, true); -- make sure wantwrite is set
 	else
+		if self.ondrain == interface.starttls then
+			self.ondrain = nil;
+		end
+		self.onwritable = interface.tlshandskake;
+		self.onreadable = interface.tlshandskake;
+		self:set(true, true);
+		log("debug", "Prepare to start TLS on %s", self);
+	end
+end
+
+function interface:tlshandskake()
+	self:setwritetimeout(false);
+	self:setreadtimeout(false);
+	if not self._tls then
+		self._tls = true;
 		log("debug", "Start TLS on %s now", self);
 		self:del();
-		local conn, err = luasec.wrap(self.conn, tls_ctx or self.tls_ctx);
+		local conn, err = luasec.wrap(self.conn, self.tls_ctx);
 		if not conn then
 			self:on("disconnect", err);
 			self:destroy();
@@ -456,22 +471,17 @@
 		end
 		conn:settimeout(0);
 		self.conn = conn;
+		self:on("starttls");
 		self.ondrain = nil;
 		self.onwritable = interface.tlshandskake;
 		self.onreadable = interface.tlshandskake;
 		return self:init();
 	end
-end
-
-function interface:tlshandskake()
-	self:setwritetimeout(false);
-	self:setreadtimeout(false);
 	local ok, err = self.conn:dohandshake();
 	if ok then
 		log("debug", "TLS handshake on %s complete", self);
 		self.onwritable = nil;
 		self.onreadable = nil;
-		self._tls = true;
 		self:on("status", "ssl-handshake-complete");
 		self:setwritetimeout();
 		self:set(true, true);
@@ -529,10 +539,9 @@
 	end
 	local client = wrapsocket(conn, self, nil, self.listeners);
 	log("debug", "New connection %s", tostring(client));
+	client:init();
 	if self.tls_direct then
 		client:starttls(self.tls_ctx);
-	else
-		client:init();
 	end
 end
 
@@ -600,10 +609,9 @@
 	if not client.peername then
 		client.peername, client.peerport = addr, port;
 	end
+	client:init();
 	if tls_ctx then
 		client:starttls(tls_ctx);
-	else
-		client:init();
 	end
 	return client;
 end
@@ -615,10 +623,9 @@
 	conn:settimeout(0);
 	conn:connect(addr, port);
 	local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx)
+	client:init();
 	if tls_ctx then
 		client:starttls(tls_ctx);
-	else
-		client:init();
 	end
 	return client, conn;
 end