Software /
code /
verse
File
util/sasl/scram.lua @ 456:6a65142052c8
sasl: Include offered mechanisms in event when no supported mechanisms found
This allows consumers of the event to determine what may be required in order
to successfully authenticate.
author | Matthew Wild <mwild1@gmail.com> |
---|---|
date | Fri, 03 Mar 2023 12:10:54 +0000 |
parent | 455:753d6983dc45 |
child | 490:6b2f31da9610 |
line wrap: on
line source
-- Client side of the SCRAM-SHA-1 and SCRAM-SHA-1-PLUS SASL mechanisms
-- (RFC 5802), with tls-unique / tls-exporter channel binding.
--
-- The module returns a factory: given a stream and a mechanism name it returns
-- the mechanism function plus a priority, or nothing when the mechanism cannot
-- be used with the credentials/connection available on the stream. The
-- mechanism function is driven as a coroutine: each coroutine.yield() sends a
-- message to the server and receives (status, server_data) back.

local base64, unbase64 = require "mime".b64, require"mime".unb64;
local hashes = require"util.hashes";
local bit = require"bit";
local random = require"util.random";

local tonumber = tonumber;
local char, byte = string.char, string.byte;
local gsub = string.gsub;
local xor = bit.bxor;

-- Byte-wise XOR of two strings. `b` must be at least as long as `a`
-- (here both are always 20-byte SHA-1 outputs).
local function XOR(a, b)
	return (gsub(a, "()(.)", function(i, c) return char(xor(byte(c), byte(b, i))) end));
end

local H, HMAC = hashes.sha1, hashes.hmac_sha1;

-- RFC 5802 Hi(): U1 = HMAC(str, salt .. INT(1)), Un = HMAC(str, Un-1),
-- result = U1 XOR U2 XOR ... XOR Ui.
local function Hi(str, salt, i)
	local U = HMAC(str, salt .. "\0\0\0\1");
	local ret = U;
	for _ = 2, i do
		U = HMAC(str, U);
		ret = XOR(ret, U);
	end
	return ret;
end

-- Password normalization hook; currently the identity function.
local function Normalize(str)
	return str; -- TODO
end

-- Escape "," and "=" in a username as "=2C" / "=3D" so it can be embedded in
-- the comma-separated SCRAM attribute syntax.
local function value_safe(str)
	return (gsub(str, "[,=]", { [","] = "=2C", ["="] = "=3D" }));
end

-- Determine channel binding for `conn`.
-- Returns the gs2 cbind flag ("p=tls-exporter" or "p=tls-unique") followed by
-- the binding data, or nothing when the connection is not TLS or no supported
-- binding method is exposed by the socket. On TLS 1.3 only tls-exporter is
-- attempted; otherwise tls-unique is used when the socket has getfinished().
local function cb(conn)
	if conn:ssl() then
		local sock = conn:socket();
		if sock.info and sock:info().protocol == "TLSv1.3" then
			if sock.exportkeyingmaterial then
				return "p=tls-exporter", sock:exportkeyingmaterial("EXPORTER-Channel-Binding", 32, "");
			end
		elseif sock.getfinished then
			return "p=tls-unique", sock:getfinished();
		end
	end
end

-- The SCRAM exchange itself, run as a coroutine.
-- @param stream  provides username plus one of: client_key+server_key,
--                salted_password, or password; and conn for channel binding.
-- @param name    "SCRAM-SHA-1" or "SCRAM-SHA-1-PLUS".
-- Yields client-first-message, then client-final-message. Returns true on a
-- verified success, or false plus an error string.
local function scram(stream, name)
	local username = "n=" .. value_safe(stream.username);
	local c_nonce = base64(random.bytes(15));
	local our_nonce = "r=" .. c_nonce;
	local client_first_message_bare = username .. "," .. our_nonce;

	-- gs2-cbind-flag: "p=<type>" when binding (-PLUS); "y" when the client
	-- could bind but is not doing so; "n" when no binding is possible.
	local cbind_data = "";
	local gs2_cbind_flag = "n";
	if name == "SCRAM-SHA-1-PLUS" then
		gs2_cbind_flag, cbind_data = cb(stream.conn);
	elseif cb(stream.conn) then
		gs2_cbind_flag = "y";
	end
	local gs2_header = gs2_cbind_flag .. ",,";
	local client_first_message = gs2_header .. client_first_message_bare;

	local cont, server_first_message = coroutine.yield(client_first_message);
	if cont ~= "challenge" then return false end

	-- server-first-message: r=<c-nonce><s-nonce>,s=<base64 salt>,i=<count>
	-- Guard against a challenge that carries no data at all.
	local nonce, salt, iteration_count;
	if server_first_message then
		nonce, salt, iteration_count = server_first_message:match("(r=[^,]+),s=([^,]*),i=(%d+)");
	end
	local i = tonumber(iteration_count);
	salt = salt and unbase64(salt);
	if not nonce or not salt or not i then
		return false, "Could not parse server_first_message";
	elseif nonce:find(c_nonce, 3, true) ~= 3 then
		-- The combined nonce must start with our client nonce (after "r=").
		return false, "nonce sent by server does not match our nonce";
	elseif nonce == our_nonce then
		return false, "server did not append s-nonce to nonce";
	end

	-- cb() may report a binding type but no data (e.g. getfinished() returning
	-- nil); fail cleanly instead of erroring on the concatenation below.
	if not cbind_data then
		return false, "channel binding data unavailable";
	end
	local cbind_input = gs2_header .. cbind_data;
	local channel_binding = "c=" .. base64(cbind_input);
	local client_final_message_without_proof = channel_binding .. "," .. nonce;

	-- Prefer precomputed keys, then a precomputed SaltedPassword, and only
	-- derive from the plain password as a last resort.
	local ClientKey, ServerKey;
	if stream.client_key and stream.server_key then
		ClientKey, ServerKey = stream.client_key, stream.server_key;
	else
		local SaltedPassword;
		if stream.salted_password then
			SaltedPassword = stream.salted_password;
		elseif stream.password then
			SaltedPassword = Hi(Normalize(stream.password), salt, i);
		end
		if not SaltedPassword then
			return false, "no usable credentials";
		end
		ServerKey = HMAC(SaltedPassword, "Server Key");
		ClientKey = HMAC(SaltedPassword, "Client Key");
	end

	local StoredKey = H(ClientKey);
	local AuthMessage = client_first_message_bare .. "," .. server_first_message .. "," .. client_final_message_without_proof;
	local ClientSignature = HMAC(StoredKey, AuthMessage);
	local ClientProof = XOR(ClientKey, ClientSignature);
	-- Kept to authenticate the server from its final message.
	local ServerSignature = HMAC(ServerKey, AuthMessage);

	local proof = "p=" .. base64(ClientProof);
	local client_final_message = client_final_message_without_proof .. "," .. proof;

	local ok, server_final_message = coroutine.yield(client_final_message);
	if ok ~= "success" then return false, "success-expected" end

	-- A success without data (or without v=) cannot be verified.
	local verifier = server_final_message and server_final_message:match("v=([^,]+)");
	if not verifier or unbase64(verifier) ~= ServerSignature then
		return false, "server signature did not match";
	end
	return true;
end

-- Mechanism factory: offer SCRAM only when the stream has a username and a
-- credential scram() can actually use. Stored keys count only as a pair,
-- mirroring the check inside scram(). -PLUS is offered (with higher priority)
-- only when channel binding is available on the connection.
return function (stream, name)
	local has_keys = stream.client_key and stream.server_key;
	if stream.username and (stream.password or stream.salted_password or has_keys) then
		if name == "SCRAM-SHA-1" then
			return scram, 99;
		elseif name == "SCRAM-SHA-1-PLUS" then
			if cb(stream.conn) then
				return scram, 100;
			end
		end
	end
end