Diff

util/sasl.lua @ 2193:8fbbdb11a520

Merge with sasl branch.
author Tobias Markmann <tm@ayena.de>
date Mon, 16 Nov 2009 21:43:57 +0100
parent 2080:ca419b92a8c7
parent 2191:e79c0ce6cf54
child 2198:d18b4d22b8da
line wrap: on
line diff
--- a/util/sasl.lua	Fri Nov 13 06:29:37 2009 +0500
+++ b/util/sasl.lua	Mon Nov 16 21:43:57 2009 +0100
@@ -16,9 +16,8 @@
 local log = require "util.logger".init("sasl");
 local tostring = tostring;
 local st = require "util.stanza";
-local generate_uuid = require "util.uuid".generate;
+local pairs, ipairs = pairs, ipairs;
 local t_insert, t_concat = table.insert, table.concat;
-local to_byte, to_char = string.byte, string.char;
 local to_unicode = require "util.encodings".idna.to_unicode;
 local s_match = string.match;
 local gmatch = string.gmatch
@@ -27,244 +26,110 @@
 local type = type
 local error = error
 local print = print
+local setmetatable = setmetatable;
+local assert = assert;
+local dofile = dofile;
+local require = require;
 
+require "util.iterators"
+local keys = keys
+
+local array = require "util.array"
 module "sasl"
 
--- Credentials handler:
---   Arguments: ("PLAIN", user, host, password)
---   Returns: true (success) | false (fail) | nil (user unknown)
-local function new_plain(realm, credentials_handler)
-	local object = { mechanism = "PLAIN", realm = realm, credentials_handler = credentials_handler}
-	function object.feed(self, message)
-		if message == "" or message == nil then return "failure", "malformed-request" end
-		local response = message
-		local authorization = s_match(response, "([^%z]+)")
-		local authentication = s_match(response, "%z([^%z]+)%z")
-		local password = s_match(response, "%z[^%z]+%z([^%z]+)")
-
-    if authentication == nil or password == nil then return "failure", "malformed-request" end
-    self.username = authentication
-    local auth_success = self.credentials_handler("PLAIN", self.username, self.realm, password)
+--[[
+Authentication Backend Prototypes:
 
-    if auth_success then
-      return "success"
-    elseif auth_success == nil then
-      return "failure", "account-disabled"
-    else
-      return "failure", "not-authorized"
-    end
-  end
-  return object
-end
+state = false : disabled
+state = true : enabled
+state = nil : non-existant
 
--- credentials_handler:
---   Arguments: (mechanism, node, domain, realm, decoder)
---   Returns: Password encoding, (plaintext) password
--- implementing RFC 2831
-local function new_digest_md5(realm, credentials_handler)
-	--TODO complete support for authzid
-
-	local function serialize(message)
-		local data = ""
-
-		if type(message) ~= "table" then error("serialize needs an argument of type table.") end
-
-		-- testing all possible values
-		if message["realm"] then data = data..[[realm="]]..message.realm..[[",]] end
-		if message["nonce"] then data = data..[[nonce="]]..message.nonce..[[",]] end
-		if message["qop"] then data = data..[[qop="]]..message.qop..[[",]] end
-		if message["charset"] then data = data..[[charset=]]..message.charset.."," end
-		if message["algorithm"] then data = data..[[algorithm=]]..message.algorithm.."," end
-		if message["rspauth"] then data = data..[[rspauth=]]..message.rspauth.."," end
-		data = data:gsub(",$", "")
-		return data
+plain:
+	function(username, realm)
+		return password, state;
 	end
 
-	local function utf8tolatin1ifpossible(passwd)
-		local i = 1;
-		while i <= #passwd do
-			local passwd_i = to_byte(passwd:sub(i, i));
-			if passwd_i > 0x7F then
-				if passwd_i < 0xC0 or passwd_i > 0xC3 then
-					return passwd;
-				end
-				i = i + 1;
-				passwd_i = to_byte(passwd:sub(i, i));
-				if passwd_i < 0x80 or passwd_i > 0xBF then
-					return passwd;
-				end
-			end
-			i = i + 1;
-		end
+plain-test:
+	function(username, realm, password)
+		return true or false, state;
+	end
 
-		local p = {};
-		local j = 0;
-		i = 1;
-		while (i <= #passwd) do
-			local passwd_i = to_byte(passwd:sub(i, i));
-			if passwd_i > 0x7F then
-				i = i + 1;
-				local passwd_i_1 = to_byte(passwd:sub(i, i));
-				t_insert(p, to_char(passwd_i%4*64 + passwd_i_1%64)); -- I'm so clever
-			else
-				t_insert(p, to_char(passwd_i));
-			end
-			i = i + 1;
-		end
-		return t_concat(p);
-	end
-	local function latin1toutf8(str)
-		local p = {};
-		for ch in gmatch(str, ".") do
-			ch = to_byte(ch);
-			if (ch < 0x80) then
-				t_insert(p, to_char(ch));
-			elseif (ch < 0xC0) then
-				t_insert(p, to_char(0xC2, ch));
-			else
-				t_insert(p, to_char(0xC3, ch - 64));
-			end
-		end
-		return t_concat(p);
-	end
-	local function parse(data)
-		local message = {}
-		for k, v in gmatch(data, [[([%w%-]+)="?([^",]*)"?,?]]) do -- FIXME The hacky regex makes me shudder
-			message[k] = v;
-		end
-		return message;
+digest-md5:
+	function(username, domain, realm, encoding) -- domain and realm are usually the same; for some broken
+												-- implementations it's not
+		return digesthash, state;
 	end
 
-	local object = { mechanism = "DIGEST-MD5", realm = realm, credentials_handler = credentials_handler};
-
-	object.nonce = generate_uuid();
-	object.step = 0;
-	object.nonce_count = {};
-
-	function object.feed(self, message)
-		self.step = self.step + 1;
-		if (self.step == 1) then
-			local challenge = serialize({	nonce = object.nonce,
-											qop = "auth",
-											charset = "utf-8",
-											algorithm = "md5-sess",
-											realm = self.realm});
-			return "challenge", challenge;
-		elseif (self.step == 2) then
-			local response = parse(message);
-			-- check for replay attack
-			if response["nc"] then
-				if self.nonce_count[response["nc"]] then return "failure", "not-authorized" end
-			end
+digest-md5-test:
+	function(username, domain, realm, encoding, digesthash)
+		return true or false, state;
+	end
+]]
 
-			-- check for username, it's REQUIRED by RFC 2831
-			if not response["username"] then
-				return "failure", "malformed-request";
-			end
-			self["username"] = response["username"];
-
-			-- check for nonce, ...
-			if not response["nonce"] then
-				return "failure", "malformed-request";
-			else
-				-- check if it's the right nonce
-				if response["nonce"] ~= tostring(self.nonce) then return "failure", "malformed-request" end
-			end
-
-			if not response["cnonce"] then return "failure", "malformed-request", "Missing entry for cnonce in SASL message." end
-			if not response["qop"] then response["qop"] = "auth" end
-
-			if response["realm"] == nil or response["realm"] == "" then
-				response["realm"] = "";
-			elseif response["realm"] ~= self.realm then
-				return "failure", "not-authorized", "Incorrect realm value";
-			end
-
-			local decoder;
-			if response["charset"] == nil then
-				decoder = utf8tolatin1ifpossible;
-			elseif response["charset"] ~= "utf-8" then
-				return "failure", "incorrect-encoding", "The client's response uses "..response["charset"].." for encoding with isn't supported by sasl.lua. Supported encodings are latin or utf-8.";
-			end
+local method = {};
+method.__index = method;
+local mechanisms = {};
+local backend_mechanism = {};
 
-			local domain = "";
-			local protocol = "";
-			if response["digest-uri"] then
-				protocol, domain = response["digest-uri"]:match("(%w+)/(.*)$");
-				if protocol == nil or domain == nil then return "failure", "malformed-request" end
-			else
-				return "failure", "malformed-request", "Missing entry for digest-uri in SASL message."
-			end
+-- register a new SASL mechanims
+local function registerMechanism(name, backends, f)
+	assert(type(name) == "string", "Parameter name MUST be a string.");
+	assert(type(backends) == "string" or type(backends) == "table", "Parameter backends MUST be either a string or a table.");
+	assert(type(f) == "function", "Parameter f MUST be a function.");
+	mechanisms[name] = f
+	for _, backend_name in ipairs(backends) do
+		if backend_mechanism[backend_name] == nil then backend_mechanism[backend_name] = {}; end
+		t_insert(backend_mechanism[backend_name], name);
+	end
+end
 
-			--TODO maybe realm support
-			self.username = response["username"];
-			local password_encoding, Y = self.credentials_handler("DIGEST-MD5", response["username"], self.realm, response["realm"], decoder);
-			if Y == nil then return "failure", "not-authorized"
-			elseif Y == false then return "failure", "account-disabled" end
-			local A1 = "";
-			if response.authzid then
-				if response.authzid == self.username or response.authzid == self.username.."@"..self.realm then
-					-- COMPAT
-					log("warn", "Client is violating RFC 3920 (section 6.1, point 7).");
-					A1 = Y..":"..response["nonce"]..":"..response["cnonce"]..":"..response.authzid;
-				else
-					return "failure", "invalid-authzid";
-				end
-			else
-				A1 = Y..":"..response["nonce"]..":"..response["cnonce"];
+-- create a new SASL object which can be used to authenticate clients
+function new(realm, profile)
+	sasl_i = {profile = profile};
+	sasl_i.realm = realm;
+	return setmetatable(sasl_i, method);
+end
+
+-- get a list of possible SASL mechanims to use
+function method:mechanisms()
+	local mechanisms = {}
+	for backend, f in pairs(self.profile) do
+		print(backend)
+		if backend_mechanism[backend] then
+			for _, mechanism in ipairs(backend_mechanism[backend]) do
+				mechanisms[mechanism] = true;
 			end
-			local A2 = "AUTHENTICATE:"..protocol.."/"..domain;
-
-			local HA1 = md5(A1, true);
-			local HA2 = md5(A2, true);
-
-			local KD = HA1..":"..response["nonce"]..":"..response["nc"]..":"..response["cnonce"]..":"..response["qop"]..":"..HA2;
-			local response_value = md5(KD, true);
-
-			if response_value == response["response"] then
-				-- calculate rspauth
-				A2 = ":"..protocol.."/"..domain;
-
-				HA1 = md5(A1, true);
-				HA2 = md5(A2, true);
-
-				KD = HA1..":"..response["nonce"]..":"..response["nc"]..":"..response["cnonce"]..":"..response["qop"]..":"..HA2
-				local rspauth = md5(KD, true);
-				self.authenticated = true;
-				return "challenge", serialize({rspauth = rspauth});
-			else
-				return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated."
-			end
-		elseif self.step == 3 then
-			if self.authenticated ~= nil then return "success"
-			else return "failure", "malformed-request" end
 		end
 	end
-	return object;
+	self["possible_mechanisms"] = mechanisms;
+	return array.collect(keys(mechanisms));
 end
 
--- Credentials handler: Can be nil. If specified, should take the mechanism as
--- the only argument, and return true for OK, or false for not-OK (TODO)
-local function new_anonymous(realm, credentials_handler)
-	local object = { mechanism = "ANONYMOUS", realm = realm, credentials_handler = credentials_handler}
-		function object.feed(self, message)
-			return "success"
-		end
-	object["username"] = generate_uuid()
-	return object
+-- select a mechanism to use
+function method:select(mechanism)
+	if self.mech_i then
+		return false;
+	end
+	
+	self.mech_i = mechanisms[mechanism]
+	if self.mech_i == nil then 
+		return false;
+	end
+	return true;
 end
 
+-- feed new messages to process into the library
+function method:process(message)
+	--if message == "" or message == nil then return "failure", "malformed-request" end
+	return self.mech_i(self, message);
+end
 
-function new(mechanism, realm, credentials_handler)
-	local object
-	if mechanism == "PLAIN" then object = new_plain(realm, credentials_handler)
-	elseif mechanism == "DIGEST-MD5" then object = new_digest_md5(realm, credentials_handler)
-	elseif mechanism == "ANONYMOUS" then object = new_anonymous(realm, credentials_handler)
-	else
-		log("debug", "Unsupported SASL mechanism: "..tostring(mechanism));
-		return nil
-	end
-	return object
+-- load the mechanisms
+load_mechs = {"plain", "digest-md5", "anonymous"}
+for _, mech in ipairs(load_mechs) do
+	local name = "util.sasl."..mech;
+	local m = require(name);
+	m.init(registerMechanism)
 end
 
 return _M;