Diff

util/sasl/scram.lua @ 11200:bf8f2da84007

Merge 0.11->trunk
author Kim Alvefur <zash@zash.se>
date Thu, 05 Nov 2020 22:31:25 +0100
parent 11174:ddc17e9c66e4
child 12024:9184bdda22be
line wrap: on
line diff
--- a/util/sasl/scram.lua	Thu Nov 05 22:27:17 2020 +0100
+++ b/util/sasl/scram.lua	Thu Nov 05 22:31:25 2020 +0100
@@ -14,16 +14,12 @@
 local s_match = string.match;
 local type = type
 local base64 = require "util.encodings".base64;
-local hmac_sha1 = require "util.hashes".hmac_sha1;
-local sha1 = require "util.hashes".sha1;
-local Hi = require "util.hashes".scram_Hi_sha1;
+local hashes = require "util.hashes";
 local generate_uuid = require "util.uuid".generate;
 local saslprep = require "util.encodings".stringprep.saslprep;
 local nodeprep = require "util.encodings".stringprep.nodeprep;
 local log = require "util.logger".init("sasl");
-local t_concat = table.concat;
-local char = string.char;
-local byte = string.byte;
+local	binaryXOR = require "util.strbitop".sxor;
 
 local _ENV = nil;
 -- luacheck: std none
@@ -47,32 +43,6 @@
 
 local default_i = 4096
 
-local xor_map = {
-	0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,1,0,3,2,5,4,7,6,9,8,11,10,
-	13,12,15,14,2,3,0,1,6,7,4,5,10,11,8,9,14,15,12,13,3,2,1,0,7,6,5,
-	4,11,10,9,8,15,14,13,12,4,5,6,7,0,1,2,3,12,13,14,15,8,9,10,11,5,
-	4,7,6,1,0,3,2,13,12,15,14,9,8,11,10,6,7,4,5,2,3,0,1,14,15,12,13,
-	10,11,8,9,7,6,5,4,3,2,1,0,15,14,13,12,11,10,9,8,8,9,10,11,12,13,
-	14,15,0,1,2,3,4,5,6,7,9,8,11,10,13,12,15,14,1,0,3,2,5,4,7,6,10,
-	11,8,9,14,15,12,13,2,3,0,1,6,7,4,5,11,10,9,8,15,14,13,12,3,2,1,
-	0,7,6,5,4,12,13,14,15,8,9,10,11,4,5,6,7,0,1,2,3,13,12,15,14,9,8,
-	11,10,5,4,7,6,1,0,3,2,14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1,15,
-	14,13,12,11,10,9,8,7,6,5,4,3,2,1,0,
-};
-
-local result = {};
-local function binaryXOR( a, b )
-	for i=1, #a do
-		local x, y = byte(a, i), byte(b, i);
-		local lowx, lowy = x % 16, y % 16;
-		local hix, hiy = (x - lowx) / 16, (y - lowy) / 16;
-		local lowr, hir = xor_map[lowx * 16 + lowy + 1], xor_map[hix * 16 + hiy + 1];
-		local r = hir * 16 + lowr;
-		result[i] = char(r)
-	end
-	return t_concat(result);
-end
-
 local function validate_username(username, _nodeprep)
 	-- check for forbidden char sequences
 	for eq in username:gmatch("=(.?.?)") do
@@ -99,24 +69,26 @@
 	return hashname:lower():gsub("-", "_");
 end
 
-local function getAuthenticationDatabaseSHA1(password, salt, iteration_count)
-	if type(password) ~= "string" or type(salt) ~= "string" or type(iteration_count) ~= "number" then
-		return false, "inappropriate argument types"
-	end
-	if iteration_count < 4096 then
-		log("warn", "Iteration count < 4096 which is the suggested minimum according to RFC 5802.")
+local function get_scram_hasher(H, HMAC, Hi)
+	return function (password, salt, iteration_count)
+		if type(password) ~= "string" or type(salt) ~= "string" or type(iteration_count) ~= "number" then
+			return false, "inappropriate argument types"
+		end
+		if iteration_count < 4096 then
+			log("warn", "Iteration count < 4096 which is the suggested minimum according to RFC 5802.")
+		end
+		password = saslprep(password);
+		if not password then
+			return false, "password fails SASLprep";
+		end
+		local salted_password = Hi(password, salt, iteration_count);
+		local stored_key = H(HMAC(salted_password, "Client Key"))
+		local server_key = HMAC(salted_password, "Server Key");
+		return true, stored_key, server_key
 	end
-	password = saslprep(password);
-	if not password then
-		return false, "password fails SASLprep";
-	end
-	local salted_password = Hi(password, salt, iteration_count);
-	local stored_key = sha1(hmac_sha1(salted_password, "Client Key"))
-	local server_key = hmac_sha1(salted_password, "Server Key");
-	return true, stored_key, server_key
 end
 
-local function scram_gen(hash_name, H_f, HMAC_f)
+local function scram_gen(hash_name, H_f, HMAC_f, get_auth_db, expect_cb)
 	local profile_name = "scram_" .. hashprep(hash_name);
 	local function scram_hash(self, message)
 		local support_channel_binding = false;
@@ -129,6 +101,7 @@
 			local client_first_message = message;
 
 			-- TODO: fail if authzid is provided, since we don't support them yet
+			-- luacheck: ignore 211/authzid
 			local gs2_header, gs2_cbind_flag, gs2_cbind_name, authzid, client_first_message_bare, username, clientnonce
 				= s_match(client_first_message, "^(([pny])=?([^,]*),([^,]*),)(m?=?[^,]*,?n=([^,]*),r=([^,]*),?.*)$");
 
@@ -144,6 +117,10 @@
 
 			if gs2_cbind_flag == "n" then
 				-- "n" -> client doesn't support channel binding.
+				if expect_cb then
+					log("debug", "Client unexpectedly doesn't support channel binding");
+					-- XXX Is it sensible to abort if the client starts -PLUS but doesn't use channel binding?
+				end
 				support_channel_binding = false;
 			end
 
@@ -181,7 +158,7 @@
 				iteration_count = default_i;
 
 				local succ;
-				succ, stored_key, server_key = getAuthenticationDatabaseSHA1(password, salt, iteration_count);
+				succ, stored_key, server_key = get_auth_db(password, salt, iteration_count);
 				if not succ then
 					log("error", "Generating authentication database failed. Reason: %s", stored_key);
 					return "failure", "temporary-auth-failure";
@@ -194,7 +171,7 @@
 			end
 
 			local nonce = clientnonce .. generate_uuid();
-			local server_first_message = "r="..nonce..",s="..base64.encode(salt)..",i="..iteration_count;
+			local server_first_message = ("r=%s,s=%s,i=%d"):format(nonce, base64.encode(salt), iteration_count);
 			self.state = {
 				gs2_header = gs2_header;
 				gs2_cbind_name = gs2_cbind_name;
@@ -251,22 +228,28 @@
 	return scram_hash;
 end
 
+local auth_db_getters = {}
 local function init(registerMechanism)
-	local function registerSCRAMMechanism(hash_name, hash, hmac_hash)
+	local function registerSCRAMMechanism(hash_name, hash, hmac_hash, pbkdf2)
+		local get_auth_db = get_scram_hasher(hash, hmac_hash, pbkdf2);
+		auth_db_getters[hash_name] = get_auth_db;
 		registerMechanism("SCRAM-"..hash_name,
 			{"plain", "scram_"..(hashprep(hash_name))},
-			scram_gen(hash_name:lower(), hash, hmac_hash));
+			scram_gen(hash_name:lower(), hash, hmac_hash, get_auth_db));
 
 		-- register channel binding equivalent
 		registerMechanism("SCRAM-"..hash_name.."-PLUS",
 			{"plain", "scram_"..(hashprep(hash_name))},
-			scram_gen(hash_name:lower(), hash, hmac_hash), {"tls-unique"});
+			scram_gen(hash_name:lower(), hash, hmac_hash, get_auth_db, true), {"tls-unique"});
 	end
 
-	registerSCRAMMechanism("SHA-1", sha1, hmac_sha1);
+	registerSCRAMMechanism("SHA-1", hashes.sha1, hashes.hmac_sha1, hashes.pbkdf2_hmac_sha1);
+	registerSCRAMMechanism("SHA-256", hashes.sha256, hashes.hmac_sha256, hashes.pbkdf2_hmac_sha256);
 end
 
 return {
-	getAuthenticationDatabaseSHA1 = getAuthenticationDatabaseSHA1;
+	get_hash = get_scram_hasher;
+	hashers = auth_db_getters;
+	getAuthenticationDatabaseSHA1 = get_scram_hasher(hashes.sha1, hashes.hmac_sha1, hashes.pbkdf2_hmac_sha1); -- COMPAT
 	init = init;
 }