Comparison

net/server_epoll.lua @ 9430:412ff404bf58

net.server_epoll: Delay wrapping sockets in TLS until just before first handshake
author Kim Alvefur <zash@zash.se>
date Fri, 14 Sep 2018 01:34:38 +0200
parent 9387:33e52f727f0f
child 9431:c3c0523a37c6
comparison
equal deleted inserted replaced
9429:5f51710d7c1e 9430:412ff404bf58
438 return self._tls; 438 return self._tls;
439 end 439 end
440 440
441 function interface:starttls(tls_ctx) 441 function interface:starttls(tls_ctx)
442 if tls_ctx then self.tls_ctx = tls_ctx; end 442 if tls_ctx then self.tls_ctx = tls_ctx; end
443 self.starttls = false;
443 if self.writebuffer and self.writebuffer[1] then 444 if self.writebuffer and self.writebuffer[1] then
444 log("debug", "Start TLS on %s after write", self); 445 log("debug", "Start TLS on %s after write", self);
445 self.ondrain = interface.starttls; 446 self.ondrain = interface.starttls;
446 self.starttls = false;
447 self:set(nil, true); -- make sure wantwrite is set 447 self:set(nil, true); -- make sure wantwrite is set
448 else 448 else
449 if self.ondrain == interface.starttls then
450 self.ondrain = nil;
451 end
452 self.onwritable = interface.tlshandskake;
453 self.onreadable = interface.tlshandskake;
454 self:set(true, true);
455 log("debug", "Prepare to start TLS on %s", self);
456 end
457 end
458
459 function interface:tlshandskake()
460 self:setwritetimeout(false);
461 self:setreadtimeout(false);
462 if not self._tls then
463 self._tls = true;
449 log("debug", "Start TLS on %s now", self); 464 log("debug", "Start TLS on %s now", self);
450 self:del(); 465 self:del();
451 local conn, err = luasec.wrap(self.conn, tls_ctx or self.tls_ctx); 466 local conn, err = luasec.wrap(self.conn, self.tls_ctx);
452 if not conn then 467 if not conn then
453 self:on("disconnect", err); 468 self:on("disconnect", err);
454 self:destroy(); 469 self:destroy();
455 return conn, err; 470 return conn, err;
456 end 471 end
457 conn:settimeout(0); 472 conn:settimeout(0);
458 self.conn = conn; 473 self.conn = conn;
474 self:on("starttls");
459 self.ondrain = nil; 475 self.ondrain = nil;
460 self.onwritable = interface.tlshandskake; 476 self.onwritable = interface.tlshandskake;
461 self.onreadable = interface.tlshandskake; 477 self.onreadable = interface.tlshandskake;
462 return self:init(); 478 return self:init();
463 end 479 end
464 end
465
466 function interface:tlshandskake()
467 self:setwritetimeout(false);
468 self:setreadtimeout(false);
469 local ok, err = self.conn:dohandshake(); 480 local ok, err = self.conn:dohandshake();
470 if ok then 481 if ok then
471 log("debug", "TLS handshake on %s complete", self); 482 log("debug", "TLS handshake on %s complete", self);
472 self.onwritable = nil; 483 self.onwritable = nil;
473 self.onreadable = nil; 484 self.onreadable = nil;
474 self._tls = true;
475 self:on("status", "ssl-handshake-complete"); 485 self:on("status", "ssl-handshake-complete");
476 self:setwritetimeout(); 486 self:setwritetimeout();
477 self:set(true, true); 487 self:set(true, true);
478 elseif err == "wantread" then 488 elseif err == "wantread" then
479 log("debug", "TLS handshake on %s to wait until readable", self); 489 log("debug", "TLS handshake on %s to wait until readable", self);
527 self:pausefor(cfg.accept_retry_interval); 537 self:pausefor(cfg.accept_retry_interval);
528 return; 538 return;
529 end 539 end
530 local client = wrapsocket(conn, self, nil, self.listeners); 540 local client = wrapsocket(conn, self, nil, self.listeners);
531 log("debug", "New connection %s", tostring(client)); 541 log("debug", "New connection %s", tostring(client));
542 client:init();
532 if self.tls_direct then 543 if self.tls_direct then
533 client:starttls(self.tls_ctx); 544 client:starttls(self.tls_ctx);
534 else
535 client:init();
536 end 545 end
537 end 546 end
538 547
539 -- Initialization 548 -- Initialization
540 function interface:init() 549 function interface:init()
598 local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx) 607 local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx)
599 local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx); 608 local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx);
600 if not client.peername then 609 if not client.peername then
601 client.peername, client.peerport = addr, port; 610 client.peername, client.peerport = addr, port;
602 end 611 end
612 client:init();
603 if tls_ctx then 613 if tls_ctx then
604 client:starttls(tls_ctx); 614 client:starttls(tls_ctx);
605 else
606 client:init();
607 end 615 end
608 return client; 616 return client;
609 end 617 end
610 618
611 -- New outgoing TCP connection 619 -- New outgoing TCP connection
613 local conn, err = socket.tcp(); 621 local conn, err = socket.tcp();
614 if not conn then return conn, err; end 622 if not conn then return conn, err; end
615 conn:settimeout(0); 623 conn:settimeout(0);
616 conn:connect(addr, port); 624 conn:connect(addr, port);
617 local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx) 625 local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx)
626 client:init();
618 if tls_ctx then 627 if tls_ctx then
619 client:starttls(tls_ctx); 628 client:starttls(tls_ctx);
620 else
621 client:init();
622 end 629 end
623 return client, conn; 630 return client, conn;
624 end 631 end
625 632
626 local function watchfd(fd, onreadable, onwritable) 633 local function watchfd(fd, onreadable, onwritable)