File

mod_auth_token/mock.lua @ 5193:2bb29ece216b

mod_http_oauth2: Implement stateless dynamic client registration Replaces previous explicit registration that required either the additional module mod_adhoc_oauth2_client or manually editing the database. That method was enough to have something to test with, but would probably not scale easily. Dynamic client registration allows creating clients on the fly, which may be even easier in theory. In order to not allow basically unauthenticated writes to the database, we implement a stateless model here. per_host_key := HMAC(config -> oauth2_registration_key, hostname) client_id := JWT { client metadata } signed with per_host_key client_secret := HMAC(per_host_key, client_id) This should ensure everything we need to know is part of the client_id, allowing redirects etc to be validated, and the client_secret can be validated with only the client_id and the per_host_key. A nonce injected into the client_id JWT should ensure nobody can submit the same client metadata and retrieve the same client_secret.
author Kim Alvefur <zash@zash.se>
date Fri, 03 Mar 2023 21:14:19 +0100
parent 2956:d0ca211e1b0e
line wrap: on
line source

-- Source code taken from https://github.com/britzl/deftest
-- Released under the MIT License. Copyright (c) 2009-2012 Norman Clarke.

--- Provides the ability to mock any module.

-- @usage
--
-- mock.mock(sys)
--
-- -- specifying return values
-- sys.get_sys_info.returns({my_data})
-- ...
-- local sys_info = sys.get_sys_info() -- will be my_data
-- assert(sys.get_sys_info.calls == 1) -- call counting
-- ...
-- local sys_info = sys.get_sys_info() -- original response as we are now out of mocked answers
-- assert(sys.get_sys_info.calls == 2) -- call counting
-- ...
--
-- -- specifying a replacement function
-- sys.get_sys_info.replace(function () return my_data end)
--
-- ...
-- local sys_info = sys.get_sys_info() -- will be my_data
-- assert(sys.get_sys_info.calls == 3) -- call counting
-- ...
-- local sys_info = sys.get_sys_info() -- will still be my_data
-- assert(sys.get_sys_info.calls == 4) -- call counting
-- ...
--
-- -- cleaning up
-- mock.unmock(sys) -- restore the sys library again

-- Module table; mock.mock and mock.unmock are attached to it below.
local mock = {}

--- Mock the specified module.
-- Every function field of the module is replaced by a callable mock table
-- whose logic can be overridden; non-function fields are left untouched.
--
-- Each mock exposes:
--
-- * `calls` - number of times the mock has been called
-- * `params` - positional arguments of the most recent call
-- * `answers` - queue of canned results, one consumed per call
-- * `returns(...)` - set the answer queue, either from a single table
--   (`returns({1, 2})`) or from the arguments themselves (`returns(1, 2)`)
-- * `always_returns(answer)` - make every call return `answer`
-- * `replace(fn)` - make every call delegate to `fn`
-- * `original(...)` - call the original, unmocked function
-- * `restore()` - drop the replacement installed by `replace` or
--   `always_returns` (queued answers are kept)
--
-- When called, a mock returns the next queued answer if there is one,
-- otherwise the result of the replacement function if one is set,
-- otherwise the result of the original function.
-- @param module module to mock
-- @usage
--
-- -- mock module x
-- mock.mock(x)
--
-- -- make x.f return 1, 2 then the original value
-- x.f.returns({1, 2}) -- or: x.f.returns(1, 2)
-- print(x.f()) -- prints 1
--
-- -- make x.f return 1 forever
-- x.f.replace(function () return 1 end)
-- while true do print(x.f()) end -- prints 1 forever
--
-- -- counting calls
-- assert(x.f.calls > 0)
--
-- -- return to original state of module x
-- mock.unmock(x)
--
function mock.mock(module)
	assert(module, "You must provide a module to mock")
	for k,v in pairs(module) do
		if type(v) == "function" then
			local mock_fn = {
				calls = 0,
				answers = {},
				repl_fn = nil,
				orig_fn = v,
				params = {}
			}
			function mock_fn.returns(...)
				local arg_length = select("#", ...)
				assert(arg_length > 0, "You must provide some answers")
				local first = ...
				local answers
				if arg_length == 1 and type(first) == "table" then
					-- A single table is the list of answers. Copy it so that
					-- consuming answers does not empty the caller's table.
					answers = {}
					for i = 1, #first do
						answers[i] = first[i]
					end
				else
					-- Otherwise each argument is one answer, so e.g.
					-- returns(5) queues a single 5.
					answers = { ... }
				end
				mock_fn.answers = answers
			end
			function mock_fn.always_returns(answer)
				mock_fn.repl_fn = function()
					return answer
				end
			end
			function mock_fn.replace(repl_fn)
				mock_fn.repl_fn = repl_fn
			end
			function mock_fn.original(...)
				return mock_fn.orig_fn(...)
			end
			function mock_fn.restore()
				mock_fn.repl_fn = nil
			end
			setmetatable(mock_fn, {
				__call = function (self, ...)
					self.calls = self.calls + 1

					-- Record this call's arguments. The table is cleared in
					-- place first, so arguments from an earlier, longer call
					-- do not linger and the `params` table identity is kept.
					local params = self.params
					for key in pairs(params) do
						params[key] = nil
					end
					local args = { ... }
					for i = 1, select("#", ...) do
						params[i] = args[i]
					end

					-- Test the queue length rather than the truthiness of its
					-- head: a queued `false` must be returned (and consumed)
					-- instead of blocking the queue forever.
					if #self.answers > 0 then
						return (table.remove(self.answers, 1))
					elseif self.repl_fn then
						return self.repl_fn(...)
					else
						return v(...)
					end
				end
			})
			module[k] = mock_fn
		end
	end
end

--- Remove the mocking capabilities from a module.
-- Each field of `module` that is a table with a truthy `orig_fn` entry
-- (the shape installed by `mock.mock`) is swapped back for that original
-- function; every other field is left as it is.
-- @param module module to remove mocking from
function mock.unmock(module)
	assert(module, "You must provide a module to unmock")
	for name, value in pairs(module) do
		-- Only mock tables carry `orig_fn`; plain functions and other
		-- values are skipped.
		if type(value) == "table" and value.orig_fn then
			module[name] = value.orig_fn
		end
	end
end

-- Export the module table.
return mock