File

util/jwt.lua @ 13801:a5d5fefb8b68 13.0

mod_tls: Enable Prosody's certificate checking for incoming s2s connections (fixes #1916) (thanks Damian, Zash) Various options in Prosody allow control over the behaviour of the certificate verification process. For example, some deployments choose to allow falling back to traditional "dialback" authentication (XEP-0220), while others verify via DANE, hard-coded fingerprints, or other custom plugins. Implementing this flexibility requires us to override OpenSSL's default certificate verification, to allow Prosody to verify the certificate itself, apply custom policies and make decisions based on the outcome. To enable our custom logic, we have to suppress OpenSSL's default behaviour of aborting the connection with a TLS alert message. With LuaSec, this can be achieved by using the verifyext "lsec_continue" flag. We also need to use the lsec_ignore_purpose flag, because XMPP s2s uses server certificates as "client" certificates (for mutual TLS verification in outgoing s2s connections). Commit 99d2100d2918 moved these settings out of the defaults and into mod_s2s, because we only really need these changes for s2s, and they should be opt-in, rather than automatically applied to all TLS services we offer. That commit was incomplete, because it only added the flags for incoming direct TLS connections. StartTLS connections are handled by mod_tls, which was not applying the lsec_* flags. It previously worked because they were already in the defaults. This resulted in incoming s2s connections with "invalid" certificates being aborted early by OpenSSL, even if settings such as `s2s_secure_auth = false` or DANE were present in the config. Outgoing s2s connections inherit verify "none" from the defaults, which means OpenSSL will receive the cert but will not terminate the connection when it is deemed invalid. This means we don't need lsec_continue there, and we also don't need lsec_ignore_purpose (because the remote peer is a "server").
Wondering why we can't just use verify "none" for incoming s2s? It's because in that mode, OpenSSL won't request a certificate from the peer for incoming connections. Setting verify "peer" is how you ask OpenSSL to request a certificate from the client, but also what triggers its built-in verification.
author Matthew Wild <mwild1@gmail.com>
date Tue, 01 Apr 2025 17:26:56 +0100
parent 12975:d10957394a3c
line wrap: on
line source

local s_gsub = string.gsub;
local crypto = require "prosody.util.crypto";
local json = require "prosody.util.json";
local hashes = require "prosody.util.hashes";
local base64_encode = require "prosody.util.encodings".base64.encode;
local base64_decode = require "prosody.util.encodings".base64.decode;
local secure_equals = require "prosody.util.hashes".equals;

-- Single-character translation table used in both directions of the
-- base64 <-> base64url conversion: "+" and "/" become "-" and "_" (and the
-- reverse), while "=" padding is replaced with the empty string.
local b64url_rep = {
	["+"] = "-";
	["/"] = "_";
	["="] = "";
	["-"] = "+";
	["_"] = "/";
};

-- Encode `data` with base64_encode, then rewrite every "+", "/" and "="
-- in the result through b64url_rep. Returns only the rewritten string
-- (the substitution count from gsub is discarded).
local function b64url(data)
	local encoded = base64_encode(data);
	local urlsafe = s_gsub(encoded, "[+/=]", b64url_rep);
	return urlsafe;
end
-- Inverse of b64url(): map "-" and "_" back through b64url_rep, append "=="
-- and hand the result to base64_decode, whose return value(s) are passed on
-- unchanged.
-- NOTE(review): two padding characters are appended regardless of the input
-- length; presumably the decoder tolerates surplus padding - confirm.
local function unb64url(data)
	local standard = s_gsub(data, "[-_]", b64url_rep);
	return base64_decode(standard .. "==");
end

-- Shape of a compact token: three dot-separated segments drawn from the
-- base64url alphabet. The outermost capture spans "header.payload", i.e. the
-- exact text the signature is computed over.
local jwt_pattern = "^(([A-Za-z0-9-_]+)%.([A-Za-z0-9-_]+))%.([A-Za-z0-9-_]+)$"

-- Split a compact token into its parts and validate its header.
--
-- blob:         the token as received (expected to be a string)
-- expected_alg: the only value accepted for the header's "alg" field
--
-- On success returns three strings: the signed portion ("header.payload"),
-- the signature segment and the payload segment. Both segments are returned
-- still base64url-encoded.
-- On failure returns nil plus one of:
--   "invalid-encoding"      - not a string, or not three base64url segments
--   "invalid-header"        - header did not decode to a JSON object
--   "unsupported-algorithm" - header "alg" differs from expected_alg
local function decode_jwt(blob, expected_alg)
	if type(blob) ~= "string" then
		-- string.match() raises on nil/table input; treat that as a malformed
		-- token so callers get the usual nil, err result instead of an error.
		return nil, "invalid-encoding";
	end
	local signed, bheader, bpayload, signature = string.match(blob, jwt_pattern);
	if not signed then
		return nil, "invalid-encoding";
	end
	-- Guard against a failed base64url decode before handing the result to the
	-- JSON decoder, mirroring how the signature segment is handled elsewhere
	-- in this file.
	local raw_header = unb64url(bheader);
	local header = raw_header and json.decode(raw_header);
	if type(header) ~= "table" then
		return nil, "invalid-header";
	elseif header.alg ~= expected_alg then
		return nil, "unsupported-algorithm";
	end
	return signed, signature, bpayload;
end

-- Build the encoded header segment for an algorithm once, so it can be reused
-- for every token signed with it. The result is the b64url() form of a fixed
-- JSON header naming `algorithm_name`, followed by the "." that separates it
-- from the payload segment.
local function new_static_header(algorithm_name)
	local header_json = '{"alg":"' .. algorithm_name .. '","typ":"JWT"}';
	return b64url(header_json) .. '.';
end

-- Decode the payload segment of a token.
--
-- raw_payload: the base64url-encoded payload segment
--
-- Returns true and the decoded table on success. Returns nil and
-- "json-decode-error" when the JSON decoder reports an error, or nil and
-- "invalid-payload-type" when the decoded value is not a table.
local function decode_raw_payload(raw_payload)
	local claims, decode_err = json.decode(unb64url(raw_payload));
	if decode_err ~= nil then
		return nil, "json-decode-error";
	end
	if type(claims) ~= "table" then
		return nil, "invalid-payload-type";
	end
	return true, claims;
end

-- HS*** family
-- Build the implementation table for one HMAC-based algorithm.
--
-- name: algorithm name as it appears in the token header (e.g. "HS256");
--       its last three characters pick the hashes.hmac_sha<NNN> function.
--
-- Returns a table with:
--   sign(key, payload)  -> token string
--   verify(key, blob)   -> true, payload | false/nil, error code
--   load_key(key)       -> key (raises unless key is a string)
local function new_hmac_algorithm(name)
	local header_segment = new_static_header(name);
	local digest_bits = name:sub(-3);
	local mac = hashes["hmac_sha" .. digest_bits];

	local impl = {};

	-- Serialise the payload, MAC "header.payload" and append the encoded MAC.
	function impl.sign(key, payload)
		local signing_input = header_segment .. b64url(json.encode(payload));
		return signing_input .. "." .. b64url(mac(key, signing_input));
	end

	-- Recompute the MAC over the signed portion and compare it with the
	-- token's signature segment before decoding the payload.
	function impl.verify(key, blob)
		local signing_input, signature, raw_payload = decode_jwt(blob, name);
		if not signing_input then
			-- On failure decode_jwt's second return value is the error code
			return nil, signature;
		end

		local expected = b64url(mac(key, signing_input));
		if secure_equals(expected, signature) then
			return decode_raw_payload(raw_payload);
		end
		return false, "signature-mismatch";
	end

	-- HMAC keys are used as-is; only their type is checked.
	function impl.load_key(key)
		assert(type(key) == "string", "key must be string (long, random, secure)");
		return key;
	end

	return impl;
end

-- Build the implementation table for an algorithm backed by a key pair.
--
-- name:       algorithm name placed in (and required of) the token header
-- key_type:   value that key:get_type() must return for loaded keys
-- c_sign:     function(private_key, data) -> signature
-- c_verify:   function(public_key, data, signature) -> truthy when valid
-- sig_encode: optional function applied to c_sign's output before encoding
-- sig_decode: optional function applied to the decoded signature before
--             c_verify; a falsy result is treated as a signature mismatch
--
-- Returns a table with sign, verify, load_public_key and load_private_key.
local function new_crypto_algorithm(name, key_type, c_sign, c_verify, sig_encode, sig_decode)
	local header_segment = new_static_header(name);

	-- Shared by both loaders: reject keys of the wrong type.
	local function check_key_type(key)
		assert(key:get_type() == key_type, "incorrect key type");
		return key;
	end

	local function sign(private_key, payload)
		local signing_input = header_segment .. b64url(json.encode(payload));

		local raw_signature = c_sign(private_key, signing_input);
		if sig_encode then
			raw_signature = sig_encode(raw_signature);
		end

		return signing_input .. "." .. b64url(raw_signature);
	end

	local function verify(public_key, blob)
		local signing_input, signature, raw_payload = decode_jwt(blob, name);
		if not signing_input then
			-- On failure decode_jwt's second return value is the error code
			return nil, signature;
		end

		local raw_signature = unb64url(signature);
		if raw_signature and sig_decode then
			raw_signature = sig_decode(raw_signature);
		end

		-- A signature that failed to decode is reported exactly like one that
		-- failed verification.
		if not raw_signature or not c_verify(public_key, signing_input, raw_signature) then
			return false, "signature-mismatch";
		end

		return decode_raw_payload(raw_payload);
	end

	local function load_public_key(public_key_pem)
		return check_key_type(assert(crypto.import_public_pem(public_key_pem)));
	end

	local function load_private_key(private_key_pem)
		return check_key_type(assert(crypto.import_private_pem(private_key_pem)));
	end

	return {
		sign = sign;
		verify = verify;
		load_public_key = load_public_key;
		load_private_key = load_private_key;
	};
end

-- RS***, PS***
-- Maps the two-letter algorithm family to the prefix of the matching
-- sign/verify function names in the crypto module.
local rsa_sign_algos = { RS = "rsassa_pkcs1", PS = "rsassa_pss" };

-- Build the implementation for an RSA-based algorithm. `name` is split into a
-- two-character family and a three-character digest size, which together
-- select crypto.<prefix>_sha<bits>_sign and crypto.<prefix>_sha<bits>_verify.
local function new_rsa_algorithm(name)
	local family, digest_bits = name:match("^(..)(...)$");
	local fn_prefix = rsa_sign_algos[family] .. "_sha" .. digest_bits;
	local signer = crypto[fn_prefix .. "_sign"];
	local verifier = crypto[fn_prefix .. "_verify"];
	return new_crypto_algorithm(name, "rsaEncryption", signer, verifier);
end

-- ES***
-- Build the implementation for an ECDSA-based algorithm.
--
-- name:      algorithm name used in the token header
-- c_sign:    signing function from the crypto module
-- c_verify:  verification function from the crypto module
-- sig_bytes: size in bytes of each of the two signature components
--
-- Tokens carry the two components concatenated (2 * sig_bytes bytes), so the
-- signature is converted on the way out and on the way in using the crypto
-- module's parse/build helpers.
local function new_ecdsa_algorithm(name, c_sign, c_verify, sig_bytes)
	local token_sig_length = 2 * sig_bytes;

	-- c_sign output -> concatenated components for the token
	local function to_token_sig(der_sig)
		local r, s = crypto.parse_ecdsa_signature(der_sig, sig_bytes);
		return r .. s;
	end

	-- Token signature -> form expected by c_verify; nil when the length is wrong
	local function from_token_sig(jwk_sig)
		if #jwk_sig == token_sig_length then
			local r = jwk_sig:sub(1, sig_bytes);
			local s = jwk_sig:sub(sig_bytes + 1);
			return crypto.build_ecdsa_signature(r, s);
		end
		return nil;
	end

	return new_crypto_algorithm(name, "id-ecPublicKey", c_sign, c_verify, to_token_sig, from_token_sig);
end

-- Registry of supported algorithms, keyed by the name used in the token
-- header. Each value is the implementation table produced by the matching
-- constructor.
local algorithms = {
	-- HMAC
	HS256 = new_hmac_algorithm("HS256"),
	HS384 = new_hmac_algorithm("HS384"),
	HS512 = new_hmac_algorithm("HS512"),

	-- ECDSA; the last argument is the per-component signature size in bytes
	ES256 = new_ecdsa_algorithm("ES256", crypto.ecdsa_sha256_sign, crypto.ecdsa_sha256_verify, 32),
	ES512 = new_ecdsa_algorithm("ES512", crypto.ecdsa_sha512_sign, crypto.ecdsa_sha512_verify, 66),

	-- RSA ("RS" family)
	RS256 = new_rsa_algorithm("RS256"),
	RS384 = new_rsa_algorithm("RS384"),
	RS512 = new_rsa_algorithm("RS512"),

	-- RSA ("PS" family)
	PS256 = new_rsa_algorithm("PS256"),
	PS384 = new_rsa_algorithm("PS384"),
	PS512 = new_rsa_algorithm("PS512"),
};

-- Create a signing function bound to one algorithm and key.
--
-- algorithm: key into the `algorithms` table (e.g. "HS256"); an unknown or
--            missing name raises "Unknown JWT algorithm: <name>".
-- key_input: passed to the algorithm's load_private_key(), or to load_key()
--            when the algorithm has no separate private key loader.
-- options:   optional table; `default_ttl` (seconds, default 3600) is added
--            to the issue time to fill in a missing "exp" claim.
--
-- Returns function(payload) -> token string.
-- Note that the returned function modifies `payload` in place: missing "iat"
-- and "exp" fields are filled in on the table the caller passed.
local function new_signer(algorithm, key_input, options)
	local impl = algorithms[algorithm];
	if not impl then
		-- tostring() keeps the intended message intact when `algorithm` is nil or
		-- otherwise not a string; a bare concatenation would raise its own,
		-- unrelated error first. Level 0 matches assert(): no position prefix.
		error("Unknown JWT algorithm: "..tostring(algorithm), 0);
	end
	local key = (impl.load_private_key or impl.load_key)(key_input);
	local sign = impl.sign;
	local default_ttl = (options and options.default_ttl) or 3600;
	return function (payload)
		local issued_at;
		if not payload.iat then
			issued_at = os.time();
			payload.iat = issued_at;
		end
		if not payload.exp then
			-- Reuse the timestamp just written to "iat" when there is one, so both
			-- claims are derived from the same instant.
			payload.exp = (issued_at or os.time()) + default_ttl;
		end
		return sign(key, payload);
	end
end

-- Create a token verification function bound to one algorithm and key.
--
-- algorithm: key into the `algorithms` table (e.g. "HS256"); an unknown or
--            missing name raises "Unknown JWT algorithm: <name>".
-- key_input: passed to the algorithm's load_public_key(), or to load_key()
--            when the algorithm has no separate public key loader.
-- options:   optional table:
--              accept_expired - when truthy, the "exp" claim is not checked
--              claim_verifier - function(payload); a falsy result rejects
--                               the token
--
-- The returned function(token) yields:
--   true, payload             on success
--   nil, "invalid-expiry"     "exp" is present but not a number
--   nil, "token-expired"      "exp" is earlier than os.time()
--   nil, "incorrect-claims"   claim_verifier rejected the payload
--   otherwise the falsy value and error code from the algorithm's verify()
local function new_verifier(algorithm, key_input, options)
	local impl = algorithms[algorithm];
	if not impl then
		-- tostring() keeps the intended message intact when `algorithm` is nil or
		-- otherwise not a string; a bare concatenation would raise its own,
		-- unrelated error first. Level 0 matches assert(): no position prefix.
		error("Unknown JWT algorithm: "..tostring(algorithm), 0);
	end
	local key = (impl.load_public_key or impl.load_key)(key_input);
	local verify = impl.verify;
	local check_expiry = not (options and options.accept_expired);
	local claim_verifier = options and options.claim_verifier;
	return function (token)
		local ok, payload = verify(key, token);
		if ok then
			-- Falsy when expiry checking is disabled or the token carries no "exp"
			local expires_at = check_expiry and payload.exp;
			if expires_at then
				if type(expires_at) ~= "number" then
					return nil, "invalid-expiry";
				elseif expires_at < os.time() then
					return nil, "token-expired";
				end
			end
			if claim_verifier and not claim_verifier(payload) then
				return nil, "incorrect-claims";
			end
		end
		-- On failure `payload` holds the error code returned by verify()
		return ok, payload;
	end
end

-- Convenience constructor returning a signer and a verifier for the same
-- algorithm and options. When no public key is given the private key is used
-- for verification as well.
local function init(algorithm, private_key, public_key, options)
	local verification_key = public_key or private_key;
	local signer = new_signer(algorithm, private_key, options);
	local verifier = new_verifier(algorithm, verification_key, options);
	return signer, verifier;
end

-- Public interface of the module.
return {
	-- Constructors for signer/verifier functions
	init = init,
	new_signer = new_signer,
	new_verifier = new_verifier,

	-- Deprecated: direct access to the HS256 sign/verify functions
	sign = algorithms.HS256.sign,
	verify = algorithms.HS256.verify,

	-- Algorithm registry, exported mainly for tests
	_algorithms = algorithms,
};