Diff

net/server_epoll.lua @ 10563:e8db377a2983

Merge 0.11->trunk
author Kim Alvefur <zash@zash.se>
date Tue, 24 Dec 2019 00:39:45 +0100
parent 10546:944863f878b9
child 10571:cfeb0077c9e9
line wrap: on
line diff
--- a/net/server_epoll.lua	Tue Dec 24 00:26:40 2019 +0100
+++ b/net/server_epoll.lua	Tue Dec 24 00:39:45 2019 +0100
@@ -9,20 +9,22 @@
 local t_insert = table.insert;
 local t_concat = table.concat;
 local setmetatable = setmetatable;
-local tostring = tostring;
 local pcall = pcall;
 local type = type;
 local next = next;
 local pairs = pairs;
-local log = require "util.logger".init("server_epoll");
+local logger = require "util.logger";
+local log = logger.init("server_epoll");
 local socket = require "socket";
 local luasec = require "ssl";
-local gettime = require "util.time".now;
+local realtime = require "util.time".now;
+local monotonic = require "util.time".monotonic;
 local indexedbheap = require "util.indexedbheap";
 local createtable = require "util.table".create;
 local inet = require "util.net";
 local inet_pton = inet.pton;
 local _SOCKETINVALID = socket._SOCKETINVALID or -1;
+local new_id = require "util.id".medium;
 
 local poller = require "util.poll"
 local EEXIST = poller.EEXIST;
@@ -38,7 +40,10 @@
 	read_timeout = 14 * 60;
 
 	-- How long to wait for a socket to become writable after queuing data to send
-	send_timeout = 60;
+	send_timeout = 180;
+
+	-- How long to wait for a socket to become writable after creation
+	connect_timeout = 20;
 
 	-- Some number possibly influencing how many pending connections can be accepted
 	tcp_backlog = 128;
@@ -58,6 +63,13 @@
 	-- Maximum and minimum amount of time to sleep waiting for events (adjusted for pending timers)
 	max_wait = 86400;
 	min_wait = 1e-06;
+
+	-- EXPERIMENTAL
+	-- Whether to kill connections in case of callback errors.
+	fatal_errors = false;
+
+	-- Attempt writes instantly
+	opportunistic_writes = false;
 }};
 local cfg = default_config.__index;
 
@@ -75,45 +87,43 @@
 end
 
 local function reschedule(t, time)
+	time = monotonic() + time;
 	t[1] = time;
 	timers:reprioritize(t.id, time);
 end
 
--- Add absolute timer
-local function at(time, f)
+-- Add relative timer
+local function addtimer(timeout, f)
+	local time = monotonic() + timeout;
 	local timer = { time, f, close = closetimer, reschedule = reschedule, id = nil };
 	timer.id = timers:insert(timer, time);
 	return timer;
 end
 
--- Add relative timer
-local function addtimer(timeout, f)
-	return at(gettime() + timeout, f);
-end
-
 -- Run callbacks of expired timers
 -- Return time until next timeout
 local function runtimers(next_delay, min_wait)
 	-- Any timers at all?
-	local now = gettime();
+	local elapsed = monotonic();
+	local now = realtime();
 	local peek = timers:peek();
 	while peek do
 
-		if peek > now then
-			next_delay = peek - now;
+		if peek > elapsed then
+			next_delay = peek - elapsed;
 			break;
-	end
+		end
 
-		local _, timer, id = timers:pop();
+		local _, timer = timers:pop();
 		local ok, ret = pcall(timer[2], now);
 		if ok and type(ret) == "number"  then
-			local next_time = now+ret;
+			local next_time = elapsed+ret;
 			timer[1] = next_time;
 			timers:insert(timer, next_time);
-	end
+		end
 
 		peek = timers:peek();
-			end
+	end
 	if peek == nil then
 		return next_delay;
 	end
@@ -138,6 +148,15 @@
 	return ("FD %d"):format(self:getfd());
 end
 
+interface.log = log;
+function interface:debug(msg, ...) --luacheck: ignore 212/self
+	self.log("debug", msg, ...);
+end
+
+function interface:error(msg, ...) --luacheck: ignore 212/self
+	self.log("error", msg, ...);
+end
+
 -- Replace the listener and tell the old one
 function interface:setlistener(listeners, data)
 	self:on("detach");
@@ -148,21 +167,32 @@
 -- Call a listener callback
 function interface:on(what, ...)
 	if not self.listeners then
-		log("error", "%s has no listeners", self);
+		self:debug("Interface is missing listener callbacks");
 		return;
 	end
 	local listener = self.listeners["on"..what];
 	if not listener then
-		-- log("debug", "Missing listener 'on%s'", what); -- uncomment for development and debugging
+		-- self:debug("Missing listener 'on%s'", what); -- uncomment for development and debugging
 		return;
 	end
 	local ok, err = pcall(listener, self, ...);
 	if not ok then
-		log("error", "Error calling on%s: %s", what, err);
+		if cfg.fatal_errors then
+			self:debug("Closing due to error calling on%s: %s", what, err);
+			self:destroy();
+		else
+			self:debug("Error calling on%s: %s", what, err);
+		end
+		return nil, err;
 	end
 	return err;
 end
 
+-- Allow this one to be overridden
+function interface:onincoming(...)
+	return self:on("incoming", ...);
+end
+
 -- Return the file descriptor number
 function interface:getfd()
 	if self.conn then
@@ -226,12 +256,14 @@
 	end
 	t = t or cfg.read_timeout;
 	if self._readtimeout then
-		self._readtimeout:reschedule(gettime() + t);
+		self._readtimeout:reschedule(t);
 	else
 		self._readtimeout = addtimer(t, function ()
 			if self:on("readtimeout") then
+				self:debug("Read timeout handled");
 				return cfg.read_timeout;
 			else
+				self:debug("Read timeout not handled, disconnecting");
 				self:on("disconnect", "read timeout");
 				self:destroy();
 			end
@@ -250,9 +282,10 @@
 	end
 	t = t or cfg.send_timeout;
 	if self._writetimeout then
-		self._writetimeout:reschedule(gettime() + t);
+		self._writetimeout:reschedule(t);
 	else
 		self._writetimeout = addtimer(t, function ()
+			self:debug("Write timeout");
 			self:on("disconnect", "write timeout");
 			self:destroy();
 		end);
@@ -269,15 +302,15 @@
 	local ok, err, errno = poll:add(fd, r, w);
 	if not ok then
 		if errno == EEXIST then
-			log("debug", "%s already registered!", self);
+			self:debug("FD already registered in poller! (EEXIST)");
 			return self:set(r, w); -- So try to change its flags
 		end
-		log("error", "Could not register %s: %s(%d)", self, err, errno);
+		self:debug("Could not register in poller: %s(%d)", err, errno);
 		return ok, err;
 	end
 	self._wantread, self._wantwrite = r, w;
 	fds[fd] = self;
-	log("debug", "Watching %s", self);
+	self:debug("Registered in poller");
 	return true;
 end
 
@@ -290,7 +323,7 @@
 	if w == nil then w = self._wantwrite; end
 	local ok, err, errno = poll:set(fd, r, w);
 	if not ok then
-		log("error", "Could not update poller state %s: %s(%d)", self, err, errno);
+		self:debug("Could not update poller state: %s(%d)", err, errno);
 		return ok, err;
 	end
 	self._wantread, self._wantwrite = r, w;
@@ -307,12 +340,12 @@
 	end
 	local ok, err, errno = poll:del(fd);
 	if not ok and errno ~= ENOENT then
-		log("error", "Could not unregister %s: %s(%d)", self, err, errno);
+		self:debug("Could not unregister: %s(%d)", err, errno);
 		return ok, err;
 	end
 	self._wantread, self._wantwrite = nil, nil;
 	fds[fd] = nil;
-	log("debug", "Unwatched %s", self);
+	self:debug("Unregistered from poller");
 	return true;
 end
 
@@ -334,7 +367,7 @@
 	local data, err, partial = self.conn:receive(self.read_size or cfg.read_size);
 	if data then
 		self:onconnect();
-		self:on("incoming", data);
+		self:onincoming(data);
 	else
 		if err == "wantread" then
 			self:set(true, nil);
@@ -345,7 +378,7 @@
 		end
 		if partial and partial ~= "" then
 			self:onconnect();
-			self:on("incoming", partial, err);
+			self:onincoming(partial, err);
 		end
 		if err ~= "timeout" then
 			self:on("disconnect", err);
@@ -354,6 +387,14 @@
 		end
 	end
 	if not self.conn then return; end
+	if self._limit and (data or partial) then
+		local cost = self._limit * #(data or partial);
+		if cost > cfg.min_wait then
+			self:setreadtimeout(false);
+			self:pausefor(cost);
+			return;
+		end
+	end
 	if self._wantread and self.conn:dirty() then
 		self:setreadtimeout(false);
 		self:pausefor(cfg.read_retry_delay);
@@ -378,10 +419,12 @@
 		self:ondrain(); -- Be aware of writes in ondrain
 		return;
 	elseif partial then
+		self:debug("Sent %d out of %d buffered bytes", partial, #data);
 		buffer[1] = data:sub(partial+1);
 		for i = #buffer, 2, -1 do
 			buffer[i] = nil;
 		end
+		self:set(nil, true);
 		self:setwritetimeout();
 	end
 	if err == "wantwrite" or err == "timeout" then
@@ -407,8 +450,14 @@
 	else
 		self.writebuffer = { data };
 	end
-	self:setwritetimeout();
-	self:set(nil, true);
+	if not self._write_lock then
+		if cfg.opportunistic_writes then
+			self:onwritable();
+			return #data;
+		end
+		self:setwritetimeout();
+		self:set(nil, true);
+	end
 	return #data;
 end
 interface.send = interface.write;
@@ -418,10 +467,10 @@
 	if self.writebuffer and self.writebuffer[1] then
 		self:set(false, true); -- Flush final buffer contents
 		self.write, self.send = noop, noop; -- No more writing
-		log("debug", "Close %s after writing", self);
+		self:debug("Close after writing remaining buffered data");
 		self.ondrain = interface.close;
 	else
-		log("debug", "Close %s now", self);
+		self:debug("Closing now");
 		self.write, self.send = noop, noop;
 		self.close = noop;
 		self:on("disconnect");
@@ -450,7 +499,7 @@
 	if tls_ctx then self.tls_ctx = tls_ctx; end
 	self.starttls = false;
 	if self.writebuffer and self.writebuffer[1] then
-		log("debug", "Start TLS on %s after write", self);
+		self:debug("Start TLS after write");
 		self.ondrain = interface.starttls;
 		self:set(nil, true); -- make sure wantwrite is set
 	else
@@ -460,7 +509,7 @@
 		self.onwritable = interface.tlshandskake;
 		self.onreadable = interface.tlshandskake;
 		self:set(true, true);
-		log("debug", "Prepare to start TLS on %s", self);
+		self:debug("Prepared to start TLS");
 	end
 end
 
@@ -469,12 +518,13 @@
 	self:setreadtimeout(false);
 	if not self._tls then
 		self._tls = true;
-		log("debug", "Start TLS on %s now", self);
+		self:debug("Starting TLS now");
 		self:del();
+		self:updatenames(); -- Can't getpeer/sockname after wrap()
 		local ok, conn, err = pcall(luasec.wrap, self.conn, self.tls_ctx);
 		if not ok then
 			conn, err = ok, conn;
-			log("error", "Failed to initialize TLS: %s", err);
+			self:debug("Failed to initialize TLS: %s", err);
 		end
 		if not conn then
 			self:on("disconnect", err);
@@ -483,6 +533,13 @@
 		end
 		conn:settimeout(0);
 		self.conn = conn;
+		if conn.sni then
+			if self.servername then
+				conn:sni(self.servername);
+			elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then
+				conn:sni(self._server.hosts, true);
+			end
+		end
 		self:on("starttls");
 		self.ondrain = nil;
 		self.onwritable = interface.tlshandskake;
@@ -491,40 +548,55 @@
 	end
 	local ok, err = self.conn:dohandshake();
 	if ok then
-		log("debug", "TLS handshake on %s complete", self);
+		local info = self.conn.info and self.conn:info();
+		if type(info) == "table" then
+			self:debug("TLS handshake complete (%s with %s)", info.protocol, info.cipher);
+		else
+			self:debug("TLS handshake complete");
+		end
 		self.onwritable = nil;
 		self.onreadable = nil;
 		self:on("status", "ssl-handshake-complete");
 		self:setwritetimeout();
 		self:set(true, true);
 	elseif err == "wantread" then
-		log("debug", "TLS handshake on %s to wait until readable", self);
+		self:debug("TLS handshake to wait until readable");
 		self:set(true, false);
 		self:setreadtimeout(cfg.ssl_handshake_timeout);
 	elseif err == "wantwrite" then
-		log("debug", "TLS handshake on %s to wait until writable", self);
+		self:debug("TLS handshake to wait until writable");
 		self:set(false, true);
 		self:setwritetimeout(cfg.ssl_handshake_timeout);
 	else
-		log("debug", "TLS handshake error on %s: %s", self, err);
+		self:debug("TLS handshake error: %s", err);
 		self:on("disconnect", err);
 		self:destroy();
 	end
 end
 
-local function wrapsocket(client, server, read_size, listeners, tls_ctx) -- luasocket object -> interface object
+local function wrapsocket(client, server, read_size, listeners, tls_ctx, extra) -- luasocket object -> interface object
 	client:settimeout(0);
+	local conn_id = ("conn%s"):format(new_id());
 	local conn = setmetatable({
 		conn = client;
 		_server = server;
-		created = gettime();
+		created = realtime();
 		listeners = listeners;
 		read_size = read_size or (server and server.read_size);
 		writebuffer = {};
 		tls_ctx = tls_ctx or (server and server.tls_ctx);
 		tls_direct = server and server.tls_direct;
+		id = conn_id;
+		log = logger.init(conn_id);
+		extra = extra;
 	}, interface_mt);
 
+	if extra then
+		if extra.servername then
+			conn.servername = extra.servername;
+		end
+	end
+
 	conn:updatenames();
 	return conn;
 end
@@ -532,11 +604,11 @@
 function interface:updatenames()
 	local conn = self.conn;
 	local ok, peername, peerport = pcall(conn.getpeername, conn);
-	if ok then
+	if ok and peername then
 		self.peername, self.peerport = peername, peerport;
 	end
 	local ok, sockname, sockport = pcall(conn.getsockname, conn);
-	if ok then
+	if ok and sockname then
 		self.sockname, self.sockport = sockname, sockport;
 	end
 end
@@ -546,34 +618,39 @@
 function interface:onacceptable()
 	local conn, err = self.conn:accept();
 	if not conn then
-		log("debug", "Error accepting new client: %s, server will be paused for %ds", err, cfg.accept_retry_interval);
+		self:debug("Error accepting new client: %s, server will be paused for %ds", err, cfg.accept_retry_interval);
 		self:pausefor(cfg.accept_retry_interval);
 		return;
 	end
 	local client = wrapsocket(conn, self, nil, self.listeners);
-	log("debug", "New connection %s", tostring(client));
+	client:debug("New connection %s on server %s", client, self);
 	client:init();
 	if self.tls_direct then
 		client:starttls(self.tls_ctx);
+	else
+		client:onconnect();
 	end
 end
 
 -- Initialization
 function interface:init()
-	self:setwritetimeout();
+	self:setwritetimeout(cfg.connect_timeout);
 	return self:add(true, true);
 end
 
 function interface:pause()
+	self:debug("Pause reading");
 	return self:set(false);
 end
 
 function interface:resume()
+	self:debug("Resume reading");
 	return self:set(true);
 end
 
 -- Pause connection for some time
 function interface:pausefor(t)
+	self:debug("Pause for %fs", t);
 	if self._pausefor then
 		self._pausefor:close();
 	end
@@ -582,43 +659,85 @@
 	self._pausefor = addtimer(t, function ()
 		self._pausefor = nil;
 		self:set(true);
+		self:debug("Resuming after pause, connection is %s", not self.conn and "missing" or self.conn:dirty() and "dirty" or "clean");
 		if self.conn and self.conn:dirty() then
 			self:onreadable();
 		end
 	end);
 end
 
+function interface:setlimit(Bps)
+	if Bps > 0 then
+		self._limit = 1/Bps;
+	else
+		self._limit = nil;
+	end
+end
+
+function interface:pause_writes()
+	if self._write_lock then
+		return
+	end
+	self:debug("Pause writes");
+	self._write_lock = true;
+	self:setwritetimeout(false);
+	self:set(nil, false);
+end
+
+function interface:resume_writes()
+	if not self._write_lock then
+		return
+	end
+	self:debug("Resume writes");
+	self._write_lock = nil;
+	if self.writebuffer[1] then
+		self:setwritetimeout();
+		self:set(nil, true);
+	end
+end
+
 -- Connected!
 function interface:onconnect()
-	if self.conn and not self.peername and self.conn.getpeername then
-		self.peername, self.peerport = self.conn:getpeername();
-	end
+	self:updatenames();
+	self:debug("Connected (%s)", self);
 	self.onconnect = noop;
 	self:on("connect");
 end
 
-local function addserver(addr, port, listeners, read_size, tls_ctx)
+local function listen(addr, port, listeners, config)
 	local conn, err = socket.bind(addr, port, cfg.tcp_backlog);
 	if not conn then return conn, err; end
 	conn:settimeout(0);
 	local server = setmetatable({
 		conn = conn;
-		created = gettime();
+		created = realtime();
 		listeners = listeners;
-		read_size = read_size;
+		read_size = config and config.read_size;
 		onreadable = interface.onacceptable;
-		tls_ctx = tls_ctx;
-		tls_direct = tls_ctx and true or false;
+		tls_ctx = config and config.tls_ctx;
+		tls_direct = config and config.tls_direct;
+		hosts = config and config.sni_hosts;
 		sockname = addr;
 		sockport = port;
+		log = logger.init(("serv%s"):format(new_id()));
 	}, interface_mt);
+	server:debug("Server %s created", server);
 	server:add(true, false);
 	return server;
 end
 
 -- COMPAT
-local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx)
-	local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx);
+local function addserver(addr, port, listeners, read_size, tls_ctx)
+	return listen(addr, port, listeners, {
+		read_size = read_size;
+		tls_ctx = tls_ctx;
+		tls_direct = tls_ctx and true or false;
+	});
+end
+
+-- COMPAT
+local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx, extra)
+	local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra);
 	if not client.peername then
 		client.peername, client.peerport = addr, port;
 	end
@@ -631,7 +750,7 @@
 end
 
 -- New outgoing TCP connection
-local function addclient(addr, port, listeners, read_size, tls_ctx, typ)
+local function addclient(addr, port, listeners, read_size, tls_ctx, typ, extra)
 	local create;
 	if not typ then
 		local n = inet_pton(addr);
@@ -649,13 +768,19 @@
 		return nil, "invalid socket type";
 	end
 	local conn, err = create();
+	if not conn then return conn, err; end
 	local ok, err = conn:settimeout(0);
 	if not ok then return ok, err; end
 	local ok, err = conn:setpeername(addr, port);
 	if not ok and err ~= "timeout" then return ok, err; end
-	local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx)
+	local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra)
 	local ok, err = client:init();
+	if not client.peername then
+		-- otherwise not set until connected
+		client.peername, client.peerport = addr, port;
+	end
 	if not ok then return ok, err; end
+	client:debug("Client %s created", client);
 	if tls_ctx then
 		client:starttls(tls_ctx);
 	end
@@ -677,23 +802,23 @@
 		end;
 		-- Otherwise it'll need to be something LuaSocket-compatible
 	end
+	conn.id = new_id();
+	conn.log = logger.init(("fdwatch%s"):format(conn.id));
 	conn:add(onreadable, onwritable);
 	return conn;
 end;
 
 -- Dump all data from one connection into another
-local function link(from, to)
-	from.listeners = setmetatable({
-		onincoming = function (_, data)
-			from:pause();
-			to:write(data);
-		end,
-	}, {__index=from.listeners});
-	to.listeners = setmetatable({
-		ondrain = function ()
-			from:resume();
-		end,
-	}, {__index=to.listeners});
+local function link(from, to, read_size)
+	from:debug("Linking to %s", to.id);
+	function from:onincoming(data)
+		self:pause();
+		to:write(data);
+	end
+	function to:ondrain() -- luacheck: ignore 212/self
+		from:resume();
+	end
+	from:set_mode(read_size);
 	from:set(true, nil);
 	to:set(nil, true);
 end
@@ -752,7 +877,7 @@
 	addserver = addserver;
 	addclient = addclient;
 	add_task = addtimer;
-	at = at;
+	listen = listen;
 	loop = loop;
 	closeall = closeall;
 	setquitting = setquitting;
@@ -766,6 +891,7 @@
 	-- libevent emulation
 	event = { EV_READ = "r", EV_WRITE = "w", EV_READWRITE = "rw", EV_LEAVE = -1 };
 	addevent = function (fd, mode, callback)
+		log("warn", "Using deprecated libevent emulation, please update code to use watchfd API instead");
 		local function onevent(self)
 			local ret = self:callback();
 			if ret == -1 then
@@ -785,6 +911,8 @@
 				fds[fd] = nil;
 			end;
 		}, interface_mt);
+		conn.id = conn:getfd();
+		conn.log = logger.init(("fdwatch%d"):format(conn.id));
 		local ok, err = conn:add(mode == "r" or mode == "rw", mode == "w" or mode == "rw");
 		if not ok then return ok, err; end
 		return conn;