Comparison

net/server_epoll.lua @ 9314:c9eef7e8ee65

net.server_epoll: Use util.poll
author Kim Alvefur <zash@zash.se>
date Wed, 16 May 2018 23:57:09 +0200
parent 9311:9b0604fe01f1
child 9319:7c954c75b6ac
comparison
equal deleted inserted replaced
9313:b95ef295c66d 9314:c9eef7e8ee65
3 -- 3 --
4 -- This project is MIT/X11 licensed. Please see the 4 -- This project is MIT/X11 licensed. Please see the
5 -- COPYING file in the source package for more information. 5 -- COPYING file in the source package for more information.
6 -- 6 --
7 7
8 -- server_epoll
9 -- Server backend based on https://luarocks.org/modules/zash/lua-epoll
10 8
11 local t_sort = table.sort; 9 local t_sort = table.sort;
12 local t_insert = table.insert; 10 local t_insert = table.insert;
13 local t_remove = table.remove; 11 local t_remove = table.remove;
14 local t_concat = table.concat; 12 local t_concat = table.concat;
17 local pcall = pcall; 15 local pcall = pcall;
18 local type = type; 16 local type = type;
19 local next = next; 17 local next = next;
20 local pairs = pairs; 18 local pairs = pairs;
21 local log = require "util.logger".init("server_epoll"); 19 local log = require "util.logger".init("server_epoll");
22 local epoll = require "epoll";
23 local socket = require "socket"; 20 local socket = require "socket";
24 local luasec = require "ssl"; 21 local luasec = require "ssl";
25 local gettime = require "util.time".now; 22 local gettime = require "util.time".now;
26 local createtable = require "util.table".create; 23 local createtable = require "util.table".create;
27 local _SOCKETINVALID = socket._SOCKETINVALID or -1; 24 local _SOCKETINVALID = socket._SOCKETINVALID or -1;
28 25
29 assert(socket.tcp6 and socket.tcp4, "Incompatible LuaSocket version"); 26 local poll = require "util.poll".new();
30 27
31 local _ENV = nil; 28 local _ENV = nil;
32 -- luacheck: std none 29 -- luacheck: std none
33 30
34 local default_config = { __index = { 31 local default_config = { __index = {
258 self:destroy(); 255 self:destroy();
259 end); 256 end);
260 end 257 end
261 end 258 end
262 259
263 -- lua-epoll flag for currently requested poll state 260 function interface:add(r, w)
264 function interface:flags()
265 if self._wantread then
266 if self._wantwrite then
267 return "rw";
268 end
269 return "r";
270 elseif self._wantwrite then
271 return "w";
272 end
273 end
274
275 -- Add or remove sockets or modify epoll flags
276 function interface:setflags(r, w)
277 if r ~= nil then self._wantread = r; end
278 if w ~= nil then self._wantwrite = w; end
279 local flags = self:flags();
280 local currentflags = self._flags;
281 if flags == currentflags then
282 return true;
283 end
284 local fd = self:getfd(); 261 local fd = self:getfd();
285 if fd < 0 then 262 if fd < 0 then
286 self._wantread, self._wantwrite = nil, nil;
287 return nil, "invalid fd"; 263 return nil, "invalid fd";
288 end 264 end
289 local op = "mod"; 265 if r == nil then r = self._wantread; end
290 if not flags then 266 if w == nil then w = self._wantwrite; end
291 op = "del"; 267 local ok, err = poll:add(fd, r, w);
292 elseif not currentflags then 268 if not ok then
293 op = "add"; 269 log("error", "Could not register %s: %s", self, err);
294 end 270 return ok, err;
295 local ok, err = epoll.ctl(op, fd, flags); 271 end
296 -- log("debug", "epoll_ctl(%q, %d, %q) -> %s" .. (err and ", %q" or ""), 272 self._wantread, self._wantwrite = r, w;
297 -- op, fd, flags or "", tostring(ok), err); 273 fds[fd] = self;
298 if not ok then return ok, err end 274 log("debug", "Registered %s", self);
299 if op == "add" then 275 return true;
300 fds[fd] = self; 276 end
301 elseif op == "del" then 277
302 fds[fd] = nil; 278 function interface:set(r, w)
303 end 279 local fd = self:getfd();
304 self._flags = flags; 280 if fd < 0 then
281 return nil, "invalid fd";
282 end
283 if r == nil then r = self._wantread; end
284 if w == nil then w = self._wantwrite; end
285 local ok, err = poll:set(fd, r, w);
286 if not ok then
287 log("error", "Could not update poller state %s: %s", self, err);
288 return ok, err;
289 end
290 self._wantread, self._wantwrite = r, w;
291 return true;
292 end
293
294 function interface:del()
295 local fd = self:getfd();
296 if fd < 0 then
297 return nil, "invalid fd";
298 end
299 if fds[fd] ~= self then
300 return nil, "unregistered fd";
301 end
302 local ok, err = poll:del(fd);
303 if not ok then
304 log("error", "Could not unregister %s: %s", self, err);
305 return ok, err;
306 end
307 self._wantread, self._wantwrite = nil, nil;
308 fds[fd] = nil;
309 log("debug", "Unregistered %s", self);
305 return true; 310 return true;
306 end 311 end
307 312
308 -- Called when socket is readable 313 -- Called when socket is readable
309 function interface:onreadable() 314 function interface:onreadable()
315 if partial and partial ~= "" then 320 if partial and partial ~= "" then
316 self:onconnect(); 321 self:onconnect();
317 self:on("incoming", partial, err); 322 self:on("incoming", partial, err);
318 end 323 end
319 if err == "wantread" then 324 if err == "wantread" then
320 self:setflags(true, nil); 325 self:set(true, nil);
321 elseif err == "wantwrite" then 326 elseif err == "wantwrite" then
322 self:setflags(nil, true); 327 self:set(nil, true);
323 elseif err ~= "timeout" then 328 elseif err ~= "timeout" then
324 self:on("disconnect", err); 329 self:on("disconnect", err);
325 self:destroy() 330 self:destroy()
326 return; 331 return;
327 end 332 end
341 if not self.conn then return; end -- could have been closed in onconnect 346 if not self.conn then return; end -- could have been closed in onconnect
342 local buffer = self.writebuffer; 347 local buffer = self.writebuffer;
343 local data = t_concat(buffer); 348 local data = t_concat(buffer);
344 local ok, err, partial = self.conn:send(data); 349 local ok, err, partial = self.conn:send(data);
345 if ok then 350 if ok then
346 self:setflags(nil, false); 351 self:set(nil, false);
347 for i = #buffer, 1, -1 do 352 for i = #buffer, 1, -1 do
348 buffer[i] = nil; 353 buffer[i] = nil;
349 end 354 end
350 self:setwritetimeout(false); 355 self:setwritetimeout(false);
351 self:ondrain(); -- Be aware of writes in ondrain 356 self:ondrain(); -- Be aware of writes in ondrain
356 buffer[i] = nil; 361 buffer[i] = nil;
357 end 362 end
358 self:setwritetimeout(); 363 self:setwritetimeout();
359 end 364 end
360 if err == "wantwrite" or err == "timeout" then 365 if err == "wantwrite" or err == "timeout" then
361 self:setflags(nil, true); 366 self:set(nil, true);
362 elseif err == "wantread" then 367 elseif err == "wantread" then
363 self:setflags(true, nil); 368 self:set(true, nil);
364 elseif err ~= "timeout" then 369 elseif err ~= "timeout" then
365 self:on("disconnect", err); 370 self:on("disconnect", err);
366 self:destroy(); 371 self:destroy();
367 end 372 end
368 end 373 end
379 t_insert(buffer, data); 384 t_insert(buffer, data);
380 else 385 else
381 self.writebuffer = { data }; 386 self.writebuffer = { data };
382 end 387 end
383 self:setwritetimeout(); 388 self:setwritetimeout();
384 self:setflags(nil, true); 389 self:set(nil, true);
385 return #data; 390 return #data;
386 end 391 end
387 interface.send = interface.write; 392 interface.send = interface.write;
388 393
389 -- Close, possibly after writing is done 394 -- Close, possibly after writing is done
390 function interface:close() 395 function interface:close()
391 if self.writebuffer and self.writebuffer[1] then 396 if self.writebuffer and self.writebuffer[1] then
392 self:setflags(false, true); -- Flush final buffer contents 397 self:set(false, true); -- Flush final buffer contents
393 self.write, self.send = noop, noop; -- No more writing 398 self.write, self.send = noop, noop; -- No more writing
394 log("debug", "Close %s after writing", self); 399 log("debug", "Close %s after writing", self);
395 self.ondrain = interface.close; 400 self.ondrain = interface.close;
396 else 401 else
397 log("debug", "Close %s now", self); 402 log("debug", "Close %s now", self);
401 self:destroy(); 406 self:destroy();
402 end 407 end
403 end 408 end
404 409
405 function interface:destroy() 410 function interface:destroy()
406 self:setflags(false, false); 411 self:del();
407 self:setwritetimeout(false); 412 self:setwritetimeout(false);
408 self:setreadtimeout(false); 413 self:setreadtimeout(false);
409 self.onreadable = noop; 414 self.onreadable = noop;
410 self.onwritable = noop; 415 self.onwritable = noop;
411 self.destroy = noop; 416 self.destroy = noop;
423 if tls_ctx then self.tls_ctx = tls_ctx; end 428 if tls_ctx then self.tls_ctx = tls_ctx; end
424 if self.writebuffer and self.writebuffer[1] then 429 if self.writebuffer and self.writebuffer[1] then
425 log("debug", "Start TLS on %s after write", self); 430 log("debug", "Start TLS on %s after write", self);
426 self.ondrain = interface.starttls; 431 self.ondrain = interface.starttls;
427 self.starttls = false; 432 self.starttls = false;
428 self:setflags(nil, true); -- make sure wantwrite is set 433 self:set(nil, true); -- make sure wantwrite is set
429 else 434 else
430 log("debug", "Start TLS on %s now", self); 435 log("debug", "Start TLS on %s now", self);
431 self:setflags(false, false); 436 self:del();
432 local conn, err = luasec.wrap(self.conn, tls_ctx or self.tls_ctx); 437 local conn, err = luasec.wrap(self.conn, tls_ctx or self.tls_ctx);
433 if not conn then 438 if not conn then
434 self:on("disconnect", err); 439 self:on("disconnect", err);
435 self:destroy(); 440 self:destroy();
436 return conn, err; 441 return conn, err;
438 conn:settimeout(0); 443 conn:settimeout(0);
439 self.conn = conn; 444 self.conn = conn;
440 self.ondrain = nil; 445 self.ondrain = nil;
441 self.onwritable = interface.tlshandskake; 446 self.onwritable = interface.tlshandskake;
442 self.onreadable = interface.tlshandskake; 447 self.onreadable = interface.tlshandskake;
443 self:setflags(true, true); 448 return self:init();
444 self:setwritetimeout(cfg.handshake_timeout);
445 end 449 end
446 end 450 end
447 451
448 function interface:tlshandskake() 452 function interface:tlshandskake()
449 self:setwritetimeout(false); 453 self:setwritetimeout(false);
453 log("debug", "TLS handshake on %s complete", self); 457 log("debug", "TLS handshake on %s complete", self);
454 self.onwritable = nil; 458 self.onwritable = nil;
455 self.onreadable = nil; 459 self.onreadable = nil;
456 self._tls = true; 460 self._tls = true;
457 self:on("status", "ssl-handshake-complete"); 461 self:on("status", "ssl-handshake-complete");
458 self:init(); 462 self:setwritetimeout();
463 self:set(true, true);
459 elseif err == "wantread" then 464 elseif err == "wantread" then
460 log("debug", "TLS handshake on %s to wait until readable", self); 465 log("debug", "TLS handshake on %s to wait until readable", self);
461 self:setflags(true, false); 466 self:set(true, false);
462 self:setreadtimeout(cfg.handshake_timeout); 467 self:setreadtimeout(cfg.handshake_timeout);
463 elseif err == "wantwrite" then 468 elseif err == "wantwrite" then
464 log("debug", "TLS handshake on %s to wait until writable", self); 469 log("debug", "TLS handshake on %s to wait until writable", self);
465 self:setflags(false, true); 470 self:set(false, true);
466 self:setwritetimeout(cfg.handshake_timeout); 471 self:setwritetimeout(cfg.handshake_timeout);
467 else 472 else
468 log("debug", "TLS handshake error on %s: %s", self, err); 473 log("debug", "TLS handshake error on %s: %s", self, err);
469 self:on("disconnect", err); 474 self:on("disconnect", err);
470 self:destroy(); 475 self:destroy();
511 end 516 end
512 517
513 -- Initialization 518 -- Initialization
514 function interface:init() 519 function interface:init()
515 self:setwritetimeout(); 520 self:setwritetimeout();
516 return self:setflags(true, true); 521 return self:add(true, true);
517 end 522 end
518 523
519 function interface:pause() 524 function interface:pause()
520 return self:setflags(false); 525 return self:set(false);
521 end 526 end
522 527
523 function interface:resume() 528 function interface:resume()
524 return self:setflags(true); 529 return self:set(true);
525 end 530 end
526 531
527 -- Pause connection for some time 532 -- Pause connection for some time
528 function interface:pausefor(t) 533 function interface:pausefor(t)
529 if self._pausefor then 534 if self._pausefor then
530 self._pausefor:close(); 535 self._pausefor:close();
531 end 536 end
532 if t == false then return; end 537 if t == false then return; end
533 self:setflags(false); 538 self:set(false);
534 self._pausefor = addtimer(t, function () 539 self._pausefor = addtimer(t, function ()
535 self._pausefor = nil; 540 self._pausefor = nil;
536 if self.conn and self.conn:dirty() then 541 if self.conn and self.conn:dirty() then
537 self:onreadable(); 542 self:onreadable();
538 end 543 end
539 self:setflags(true); 544 self:set(true);
540 end); 545 end);
541 end 546 end
542 547
543 -- Connected! 548 -- Connected!
544 function interface:onconnect() 549 function interface:onconnect()
562 tls_ctx = tls_ctx; 567 tls_ctx = tls_ctx;
563 tls_direct = tls_ctx and true or false; 568 tls_direct = tls_ctx and true or false;
564 sockname = addr; 569 sockname = addr;
565 sockport = port; 570 sockport = port;
566 }, interface_mt); 571 }, interface_mt);
567 server:setflags(true, false); 572 server:add(true, false);
568 return server; 573 return server;
569 end 574 end
570 575
571 -- COMPAT 576 -- COMPAT
572 local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx) 577 local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx)
601 local conn = setmetatable({ 606 local conn = setmetatable({
602 conn = fd; 607 conn = fd;
603 onreadable = onreadable; 608 onreadable = onreadable;
604 onwriteable = onwriteable; 609 onwriteable = onwriteable;
605 close = function (self) 610 close = function (self)
606 self:setflags(false, false); 611 self:del();
607 end 612 end
608 }, interface_mt); 613 }, interface_mt);
609 if type(fd) == "number" then 614 if type(fd) == "number" then
610 conn.getfd = function () 615 conn.getfd = function ()
611 return fd; 616 return fd;
612 end; 617 end;
613 -- Otherwise it'll need to be something LuaSocket-compatible 618 -- Otherwise it'll need to be something LuaSocket-compatible
614 end 619 end
615 conn:setflags(onreadable, onwriteable); 620 conn:add(onreadable, onwriteable);
616 return conn; 621 return conn;
617 end; 622 end;
618 623
619 -- Dump all data from one connection into another 624 -- Dump all data from one connection into another
620 local function link(from, to) 625 local function link(from, to)
627 to.listeners = setmetatable({ 632 to.listeners = setmetatable({
628 ondrain = function () 633 ondrain = function ()
629 from:resume(); 634 from:resume();
630 end, 635 end,
631 }, {__index=to.listeners}); 636 }, {__index=to.listeners});
632 from:setflags(true, nil); 637 from:set(true, nil);
633 to:setflags(nil, true); 638 to:set(nil, true);
634 end 639 end
635 640
636 -- XXX What uses this? 641 -- XXX What uses this?
637 -- net.adns 642 -- net.adns
638 function interface:set_send(new_send) 643 function interface:set_send(new_send)
660 665
661 -- Main loop 666 -- Main loop
662 local function loop(once) 667 local function loop(once)
663 repeat 668 repeat
664 local t = runtimers(cfg.max_wait, cfg.min_wait); 669 local t = runtimers(cfg.max_wait, cfg.min_wait);
665 local fd, r, w = epoll.wait(t); 670 local fd, r, w = poll:wait(t);
666 if fd then 671 if fd then
667 local conn = fds[fd]; 672 local conn = fds[fd];
668 if conn then 673 if conn then
669 if r then 674 if r then
670 conn:onreadable(); 675 conn:onreadable();
672 if w then 677 if w then
673 conn:onwritable(); 678 conn:onwritable();
674 end 679 end
675 else 680 else
676 log("debug", "Removing unknown fd %d", fd); 681 log("debug", "Removing unknown fd %d", fd);
677 epoll.ctl("del", fd); 682 poll:del(fd);
678 end 683 end
679 elseif r ~= "timeout" then 684 elseif r ~= "timeout" then
680 log("debug", "epoll_wait error: %s", tostring(r)); 685 log("debug", "epoll_wait error: %s", tostring(r));
681 end 686 end
682 until once or (quitting and next(fds) == nil); 687 until once or (quitting and next(fds) == nil);
703 event = { EV_READ = "r", EV_WRITE = "w", EV_READWRITE = "rw", EV_LEAVE = -1 }; 708 event = { EV_READ = "r", EV_WRITE = "w", EV_READWRITE = "rw", EV_LEAVE = -1 };
704 addevent = function (fd, mode, callback) 709 addevent = function (fd, mode, callback)
705 local function onevent(self) 710 local function onevent(self)
706 local ret = self:callback(); 711 local ret = self:callback();
707 if ret == -1 then 712 if ret == -1 then
708 self:setflags(false, false); 713 self:set(false, false);
709 elseif ret then 714 elseif ret then
710 self:setflags(mode == "r" or mode == "rw", mode == "w" or mode == "rw"); 715 self:set(mode == "r" or mode == "rw", mode == "w" or mode == "rw");
711 end 716 end
712 end 717 end
713 718
714 local conn = setmetatable({ 719 local conn = setmetatable({
715 getfd = function () return fd; end; 720 getfd = function () return fd; end;
716 callback = callback; 721 callback = callback;
717 onreadable = onevent; 722 onreadable = onevent;
718 onwritable = onevent; 723 onwritable = onevent;
719 close = function (self) 724 close = function (self)
720 self:setflags(false, false); 725 self:del();
721 fds[fd] = nil; 726 fds[fd] = nil;
722 end; 727 end;
723 }, interface_mt); 728 }, interface_mt);
724 local ok, err = conn:setflags(mode == "r" or mode == "rw", mode == "w" or mode == "rw"); 729 local ok, err = conn:add(mode == "r" or mode == "rw", mode == "w" or mode == "rw");
725 if not ok then return ok, err; end 730 if not ok then return ok, err; end
726 return conn; 731 return conn;
727 end; 732 end;
728 }; 733 };