File

util/sasl/scram.lua @ 505:289c866d7fb0

verse: Fix to work with server_epoll
author Kim Alvefur <zash@zash.se>
date Sat, 24 Jun 2023 09:48:23 +0200
parent 490:6b2f31da9610
line wrap: on
line source


local base64, unbase64 = require "mime".b64, require"mime".unb64;
local hashes = require"prosody.util.hashes";
local bit = require"bit";
local random = require"prosody.util.random";

local tonumber = tonumber;
local char, byte = string.char, string.byte;
local gsub = string.gsub;
local xor = bit.bxor;

--- Byte-wise XOR of two strings.
-- The result has the same length as `a`; each byte of `a` is combined with
-- the byte of `b` at the same position.
-- @param a string providing the length and left-hand bytes
-- @param b string providing the right-hand bytes
-- @return the XORed string
local function XOR(a, b)
	local out = {};
	for pos = 1, #a do
		out[pos] = char(xor(byte(a, pos), byte(b, pos)));
	end
	return table.concat(out);
end

local H, HMAC = hashes.sha1, hashes.hmac_sha1;

--- The Hi() key-stretching function used by SCRAM.
-- Chains HMAC over `salt .. "\0\0\0\1"` and XORs every intermediate result
-- together.
-- @param str the HMAC key (the normalized password)
-- @param salt raw salt bytes
-- @param i iteration count; values below 2 leave only the first HMAC result
-- @return the accumulated XOR of all rounds
local function Hi(str, salt, i)
	-- First round: the salt with a big-endian 32-bit block index of 1 appended.
	local previous = HMAC(str, salt .. "\0\0\0\1");
	local result = previous;
	for _ = 2, i do
		-- Each later round feeds the prior output back in and folds it into the result.
		previous = HMAC(str, previous);
		result = XOR(result, previous);
	end
	return result;
end

--- Placeholder for password normalization.
-- Currently an identity function: the input is returned unchanged.
-- TODO: implement real normalization (RFC 5802 specifies SASLprep here).
-- @param str the password
-- @return `str`, untouched
local function Normalize(str)
	local normalized = str;
	return normalized;
end

--- Escape a value for use inside a SCRAM attribute.
-- "," becomes "=2C" and "=" becomes "=3D"; everything else is left alone.
-- @param str the raw value (e.g. a username)
-- @return the escaped string (single return value)
local function value_safe(str)
	local function escape(ch)
		if ch == "," then
			return "=2C";
		elseif ch == "=" then
			return "=3D";
		end
	end
	local escaped = gsub(str, "[,=]", escape);
	return escaped;
end

--- Work out which channel binding (if any) the connection can provide.
-- @param conn connection object exposing `:ssl()` and `:socket()`
-- @return the gs2 channel-binding flag ("p=tls-exporter" or "p=tls-unique")
--   followed by whatever the socket returns as binding data; returns nothing
--   when the connection is not using TLS or the socket lacks the needed method.
local function cb(conn)
	if not conn:ssl() then
		return;
	end

	local tls_sock = conn:socket();
	local is_tls13 = tls_sock.info and tls_sock:info().protocol == "TLSv1.3";

	if is_tls13 then
		-- On TLS 1.3 only the exporter binding is attempted; there is no
		-- fallback to tls-unique.
		if tls_sock.exportkeyingmaterial then
			return "p=tls-exporter", tls_sock:exportkeyingmaterial("EXPORTER-Channel-Binding", 32, "");
		end
		return;
	end

	if tls_sock.getfinished then
		return "p=tls-unique", tls_sock:getfinished();
	end
end

--- Client side of a SCRAM-SHA-1 / SCRAM-SHA-1-PLUS exchange.
-- Written to run inside a coroutine:
--   1. yields the client-first-message and expects to be resumed with
--      ("challenge", server_first_message);
--   2. yields the client-final-message and expects to be resumed with
--      ("success", server_final_message).
-- Credentials are read from `stream`: `username`, and either both of
-- `client_key`/`server_key`, or `salted_password`, or `password`.
-- @param stream table holding credentials and `conn`
-- @param name mechanism name; "SCRAM-SHA-1-PLUS" enables channel binding
-- @return true on success; false (usually followed by an error string) on
--   any failure
local function scram(stream, name)
	local username = "n=" .. value_safe(stream.username);
	local c_nonce = base64(random.bytes(15));
	local our_nonce = "r=" .. c_nonce;
	local client_first_message_bare = username .. "," .. our_nonce;
	local cbind_data = "";
	local gs2_cbind_flag = "n";
	if name == "SCRAM-SHA-1-PLUS" then
		gs2_cbind_flag, cbind_data = cb(stream.conn);
		-- cb() returns nothing when no binding is available, and the socket
		-- call behind it may fail to yield data; fail cleanly instead of
		-- raising on a nil concatenation further down.
		if not gs2_cbind_flag or not cbind_data then
			return false, "channel binding data unavailable";
		end
	elseif cb(stream.conn) then
		-- We could do channel binding but the selected mechanism is not -PLUS.
		gs2_cbind_flag = "y";
	end
	local gs2_header = gs2_cbind_flag .. ",,";
	local client_first_message = gs2_header .. client_first_message_bare;
	local cont, server_first_message = coroutine.yield(client_first_message);
	if cont ~= "challenge" then return false end
	if type(server_first_message) ~= "string" then
		return false, "Could not parse server_first_message";
	end

	local nonce, salt, iteration_count = server_first_message:match("(r=[^,]+),s=([^,]*),i=(%d+)");
	if not nonce then
		return false, "Could not parse server_first_message";
	end
	-- Only decode once the pattern is known to have matched.
	local i = tonumber(iteration_count);
	salt = unbase64(salt);
	if not salt or not i then
		return false, "Could not parse server_first_message";
	elseif nonce:find(c_nonce, 3, true) ~= 3 then
		-- The combined nonce ("r=" prefix, hence offset 3) must start with ours.
		return false, "nonce sent by server does not match our nonce";
	elseif nonce == our_nonce then
		return false, "server did not append s-nonce to nonce";
	end

	local cbind_input = gs2_header .. cbind_data;
	local channel_binding = "c=" .. base64(cbind_input);
	local client_final_message_without_proof = channel_binding .. "," .. nonce;

	local ClientKey;
	local ServerKey;

	if stream.client_key and stream.server_key then
		-- Pre-computed keys: no password stretching needed.
		ClientKey = stream.client_key;
		ServerKey = stream.server_key;
	else
		local SaltedPassword;
		if stream.salted_password then
			SaltedPassword = stream.salted_password;
		elseif stream.password then
			SaltedPassword = Hi(Normalize(stream.password), salt, i);
		else
			-- Nothing to derive the keys from (e.g. only one of the two
			-- cached keys was supplied).
			return false, "no password or key material available";
		end
		ServerKey = HMAC(SaltedPassword, "Server Key");
		ClientKey = HMAC(SaltedPassword, "Client Key");
	end

	local StoredKey       = H(ClientKey);
	local AuthMessage     = client_first_message_bare .. "," ..  server_first_message .. "," ..  client_final_message_without_proof;
	local ClientSignature = HMAC(StoredKey, AuthMessage);
	local ClientProof     = XOR(ClientKey, ClientSignature);
	local ServerSignature = HMAC(ServerKey, AuthMessage);

	local proof = "p=" .. base64(ClientProof);
	local client_final_message = client_final_message_without_proof .. "," .. proof;

	local ok, server_final_message = coroutine.yield(client_final_message);
	if ok ~= "success" then return false, "success-expected" end

	-- Verify the server also knows the key material before reporting success.
	local verifier = type(server_final_message) == "string"
		and server_final_message:match("v=([^,]+)");
	if not verifier or unbase64(verifier) ~= ServerSignature then
		return false, "server signature did not match";
	end
	return true;
end

--- Mechanism selector exported by this module.
-- Offers the SCRAM implementation when the stream carries a username and
-- credentials that scram() can actually use: a password, a salted password,
-- or BOTH pre-computed keys (a single key on its own is not enough to
-- complete the exchange).
-- @param stream table with credentials and `conn`
-- @param name mechanism name advertised by the server
-- @return the mechanism function and its preference score (99 for
--   "SCRAM-SHA-1", 100 for "SCRAM-SHA-1-PLUS" when channel binding is
--   available); nothing otherwise
return function (stream, name)
	local has_secret = stream.password
		or stream.salted_password
		or (stream.client_key and stream.server_key);
	if stream.username and has_secret then
		if name == "SCRAM-SHA-1" then
			return scram, 99;
		elseif name == "SCRAM-SHA-1-PLUS" then
			-- Only offer the -PLUS variant when the connection can supply
			-- channel binding data.
			if cb(stream.conn) then
				return scram, 100;
			end
		end
	end
end