Diff

net/websocket/frames.lua @ 11106:76f46c2579a2 0.11

net.websocket.frames: Allow all methods to work on non-string objects Instead of using the string library, use methods from the passed object, which are assumed to be equivalent. This provides compatibility with objects from util.ringbuffer and util.dbuffer, for example.
author Matthew Wild <mwild1@gmail.com>
date Thu, 17 Sep 2020 13:00:19 +0100
parent 8728:41c959c5c84b
child 11107:ddd0007e0f1b
line wrap: on
line diff
--- a/net/websocket/frames.lua	Mon Aug 24 17:28:48 2020 +0200
+++ b/net/websocket/frames.lua	Thu Sep 17 13:00:19 2020 +0100
@@ -16,13 +16,12 @@
 local bxor = bit.bxor;
 local lshift = bit.lshift;
 local rshift = bit.rshift;
+local unpack = table.unpack or unpack; -- luacheck: ignore 113
 
 local t_concat = table.concat;
-local s_byte = string.byte;
 local s_char= string.char;
-local s_sub = string.sub;
-local s_pack = string.pack; -- luacheck: ignore 143
-local s_unpack = string.unpack; -- luacheck: ignore 143
+local s_pack = string.pack;
+local s_unpack = string.unpack;
 
 if not s_pack and softreq"struct" then
 	s_pack = softreq"struct".pack;
@@ -30,12 +29,12 @@
 end
 
 local function read_uint16be(str, pos)
-	local l1, l2 = s_byte(str, pos, pos+1);
+	local l1, l2 = str:byte(pos, pos+1);
 	return l1*256 + l2;
 end
 -- FIXME: this may lose precision
 local function read_uint64be(str, pos)
-	local l1, l2, l3, l4, l5, l6, l7, l8 = s_byte(str, pos, pos+7);
+	local l1, l2, l3, l4, l5, l6, l7, l8 = str:byte(pos, pos+7);
 	local h = lshift(l1, 24) + lshift(l2, 16) + lshift(l3, 8) + l4;
 	local l = lshift(l5, 24) + lshift(l6, 16) + lshift(l7, 8) + l8;
 	return h * 2^32 + l;
@@ -63,9 +62,15 @@
 
 if s_unpack then
 	function read_uint16be(str, pos)
+		if type(str) ~= "string" then
+			str, pos = str:sub(pos, pos+1), 1;
+		end
 		return s_unpack(">I2", str, pos);
 	end
 	function read_uint64be(str, pos)
+		if type(str) ~= "string" then
+			str, pos = str:sub(pos, pos+7), 1;
+		end
 		return s_unpack(">I8", str, pos);
 	end
 end
@@ -73,7 +78,7 @@
 local function parse_frame_header(frame)
 	if #frame < 2 then return; end
 
-	local byte1, byte2 = s_byte(frame, 1, 2);
+	local byte1, byte2 = frame:byte(1, 2);
 	local result = {
 		FIN = band(byte1, 0x80) > 0;
 		RSV1 = band(byte1, 0x40) > 0;
@@ -102,7 +107,7 @@
 	end
 
 	if result.MASK then
-		result.key = { s_byte(frame, length_bytes+3, length_bytes+6) };
+		result.key = { frame:byte(length_bytes+3, length_bytes+6) };
 	end
 
 	return result, header_length;
@@ -121,7 +126,7 @@
 	for i = from, to do
 		local key_index = counter%key_len + 1;
 		counter = counter + 1;
-		data[counter] = s_char(bxor(key[key_index], s_byte(str, i)));
+		data[counter] = s_char(bxor(key[key_index], str:byte(i)));
 	end
 	return t_concat(data);
 end
@@ -189,7 +194,7 @@
 	if #data >= 2 then
 		code = read_uint16be(data, 1);
 		if #data > 2 then
-			message = s_sub(data, 3);
+			message = data:sub(3);
 		end
 	end
 	return code, message