Comparison

net/server_epoll.lua @ 7547:b327322ce2dd

net.server_epoll: New experimental server backend
author Kim Alvefur <zash@zash.se>
date Wed, 10 Aug 2016 16:57:16 +0200
child 7550:8c2bc1b6d84a
comparison
equal deleted inserted replaced
7546:9606a99f8617 7547:b327322ce2dd
1 -- Prosody IM
2 -- Copyright (C) 2016 Kim Alvefur
3 --
4 -- This project is MIT/X11 licensed. Please see the
5 -- COPYING file in the source package for more information.
6 --
7
8 -- server_epoll
9 -- Server backend based on https://luarocks.org/modules/zash/lua-epoll
10
11 local t_sort = table.sort;
12 local t_insert = table.insert;
13 local t_remove = table.remove;
14 local t_concat = table.concat;
15 local setmetatable = setmetatable;
16 local tostring = tostring;
17 local log = require "util.logger".init("server_epoll");
18 local epoll = require "epoll";
19 local socket = require "socket";
20 local luasec = require "ssl";
21 local gettime = require "util.time".now;
22 local createtable = require "util.table".create;
23
24 local _ENV = nil;
25
-- Tunable settings for this backend. All values are in seconds except
-- tcp_backlog; timeouts are added to gettime() when timers are armed.
local cfg = {
	-- Delay before a connection's read timeout fires (re-armed on every read).
	read_timeout = 900;
	-- Delay allowed for buffered outgoing data before "write timeout".
	write_timeout = 7;
	-- Backlog passed to socket.bind() for listening sockets.
	tcp_backlog = 128;
	-- How long a listener is paused after accept() fails.
	accept_retry_interval = 10;
	-- Delay allowed for each step of a TLS handshake. This key is read by
	-- interface:tlshandskake(); when it was absent the handshake fell back
	-- to read_timeout/write_timeout.
	handshake_timeout = 60;
};

local fds = createtable(10, 0); -- FD -> conn
-- Pending timers; runtimers() sorts them so the earliest deadline is last.
local timers = {};
35
-- Placeholder callback: does nothing and returns nothing.
local noop = function () end

-- Timer method (installed as timer.close by at()): neutralise a timer in
-- place. Its deadline becomes 0 so it is due immediately, and its callback
-- becomes noop. noop returns nil, so runtimers() then drops the timer.
local function closetimer(timer)
	timer[2], timer[1] = noop, 0;
end
41
-- Set whenever a timer is added or its deadline changes, so that
-- runtimers() re-sorts the timers list before walking it.
local resort_timers = false;

-- Schedule f to run at absolute time `time` (same clock as gettime()).
-- f is called as f(now). Returning a number reschedules it that many seconds
-- after its previous deadline; returning nil/false makes it one-shot.
-- Returns the timer object; timer:close() cancels it.
local function at(time, f)
	local timer = { close = closetimer };
	timer[1], timer[2] = time, f;
	t_insert(timers, timer);
	resort_timers = true;
	return timer;
end

-- Schedule f to run `timeout` seconds from now (see at()).
local function addtimer(timeout, f)
	local when = gettime() + timeout;
	return at(when, f);
end
52
-- Run every timer whose deadline has passed, and work out how long the
-- event loop may sleep before the next one is due.
-- Timers are sorted latest-first, so the earliest deadline sits at the end
-- and the list is walked backwards from there.
-- Returns a positive delay in seconds (capped at 86400).
local function runtimers()
	if resort_timers then
		-- Latest deadline first, earliest last.
		t_sort(timers, function (x, y) return x[1] > y[1]; end);
		resort_timers = false;
	end

	-- (Trimming trailing closed timers, whose callback is noop, before the
	-- loop was considered; they are simply run and removed below instead.)

	local delay = 86400;
	local idx = #timers;
	while idx >= 1 do
		local entry = timers[idx];
		local due, callback = entry[1], entry[2];
		local now = gettime(); -- sampled afresh on every iteration
		if due > now then
			-- The list is sorted, so nothing further towards the front is due.
			local remaining = due - now;
			if remaining < delay then
				delay = remaining;
			end
			return delay;
		end
		local repeat_after = callback(now);
		if not repeat_after then
			-- One-shot (or cancelled) timer: drop it.
			t_remove(timers, idx);
		else
			-- Periodic timer: the new deadline counts from the scheduled
			-- time, not from when the callback actually ran.
			local next_due = due + repeat_after;
			local until_next = next_due - now;
			if until_next < 1e-6 then
				until_next = 1e-6;
			end
			if until_next < delay then
				delay = until_next;
			end
			entry[1] = next_due;
			resort_timers = true;
		end
		idx = idx - 1;
	end
	if delay < 1e-6 then
		delay = 1e-6;
	end
	return delay;
end
104
-- Methods shared by every connection and listening-socket object.
local interface = {};
-- Metatable for those objects; unknown fields are looked up in `interface`.
local interface_mt = {};
interface_mt.__index = interface;
107
-- Human-readable description for log messages: the fd, followed by the
-- underlying socket object and the peer address/port when a peer is known.
function interface_mt:__tostring()
	local peer = self.peer;
	if not peer then
		return tostring(self:getfd());
	end
	local fd = self:getfd();
	if self.conn then
		return ("%d %s [%s]:%d"):format(fd, tostring(self.conn), peer[1], peer[2]);
	end
	return ("%d [%s]:%d"):format(fd, peer[1], peer[2]);
end
118
-- Replace the table of listener callbacks for this object.
function interface:setlistener(listeners)
	self.listeners = listeners;
end

-- File descriptor of the underlying socket (the key used in fds).
function interface:getfd()
	local sock = self.conn;
	return sock:getfd();
end

-- Peer IP address (first element of self.peer).
function interface:ip()
	local peer = self.peer;
	return peer[1];
end

-- The wrapped LuaSocket/LuaSec object.
function interface:socket()
	local sock = self.conn;
	return sock;
end
134
-- Set a socket option when the underlying object supports it.
-- LuaSec-wrapped sockets do not expose setoption(), so in that case this
-- does nothing.
function interface:setoption(k, v)
	local sock = self.conn;
	if not sock.setoption then
		return;
	end
	sock:setoption(k, v);
end
141
-- Arm, re-arm or cancel this connection's read timeout.
-- t == false cancels it. Otherwise it fires t seconds from now (default
-- cfg.read_timeout). On expiry onreadtimeout() is consulted:
--  * a truthy answer keeps the connection and re-arms for cfg.read_timeout;
--  * otherwise ondisconnect(self, "read timeout") is called and the
--    connection is destroyed.
function interface:setreadtimeout(t)
	local timer = self._readtimeout;
	if t == false then
		if timer then
			timer:close();
			self._readtimeout = nil;
		end
		return;
	end
	local delay = t or cfg.read_timeout;
	if timer then
		-- Reuse the existing timer; runtimers() re-sorts on its next run.
		timer[1] = gettime() + delay;
		resort_timers = true;
		return;
	end
	self._readtimeout = addtimer(delay, function ()
		if not self:onreadtimeout() then
			self.listeners.ondisconnect(self, "read timeout");
			self:destroy();
			return;
		end
		return cfg.read_timeout;
	end);
end
165
-- Ask the listener whether an expired read timeout should be tolerated.
-- Returns whatever listeners.onreadtimeout returns, or nothing when the
-- listener does not define it (which the caller treats as "disconnect").
function interface:onreadtimeout()
	local handler = self.listeners.onreadtimeout;
	if handler then
		return handler(self);
	end
end
171
-- Arm, re-arm or cancel this connection's write timeout.
-- t == false cancels it. Otherwise it fires t seconds from now (default
-- cfg.write_timeout). On expiry ondisconnect(self, "write timeout") is
-- called and the connection is destroyed; the timer does not repeat.
function interface:setwritetimeout(t)
	local timer = self._writetimeout;
	if t == false then
		if timer then
			timer:close();
			self._writetimeout = nil;
		end
		return;
	end
	local delay = t or cfg.write_timeout;
	if timer then
		-- Reuse the existing timer; runtimers() re-sorts on its next run.
		timer[1] = gettime() + delay;
		resort_timers = true;
		return;
	end
	self._writetimeout = addtimer(delay, function ()
		self.listeners.ondisconnect(self, "write timeout");
		self:destroy();
	end);
end
191
-- Encode the wanted events as the flag string passed to epoll.ctl():
-- "rw", "r", "w", or nil when neither reading nor writing is wanted.
function interface:flags()
	local r, w = self._wantread, self._wantwrite;
	if r and w then
		return "rw";
	end
	if r then
		return "r";
	end
	if w then
		return "w";
	end
end
202
-- Update read/write interest and sync it with epoll.
-- r, w: true/false to change that direction, nil to leave it unchanged.
-- Depending on the old and new flags, the fd is registered ("add"),
-- updated ("mod") or unregistered ("del").
-- Returns true, or the failure values from epoll.ctl(); on failure
-- self._flags keeps its previous value.
function interface:setflags(r, w)
	if r ~= nil then self._wantread = r; end
	if w ~= nil then self._wantwrite = w; end
	local wanted, current = self:flags(), self._flags;
	if wanted == current then
		return true;
	end
	local op;
	if not wanted then
		op = "del";
	elseif current then
		op = "mod";
	else
		op = "add";
	end
	local ok, err = epoll.ctl(op, self:getfd(), wanted);
	if not ok then
		return ok, err;
	end
	self._flags = wanted;
	return true;
end
223
-- Handle a readable event: receive with the connection's pattern and pass
-- the data to listeners.onincoming.
-- A complete read (data returned) is never an error. When receive() fails,
-- any non-empty partial data is delivered first:
--  * LuaSec "wantread"/"wantwrite" adjust epoll interest;
--  * "timeout" (no more data for now) is ignored;
--  * any other error disconnects and destroys the connection.
-- The read timeout is re-armed unless the connection was destroyed.
function interface:onreadable()
	local data, err, partial = self.conn:receive(self._pattern);
	if data then
		-- err is nil on success. Previously a nil err fell into the
		-- "err ~= 'timeout'" branch and disconnected a healthy connection.
		self.listeners.onincoming(self, data);
	else
		if partial and partial ~= "" then
			self.listeners.onincoming(self, partial, err);
		end
		if err == "wantread" then
			self:setflags(true, nil);
		elseif err == "wantwrite" then
			self:setflags(nil, true);
		elseif err ~= "timeout" then
			self.listeners.ondisconnect(self, err);
			self:destroy();
			return;
		end
	end
	self:setreadtimeout();
end
240
-- Flush the write buffer. The buffered chunks are joined and sent in one
-- call.
--  * Full success: the buffer is emptied, write interest and the write
--    timeout are dropped, then ondrain() runs.
--  * Partial send: the unsent tail stays buffered and the write timeout is
--    re-armed.
--  * LuaSec "wantread"/"wantwrite" (and "timeout") adjust epoll interest.
--  * Any other error disconnects and destroys the connection.
function interface:onwriteable()
	local buffer = self.writebuffer;
	local data = t_concat(buffer);
	local ok, err, partial = self.conn:send(data);
	if ok then
		for i = #buffer, 1, -1 do
			buffer[i] = nil;
		end
		-- Clear write interest *before* ondrain(). ondrain() may:
		--  * call close(), which only defers while _wantwrite is set, so the
		--    close would otherwise never happen;
		--  * call starttls(), which enables write interest that must not be
		--    cleared afterwards.
		-- Anything written from a listener's ondrain re-enables write
		-- interest and the write timeout through write().
		self:setflags(nil, false);
		self:setwritetimeout(false);
		self:ondrain();
		self._writable = true;
		return;
	end
	if partial then
		-- Keep only the bytes that were not sent.
		buffer[1] = data:sub(partial + 1);
		for i = #buffer, 2, -1 do
			buffer[i] = nil;
		end
		self:setwritetimeout();
		self._writable = false;
	end
	if err == "wantwrite" or err == "timeout" then
		self:setflags(nil, true);
	elseif err == "wantread" then
		self:setflags(true, nil);
	elseif err then
		self.listeners.ondisconnect(self, err);
		self:destroy();
	end
end
274
-- Called after the write buffer has been flushed completely. Notifies the
-- listener, then performs a STARTTLS or close that was deferred until the
-- pending data had been sent (STARTTLS takes precedence).
function interface:ondrain()
	local on_drain = self.listeners.ondrain;
	if on_drain then
		on_drain(self);
	end
	if self._starttls then
		self:starttls();
	elseif self._toclose then
		self:close();
	end
end
285
-- Queue data for sending. The actual send happens in onwriteable() once
-- epoll reports the socket writable; this (re)arms the write timeout and
-- enables write interest.
-- Returns the number of bytes queued.
function interface:write(data)
	local buffer = self.writebuffer;
	if buffer then
		t_insert(buffer, data);
	else
		-- Objects created without a buffer (e.g. listening sockets) get one lazily.
		self.writebuffer = { data };
	end
	-- (An unreachable "write immediately if writable" path, guarded by
	-- `self._writable and false`, was removed.)
	self:setwritetimeout();
	self:setflags(nil, true);
	return #data;
end
interface.send = interface.write;
302
-- Close the connection.
-- While write interest is set (e.g. buffered data still pending), only
-- mark it with _toclose; ondrain() calls close() again after flushing.
-- Otherwise notify the listener and destroy the object. Later close()
-- calls on this object become no-ops.
function interface:close()
	if not self._wantwrite then
		self.close = noop;
		self.listeners.ondisconnect(self);
		self:destroy();
		return;
	end
	self._toclose = true;
end
312
-- Tear the connection down without notifying listeners: drop epoll
-- interest, cancel both timeouts, forget the fd and close the socket.
-- Returns the result of the socket's close().
function interface:destroy()
	self:setflags(false, false);
	self:setwritetimeout(false);
	self:setreadtimeout(false);
	local fd = self:getfd();
	fds[fd] = nil;
	local sock = self.conn;
	return sock:close();
end
320
-- TLS state of the connection: true once a handshake has completed, false
-- while onacceptable() has a handshake pending, otherwise nil.
function interface:ssl()
	local tls_state = self._tls;
	return tls_state;
end
324
-- Begin TLS on this connection, optionally with a new context `ctx`, which
-- is remembered in self.tls.
-- If writes are still buffered, the upgrade is deferred (self._starttls)
-- and ondrain() calls back here once they are flushed. Otherwise the socket
-- is wrapped with LuaSec and both read and write events are routed to
-- tlshandskake() until the handshake completes. If wrapping fails, the
-- listener is told and the connection destroyed.
function interface:starttls(ctx)
	if ctx then self.tls = ctx; end
	if self.writebuffer and self.writebuffer[1] then
		self._starttls = true;
		return;
	end
	self:setflags(false, false);
	local conn, err = luasec.wrap(self.conn, ctx or self.tls);
	if not conn then
		self.listeners.ondisconnect(self, err);
		self:destroy();
		-- Stop here: conn is nil, so the setup below would raise an error.
		return;
	end
	conn:settimeout(0);
	self.conn = conn;
	self._starttls = nil;
	self.onwriteable = interface.tlshandskake;
	self.onreadable = interface.tlshandskake;
	self:setflags(true, true);
end
344
-- Drive the TLS handshake; starttls() installs this as both onreadable and
-- onwriteable.
-- While LuaSec needs more I/O, only the needed direction is watched and its
-- timeout is armed with cfg.handshake_timeout. Any other error disconnects.
-- On success:
--  * the instance overrides are removed, so the normal handlers apply again;
--  * both directions are enabled;
--  * either onconnect() runs (when _tls was false, as set by onacceptable())
--    or the listener's onstatus("ssl-handshake-complete") is called.
function interface:tlshandskake()
	local ok, err = self.conn:dohandshake();
	if not ok then
		if err == "wantread" then
			self:setflags(true, false);
			self:setwritetimeout(false);
			self:setreadtimeout(cfg.handshake_timeout);
		elseif err == "wantwrite" then
			self:setflags(false, true);
			self:setreadtimeout(false);
			self:setwritetimeout(cfg.handshake_timeout);
		else
			self.listeners.ondisconnect(self, err);
			self:destroy();
		end
		return;
	end
	self.onwriteable = nil;
	self.onreadable = nil;
	self:setflags(true, true);
	local previous = self._tls;
	self._tls = true;
	-- Shadow the starttls method so it cannot be used again on this object.
	self.starttls = false;
	if previous == false then
		self:onconnect();
	else
		local onstatus = self.listeners.onstatus;
		if onstatus then
			onstatus(self, "ssl-handshake-complete");
		end
	end
end
372
-- Wrap a connected LuaSocket object in an interface object, make it
-- non-blocking and register it in fds.
-- `server` is the listening interface it was accepted from, or nil for
-- outgoing connections. When no explicit receive `pattern` is given, the
-- server's _pattern is used if there is a server.
local function wrapsocket(client, server, pattern, listeners, tls) -- luasocket object -> interface object
	client:settimeout(0);
	-- Guard against server == nil (addclient passes nil); the previous
	-- `pattern or server._pattern` raised an error in that case.
	if not pattern and server then
		pattern = server._pattern;
	end
	local conn = setmetatable({
		conn = client;
		server = server;
		created = gettime();
		listeners = listeners;
		_pattern = pattern;
		writebuffer = {};
		tls = tls;
	}, interface_mt);
	if client.getpeername then
		conn.peer = { client:getpeername() };
	end
	fds[conn:getfd()] = conn;
	return conn;
end
391
-- Readable handler of a listening socket: accept one pending connection.
-- If accept() fails, the listener is paused for cfg.accept_retry_interval
-- seconds.
-- New connections inherit the listener's callbacks and TLS context:
--  * TLS connections start a handshake, and onconnect fires when it ends;
--  * plain connections get onconnect immediately, plus read interest.
-- A read timeout is armed either way.
function interface:onacceptable()
	local conn, err = self.conn:accept();
	if not conn then
		-- The level must be the string "debug". A bare `debug` is a global
		-- lookup, which fails because this file sets _ENV to nil.
		log("debug", "Error accepting new client: %s, server will be paused for %ds", err, cfg.accept_retry_interval);
		self:pausefor(cfg.accept_retry_interval);
		return;
	end
	local client = wrapsocket(conn, self, nil, self.listeners, self.tls);
	if self.tls then
		-- false = handshake pending; tlshandskake() then calls onconnect().
		client._tls = false;
		client:starttls();
	else
		self.listeners.onconnect(client);
		client:setflags(true);
	end
	client:setreadtimeout();
end
409
-- Stop watching for readability; write interest is left untouched.
-- link() uses this to stop reading until the other side has drained.
function interface:pause()
	self:setflags(false);
end

-- Start watching for readability again.
function interface:resume()
	self:setflags(true);
end

-- Stop reading for t seconds, then resume via a one-shot timer.
-- Does nothing if the object is not currently reading.
function interface:pausefor(t)
	if not self._wantread then
		return;
	end
	self:setflags(false);
	addtimer(t, function ()
		self:setflags(true);
	end);
end
424
-- One-shot handler for a newly usable connection. It removes the instance
-- overrides of onreadable/onwriteable, so the interface methods apply
-- again, then calls the listener's onconnect.
function interface:onconnect()
	self.onwriteable = nil;
	self.onreadable = nil;
	self.listeners.onconnect(self);
end
431 local function addclient(addr, port, listeners, pattern, tls)
432 local conn, err = socket.connect(addr, port);
433 if not conn then return conn, err; end
434 return wrapsocket(conn, nil, pattern, listeners, tls);
435 end
436
-- Create a listening TCP socket on addr:port and register it for reading.
-- Incoming connections are handled by interface.onacceptable and inherit
-- `listeners`, `pattern` and the TLS context `tls`.
-- Returns the server object, or nil, err if binding or the epoll
-- registration fails.
local function addserver(addr, port, listeners, pattern, tls)
	local conn, err = socket.bind(addr, port, cfg.tcp_backlog);
	if not conn then return conn, err; end
	conn:settimeout(0);
	local server = setmetatable({
		conn = conn;
		created = gettime();
		listeners = listeners;
		_pattern = pattern;
		onreadable = interface.onacceptable;
		tls = tls;
		peer = { addr, port };
	}, interface_mt);
	-- Do not hand out a listener that epoll will never report on.
	local ok, ctl_err = server:setflags(true, false);
	if not ok then
		conn:close();
		return nil, ctl_err;
	end
	fds[server:getfd()] = server;
	return server;
end
454
-- COMPAT
-- Wrap an existing client socket that the caller created and connected.
-- `mode` is the receive pattern. The first read or write event runs
-- interface.onconnect, which restores the normal handlers and calls
-- listeners.onconnect.
-- NOTE(review): `tls` is stored, but no handshake is started here.
local function wrapclient(client, addr, port, listeners, mode, tls)
	-- The event loop must never block on this socket (wrapsocket() does the same).
	client:settimeout(0);
	local conn = setmetatable({
		conn = client;
		created = gettime();
		listeners = listeners;
		_pattern = mode;
		writebuffer = {};
		tls = tls;
		onreadable = interface.onconnect;
		onwriteable = interface.onconnect;
		peer = { addr, port };
	}, interface_mt);
	fds[conn:getfd()] = conn;
	conn:setflags(true, true);
	return conn;
end
472
-- Pipe data from one connection to another with simple back-pressure.
-- Each chunk read from `from` pauses it and is written to `to`; when `to`
-- has flushed (ondrain), `from` resumes. All other callbacks of both
-- connections still reach their previous listeners through __index.
local function link(from, to)
	local from_overrides = {
		onincoming = function (_, data)
			from:pause();
			to:write(data);
		end,
	};
	from.listeners = setmetatable(from_overrides, { __index = from.listeners });

	local to_overrides = {
		ondrain = function ()
			from:resume();
		end,
	};
	to.listeners = setmetatable(to_overrides, { __index = to.listeners });

	from:setflags(true, nil);
	to:setflags(nil, true);
end
488
-- Replace this object's send method; this shadows interface.send on this
-- object only.
-- NOTE(review): the original note names net.adns as the user of this hook;
-- confirm before removing it.
function interface:set_send(new_send)
	self.send = new_send;
end
494
-- Non-nil once setquitting() has been called; loop() checks it after every
-- iteration.
local quitting;

-- Ask loop() to stop; it then returns "quitting".
local function setquitting()
	quitting = "quitting";
end
500
-- Main event loop. Each iteration:
--  1. runs due timers;
--  2. waits, at most until the next timer, for one epoll event;
--  3. dispatches that event to the object registered for the fd.
-- Events for fds missing from `fds` are unregistered.
-- Runs until setquitting() is called, and returns the quitting reason.
local function loop()
	repeat
		local t = runtimers();
		local fd, r, w = epoll.wait(t);
		if fd then
			local conn = fds[fd];
			if conn then
				if r then
					conn:onreadable();
				end
				-- The read handler may have destroyed the connection, which
				-- would mean a second ondisconnect. It may also have replaced
				-- it, if the fd number was reused. Only dispatch the write
				-- event if this object still owns the fd.
				if w and fds[fd] == conn then
					conn:onwriteable();
				end
			else
				log("debug", "Removing unknown fd %d", fd);
				epoll.ctl("del", fd);
			end
		elseif r ~= "timeout" then
			log("debug", "epoll_wait error: %s", tostring(r));
		end
	until quitting;
	return quitting;
end
524
return {
	get_backend = function () return "epoll"; end;
	addserver = addserver;
	addclient = addclient;
	add_task = addtimer;
	at = at;
	loop = loop;
	setquitting = setquitting;
	wrapclient = wrapclient;
	link = link;

	-- libevent emulation
	event = { EV_READ = "r", EV_WRITE = "w", EV_READWRITE = "rw", EV_LEAVE = -1 };
	-- Watch a raw fd. `mode` is an epoll flag string ("r" by default).
	-- callback is called as callback(handle) on every event:
	--  * returning EV_LEAVE (-1) unregisters the fd;
	--  * any other truthy value re-applies `mode`.
	-- Returns a handle with close(), or nil, err if registration fails.
	addevent = function (fd, mode, callback)
		-- Use the same default for "add" and the later "mod"; previously
		-- "mod" could be called with a nil mode.
		mode = mode or "r";

		local function onevent(self)
			local ret = self:callback();
			if ret == -1 then
				-- Also forget the handle, not just the epoll registration.
				if fds[fd] == self then
					fds[fd] = nil;
				end
				epoll.ctl("del", fd);
			elseif ret then
				epoll.ctl("mod", fd, mode);
			end
		end

		local conn = {
			callback = callback;
			onreadable = onevent;
			onwriteable = onevent;
			close = function ()
				fds[fd] = nil;
				return epoll.ctl("del", fd);
			end;
		};
		fds[fd] = conn;
		local ok, err = epoll.ctl("add", fd, mode);
		if not ok then
			-- Do not leave a handle behind for an fd epoll does not watch.
			fds[fd] = nil;
			return ok, err;
		end
		return conn;
	end;
};