Changeset

11117:590ac42d81c5 0.11

Merge
author Matthew Wild <mwild1@gmail.com>
date Wed, 30 Sep 2020 09:46:30 +0100
parents 11115:7d4c292f178e (current diff) 11113:10301c214f4e (diff)
children 11118:ece430d49809
files
diffstat 5 files changed, 434 insertions(+), 85 deletions(-) [+]
line wrap: on
line diff
--- a/net/websocket/frames.lua	Tue Sep 29 21:27:16 2020 -0500
+++ b/net/websocket/frames.lua	Wed Sep 30 09:46:30 2020 +0100
@@ -16,11 +16,10 @@
 local bxor = bit.bxor;
 local lshift = bit.lshift;
 local rshift = bit.rshift;
+local unpack = table.unpack or unpack; -- luacheck: ignore 113
 
 local t_concat = table.concat;
-local s_byte = string.byte;
 local s_char= string.char;
-local s_sub = string.sub;
 local s_pack = string.pack; -- luacheck: ignore 143
 local s_unpack = string.unpack; -- luacheck: ignore 143
 
@@ -30,12 +29,12 @@
 end
 
 local function read_uint16be(str, pos)
-	local l1, l2 = s_byte(str, pos, pos+1);
+	local l1, l2 = str:byte(pos, pos+1);
 	return l1*256 + l2;
 end
 -- FIXME: this may lose precision
 local function read_uint64be(str, pos)
-	local l1, l2, l3, l4, l5, l6, l7, l8 = s_byte(str, pos, pos+7);
+	local l1, l2, l3, l4, l5, l6, l7, l8 = str:byte(pos, pos+7);
 	local h = lshift(l1, 24) + lshift(l2, 16) + lshift(l3, 8) + l4;
 	local l = lshift(l5, 24) + lshift(l6, 16) + lshift(l7, 8) + l8;
 	return h * 2^32 + l;
@@ -63,9 +62,15 @@
 
 if s_unpack then
 	function read_uint16be(str, pos)
+		if type(str) ~= "string" then
+			str, pos = str:sub(pos, pos+1), 1;
+		end
 		return s_unpack(">I2", str, pos);
 	end
 	function read_uint64be(str, pos)
+		if type(str) ~= "string" then
+			str, pos = str:sub(pos, pos+7), 1;
+		end
 		return s_unpack(">I8", str, pos);
 	end
 end
@@ -73,7 +78,7 @@
 local function parse_frame_header(frame)
 	if #frame < 2 then return; end
 
-	local byte1, byte2 = s_byte(frame, 1, 2);
+	local byte1, byte2 = frame:byte(1, 2);
 	local result = {
 		FIN = band(byte1, 0x80) > 0;
 		RSV1 = band(byte1, 0x40) > 0;
@@ -102,7 +107,7 @@
 	end
 
 	if result.MASK then
-		result.key = { s_byte(frame, length_bytes+3, length_bytes+6) };
+		result.key = { frame:byte(length_bytes+3, length_bytes+6) };
 	end
 
 	return result, header_length;
@@ -121,7 +126,7 @@
 	for i = from, to do
 		local key_index = counter%key_len + 1;
 		counter = counter + 1;
-		data[counter] = s_char(bxor(key[key_index], s_byte(str, i)));
+		data[counter] = s_char(bxor(key[key_index], str:byte(i)));
 	end
 	return t_concat(data);
 end
@@ -136,7 +141,7 @@
 
 local function parse_frame(frame)
 	local result, pos = parse_frame_header(frame);
-	if result == nil or #frame < (pos + result.length) then return; end
+	if result == nil or #frame < (pos + result.length) then return nil, nil, result; end
 	result.data = parse_frame_body(frame, result, pos+1);
 	return result, pos + result.length;
 end
@@ -189,7 +194,7 @@
 	if #data >= 2 then
 		code = read_uint16be(data, 1);
 		if #data > 2 then
-			message = s_sub(data, 3);
+			message = data:sub(3);
 		end
 	end
 	return code, message
--- a/plugins/mod_websocket.lua	Tue Sep 29 21:27:16 2020 -0500
+++ b/plugins/mod_websocket.lua	Wed Sep 30 09:46:30 2020 +0100
@@ -18,6 +18,7 @@
 local portmanager = require "core.portmanager";
 local sm_destroy_session = require"core.sessionmanager".destroy_session;
 local log = module._log;
+local dbuffer = require "util.dbuffer";
 
 local websocket_frames = require"net.websocket.frames";
 local parse_frame = websocket_frames.parse;
@@ -27,6 +28,9 @@
 
 local t_concat = table.concat;
 
+local stanza_size_limit = module:get_option_number("c2s_stanza_size_limit", 10 * 1024 * 1024);
+local frame_buffer_limit = module:get_option_number("websocket_frame_buffer_limit", 2 * stanza_size_limit);
+local frame_fragment_limit = module:get_option_number("websocket_frame_fragment_limit", 8);
 local stream_close_timeout = module:get_option_number("c2s_close_timeout", 5);
 local consider_websocket_secure = module:get_option_boolean("consider_websocket_secure");
 local cross_domain = module:get_option_set("cross_domain_websocket", {});
@@ -138,6 +142,65 @@
 
 	return data;
 end
+
+local function validate_frame(frame, max_length)
+	local opcode, length = frame.opcode, frame.length;
+
+	if max_length and length > max_length then
+		return false, 1009, "Payload too large";
+	end
+
+	-- Error cases
+	if frame.RSV1 or frame.RSV2 or frame.RSV3 then -- Reserved bits non zero
+		return false, 1002, "Reserved bits not zero";
+	end
+
+	if opcode == 0x8 and frame.data then -- close frame
+		if length == 1 then
+			return false, 1002, "Close frame with payload, but too short for status code";
+		elseif length >= 2 then
+			local status_code = parse_close(frame.data)
+			if status_code < 1000 then
+				return false, 1002, "Closed with invalid status code";
+			elseif ((status_code > 1003 and status_code < 1007) or status_code > 1011) and status_code < 3000 then
+				return false, 1002, "Closed with reserved status code";
+			end
+		end
+	end
+
+	if opcode >= 0x8 then
+		if length > 125 then -- Control frame with too much payload
+			return false, 1002, "Payload too large";
+		end
+
+		if not frame.FIN then -- Fragmented control frame
+			return false, 1002, "Fragmented control frame";
+		end
+	end
+
+	if (opcode > 0x2 and opcode < 0x8) or (opcode > 0xA) then
+		return false, 1002, "Reserved opcode";
+	end
+
+	-- Check opcode
+	if opcode == 0x2 then -- Binary frame
+		return false, 1003, "Only text frames are supported, RFC 7395 3.2";
+	elseif opcode == 0x8 then -- Close request
+		return false, 1000, "Goodbye";
+	end
+
+	-- Other (XMPP-specific) validity checks
+	if not frame.FIN then
+		return false, 1003, "Continuation frames are not supported, RFC 7395 3.3.3";
+	end
+	if opcode == 0x01 and frame.data and frame.data:byte(1, 1) ~= 60 then
+		return false, 1007, "Invalid payload start character, RFC 7395 3.3.3";
+	end
+
+	return true;
+end
+
+
 function handle_request(event)
 	local request, response = event.request, event.response;
 	local conn = response.conn;
@@ -168,90 +231,40 @@
 		conn:close();
 	end
 
-	local dataBuffer;
+	local function websocket_handle_error(session, code, message)
+		if code == 1009 then -- stanza size limit exceeded
+			-- we close the session, rather than the connection,
+			-- otherwise a resuming client will simply resend the
+			-- offending stanza
+			session:close({ condition = "policy-violation", text = "stanza too large" });
+		else
+			websocket_close(code, message);
+		end
+	end
+
 	local function handle_frame(frame)
-		local opcode = frame.opcode;
-		local length = frame.length;
 		module:log("debug", "Websocket received frame: opcode=%0x, %i bytes", frame.opcode, #frame.data);
 
-		-- Error cases
-		if frame.RSV1 or frame.RSV2 or frame.RSV3 then -- Reserved bits non zero
-			websocket_close(1002, "Reserved bits not zero");
-			return false;
-		end
-
-		if opcode == 0x8 then -- close frame
-			if length == 1 then
-				websocket_close(1002, "Close frame with payload, but too short for status code");
-				return false;
-			elseif length >= 2 then
-				local status_code = parse_close(frame.data)
-				if status_code < 1000 then
-					websocket_close(1002, "Closed with invalid status code");
-					return false;
-				elseif ((status_code > 1003 and status_code < 1007) or status_code > 1011) and status_code < 3000 then
-					websocket_close(1002, "Closed with reserved status code");
-					return false;
-				end
-			end
+		-- Check frame makes sense
+		local frame_ok, err_status, err_text = validate_frame(frame, stanza_size_limit);
+		if not frame_ok then
+			return frame_ok, err_status, err_text;
 		end
 
-		if opcode >= 0x8 then
-			if length > 125 then -- Control frame with too much payload
-				websocket_close(1002, "Payload too large");
-				return false;
-			end
-
-			if not frame.FIN then -- Fragmented control frame
-				websocket_close(1002, "Fragmented control frame");
-				return false;
-			end
-		end
-
-		if (opcode > 0x2 and opcode < 0x8) or (opcode > 0xA) then
-			websocket_close(1002, "Reserved opcode");
-			return false;
-		end
-
-		if opcode == 0x0 and not dataBuffer then
-			websocket_close(1002, "Unexpected continuation frame");
-			return false;
-		end
-
-		if (opcode == 0x1 or opcode == 0x2) and dataBuffer then
-			websocket_close(1002, "Continuation frame expected");
-			return false;
-		end
-
-		-- Valid cases
-		if opcode == 0x0 then -- Continuation frame
-			dataBuffer[#dataBuffer+1] = frame.data;
-		elseif opcode == 0x1 then -- Text frame
-			dataBuffer = {frame.data};
-		elseif opcode == 0x2 then -- Binary frame
-			websocket_close(1003, "Only text frames are supported");
-			return;
-		elseif opcode == 0x8 then -- Close request
-			websocket_close(1000, "Goodbye");
-			return;
-		elseif opcode == 0x9 then -- Ping frame
+		local opcode = frame.opcode;
+		if opcode == 0x9 then -- Ping frame
 			frame.opcode = 0xA;
 			frame.MASK = false; -- Clients send masked frames, servers don't, see #1484
 			conn:write(build_frame(frame));
 			return "";
 		elseif opcode == 0xA then -- Pong frame, MAY be sent unsolicited, eg as keepalive
 			return "";
-		else
+		elseif opcode ~= 0x1 then -- Not text frame (which is all we support)
 			log("warn", "Received frame with unsupported opcode %i", opcode);
 			return "";
 		end
 
-		if frame.FIN then
-			local data = t_concat(dataBuffer, "");
-			dataBuffer = nil;
-			return data;
-		end
-		return "";
+		return frame.data;
 	end
 
 	conn:setlistener(c2s_listener);
@@ -269,19 +282,37 @@
 	session.open_stream = session_open_stream;
 	session.close = session_close;
 
-	local frameBuffer = "";
+	local frameBuffer = dbuffer.new(frame_buffer_limit, frame_fragment_limit);
 	add_filter(session, "bytes/in", function(data)
+		if not frameBuffer:write(data) then
+			session.log("warn", "websocket frame buffer full - terminating session");
+			session:close({ condition = "resource-constraint", text = "frame buffer exceeded" });
+			return;
+		end
+
 		local cache = {};
-		frameBuffer = frameBuffer .. data;
-		local frame, length = parse_frame(frameBuffer);
+		local frame, length, partial = parse_frame(frameBuffer);
 
 		while frame do
-			frameBuffer = frameBuffer:sub(length + 1);
-			local result = handle_frame(frame);
-			if not result then return; end
+			frameBuffer:discard(length);
+			local result, err_status, err_text = handle_frame(frame);
+			if not result then
+				websocket_handle_error(session, err_status, err_text);
+				break;
+			end
 			cache[#cache+1] = filter_open_close(result);
-			frame, length = parse_frame(frameBuffer);
+			frame, length, partial = parse_frame(frameBuffer);
 		end
+
+		if partial then
+			-- The header of the next frame is already in the buffer, run
+			-- some early validation here
+			local frame_ok, err_status, err_text = validate_frame(partial, stanza_size_limit);
+			if not frame_ok then
+				websocket_handle_error(session, err_status, err_text);
+			end
+		end
+
 		return t_concat(cache, "");
 	end);
 
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/spec/util_dbuffer_spec.lua	Wed Sep 30 09:46:30 2020 +0100
@@ -0,0 +1,130 @@
+local dbuffer = require "util.dbuffer";
+describe("util.dbuffer", function ()
+	describe("#new", function ()
+		it("has a constructor", function ()
+			assert.Function(dbuffer.new);
+		end);
+		it("can be created", function ()
+			assert.truthy(dbuffer.new());
+		end);
+		it("won't create an empty buffer", function ()
+			assert.falsy(dbuffer.new(0));
+		end);
+		it("won't create a negatively sized buffer", function ()
+			assert.falsy(dbuffer.new(-1));
+		end);
+	end);
+	describe(":write", function ()
+		local b = dbuffer.new();
+		it("works", function ()
+			assert.truthy(b:write("hi"));
+		end);
+	end);
+
+	describe(":read", function ()
+		it("supports optional bytes parameter", function ()
+			-- should return the frontmost chunk
+			local b = dbuffer.new();
+			assert.truthy(b:write("hello"));
+			assert.truthy(b:write(" "));
+			assert.truthy(b:write("world"));
+			assert.equal("h", b:read(1));
+
+			assert.equal("ello", b:read());
+			assert.equal(" ", b:read());
+			assert.equal("world", b:read());
+		end);
+	end);
+
+	describe(":discard", function ()
+		local b = dbuffer.new();
+		it("works", function ()
+			assert.truthy(b:write("hello world"));
+			assert.truthy(b:discard(6));
+			assert.equal(5, b:length());
+			assert.equal("world", b:read(5));
+		end);
+	end);
+
+	describe(":collapse()", function ()
+		it("works on an empty buffer", function ()
+			local b = dbuffer.new();
+			b:collapse();
+		end);
+	end);
+
+	describe(":sub", function ()
+		-- Helper function to compare buffer:sub() with string:sub()
+		local s = "hello world";
+		local function test_sub(b, x, y)
+			local string_result, buffer_result = s:sub(x, y), b:sub(x, y);
+			assert.equals(string_result, buffer_result, ("buffer:sub(%d, %s) does not match string:sub()"):format(x, y and ("%d"):format(y) or "nil"));
+		end
+
+		it("works", function ()
+			local b = dbuffer.new();
+			assert.truthy(b:write("hello world"));
+			assert.equals("hello", b:sub(1, 5));
+		end);
+
+		it("works after discard", function ()
+			local b = dbuffer.new(256);
+			assert.truthy(b:write("foobar"));
+			assert.equals("foobar", b:sub(1, 6));
+			assert.truthy(b:discard(3)); -- consume "foo"
+			assert.equals("bar", b:sub(1, 3));
+		end);
+
+		it("supports optional end parameter", function ()
+			local b = dbuffer.new();
+			assert.truthy(b:write("hello world"));
+			assert.equals("hello world", b:sub(1));
+			assert.equals("world", b:sub(-5));
+		end);
+
+		it("is equivalent to string:sub", function ()
+			local b = dbuffer.new(11);
+			assert.truthy(b:write(s));
+			for i = -13, 13 do
+				for j = -13, 13 do
+					test_sub(b, i, j);
+				end
+			end
+		end);
+	end);
+
+	describe(":byte", function ()
+		-- Helper function to compare buffer:byte() with string:byte()
+		local s = "hello world"
+		local function test_byte(b, x, y)
+			local string_result, buffer_result = {s:byte(x, y)}, {b:byte(x, y)};
+			assert.same(string_result, buffer_result, ("buffer:byte(%d, %s) does not match string:byte()"):format(x, y and ("%d"):format(y) or "nil"));
+		end
+
+		it("is equivalent to string:byte", function ()
+			local b = dbuffer.new(11);
+			assert.truthy(b:write(s));
+			test_byte(b, 1);
+			test_byte(b, 3);
+			test_byte(b, -1);
+			test_byte(b, -3);
+			for i = -13, 13 do
+				for j = -13, 13 do
+					test_byte(b, i, j);
+				end
+			end
+		end);
+
+		it("works with characters > 127", function ()
+			local b = dbuffer.new();
+			b:write(string.char(0, 140));
+			local r = { b:byte(1, 2) };
+			assert.same({ 0, 140 }, r);
+		end);
+
+		it("works on an empty buffer", function ()
+			local b = dbuffer.new();
+			assert.equal("", b:sub(1,1));
+		end);
+	end);
+end);
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/util/dbuffer.lua	Wed Sep 30 09:46:30 2020 +0100
@@ -0,0 +1,176 @@
+local queue = require "util.queue";
+
+local dbuffer_methods = {};
+local dynamic_buffer_mt = { __index = dbuffer_methods };
+
+function dbuffer_methods:write(data)
+	if self.max_size and #data + self._length > self.max_size then
+		return nil;
+	end
+	local ok = self.items:push(data);
+	if not ok then
+		self:collapse();
+		ok = self.items:push(data);
+	end
+	if not ok then
+		return nil;
+	end
+	self._length = self._length + #data;
+	return true;
+end
+
+function dbuffer_methods:read_chunk(requested_bytes)
+	local chunk, consumed = self.items:peek(), self.front_consumed;
+	if not chunk then return; end
+	local chunk_length = #chunk;
+	local remaining_chunk_length = chunk_length - consumed;
+	if not requested_bytes then
+		requested_bytes = remaining_chunk_length;
+	end
+	if remaining_chunk_length <= requested_bytes then
+		self.front_consumed = 0;
+		self._length = self._length - remaining_chunk_length;
+		self.items:pop();
+		assert(#chunk:sub(consumed + 1, -1) == remaining_chunk_length);
+		return chunk:sub(consumed + 1, -1), remaining_chunk_length;
+	end
+	local end_pos = consumed + requested_bytes;
+	self.front_consumed = end_pos;
+	self._length = self._length - requested_bytes;
+	assert(#chunk:sub(consumed + 1, end_pos) == requested_bytes);
+	return chunk:sub(consumed + 1, end_pos), requested_bytes;
+end
+
+function dbuffer_methods:read(requested_bytes)
+	local chunks;
+
+	if requested_bytes and requested_bytes > self._length then
+		return nil;
+	end
+
+	local chunk, read_bytes = self:read_chunk(requested_bytes);
+	if not requested_bytes then
+		return chunk;
+	elseif chunk then
+		requested_bytes = requested_bytes - read_bytes;
+		if requested_bytes == 0 then -- Already read everything we need
+			return chunk;
+		end
+		chunks = {};
+	else
+		return nil;
+	end
+
+	-- Need to keep reading more chunks
+	while chunk do
+		table.insert(chunks, chunk);
+		if requested_bytes > 0 then
+			chunk, read_bytes = self:read_chunk(requested_bytes);
+			requested_bytes = requested_bytes - read_bytes;
+		else
+			break;
+		end
+	end
+
+	return table.concat(chunks);
+end
+
+function dbuffer_methods:discard(requested_bytes)
+	if requested_bytes > self._length then
+		return nil;
+	end
+
+	local chunk, read_bytes = self:read_chunk(requested_bytes);
+	if chunk then
+		requested_bytes = requested_bytes - read_bytes;
+		if requested_bytes == 0 then -- Already read everything we need
+			return true;
+		end
+	else
+		return nil;
+	end
+
+	while chunk do
+		if requested_bytes > 0 then
+			chunk, read_bytes = self:read_chunk(requested_bytes);
+			requested_bytes = requested_bytes - read_bytes;
+		else
+			break;
+		end
+	end
+	return true;
+end
+
+function dbuffer_methods:sub(i, j)
+	if j == nil then
+		j = -1;
+	end
+	if j < 0 then
+		j = self._length + (j+1);
+	end
+	if i < 0 then
+		i = self._length + (i+1);
+	end
+	if i < 1 then
+		i = 1;
+	end
+	if j > self._length then
+		j = self._length;
+	end
+	if i > j then
+		return "";
+	end
+
+	self:collapse(j);
+
+	return self.items:peek():sub(self.front_consumed+1):sub(i, j);
+end
+
+function dbuffer_methods:byte(i, j)
+	i = i or 1;
+	j = j or i;
+	return string.byte(self:sub(i, j), 1, -1);
+end
+
+function dbuffer_methods:length()
+	return self._length;
+end
+dynamic_buffer_mt.__len = dbuffer_methods.length; -- support # operator
+
+function dbuffer_methods:collapse(bytes)
+	bytes = bytes or self._length;
+
+	local front_chunk = self.items:peek();
+
+	if not front_chunk or #front_chunk - self.front_consumed >= bytes then
+		return;
+	end
+
+	local front_chunks = { front_chunk:sub(self.front_consumed+1) };
+	local front_bytes = #front_chunks[1];
+
+	while front_bytes < bytes do
+		self.items:pop();
+		local chunk = self.items:peek();
+		front_bytes = front_bytes + #chunk;
+		table.insert(front_chunks, chunk);
+	end
+	self.items:replace(table.concat(front_chunks));
+	self.front_consumed = 0;
+end
+
+local function new(max_size, max_chunks)
+	if max_size and max_size <= 0 then
+		return nil;
+	end
+	return setmetatable({
+		front_consumed = 0;
+		_length = 0;
+		max_size = max_size;
+		items = queue.new(max_chunks or 32);
+	}, dynamic_buffer_mt);
+end
+
+return {
+	new = new;
+};
--- a/util/queue.lua	Tue Sep 29 21:27:16 2020 -0500
+++ b/util/queue.lua	Wed Sep 30 09:46:30 2020 +0100
@@ -51,6 +51,13 @@
 			end
 			return t[tail];
 		end;
+		replace = function (self, data)
+			if items == 0 then
+				return self:push(data);
+			end
+			t[tail] = data;
+			return true;
+		end;
 		items = function (self)
 			--luacheck: ignore 431/t
 			return function (t, pos)