File

spec/util_async_spec.lua @ 13652:a08065207ef0

net.server_epoll: Call :shutdown() on TLS sockets when supported. Comment from Matthew: This fixes a potential issue where the Prosody process gets blocked on sockets waiting for them to close. Unlike non-TLS sockets, closing a TLS socket sends layer 7 data, and this can cause problems for sockets which are in the process of being cleaned up. This depends on LuaSec changes which are not yet upstream. From Martijn's original email: So, first, my analysis of luasec. In ssl.c, the socket is put into blocking mode right before calling SSL_shutdown() inside meth_destroy(). My best guess as to why is that meth_destroy is linked to the __close and __gc methods, which can't exactly be called multiple times, and luasec does want to make sure that a TLS session is shut down as cleanly as possible. I can't say I disagree with this reasoning and don't want to change this behaviour. My solution to this without changing the current behaviour is to introduce a shutdown() method. I am aware that this overlaps in a conflicting way with tcp's shutdown method, but it stays close to the OpenSSL name. This method calls SSL_shutdown() in the current (non)blocking mode of the underlying socket and returns a boolean indicating whether or not the shutdown is completed (matching SSL_shutdown()'s 0 or 1 return values), and returns the familiar ssl_ioerror() strings on error with a false for completion. This error can then be used to determine whether we need wantread/wantwrite to finalize things. Once meth_shutdown() has been called, a shutdown flag is set, which indicates to meth_destroy() that the SSL_shutdown() has been handled by the application and that it shouldn't be necessary to set the socket to blocking mode. I've left the SSL_shutdown() call in the LSEC_STATE_CONNECTED to prevent TOCTOU if the application reaches a timeout for the shutdown code, which might allow SSL_shutdown() to clean up anyway at the last possible moment. 
Another thing I've changed in luasec is the call to socket_setblocking() right before calling close(2) in socket_destroy() in usocket.c. According to the latest POSIX[0]: Note that the requirement for close() on a socket to block for up to the current linger interval is not conditional on the O_NONBLOCK setting. I read this to mean that removing O_NONBLOCK on the socket before close doesn't impact the behaviour and only causes noise in system call tracers. I didn't touch the Windows bits of this, since I don't do Windows. For the prosody side of things I've made the TLS shutdown bits resemble interface:onwritable(), and put them under a combined guard of self._tls and self.conn.shutdown. The self._tls bit is there to prevent getting stuck on this condition, and self.conn.shutdown is there to prevent the code being called by instances where the patched luasec isn't deployed. The destroy() method can be called from various places and I read it as the "we give up" error path. To accommodate these unexpected entry points I've added a single call to self.conn:shutdown() to prevent the socket being put into blocking mode. I don't expect it to have any other use here. As before, the self.conn.shutdown check is there to make sure it's not called on unpatched luasec deployments and self._tls is there to make sure we don't call shutdown() on tcp sockets. I wouldn't recommend logging the conn:shutdown() error inside close(), since a lot of clients simply close the connection before SSL_shutdown() is done.
author Martijn van Duren <martijn@openbsd.org>
date Thu, 06 Feb 2025 15:04:38 +0000
parent 11964:563ee7969f6c
line wrap: on
line source

local async = require "util.async";
local match = require "luassert.match";

describe("util.async", function()
	-- Set to true to register `print` as a util.logger sink and keep the
	-- print() calls made by the tests below; otherwise `print` is replaced
	-- with a no-op for the rest of this spec.
	-- Named `verbose` (not `debug`) so the standard `debug` library is not
	-- shadowed inside the spec.
	local verbose = false;
	local print = print;
	if verbose then
		require "util.logger".add_simple_sink(print);
	else
		print = function () end
	end

	-- Build the watchers table handed to async.runner() by new().
	-- The "ready", "waiting" and "error" watchers are luassert mocks. Each call
	-- appends an entry to event_log that holds:
	--   * `name`: the watcher's name,
	--   * `n`: the number of arguments after the first (runner) argument,
	--   * those remaining arguments, in its array part.
	-- Looking up any other watcher name raises an assertion error.
	local function mock_watchers(event_log)
		local function recorder(watcher_name)
			return function (...)
				local entry = { select(2, ...) };
				entry.name = watcher_name;
				entry.n = select("#", ...) - 1;
				event_log[#event_log + 1] = entry;
			end;
		end;

		local watchers = {};
		for _, watcher_name in ipairs({ "ready", "waiting", "error" }) do
			watchers[watcher_name] = recorder(watcher_name);
		end

		local function reject_unknown(_, event)
			-- Unexpected watcher called
			assert(false, "unexpected watcher called: "..event);
		end

		return setmetatable(mock(watchers), { __index = reject_unknown });
	end

	-- Create a runner around `func` for the tests below.
	-- `func` is wrapped in a luassert spy, and the runner is given mocked
	-- watchers that record events into a fresh log table.
	-- Returns: the runner, the spy wrapping `func`, and the event log.
	-- NOTE: some callers pass an extra (name) argument; it is ignored here.
	local function new(func)
		local log = {};
		local wrapped_func = spy.new(func);
		local watchers = mock_watchers(log);
		return async.runner(wrapped_func, watchers), wrapped_func, log;
	end
	-- Tests for async.runner(): lifecycle state, queueing of work items, and
	-- how errors raised by the runner function are reported.
	describe("#runner", function()
		it("should work", function()
			local r = new(function (item) assert(type(item) == "number") end);
			r:run(1);
			r:run(2);
			-- A failing assertion inside the runner function is passed to the
			-- (mocked) error watcher rather than failing this test, so check
			-- explicitly that the watcher was never invoked.
			assert.spy(r.watchers.error).was_not.called();
		end);

		it("should be ready after creation", function ()
			local r = new(function () end);
			assert.equal(r.state, "ready");
		end);

		it("should do nothing if the queue is empty", function ()
			local did_run;
			local r = new(function () did_run = true end);
			-- run() with no item and nothing queued must not call the function
			r:run();
			assert.equal(r.state, "ready");
			assert.is_nil(did_run);
			r:run("hello");
			assert.is_true(did_run);
		end);

		it("should support queuing work items without running", function ()
			local did_run;
			local r = new(function () did_run = true end);
			-- enqueue() only queues; the item is processed by the next run()
			r:enqueue("hello");
			assert.equal(r.state, "ready");
			assert.is_nil(did_run);
			r:run();
			assert.is_true(did_run);
		end);

		it("should support queuing multiple work items", function ()
			local last_item;
			local r, s = new(function (item) last_item = item; end);
			r:enqueue("hello");
			r:enqueue("there");
			r:enqueue("world");
			assert.equal(r.state, "ready");
			-- A single run() processes every queued item; the last one
			-- enqueued is the last one processed
			r:run();
			assert.equal(r.state, "ready");
			assert.spy(s).was.called(3);
			assert.equal(last_item, "world");
		end);

		it("should support all simple data types", function ()
			local last_item;
			local r, s = new(function (item) last_item = item; end);
			local values = { {}, 123, "hello", true, false };
			for i = 1, #values do
				r:enqueue(values[i]);
			end
			assert.equal(r.state, "ready");
			r:run();
			assert.equal(r.state, "ready");
			assert.spy(s).was.called(#values);
			for i = 1, #values do
				assert.spy(s).was.called_with(values[i]);
			end
			assert.equal(last_item, values[#values]);
		end);

		it("should work with no parameters", function ()
			local item = "fail";
			-- Without a runner function, the queued item itself is called
			local r = async.runner();
			local f = spy.new(function () item = "success"; end);
			r:run(f);
			assert.spy(f).was.called();
			assert.equal(item, "success");
		end);

		it("supports a default error handler", function ()
			local item = "fail";
			-- Without watchers, an error raised by the work item propagates
			-- out of run()
			local r = async.runner();
			local f = spy.new(function () error("test error"); end);
			assert.error_matches(function ()
				r:run(f);
			end, "test error");
			assert.spy(f).was.called();
			assert.equal(item, "fail");
		end);

		describe("#errors", function ()
			-- NOTE: this describe body calls r:run() and asserts while the spec
			-- is being collected (not inside an it/test); the test() blocks
			-- below then check the state captured at that point.
			describe("should notify", function ()
				local last_processed_item, last_error;
				local r;
				r = async.runner(function (item)
					if item == "error" then
						error({ e = "test error" });
					end
					last_processed_item = item;
				end, mock{
					ready = function () end;
					waiting = function () end;
					error = function (runner, err)
						assert.equal(r, runner);
						last_error = err;
					end;
				});

				-- Simple item, no error
				r:run("hello");
				assert.equal(r.state, "ready");
				assert.equal(last_processed_item, "hello");
				assert.spy(r.watchers.ready).was_not.called();
				assert.spy(r.watchers.error).was_not.called();

				-- Trigger an error inside the runner
				assert.equal(last_error, nil);
				r:run("error");
				test("the correct watcher functions", function ()
					-- Only the error watcher should have been called
					assert.spy(r.watchers.ready).was_not.called();
					assert.spy(r.watchers.waiting).was_not.called();
					assert.spy(r.watchers.error).was.called(1);
				end);
				test("with the correct error", function ()
					-- The error watcher state should be correct, to
					-- demonstrate the error was passed correctly
					assert.is_table(last_error);
					assert.equal(last_error.e, "test error");
					last_error = nil;
				end);
				assert.equal(r.state, "ready");
				assert.equal(last_processed_item, "hello");
			end);

			-- The two tests in this block share one runner `r` and its spies,
			-- so they must run in declaration order (hence randomize(false)).
			do
				local last_processed_item, last_error;
				local r;
				local wait, done;
				r = async.runner(function (item)
					if item == "error" then
						error({ e = "test error" });
					elseif item == "wait" then
						wait, done = async.waiter();
						wait();
						error({ e = "post wait error" });
					end
					last_processed_item = item;
				end, mock({
					ready = function () end;
					waiting = function () end;
					error = function (runner, err)
						assert.equal(r, runner);
						last_error = err;
					end;
				}));

				randomize(false); --luacheck: ignore 113/randomize

				it("should not be fatal to the runner", function ()
					r:run("world");
					assert.equal(r.state, "ready");
					assert.spy(r.watchers.ready).was_not.called();
					assert.equal(last_processed_item, "world");
				end);
				it("should work despite a #waiter", function ()
					-- This test covers an important case where a runner
					-- throws an error while being executed outside of the
					-- main loop. This happens when it was blocked ('waiting'),
					-- and then released (via a call to done()).
					last_error = nil;
					r:run("wait");
					assert.equal(r.state, "waiting");
					assert.spy(r.watchers.waiting).was.called(1);
					done();
					-- At this point an error happens (state goes error->ready)
					assert.equal(r.state, "ready");
					assert.spy(r.watchers.error).was.called(1);
					assert.spy(r.watchers.ready).was.called(1);
					assert.is_table(last_error);
					assert.equal(last_error.e, "post wait error");
					last_error = nil;
					r:run("hello again");
					assert.spy(r.watchers.ready).was.called(1);
					assert.spy(r.watchers.waiting).was.called(1);
					assert.spy(r.watchers.error).was.called(1);
					assert.equal(r.state, "ready");
					assert.equal(last_processed_item, "hello again");
				end);
			end

			it("should continue to process work items", function ()
				local last_item;
				local runner, runner_func = new(function (item)
					if item == "error" then
						error("test error");
					end
					last_item = item;
				end);
				runner:enqueue("one");
				runner:enqueue("error");
				runner:enqueue("two");
				runner:run();
				assert.equal(runner.state, "ready");
				assert.spy(runner_func).was.called(3);
				assert.spy(runner.watchers.error).was.called(1);
				assert.spy(runner.watchers.ready).was.called(0);
				assert.spy(runner.watchers.waiting).was.called(0);
				assert.equal(last_item, "two");
			end);

			it("should continue to process work items during resume", function ()
				local wait, done, last_item;
				local runner, runner_func = new(function (item)
					if item == "wait-error" then
						wait, done = async.waiter();
						wait();
						error("test error");
					end
					last_item = item;
				end);
				runner:enqueue("one");
				runner:enqueue("wait-error");
				runner:enqueue("two");
				runner:run();
				-- Resuming the waiting item raises the error; the queue
				-- must still be drained afterwards
				done();
				assert.equal(runner.state, "ready");
				assert.spy(runner_func).was.called(3);
				assert.spy(runner.watchers.error).was.called(1);
				assert.spy(runner.watchers.waiting).was.called(1);
				assert.spy(runner.watchers.ready).was.called(1);
				assert.equal(last_item, "two");
			end);
		end);
	end);
	-- Tests for async.waiter(): suspending a runner with wait() and resuming
	-- it with done().
	--
	-- Errors raised inside a runner function, including failed assertions,
	-- are passed to the runner's (mocked) error watcher instead of failing the
	-- test. Tests therefore assert on that watcher explicitly so that the
	-- in-runner checks actually take effect.
	describe("#waiter", function()
		it("should error outside of async context", function ()
			assert.has_error(function ()
				async.waiter();
			end);
		end);

		it("should resume a waiting runner when done() is called", function ()
			local wait, done;

			local r = new(function (item)
				assert.is_number(item);
				if item == 3 then
					wait, done = async.waiter();
					wait();
				end
			end);

			r:run(1);
			assert.equal(r.state, "ready");
			r:run(2);
			assert.equal(r.state, "ready");
			r:run(3);
			assert.equal(r.state, "waiting");
			done();
			assert.equal(r.state, "ready");
			assert.spy(r.watchers.error).was_not.called();
		end);

		it("should queue items submitted while waiting", function ()
			local wait, done;
			local last_item = 0;
			local r = new(function (item)
				assert.is_number(item);
				-- Items must be processed strictly in submission order
				assert(item == last_item + 1);
				last_item = item;
				if item == 3 then
					wait, done = async.waiter();
					wait();
				end
			end);

			r:run(1);
			assert.equal(r.state, "ready");
			r:run(2);
			assert.equal(r.state, "ready");
			r:run(3);
			assert.equal(r.state, "waiting");
			-- Queued behind the item that is waiting
			r:run(4);
			assert.equal(r.state, "waiting");
			done();
			assert.equal(r.state, "ready");
			assert.equal(last_item, 4);
			assert.spy(r.watchers.error).was_not.called();
		end);

		it("should wait again for each queued item that waits", function ()
			local wait, done;
			local last_item = 0;
			local r = new(function (item)
				assert.is_number(item);
				-- Sequential items, except that 3 may be repeated
				assert((item == last_item + 1) or item == 3);
				last_item = item;
				if item == 3 then
					-- Each 3 creates a fresh waiter; `done` always refers to
					-- the most recent one
					wait, done = async.waiter();
					wait();
				end
			end);

			r:run(1);
			assert.equal(r.state, "ready");
			r:run(2);
			assert.equal(r.state, "ready");

			r:run(3);
			assert.equal(r.state, "waiting");
			r:run(3);
			assert.equal(r.state, "waiting");
			r:run(3);
			assert.equal(r.state, "waiting");
			r:run(4);
			assert.equal(r.state, "waiting");

			-- Each done() releases one 3, after which the next queued 3 waits
			-- again; the last release lets 4 run and the runner becomes ready
			for i = 1, 3 do
				done();
				if i < 3 then
					assert.equal(r.state, "waiting");
				end
			end

			assert.equal(r.state, "ready");
			assert.spy(r.watchers.error).was_not.called();
		end);

		it("should accept new items once all waits have completed", function ()
			local wait, done;
			local last_item = 0;
			local r = new(function (item)
				assert.is_number(item);
				-- Sequential items, except that 3 may be repeated
				assert((item == last_item + 1) or item == 3);
				last_item = item;
				if item == 3 then
					wait, done = async.waiter();
					wait();
				end
			end);

			r:run(1);
			assert.equal(r.state, "ready");
			r:run(2);
			assert.equal(r.state, "ready");

			r:run(3);
			assert.equal(r.state, "waiting");
			r:run(3);
			assert.equal(r.state, "waiting");

			for i = 1, 2 do
				done();
				if i < 2 then
					assert.equal(r.state, "waiting");
				end
			end

			assert.equal(r.state, "ready");
			r:run(4);
			assert.equal(r.state, "ready");
			assert.spy(r.watchers.error).was_not.called();
		end);

		it("should work with multiple runners in parallel", function ()
			-- Two independent runners, each of which waits on item 3.
			-- (new() takes only the function; the "r1"/"r2" argument is
			-- ignored and just labels the runner for the reader.)
			local wait1, done1;
			local last_item1 = 0;
			local r1 = new(function (item)
				assert.is_number(item);
				assert((item == last_item1 + 1) or item == 3);
				last_item1 = item;
				if item == 3 then
					wait1, done1 = async.waiter();
					wait1();
				end
			end, "r1");

			local wait2, done2;
			local last_item2 = 0;
			local r2 = new(function (item)
				assert.is_number(item);
				assert((item == last_item2 + 1) or item == 3);
				last_item2 = item;
				if item == 3 then
					wait2, done2 = async.waiter();
					wait2();
				end
			end, "r2");

			r1:run(1);
			assert.equal(r1.state, "ready");
			r1:run(2);
			assert.equal(r1.state, "ready");

			r1:run(3);
			assert.equal(r1.state, "waiting");
			r1:run(3);
			assert.equal(r1.state, "waiting");

			-- r2 proceeds independently while r1 is blocked
			r2:run(1);
			assert.equal(r1.state, "waiting");
			assert.equal(r2.state, "ready");

			r2:run(2);
			assert.equal(r1.state, "waiting");
			assert.equal(r2.state, "ready");

			r2:run(3);
			assert.equal(r1.state, "waiting");
			assert.equal(r2.state, "waiting");
			done2();

			r2:run(3);
			assert.equal(r1.state, "waiting");
			assert.equal(r2.state, "waiting");
			done2();

			r2:run(4);
			assert.equal(r1.state, "waiting");
			assert.equal(r2.state, "ready");

			for i = 1, 2 do
				done1();
				if i < 2 then
					assert.equal(r1.state, "waiting");
				end
			end

			assert.equal(r1.state, "ready");
			r1:run(4);
			assert.equal(r1.state, "ready");

			assert.spy(r1.watchers.error).was_not.called();
			assert.spy(r2.watchers.error).was_not.called();
		end);

		it("should recover from errors with multiple runners in parallel", function ()
			-- As above, but r1 is also fed out-of-sequence items that make its
			-- runner function raise. (The "r1"/"r2" argument is ignored by
			-- new().)
			local wait1, done1;
			local last_item1 = 0;
			local r1 = new(function (item)
				print("r1 processing ", item);
				assert.is_number(item);
				assert((item == last_item1 + 1) or item == 3);
				last_item1 = item;
				if item == 3 then
					wait1, done1 = async.waiter();
					wait1();
				end
			end, "r1");

			local wait2, done2;
			local last_item2 = 0;
			local r2 = new(function (item)
				print("r2 processing ", item);
				assert.is_number(item);
				assert((item == last_item2 + 1) or item == 3);
				last_item2 = item;
				if item == 3 then
					wait2, done2 = async.waiter();
					wait2();
				end
			end, "r2");

			r1:run(1);
			assert.equal(r1.state, "ready");
			r1:run(2);
			assert.equal(r1.state, "ready");

			-- Out of sequence (expected 3): errors immediately (1st error),
			-- and the runner stays ready
			r1:run(5);
			assert.equal(r1.state, "ready");

			r1:run(3);
			assert.equal(r1.state, "waiting");
			r1:run(5); -- Will error, when we get to it (2nd error)
			assert.equal(r1.state, "waiting");
			done1();
			assert.equal(r1.state, "ready");
			r1:run(3);
			assert.equal(r1.state, "waiting");

			r2:run(1);
			assert.equal(r1.state, "waiting");
			assert.equal(r2.state, "ready");

			r2:run(2);
			assert.equal(r1.state, "waiting");
			assert.equal(r2.state, "ready");

			r2:run(3);
			assert.equal(r1.state, "waiting");
			assert.equal(r2.state, "waiting");

			done2();
			assert.equal(r1.state, "waiting");
			assert.equal(r2.state, "ready");

			r2:run(3);
			assert.equal(r1.state, "waiting");
			assert.equal(r2.state, "waiting");

			done2();
			assert.equal(r1.state, "waiting");
			assert.equal(r2.state, "ready");

			r2:run(4);
			assert.equal(r1.state, "waiting");
			assert.equal(r2.state, "ready");

			done1();

			assert.equal(r1.state, "ready");
			r1:run(4);
			assert.equal(r1.state, "ready");

			-- Only the two deliberate out-of-sequence items failed, and only
			-- in r1
			assert.spy(r1.watchers.error).was.called(2);
			assert.spy(r2.watchers.error).was_not.called();
		end);

		-- In the tests below, `runner_func` is the spy that new() wraps
		-- around the runner function.
		it("should support multiple done() calls", function ()
			local processed_item;
			local wait, done;
			local r, runner_func = new(function (item)
				wait, done = async.waiter(4);
				wait();
				processed_item = item;
			end);
			r:run("test");
			-- A waiter created with 4 needs four done() calls before it resumes
			for _ = 1, 3 do
				done();
				assert.equal(r.state, "waiting");
				assert.is_nil(processed_item);
			end
			done();
			assert.equal(r.state, "ready");
			assert.equal(processed_item, "test");
			assert.spy(runner_func).was.called(1);
			assert.spy(r.watchers.error).was_not.called();
		end);

		it("should not allow done() to be called more than specified", function ()
			local processed_item;
			local wait, done;
			local r, runner_func = new(function (item)
				wait, done = async.waiter(4);
				wait();
				processed_item = item;
			end);
			r:run("test");
			for _ = 1, 4 do
				done();
			end
			assert.has_error(done);
			assert.equal(r.state, "ready");
			assert.equal(processed_item, "test");
			assert.spy(runner_func).was.called(1);
			assert.spy(r.watchers.error).was_not.called();
		end);

		it("should allow done() to be called before wait()", function ()
			local processed_item;
			local r, runner_func = new(function (item)
				local wait, done = async.waiter();
				done();
				wait();
				processed_item = item;
			end);
			r:run("test");
			assert.equal(processed_item, "test");
			assert.equal(r.state, "ready");
			assert.spy(runner_func).was.called(1);
			-- Since the observable state did not change,
			-- the watchers should not have been called
			assert.spy(r.watchers.waiting).was_not.called();
			assert.spy(r.watchers.ready).was_not.called();
		end);
	end);

	-- async.ready() reports whether the caller is running inside a runner.
	describe("#ready()", function ()
		it("should return false outside an async context", function ()
			-- Called directly from test code, not from a runner
			assert.falsy(async.ready());
		end);

		it("should return true inside an async context", function ()
			local runner = new(function ()
				assert.truthy(async.ready());
			end);
			runner:run(true);

			-- The function must have run, and its assertion must not have
			-- been reported to the error watcher
			assert.spy(runner.func).was.called();
			assert.spy(runner.watchers.error).was_not.called();
		end);
	end);

	-- async.sleep() suspends a runner until a scheduled callback fires. The
	-- scheduler is injected with async.set_schedule_function().
	describe("#sleep()", function ()
		after_each(function ()
			-- Reinstate the default scheduler after every test
			async.set_schedule_function(nil);
		end);

		it("should fail if no scheduler configured", function ()
			-- A runner whose only job is to call async.sleep(5)
			local function new_sleeper()
				return new(function ()
					async.sleep(5);
				end);
			end

			local unscheduled = new_sleeper();
			unscheduled:run(true);
			assert.spy(unscheduled.watchers.error).was.called();

			-- With a do-nothing scheduler installed, sleep() is expected
			-- not to error
			async.set_schedule_function(function () end);

			local scheduled = new_sleeper();
			scheduled:run(true);
			assert.spy(scheduled.watchers.error).was_not.called();
		end);

		it("should work", function ()
			local timers = {};
			local fake_scheduler = spy.new(function (delay, callback)
				timers[#timers + 1] = { delay, callback };
			end);
			async.set_schedule_function(fake_scheduler);

			local result;
			local runner = new(function (item)
				async.sleep(5);
				result = item;
			end);
			runner:run("test");

			-- The runner is asleep: the item is not processed yet, and exactly
			-- one timer has been requested
			assert.is_nil(result);
			assert.equal(runner.state, "waiting");
			assert.spy(fake_scheduler).was_called(1);
			assert.spy(fake_scheduler).was_called_with(match.is_number(), match.is_function());
			assert.spy(runner.watchers.waiting).was.called();
			assert.spy(runner.watchers.ready).was_not.called();

			-- Fire the timer by invoking the callback the runner scheduled
			local fire = timers[1][2];
			fire();

			assert.equal(result, "test");
			assert.equal(runner.state, "ready");

			assert.spy(runner.watchers.ready).was.called();
		end);
	end);

	-- async.set_nexttick() installs the function used to defer a runner's
	-- resumption after done() is called.
	describe("#set_nexttick()", function ()
		after_each(function ()
			-- Reinstate the default nexttick function after every test
			async.set_nexttick(nil);
		end);

		it("should work", function ()
			local pending = {};
			local fake_nexttick = spy.new(function (fn)
				assert.is_function(fn);
				pending[#pending + 1] = fn;
			end);
			async.set_nexttick(fake_nexttick);

			local result;
			local release;
			local runner = new(function (item)
				local wait;
				wait, release = async.waiter();
				wait();
				result = item;
			end);
			runner:run("test");

			-- Blocked in wait(): nothing processed and nothing deferred yet
			assert.is_nil(result);
			assert.equal(runner.state, "waiting");
			assert.spy(fake_nexttick).was_called(0);
			assert.spy(runner.watchers.waiting).was.called();
			assert.spy(runner.watchers.ready).was_not.called();

			-- Releasing the waiter should hand the resumption to nexttick
			-- rather than resuming immediately
			release();

			assert.spy(fake_nexttick).was_called(1);
			assert.spy(fake_nexttick).was_called_with(match.is_function());
			assert.equal(1, #pending);

			-- Simulate the next tick by running the deferred function
			local next_tick = pending[1];
			next_tick();

			assert.equal(result, "test");
			assert.equal(runner.state, "ready");
			assert.spy(runner.watchers.ready).was.called();
		end);
	end);
end);