Software /
code /
prosody
Comparison
net/server_epoll.lua @ 9314:c9eef7e8ee65
net.server_epoll: Use util.poll
author | Kim Alvefur <zash@zash.se> |
---|---|
date | Wed, 16 May 2018 23:57:09 +0200 |
parent | 9311:9b0604fe01f1 |
child | 9319:7c954c75b6ac |
comparison
equal
deleted
inserted
replaced
9313:b95ef295c66d | 9314:c9eef7e8ee65 |
---|---|
3 -- | 3 -- |
4 -- This project is MIT/X11 licensed. Please see the | 4 -- This project is MIT/X11 licensed. Please see the |
5 -- COPYING file in the source package for more information. | 5 -- COPYING file in the source package for more information. |
6 -- | 6 -- |
7 | 7 |
8 -- server_epoll | |
9 -- Server backend based on https://luarocks.org/modules/zash/lua-epoll | |
10 | 8 |
11 local t_sort = table.sort; | 9 local t_sort = table.sort; |
12 local t_insert = table.insert; | 10 local t_insert = table.insert; |
13 local t_remove = table.remove; | 11 local t_remove = table.remove; |
14 local t_concat = table.concat; | 12 local t_concat = table.concat; |
17 local pcall = pcall; | 15 local pcall = pcall; |
18 local type = type; | 16 local type = type; |
19 local next = next; | 17 local next = next; |
20 local pairs = pairs; | 18 local pairs = pairs; |
21 local log = require "util.logger".init("server_epoll"); | 19 local log = require "util.logger".init("server_epoll"); |
22 local epoll = require "epoll"; | |
23 local socket = require "socket"; | 20 local socket = require "socket"; |
24 local luasec = require "ssl"; | 21 local luasec = require "ssl"; |
25 local gettime = require "util.time".now; | 22 local gettime = require "util.time".now; |
26 local createtable = require "util.table".create; | 23 local createtable = require "util.table".create; |
27 local _SOCKETINVALID = socket._SOCKETINVALID or -1; | 24 local _SOCKETINVALID = socket._SOCKETINVALID or -1; |
28 | 25 |
29 assert(socket.tcp6 and socket.tcp4, "Incompatible LuaSocket version"); | 26 local poll = require "util.poll".new(); |
30 | 27 |
31 local _ENV = nil; | 28 local _ENV = nil; |
32 -- luacheck: std none | 29 -- luacheck: std none |
33 | 30 |
34 local default_config = { __index = { | 31 local default_config = { __index = { |
258 self:destroy(); | 255 self:destroy(); |
259 end); | 256 end); |
260 end | 257 end |
261 end | 258 end |
262 | 259 |
263 -- lua-epoll flag for currently requested poll state | 260 function interface:add(r, w) |
264 function interface:flags() | |
265 if self._wantread then | |
266 if self._wantwrite then | |
267 return "rw"; | |
268 end | |
269 return "r"; | |
270 elseif self._wantwrite then | |
271 return "w"; | |
272 end | |
273 end | |
274 | |
275 -- Add or remove sockets or modify epoll flags | |
276 function interface:setflags(r, w) | |
277 if r ~= nil then self._wantread = r; end | |
278 if w ~= nil then self._wantwrite = w; end | |
279 local flags = self:flags(); | |
280 local currentflags = self._flags; | |
281 if flags == currentflags then | |
282 return true; | |
283 end | |
284 local fd = self:getfd(); | 261 local fd = self:getfd(); |
285 if fd < 0 then | 262 if fd < 0 then |
286 self._wantread, self._wantwrite = nil, nil; | |
287 return nil, "invalid fd"; | 263 return nil, "invalid fd"; |
288 end | 264 end |
289 local op = "mod"; | 265 if r == nil then r = self._wantread; end |
290 if not flags then | 266 if w == nil then w = self._wantwrite; end |
291 op = "del"; | 267 local ok, err = poll:add(fd, r, w); |
292 elseif not currentflags then | 268 if not ok then |
293 op = "add"; | 269 log("error", "Could not register %s: %s", self, err); |
294 end | 270 return ok, err; |
295 local ok, err = epoll.ctl(op, fd, flags); | 271 end |
296 -- log("debug", "epoll_ctl(%q, %d, %q) -> %s" .. (err and ", %q" or ""), | 272 self._wantread, self._wantwrite = r, w; |
297 -- op, fd, flags or "", tostring(ok), err); | 273 fds[fd] = self; |
298 if not ok then return ok, err end | 274 log("debug", "Registered %s", self); |
299 if op == "add" then | 275 return true; |
300 fds[fd] = self; | 276 end |
301 elseif op == "del" then | 277 |
302 fds[fd] = nil; | 278 function interface:set(r, w) |
303 end | 279 local fd = self:getfd(); |
304 self._flags = flags; | 280 if fd < 0 then |
281 return nil, "invalid fd"; | |
282 end | |
283 if r == nil then r = self._wantread; end | |
284 if w == nil then w = self._wantwrite; end | |
285 local ok, err = poll:set(fd, r, w); | |
286 if not ok then | |
287 log("error", "Could not update poller state %s: %s", self, err); | |
288 return ok, err; | |
289 end | |
290 self._wantread, self._wantwrite = r, w; | |
291 return true; | |
292 end | |
293 | |
294 function interface:del() | |
295 local fd = self:getfd(); | |
296 if fd < 0 then | |
297 return nil, "invalid fd"; | |
298 end | |
299 if fds[fd] ~= self then | |
300 return nil, "unregistered fd"; | |
301 end | |
302 local ok, err = poll:del(fd); | |
303 if not ok then | |
304 log("error", "Could not unregister %s: %s", self, err); | |
305 return ok, err; | |
306 end | |
307 self._wantread, self._wantwrite = nil, nil; | |
308 fds[fd] = nil; | |
309 log("debug", "Unregistered %s", self); | |
305 return true; | 310 return true; |
306 end | 311 end |
307 | 312 |
308 -- Called when socket is readable | 313 -- Called when socket is readable |
309 function interface:onreadable() | 314 function interface:onreadable() |
315 if partial and partial ~= "" then | 320 if partial and partial ~= "" then |
316 self:onconnect(); | 321 self:onconnect(); |
317 self:on("incoming", partial, err); | 322 self:on("incoming", partial, err); |
318 end | 323 end |
319 if err == "wantread" then | 324 if err == "wantread" then |
320 self:setflags(true, nil); | 325 self:set(true, nil); |
321 elseif err == "wantwrite" then | 326 elseif err == "wantwrite" then |
322 self:setflags(nil, true); | 327 self:set(nil, true); |
323 elseif err ~= "timeout" then | 328 elseif err ~= "timeout" then |
324 self:on("disconnect", err); | 329 self:on("disconnect", err); |
325 self:destroy() | 330 self:destroy() |
326 return; | 331 return; |
327 end | 332 end |
341 if not self.conn then return; end -- could have been closed in onconnect | 346 if not self.conn then return; end -- could have been closed in onconnect |
342 local buffer = self.writebuffer; | 347 local buffer = self.writebuffer; |
343 local data = t_concat(buffer); | 348 local data = t_concat(buffer); |
344 local ok, err, partial = self.conn:send(data); | 349 local ok, err, partial = self.conn:send(data); |
345 if ok then | 350 if ok then |
346 self:setflags(nil, false); | 351 self:set(nil, false); |
347 for i = #buffer, 1, -1 do | 352 for i = #buffer, 1, -1 do |
348 buffer[i] = nil; | 353 buffer[i] = nil; |
349 end | 354 end |
350 self:setwritetimeout(false); | 355 self:setwritetimeout(false); |
351 self:ondrain(); -- Be aware of writes in ondrain | 356 self:ondrain(); -- Be aware of writes in ondrain |
356 buffer[i] = nil; | 361 buffer[i] = nil; |
357 end | 362 end |
358 self:setwritetimeout(); | 363 self:setwritetimeout(); |
359 end | 364 end |
360 if err == "wantwrite" or err == "timeout" then | 365 if err == "wantwrite" or err == "timeout" then |
361 self:setflags(nil, true); | 366 self:set(nil, true); |
362 elseif err == "wantread" then | 367 elseif err == "wantread" then |
363 self:setflags(true, nil); | 368 self:set(true, nil); |
364 elseif err ~= "timeout" then | 369 elseif err ~= "timeout" then |
365 self:on("disconnect", err); | 370 self:on("disconnect", err); |
366 self:destroy(); | 371 self:destroy(); |
367 end | 372 end |
368 end | 373 end |
379 t_insert(buffer, data); | 384 t_insert(buffer, data); |
380 else | 385 else |
381 self.writebuffer = { data }; | 386 self.writebuffer = { data }; |
382 end | 387 end |
383 self:setwritetimeout(); | 388 self:setwritetimeout(); |
384 self:setflags(nil, true); | 389 self:set(nil, true); |
385 return #data; | 390 return #data; |
386 end | 391 end |
387 interface.send = interface.write; | 392 interface.send = interface.write; |
388 | 393 |
389 -- Close, possibly after writing is done | 394 -- Close, possibly after writing is done |
390 function interface:close() | 395 function interface:close() |
391 if self.writebuffer and self.writebuffer[1] then | 396 if self.writebuffer and self.writebuffer[1] then |
392 self:setflags(false, true); -- Flush final buffer contents | 397 self:set(false, true); -- Flush final buffer contents |
393 self.write, self.send = noop, noop; -- No more writing | 398 self.write, self.send = noop, noop; -- No more writing |
394 log("debug", "Close %s after writing", self); | 399 log("debug", "Close %s after writing", self); |
395 self.ondrain = interface.close; | 400 self.ondrain = interface.close; |
396 else | 401 else |
397 log("debug", "Close %s now", self); | 402 log("debug", "Close %s now", self); |
401 self:destroy(); | 406 self:destroy(); |
402 end | 407 end |
403 end | 408 end |
404 | 409 |
405 function interface:destroy() | 410 function interface:destroy() |
406 self:setflags(false, false); | 411 self:del(); |
407 self:setwritetimeout(false); | 412 self:setwritetimeout(false); |
408 self:setreadtimeout(false); | 413 self:setreadtimeout(false); |
409 self.onreadable = noop; | 414 self.onreadable = noop; |
410 self.onwritable = noop; | 415 self.onwritable = noop; |
411 self.destroy = noop; | 416 self.destroy = noop; |
423 if tls_ctx then self.tls_ctx = tls_ctx; end | 428 if tls_ctx then self.tls_ctx = tls_ctx; end |
424 if self.writebuffer and self.writebuffer[1] then | 429 if self.writebuffer and self.writebuffer[1] then |
425 log("debug", "Start TLS on %s after write", self); | 430 log("debug", "Start TLS on %s after write", self); |
426 self.ondrain = interface.starttls; | 431 self.ondrain = interface.starttls; |
427 self.starttls = false; | 432 self.starttls = false; |
428 self:setflags(nil, true); -- make sure wantwrite is set | 433 self:set(nil, true); -- make sure wantwrite is set |
429 else | 434 else |
430 log("debug", "Start TLS on %s now", self); | 435 log("debug", "Start TLS on %s now", self); |
431 self:setflags(false, false); | 436 self:del(); |
432 local conn, err = luasec.wrap(self.conn, tls_ctx or self.tls_ctx); | 437 local conn, err = luasec.wrap(self.conn, tls_ctx or self.tls_ctx); |
433 if not conn then | 438 if not conn then |
434 self:on("disconnect", err); | 439 self:on("disconnect", err); |
435 self:destroy(); | 440 self:destroy(); |
436 return conn, err; | 441 return conn, err; |
438 conn:settimeout(0); | 443 conn:settimeout(0); |
439 self.conn = conn; | 444 self.conn = conn; |
440 self.ondrain = nil; | 445 self.ondrain = nil; |
441 self.onwritable = interface.tlshandskake; | 446 self.onwritable = interface.tlshandskake; |
442 self.onreadable = interface.tlshandskake; | 447 self.onreadable = interface.tlshandskake; |
443 self:setflags(true, true); | 448 return self:init(); |
444 self:setwritetimeout(cfg.handshake_timeout); | |
445 end | 449 end |
446 end | 450 end |
447 | 451 |
448 function interface:tlshandskake() | 452 function interface:tlshandskake() |
449 self:setwritetimeout(false); | 453 self:setwritetimeout(false); |
453 log("debug", "TLS handshake on %s complete", self); | 457 log("debug", "TLS handshake on %s complete", self); |
454 self.onwritable = nil; | 458 self.onwritable = nil; |
455 self.onreadable = nil; | 459 self.onreadable = nil; |
456 self._tls = true; | 460 self._tls = true; |
457 self:on("status", "ssl-handshake-complete"); | 461 self:on("status", "ssl-handshake-complete"); |
458 self:init(); | 462 self:setwritetimeout(); |
463 self:set(true, true); | |
459 elseif err == "wantread" then | 464 elseif err == "wantread" then |
460 log("debug", "TLS handshake on %s to wait until readable", self); | 465 log("debug", "TLS handshake on %s to wait until readable", self); |
461 self:setflags(true, false); | 466 self:set(true, false); |
462 self:setreadtimeout(cfg.handshake_timeout); | 467 self:setreadtimeout(cfg.handshake_timeout); |
463 elseif err == "wantwrite" then | 468 elseif err == "wantwrite" then |
464 log("debug", "TLS handshake on %s to wait until writable", self); | 469 log("debug", "TLS handshake on %s to wait until writable", self); |
465 self:setflags(false, true); | 470 self:set(false, true); |
466 self:setwritetimeout(cfg.handshake_timeout); | 471 self:setwritetimeout(cfg.handshake_timeout); |
467 else | 472 else |
468 log("debug", "TLS handshake error on %s: %s", self, err); | 473 log("debug", "TLS handshake error on %s: %s", self, err); |
469 self:on("disconnect", err); | 474 self:on("disconnect", err); |
470 self:destroy(); | 475 self:destroy(); |
511 end | 516 end |
512 | 517 |
513 -- Initialization | 518 -- Initialization |
514 function interface:init() | 519 function interface:init() |
515 self:setwritetimeout(); | 520 self:setwritetimeout(); |
516 return self:setflags(true, true); | 521 return self:add(true, true); |
517 end | 522 end |
518 | 523 |
519 function interface:pause() | 524 function interface:pause() |
520 return self:setflags(false); | 525 return self:set(false); |
521 end | 526 end |
522 | 527 |
523 function interface:resume() | 528 function interface:resume() |
524 return self:setflags(true); | 529 return self:set(true); |
525 end | 530 end |
526 | 531 |
527 -- Pause connection for some time | 532 -- Pause connection for some time |
528 function interface:pausefor(t) | 533 function interface:pausefor(t) |
529 if self._pausefor then | 534 if self._pausefor then |
530 self._pausefor:close(); | 535 self._pausefor:close(); |
531 end | 536 end |
532 if t == false then return; end | 537 if t == false then return; end |
533 self:setflags(false); | 538 self:set(false); |
534 self._pausefor = addtimer(t, function () | 539 self._pausefor = addtimer(t, function () |
535 self._pausefor = nil; | 540 self._pausefor = nil; |
536 if self.conn and self.conn:dirty() then | 541 if self.conn and self.conn:dirty() then |
537 self:onreadable(); | 542 self:onreadable(); |
538 end | 543 end |
539 self:setflags(true); | 544 self:set(true); |
540 end); | 545 end); |
541 end | 546 end |
542 | 547 |
543 -- Connected! | 548 -- Connected! |
544 function interface:onconnect() | 549 function interface:onconnect() |
562 tls_ctx = tls_ctx; | 567 tls_ctx = tls_ctx; |
563 tls_direct = tls_ctx and true or false; | 568 tls_direct = tls_ctx and true or false; |
564 sockname = addr; | 569 sockname = addr; |
565 sockport = port; | 570 sockport = port; |
566 }, interface_mt); | 571 }, interface_mt); |
567 server:setflags(true, false); | 572 server:add(true, false); |
568 return server; | 573 return server; |
569 end | 574 end |
570 | 575 |
571 -- COMPAT | 576 -- COMPAT |
572 local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx) | 577 local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx) |
601 local conn = setmetatable({ | 606 local conn = setmetatable({ |
602 conn = fd; | 607 conn = fd; |
603 onreadable = onreadable; | 608 onreadable = onreadable; |
604 onwriteable = onwriteable; | 609 onwriteable = onwriteable; |
605 close = function (self) | 610 close = function (self) |
606 self:setflags(false, false); | 611 self:del(); |
607 end | 612 end |
608 }, interface_mt); | 613 }, interface_mt); |
609 if type(fd) == "number" then | 614 if type(fd) == "number" then |
610 conn.getfd = function () | 615 conn.getfd = function () |
611 return fd; | 616 return fd; |
612 end; | 617 end; |
613 -- Otherwise it'll need to be something LuaSocket-compatible | 618 -- Otherwise it'll need to be something LuaSocket-compatible |
614 end | 619 end |
615 conn:setflags(onreadable, onwriteable); | 620 conn:add(onreadable, onwriteable); |
616 return conn; | 621 return conn; |
617 end; | 622 end; |
618 | 623 |
619 -- Dump all data from one connection into another | 624 -- Dump all data from one connection into another |
620 local function link(from, to) | 625 local function link(from, to) |
627 to.listeners = setmetatable({ | 632 to.listeners = setmetatable({ |
628 ondrain = function () | 633 ondrain = function () |
629 from:resume(); | 634 from:resume(); |
630 end, | 635 end, |
631 }, {__index=to.listeners}); | 636 }, {__index=to.listeners}); |
632 from:setflags(true, nil); | 637 from:set(true, nil); |
633 to:setflags(nil, true); | 638 to:set(nil, true); |
634 end | 639 end |
635 | 640 |
636 -- XXX What uses this? | 641 -- XXX What uses this? |
637 -- net.adns | 642 -- net.adns |
638 function interface:set_send(new_send) | 643 function interface:set_send(new_send) |
660 | 665 |
661 -- Main loop | 666 -- Main loop |
662 local function loop(once) | 667 local function loop(once) |
663 repeat | 668 repeat |
664 local t = runtimers(cfg.max_wait, cfg.min_wait); | 669 local t = runtimers(cfg.max_wait, cfg.min_wait); |
665 local fd, r, w = epoll.wait(t); | 670 local fd, r, w = poll:wait(t); |
666 if fd then | 671 if fd then |
667 local conn = fds[fd]; | 672 local conn = fds[fd]; |
668 if conn then | 673 if conn then |
669 if r then | 674 if r then |
670 conn:onreadable(); | 675 conn:onreadable(); |
672 if w then | 677 if w then |
673 conn:onwritable(); | 678 conn:onwritable(); |
674 end | 679 end |
675 else | 680 else |
676 log("debug", "Removing unknown fd %d", fd); | 681 log("debug", "Removing unknown fd %d", fd); |
677 epoll.ctl("del", fd); | 682 poll:del(fd); |
678 end | 683 end |
679 elseif r ~= "timeout" then | 684 elseif r ~= "timeout" then |
680 log("debug", "epoll_wait error: %s", tostring(r)); | 685 log("debug", "epoll_wait error: %s", tostring(r)); |
681 end | 686 end |
682 until once or (quitting and next(fds) == nil); | 687 until once or (quitting and next(fds) == nil); |
703 event = { EV_READ = "r", EV_WRITE = "w", EV_READWRITE = "rw", EV_LEAVE = -1 }; | 708 event = { EV_READ = "r", EV_WRITE = "w", EV_READWRITE = "rw", EV_LEAVE = -1 }; |
704 addevent = function (fd, mode, callback) | 709 addevent = function (fd, mode, callback) |
705 local function onevent(self) | 710 local function onevent(self) |
706 local ret = self:callback(); | 711 local ret = self:callback(); |
707 if ret == -1 then | 712 if ret == -1 then |
708 self:setflags(false, false); | 713 self:set(false, false); |
709 elseif ret then | 714 elseif ret then |
710 self:setflags(mode == "r" or mode == "rw", mode == "w" or mode == "rw"); | 715 self:set(mode == "r" or mode == "rw", mode == "w" or mode == "rw"); |
711 end | 716 end |
712 end | 717 end |
713 | 718 |
714 local conn = setmetatable({ | 719 local conn = setmetatable({ |
715 getfd = function () return fd; end; | 720 getfd = function () return fd; end; |
716 callback = callback; | 721 callback = callback; |
717 onreadable = onevent; | 722 onreadable = onevent; |
718 onwritable = onevent; | 723 onwritable = onevent; |
719 close = function (self) | 724 close = function (self) |
720 self:setflags(false, false); | 725 self:del(); |
721 fds[fd] = nil; | 726 fds[fd] = nil; |
722 end; | 727 end; |
723 }, interface_mt); | 728 }, interface_mt); |
724 local ok, err = conn:setflags(mode == "r" or mode == "rw", mode == "w" or mode == "rw"); | 729 local ok, err = conn:add(mode == "r" or mode == "rw", mode == "w" or mode == "rw"); |
725 if not ok then return ok, err; end | 730 if not ok then return ok, err; end |
726 return conn; | 731 return conn; |
727 end; | 732 end; |
728 }; | 733 }; |