Diff

net/server_epoll.lua @ 9314:c9eef7e8ee65

net.server_epoll: Use util.poll
author Kim Alvefur <zash@zash.se>
date Wed, 16 May 2018 23:57:09 +0200
parent 9311:9b0604fe01f1
child 9319:7c954c75b6ac
line wrap: on
line diff
--- a/net/server_epoll.lua	Wed May 16 23:56:34 2018 +0200
+++ b/net/server_epoll.lua	Wed May 16 23:57:09 2018 +0200
@@ -5,8 +5,6 @@
 -- COPYING file in the source package for more information.
 --
 
--- server_epoll
---  Server backend based on https://luarocks.org/modules/zash/lua-epoll
 
 local t_sort = table.sort;
 local t_insert = table.insert;
@@ -19,14 +17,13 @@
 local next = next;
 local pairs = pairs;
 local log = require "util.logger".init("server_epoll");
-local epoll = require "epoll";
 local socket = require "socket";
 local luasec = require "ssl";
 local gettime = require "util.time".now;
 local createtable = require "util.table".create;
 local _SOCKETINVALID = socket._SOCKETINVALID or -1;
 
-assert(socket.tcp6 and socket.tcp4, "Incompatible LuaSocket version");
+local poll = require "util.poll".new();
 
 local _ENV = nil;
 -- luacheck: std none
@@ -260,48 +257,56 @@
 	end
 end
 
--- lua-epoll flag for currently requested poll state
-function interface:flags()
-	if self._wantread then
-		if self._wantwrite then
-			return "rw";
-		end
-		return "r";
-	elseif self._wantwrite then
-		return "w";
+function interface:add(r, w)
+	local fd = self:getfd();
+	if fd < 0 then
+		return nil, "invalid fd";
 	end
+	if r == nil then r = self._wantread; end
+	if w == nil then w = self._wantwrite; end
+	local ok, err = poll:add(fd, r, w);
+	if not ok then
+		log("error", "Could not register %s: %s", self, err);
+		return ok, err;
+	end
+	self._wantread, self._wantwrite = r, w;
+	fds[fd] = self;
+	log("debug", "Registered %s", self);
+	return true;
 end
 
--- Add or remove sockets or modify epoll flags
-function interface:setflags(r, w)
-	if r ~= nil then self._wantread = r; end
-	if w ~= nil then self._wantwrite = w; end
-	local flags = self:flags();
-	local currentflags = self._flags;
-	if flags == currentflags then
-		return true;
-	end
+function interface:set(r, w)
 	local fd = self:getfd();
 	if fd < 0 then
-		self._wantread, self._wantwrite = nil, nil;
 		return nil, "invalid fd";
 	end
-	local op = "mod";
-	if not flags then
-		op = "del";
-	elseif not currentflags then
-		op = "add";
+	if r == nil then r = self._wantread; end
+	if w == nil then w = self._wantwrite; end
+	local ok, err = poll:set(fd, r, w);
+	if not ok then
+		log("error", "Could not update poller state %s: %s", self, err);
+		return ok, err;
 	end
-	local ok, err = epoll.ctl(op, fd, flags);
---	log("debug", "epoll_ctl(%q, %d, %q) -> %s" .. (err and ", %q" or ""),
---		op, fd, flags or "", tostring(ok), err);
-	if not ok then return ok, err end
-	if op == "add" then
-		fds[fd] = self;
-	elseif op == "del" then
-		fds[fd] = nil;
+	self._wantread, self._wantwrite = r, w;
+	return true;
+end
+
+function interface:del()
+	local fd = self:getfd();
+	if fd < 0 then
+		return nil, "invalid fd";
 	end
-	self._flags = flags;
+	if fds[fd] ~= self then
+		return nil, "unregistered fd";
+	end
+	local ok, err = poll:del(fd);
+	if not ok then
+		log("error", "Could not unregister %s: %s", self, err);
+		return ok, err;
+	end
+	self._wantread, self._wantwrite = nil, nil;
+	fds[fd] = nil;
+	log("debug", "Unregistered %s", self);
 	return true;
 end
 
@@ -317,9 +322,9 @@
 			self:on("incoming", partial, err);
 		end
 		if err == "wantread" then
-			self:setflags(true, nil);
+			self:set(true, nil);
 		elseif err == "wantwrite" then
-			self:setflags(nil, true);
+			self:set(nil, true);
 		elseif err ~= "timeout" then
 			self:on("disconnect", err);
 			self:destroy()
@@ -343,7 +348,7 @@
 	local data = t_concat(buffer);
 	local ok, err, partial = self.conn:send(data);
 	if ok then
-		self:setflags(nil, false);
+		self:set(nil, false);
 		for i = #buffer, 1, -1 do
 			buffer[i] = nil;
 		end
@@ -358,9 +363,9 @@
 		self:setwritetimeout();
 	end
 	if err == "wantwrite" or err == "timeout" then
-		self:setflags(nil, true);
+		self:set(nil, true);
 	elseif err == "wantread" then
-		self:setflags(true, nil);
+		self:set(true, nil);
 	elseif err ~= "timeout" then
 		self:on("disconnect", err);
 		self:destroy();
@@ -381,7 +386,7 @@
 		self.writebuffer = { data };
 	end
 	self:setwritetimeout();
-	self:setflags(nil, true);
+	self:set(nil, true);
 	return #data;
 end
 interface.send = interface.write;
@@ -389,7 +394,7 @@
 -- Close, possibly after writing is done
 function interface:close()
 	if self.writebuffer and self.writebuffer[1] then
-		self:setflags(false, true); -- Flush final buffer contents
+		self:set(false, true); -- Flush final buffer contents
 		self.write, self.send = noop, noop; -- No more writing
 		log("debug", "Close %s after writing", self);
 		self.ondrain = interface.close;
@@ -403,7 +408,7 @@
 end
 
 function interface:destroy()
-	self:setflags(false, false);
+	self:del();
 	self:setwritetimeout(false);
 	self:setreadtimeout(false);
 	self.onreadable = noop;
@@ -425,10 +430,10 @@
 		log("debug", "Start TLS on %s after write", self);
 		self.ondrain = interface.starttls;
 		self.starttls = false;
-		self:setflags(nil, true); -- make sure wantwrite is set
+		self:set(nil, true); -- make sure wantwrite is set
 	else
 		log("debug", "Start TLS on %s now", self);
-		self:setflags(false, false);
+		self:del();
 		local conn, err = luasec.wrap(self.conn, tls_ctx or self.tls_ctx);
 		if not conn then
 			self:on("disconnect", err);
@@ -440,8 +445,7 @@
 		self.ondrain = nil;
 		self.onwritable = interface.tlshandskake;
 		self.onreadable = interface.tlshandskake;
-		self:setflags(true, true);
-		self:setwritetimeout(cfg.handshake_timeout);
+		return self:init();
 	end
 end
 
@@ -455,14 +459,15 @@
 		self.onreadable = nil;
 		self._tls = true;
 		self:on("status", "ssl-handshake-complete");
-		self:init();
+		self:setwritetimeout();
+		self:set(true, true);
 	elseif err == "wantread" then
 		log("debug", "TLS handshake on %s to wait until readable", self);
-		self:setflags(true, false);
+		self:set(true, false);
 		self:setreadtimeout(cfg.handshake_timeout);
 	elseif err == "wantwrite" then
 		log("debug", "TLS handshake on %s to wait until writable", self);
-		self:setflags(false, true);
+		self:set(false, true);
 		self:setwritetimeout(cfg.handshake_timeout);
 	else
 		log("debug", "TLS handshake error on %s: %s", self, err);
@@ -513,15 +518,15 @@
 -- Initialization
 function interface:init()
 	self:setwritetimeout();
-	return self:setflags(true, true);
+	return self:add(true, true);
 end
 
 function interface:pause()
-	return self:setflags(false);
+	return self:set(false);
 end
 
 function interface:resume()
-	return self:setflags(true);
+	return self:set(true);
 end
 
 -- Pause connection for some time
@@ -530,13 +535,13 @@
 		self._pausefor:close();
 	end
 	if t == false then return; end
-	self:setflags(false);
+	self:set(false);
 	self._pausefor = addtimer(t, function ()
 		self._pausefor = nil;
 		if self.conn and self.conn:dirty() then
 			self:onreadable();
 		end
-		self:setflags(true);
+		self:set(true);
 	end);
 end
 
@@ -564,7 +569,7 @@
 		sockname = addr;
 		sockport = port;
 	}, interface_mt);
-	server:setflags(true, false);
+	server:add(true, false);
 	return server;
 end
 
@@ -603,7 +608,7 @@
 		onreadable = onreadable;
 		onwriteable = onwriteable;
 		close = function (self)
-			self:setflags(false, false);
+			self:del();
 		end
 	}, interface_mt);
 	if type(fd) == "number" then
@@ -612,7 +617,7 @@
 		end;
 		-- Otherwise it'll need to be something LuaSocket-compatible
 	end
-	conn:setflags(onreadable, onwriteable);
+	conn:add(onreadable, onwriteable);
 	return conn;
 end;
 
@@ -629,8 +634,8 @@
 			from:resume();
 		end,
 	}, {__index=to.listeners});
-	from:setflags(true, nil);
-	to:setflags(nil, true);
+	from:set(true, nil);
+	to:set(nil, true);
 end
 
 -- XXX What uses this?
@@ -662,7 +667,7 @@
 local function loop(once)
 	repeat
 		local t = runtimers(cfg.max_wait, cfg.min_wait);
-		local fd, r, w = epoll.wait(t);
+		local fd, r, w = poll:wait(t);
 		if fd then
 			local conn = fds[fd];
 			if conn then
@@ -674,7 +679,7 @@
 				end
 			else
 				log("debug", "Removing unknown fd %d", fd);
-				epoll.ctl("del", fd);
+				poll:del(fd);
 			end
 		elseif r ~= "timeout" then
 			log("debug", "epoll_wait error: %s", tostring(r));
@@ -705,9 +710,9 @@
 		local function onevent(self)
 			local ret = self:callback();
 			if ret == -1 then
-				self:setflags(false, false);
+				self:set(false, false);
 			elseif ret then
-				self:setflags(mode == "r" or mode == "rw", mode == "w" or mode == "rw");
+				self:set(mode == "r" or mode == "rw", mode == "w" or mode == "rw");
 			end
 		end
 
@@ -717,11 +722,11 @@
 			onreadable = onevent;
 			onwritable = onevent;
 			close = function (self)
-				self:setflags(false, false);
+				self:del();
 				fds[fd] = nil;
 			end;
 		}, interface_mt);
-		local ok, err = conn:setflags(mode == "r" or mode == "rw", mode == "w" or mode == "rw");
+		local ok, err = conn:add(mode == "r" or mode == "rw", mode == "w" or mode == "rw");
 		if not ok then return ok, err; end
 		return conn;
 	end;