Software /
code /
prosody
File
net/server_epoll.lua @ 7602:a4d5eec36a17
net.server_epoll: Prevent invalid FDs from being added to epoll
author | Kim Alvefur <zash@zash.se> |
---|---|
date | Thu, 18 Aug 2016 17:16:21 +0200 |
parent | 7601:ce20546d7714 |
child | 7603:4e2b019ff176 |
line wrap: on
line source
-- Prosody IM -- Copyright (C) 2016 Kim Alvefur -- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- -- server_epoll -- Server backend based on https://luarocks.org/modules/zash/lua-epoll local t_sort = table.sort; local t_insert = table.insert; local t_remove = table.remove; local t_concat = table.concat; local setmetatable = setmetatable; local tostring = tostring; local pcall = pcall; local log = require "util.logger".init("server_epoll"); local epoll = require "epoll"; local socket = require "socket"; local luasec = require "ssl"; local gettime = require "util.time".now; local createtable = require "util.table".create; assert(socket.tcp6 and socket.tcp4, "Incompatible LuaSocket version"); local _ENV = nil; local cfg = { read_timeout = 900; write_timeout = 7; tcp_backlog = 128; accept_retry_interval = 10; read_retry_delay = 1e-06; connect_timeout = 20; handshake_timeout = 60; }; local fds = createtable(10, 0); -- FD -> conn local timers = {}; local function noop() end local function closetimer(t) t[1] = 0; t[2] = noop; end -- Set to true when timers have changed local resort_timers = false; -- Add absolute timer local function at(time, f) local timer = { time, f, close = closetimer }; t_insert(timers, timer); resort_timers = true; return timer; end -- Add relative timer local function addtimer(timeout, f) return at(gettime() + timeout, f); end -- Run callbacks of expired timers -- Return time until next timeout local function runtimers() if resort_timers then -- Sort earliest timers to the end t_sort(timers, function (a, b) return a[1] > b[1]; end); resort_timers = false; end --[[ Is it worth it to skip the noop calls? for i = #timers, 1, -1 do if timers[i][2] == noop then timers[i] = nil; else break; end end --]] local next_delay = 86400; -- Iterate from the end and remove completed timers for i = #timers, 1, -1 do local timer = timers[i]; local t, f = timer[1], timer[2]; local now = gettime(); -- inside or before the loop? if t > now then -- This timer should not fire yet local diff = t - now; if diff < next_delay then next_delay = diff; end break; end local new_timeout = f(now); if new_timeout then timer[1] = t + new_timeout; resort_timers = true; else t_remove(timers, i); end end if resort_timers or next_delay < 1e-6 then -- Timers may be added from within a timer callback. -- Those would not be considered for next_delay, -- and we might sleep for too long, so instead -- we return a shorter timeout so we can -- properly sort all new timers. next_delay = 1e-6; end return next_delay; end local interface = {}; local interface_mt = { __index = interface }; function interface_mt:__tostring() if self.sockname and self.peername then return ("FD %d (%s, %d, %s, %d)"):format(self:getfd(), self.peername, self.peerport, self.sockname, self.sockport); elseif self.sockname or self.peername then return ("FD %d (%s, %d)"):format(self:getfd(), self.sockname or self.peername, self.sockport or self.peerport); end return ("%s FD %d"):format(tostring(self.conn), self:getfd()); end function interface:setlistener(listeners) self.listeners = listeners; end -- Call callback function interface:on(what, ...) local listener = self.listeners["on"..what]; if not listener then -- log("debug", "Missing listener 'on%s'", what); -- uncomment for development and debugging return; end local ok, err = pcall(listener, self, ...); if not ok then log("error", "Error calling on%s: %s", what, err); end return err; end function interface:getfd() if self.conn then return self.conn:getfd(); end return -1; end function interface:ip() return self.peername or self.sockname; end function interface:port() return self.sockport or self.peerport; end function interface:clientport() return self.sockport; end function interface:serverport() if self.sockport then return self.sockport; elseif self.server then self.server:port(); end end function interface:socket() return self.conn; end function interface:setoption(k, v) -- LuaSec doesn't expose setoption :( if self.conn.setoption then self.conn:setoption(k, v); end end function interface:setreadtimeout(t) if t == false then if self._readtimeout then self._readtimeout:close(); self._readtimeout = nil; end return end t = t or cfg.read_timeout; if self._readtimeout then self._readtimeout[1] = gettime() + t; resort_timers = true; else self._readtimeout = addtimer(t, function () if self:on("readtimeout") then return cfg.read_timeout; else self:on("disconnect", "read timeout"); self:destroy(); end end); end end function interface:setwritetimeout(t) if t == false then if self._writetimeout then self._writetimeout:close(); self._writetimeout = nil; end return end t = t or cfg.write_timeout; if self._writetimeout then self._writetimeout[1] = gettime() + t; resort_timers = true; else self._writetimeout = addtimer(t, function () self:on("disconnect", "write timeout"); self:destroy(); end); end end function interface:flags() if self._wantread then if self._wantwrite then return "rw"; end return "r"; elseif self._wantwrite then return "w"; end end function interface:setflags(r, w) if r ~= nil then self._wantread = r; end if w ~= nil then self._wantwrite = w; end local flags = self:flags(); local currentflags = self._flags; if flags == currentflags then return true; end local fd = self:getfd(); if fd < 0 then self._wantread, self._wantwrite = nil, nil; return nil, "invalid fd"; end local op = "mod"; if not flags then op = "del"; elseif not currentflags then op = "add"; end local ok, err = epoll.ctl(op, fd, flags); -- log("debug", "epoll_ctl(%q, %d, %q) -> %s" .. (err and ", %q" or ""), -- op, fd, flags or "", tostring(ok), err); if not ok then return ok, err end self._flags = flags; return true; end -- Called when socket is readable function interface:onreadable() local data, err, partial = self.conn:receive(self._pattern); if data or partial then self:on("incoming", data or partial, err); end if err == "wantread" then self:setflags(true, nil); elseif err == "wantwrite" then self:setflags(nil, true); elseif not data and err ~= "timeout" then self:on("disconnect", err); self:destroy() return; end if self.conn:dirty() then self:setreadtimeout(false); self:pausefor(cfg.read_retry_delay); else self:setreadtimeout(); end end -- Called when socket is writable function interface:onwriteable() local buffer = self.writebuffer; local data = t_concat(buffer); local ok, err, partial = self.conn:send(data); if ok then for i = #buffer, 1, -1 do buffer[i] = nil; end self:setflags(nil, false); self:setwritetimeout(false); self:ondrain(); -- Be aware of writes in ondrain return; elseif partial then buffer[1] = data:sub(partial+1); for i = #buffer, 2, -1 do buffer[i] = nil; end self:setwritetimeout(); end if err == "wantwrite" or err == "timeout" then self:setflags(nil, true); elseif err == "wantread" then self:setflags(true, nil); elseif err and err ~= "timeout" then self:on("disconnect", err); self:destroy(); end end function interface:ondrain() if self._toclose then return self:close(); elseif self._starttls then return self:starttls(); else return self:on("drain"); end end function interface:write(data) local buffer = self.writebuffer; if buffer then t_insert(buffer, data); else self.writebuffer = { data }; end self:setwritetimeout(); self:setflags(nil, true); return #data; end interface.send = interface.write; function interface:close() if self._wantwrite then self:setflags(false, true); -- Flush final buffer contents self.write, self.send = noop, noop; -- No more writing log("debug", "Close %s after writing", tostring(self)); self._toclose = true; else log("debug", "Close %s now", tostring(self)); self.write, self.send = noop, noop; self.close = noop; self:on("disconnect"); self:destroy(); end end function interface:destroy() self:setflags(false, false); self:setwritetimeout(false); self:setreadtimeout(false); self.onreadable = noop; self.onwriteable = noop; self.destroy = noop; self.close = noop; self.on = noop; fds[self:getfd()] = nil; self.conn:close(); self.conn = nil; end function interface:ssl() return self._tls; end function interface:starttls(ctx) if ctx then self.tls = ctx; end if self.writebuffer and self.writebuffer[1] then log("debug", "Start TLS on %s after write", tostring(self)); self._starttls = true; else log("debug", "Start TLS on %s now", tostring(self)); self:setflags(false, false); local conn, err = luasec.wrap(self.conn, ctx or self.tls); if not conn then self:on("disconnect", err); self:destroy(); return conn, err; end conn:settimeout(0); self.conn = conn; self._starttls = nil; self.onwriteable = interface.tlshandskake; self.onreadable = interface.tlshandskake; self:setflags(true, true); self:setwritetimeout(cfg.handshake_timeout); end end function interface:tlshandskake() self:setwritetimeout(false); self:setreadtimeout(false); local ok, err = self.conn:dohandshake(); if ok then log("debug", "TLS handshake on %s complete", tostring(self)); self.onwriteable = nil; self.onreadable = nil; self:setflags(true, true); local old = self._tls; self._tls = true; self.starttls = false; if old == false then self.onwriteable = interface.onconnect; self:setflags(false, true); self:setwritetimeout(); else self:setflags(true, true); self:on("status", "ssl-handshake-complete"); end elseif err == "wantread" then log("debug", "TLS handshake on %s to wait until readable", tostring(self)); self:setflags(true, false); self:setreadtimeout(cfg.handshake_timeout); elseif err == "wantwrite" then log("debug", "TLS handshake on %s to wait until writable", tostring(self)); self:setflags(false, true); self:setwritetimeout(cfg.handshake_timeout); else log("debug", "TLS handshake error on %s: %s", tostring(self), err); self:on("disconnect", err); self:destroy(); end end local function wrapsocket(client, server, pattern, listeners, tls) -- luasocket object -> interface object client:settimeout(0); local conn = setmetatable({ conn = client; server = server; created = gettime(); listeners = listeners; _pattern = pattern or server._pattern; onwriteable = interface.onconnect; writebuffer = {}; tls = tls; }, interface_mt); if client.getpeername then conn.peername, conn.peerport = client:getpeername(); end if client.getsockname then conn.sockname, conn.sockport = client:getsockname(); end fds[conn:getfd()] = conn; return conn; end function interface:onacceptable() local conn, err = self.conn:accept(); if not conn then log("debug", "Error accepting new client: %s, server will be paused for %ds", err, cfg.accept_retry_interval); self:pausefor(cfg.accept_retry_interval); return; end local client = wrapsocket(conn, self, nil, self.listeners, self.tls); log("debug", "New connection %s", tostring(client)); client:init(); end function interface:init() if self.tls and not self._tls then self._tls = false; -- This means we should call onconnect when TLS is up return self:starttls(); else self:setflags(false, true); self:setwritetimeout(cfg.connect_timeout); end end function interface:pause() self:setflags(false); end function interface:resume() self:setflags(true); end function interface:pausefor(t) if self._pausefor then self._pausefor:close(); end if t == false then return; end self:setflags(false); self._pausefor = addtimer(t, function () self._pausefor = nil; self:setflags(true); if self.conn:dirty() then self:onreadable(); end end); end function interface:onconnect() self.onwriteable = nil; self:on("connect"); self:setflags(true); return self:onwriteable(); end local function addserver(addr, port, listeners, pattern, tls) local conn, err = socket.bind(addr, port, cfg.tcp_backlog); if not conn then return conn, err; end conn:settimeout(0); local server = setmetatable({ conn = conn; created = gettime(); listeners = listeners; _pattern = pattern; onreadable = interface.onacceptable; tls = tls; sockname = addr; sockport = port; }, interface_mt); server:setflags(true, false); fds[server:getfd()] = server; return server; end -- COMPAT local function wrapclient(conn, addr, port, listeners, pattern, tls) local client = wrapsocket(conn, nil, pattern, listeners, tls); if not client.peername then client.peername, client.peerport = addr, port; end client:init(); return client; end local function addclient(addr, port, listeners, pattern, tls) local conn, err = socket.tcp(); if not conn then return conn, err; end conn:settimeout(0); conn:connect(addr, port); local client = wrapsocket(conn, nil, pattern, listeners, tls) client:init(); return client, conn; end local function link(from, to) from.listeners = setmetatable({ onincoming = function (_, data) from:pause(); to:write(data); end, }, {__index=from.listeners}); to.listeners = setmetatable({ ondrain = function () from:resume(); end, }, {__index=to.listeners}); from:setflags(true, nil); to:setflags(nil, true); end -- XXX What uses this? -- net.adns function interface:set_send(new_send) self.send = new_send; end local quitting = nil; local function setquitting() quitting = "quitting"; end local function loop() repeat local t = runtimers(); local fd, r, w = epoll.wait(t); if fd then local conn = fds[fd]; if conn then if r then conn:onreadable(); end if w then conn:onwriteable(); end else log("debug", "Removing unknown fd %d", fd); epoll.ctl("del", fd); end elseif r ~= "timeout" then log("debug", "epoll_wait error: %s", tostring(r)); end until quitting; return quitting; end return { get_backend = function () return "epoll"; end; addserver = addserver; addclient = addclient; add_task = addtimer; at = at; loop = loop; setquitting = setquitting; wrapclient = wrapclient; link = link; -- libevent emulation event = { EV_READ = "r", EV_WRITE = "w", EV_READWRITE = "rw", EV_LEAVE = -1 }; addevent = function (fd, mode, callback) local function onevent(self) local ret = self:callback(); if ret == -1 then epoll.ctl("del", fd); elseif ret then epoll.ctl("mod", fd, mode); end end local conn = { callback = callback; onreadable = onevent; onwriteable = onevent; close = function () fds[fd] = nil; return epoll.ctl("del", fd); end; }; fds[fd] = conn; local ok, err = epoll.ctl("add", fd, mode or "r"); if not ok then return ok, err; end return conn; end; };