Software /
code /
prosody
File
net/websocket/frames.lua @ 10066:216ae100c04a
mod_pep: Only log when creating new pubsub services
Once upon a time get_pep_service() would get called with random bare
JIDs and remote hostnames, which is why it was logged this way. This
seems to have been fixed, so it's not as useful anymore. It's still
useful to know when it creates a new service object.
author | Kim Alvefur <zash@zash.se> |
---|---|
date | Tue, 09 Jul 2019 15:12:32 +0200 |
parent | 9692:affcbccc1dff |
child | 10241:48f7cda4174d |
line wrap: on
line source
-- Prosody IM
-- Copyright (C) 2012 Florian Zeitz
-- Copyright (C) 2014 Daurnimator
--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--

-- WebSocket (RFC 6455) frame parsing and building helpers.

local softreq = require "util.dependencies".softreq;
local random_bytes = require "util.random".bytes;

local bit = assert(softreq"bit32" or softreq"bit", "No bit module found. See https://prosody.im/doc/depends#bitop");
local band = bit.band;
local bor = bit.bor;
local bxor = bit.bxor;
local rshift = bit.rshift;
local unpack = table.unpack or unpack; -- luacheck: ignore 113

local t_concat = table.concat;
local m_floor = math.floor;
local s_byte = string.byte;
local s_char = string.char;
local s_sub = string.sub;
local s_pack = string.pack;
local s_unpack = string.unpack;

if not s_pack and softreq"struct" then
	s_pack = softreq"struct".pack;
	s_unpack = softreq"struct".unpack;
end

-- Pure-Lua fallbacks for big-endian integer (de)serialisation. They are
-- replaced further down by string.pack/unpack (or the "struct" module)
-- whenever one of those is available.

-- Reads a big-endian unsigned 16-bit integer starting at byte `pos` of `str`.
local function read_uint16be(str, pos)
	local l1, l2 = s_byte(str, pos, pos+1);
	return l1*256 + l2;
end

-- Reads a big-endian unsigned 64-bit integer starting at byte `pos` of `str`.
-- Plain arithmetic is used instead of bit shifts: LuaJIT's `bit.lshift`
-- returns a *signed* 32-bit result, so shifting a byte >= 0x80 into the top
-- position would produce a negative (wrong) value.
-- Values above 2^53 cannot be represented exactly and lose precision.
local function read_uint64be(str, pos)
	local l1, l2, l3, l4, l5, l6, l7, l8 = s_byte(str, pos, pos+7);
	local h = ((l1*256 + l2)*256 + l3)*256 + l4;
	local l = ((l5*256 + l6)*256 + l7)*256 + l8;
	return h * 2^32 + l;
end

-- Packs `x` (expected to be 0..0xFFFF) as a 2-byte big-endian string.
local function pack_uint16be(x)
	return s_char(rshift(x, 8), band(x, 0xFF));
end

-- Extracts the 8 bits of `x` starting at bit offset `n`.
local function get_byte(x, n)
	return band(rshift(x, n), 0xFF);
end

-- Packs non-negative integer `x` as an 8-byte big-endian string.
-- The high word is floored explicitly because the bit libraries convert
-- non-integer arguments in an implementation-defined way (possibly by
-- rounding), which could turn e.g. 0xC0000000 / 2^32 = 0.75 into 1.
local function pack_uint64be(x)
	local h = band(m_floor(x / 2^32), 2^32-1);
	return s_char(get_byte(h, 24), get_byte(h, 16), get_byte(h, 8), band(h, 0xFF),
		get_byte(x, 24), get_byte(x, 16), get_byte(x, 8), band(x, 0xFF));
end

if s_pack then
	function pack_uint16be(x)
		return s_pack(">I2", x);
	end
	function pack_uint64be(x)
		return s_pack(">I8", x);
	end
end

if s_unpack then
	function read_uint16be(str, pos)
		return s_unpack(">I2", str, pos);
	end
	-- NOTE(review): with Lua 5.3+'s string.unpack, ">I8" presumably yields a
	-- native integer, so a length with the most significant bit set (which
	-- RFC 6455 5.2 forbids) would come back negative — confirm whether
	-- callers need to reject that.
	function read_uint64be(str, pos)
		return s_unpack(">I8", str, pos);
	end
end

-- Parses the header of the frame at the start of `frame`.
-- Returns nil if `frame` does not yet contain the complete header.
-- Otherwise returns:
--   1. a table with FIN, RSV1, RSV2, RSV3, MASK (booleans), opcode,
--      length (payload length in bytes) and, when MASK is set,
--      key (array of the 4 masking-key bytes);
--   2. the header length in bytes.
local function parse_frame_header(frame)
	if #frame < 2 then return; end

	local byte1, byte2 = s_byte(frame, 1, 2);
	local result = {
		FIN = band(byte1, 0x80) > 0;
		RSV1 = band(byte1, 0x40) > 0;
		RSV2 = band(byte1, 0x20) > 0;
		RSV3 = band(byte1, 0x10) > 0;
		opcode = band(byte1, 0x0F);

		MASK = band(byte2, 0x80) > 0;
		length = band(byte2, 0x7F);
	};

	-- A 7-bit length of 126 or 127 means the real payload length follows
	-- in the next 2 or 8 bytes respectively.
	local length_bytes = 0;
	if result.length == 126 then
		length_bytes = 2;
	elseif result.length == 127 then
		length_bytes = 8;
	end

	local header_length = 2 + length_bytes + (result.MASK and 4 or 0);
	if #frame < header_length then return; end

	if length_bytes == 2 then
		result.length = read_uint16be(frame, 3);
	elseif length_bytes == 8 then
		result.length = read_uint64be(frame, 3);
	end

	if result.MASK then
		result.key = { s_byte(frame, length_bytes+3, length_bytes+6) };
	end

	return result, header_length;
end

-- XORs bytes `from`..`to` of `str` with the repeating array of bytes `key`
-- and returns the result. `from` defaults to 1 and `to` to #str; negative
-- indices count back from the end of `str`.
-- TODO: optimize
local function apply_mask(str, key, from, to)
	from = from or 1
	if from < 0 then from = #str + from + 1 end -- negative indices
	to = to or #str
	if to < 0 then to = #str + to + 1 end -- negative indices
	local key_len = #key
	local counter = 0;
	local data = {};
	for i = from, to do
		local key_index = counter%key_len + 1;
		counter = counter + 1;
		data[counter] = s_char(bxor(key[key_index], s_byte(str, i)));
	end
	return t_concat(data);
end

-- Returns the payload described by `header` (unmasked if header.MASK is
-- set), whose first byte is at index `pos` of `frame`.
local function parse_frame_body(frame, header, pos)
	if header.MASK then
		return apply_mask(frame, header.key, pos, pos + header.length - 1);
	else
		return s_sub(frame, pos, pos + header.length - 1);
	end
end

-- Parses one complete frame from the start of `frame`.
-- Returns nil if the buffer does not yet hold the whole frame. Otherwise
-- returns the header table (see parse_frame_header) with the payload stored
-- in `.data`, and the total number of bytes the frame occupies.
local function parse_frame(frame)
	local result, pos = parse_frame_header(frame);
	if result == nil or #frame < (pos + result.length) then return; end
	result.data = parse_frame_body(frame, result, pos+1);
	return result, pos + result.length;
end

-- Serialises the frame described by `desc`:
--   opcode                  required, integer 0x0..0xF
--   FIN, RSV1, RSV2, RSV3   booleans
--   MASK                    boolean; mask the payload
--   key                     optional array of 4 mask bytes (random if omitted)
--   data                    payload string (default "")
-- Control frames (opcode >= 0x8) are limited to 125-byte payloads.
local function build_frame(desc)
	local data = desc.data or "";

	assert(desc.opcode and desc.opcode >= 0 and desc.opcode <= 0xF, "Invalid WebSocket opcode");
	if desc.opcode >= 0x8 then
		-- RFC 6455 5.5
		assert(#data <= 125, "WebSocket control frames MUST have a payload length of 125 bytes or less.");
	end

	local b1 = bor(desc.opcode,
		desc.FIN and 0x80 or 0,
		desc.RSV1 and 0x40 or 0,
		desc.RSV2 and 0x20 or 0,
		desc.RSV3 and 0x10 or 0);

	local b2 = #data;
	local length_extra;
	if b2 <= 125 then -- 7-bit length
		length_extra = "";
	elseif b2 <= 0xFFFF then -- 2-byte length
		b2 = 126;
		length_extra = pack_uint16be(#data);
	else -- 8-byte length
		b2 = 127;
		length_extra = pack_uint64be(#data);
	end

	local key = ""
	if desc.MASK then
		if desc.key then
			key = s_char(unpack(desc.key, 1, 4));
		else
			key = random_bytes(4);
		end
		-- Mask with exactly the 4 bytes that are transmitted, so a caller
		-- supplying a key array of another length cannot desynchronise
		-- the header and the masked payload.
		local key_a = { s_byte(key, 1, 4) };
		b2 = bor(b2, 0x80);
		data = apply_mask(data, key_a);
	end

	return s_char(b1, b2) .. length_extra .. key .. data
end

-- Splits a close-frame payload into its status code and reason.
-- Both are nil if the payload is shorter than 2 bytes; the reason is nil
-- if nothing follows the code.
local function parse_close(data)
	local code, message
	if #data >= 2 then
		code = read_uint16be(data, 1);
		if #data > 2 then
			message = s_sub(data, 3);
		end
	end
	return code, message
end

-- Builds a close frame (FIN set). `code` is the 16-bit status code,
-- `message` an optional reason of at most 123 bytes, and `mask` whether to
-- mask the frame. RFC 6455 5.5.1 allows a close frame without a body, so
-- `code` may be nil; a reason cannot be sent without a code.
local function build_close(code, message, mask)
	local data = "";
	if code ~= nil then
		data = pack_uint16be(code);
		if message then
			assert(#message<=123, "Close reason must be <=123 bytes");
			data = data .. message;
		end
	else
		assert(not message, "Close reason requires a status code");
	end
	return build_frame({
		opcode = 0x8;
		FIN = true;
		MASK = mask;
		data = data;
	});
end

return {
	parse_header = parse_frame_header;
	parse_body = parse_frame_body;
	parse = parse_frame;
	build = build_frame;
	parse_close = parse_close;
	build_close = build_close;
};