Diff

util/sasl/scram.lua @ 5868:bc37c6758f3a

util.sasl.scram: Create the state table as late as possible, keep state in locals for faster access
author Kim Alvefur <zash@zash.se>
date Sun, 13 Oct 2013 00:29:47 +0200
parent 5867:72d49d1e2d11
child 5869:35780ef2d689
line wrap: on
line diff
--- a/util/sasl/scram.lua	Sat Oct 12 21:15:36 2013 +0200
+++ b/util/sasl/scram.lua	Sun Oct 13 00:29:47 2013 +0200
@@ -102,22 +102,19 @@
 
 local function scram_gen(hash_name, H_f, HMAC_f)
 	local function scram_hash(self, message)
-		if not self.state then self["state"] = {} end
 		local support_channel_binding = false;
 		if self.profile.cb then support_channel_binding = true; end
 
 		if type(message) ~= "string" or #message == 0 then return "failure", "malformed-request" end
-		if not self.state.name then
+		local state = self.state;
+		if not state then
 			-- we are processing client_first_message
 			local client_first_message = message;
 
 			-- TODO: fail if authzid is provided, since we don't support them yet
-			self.state["client_first_message"] = client_first_message;
-			self.state["gs2_header"], self.state["gs2_cbind_flag"], self.state["gs2_cbind_name"], self.state["authzid"], self.state["name"], self.state["clientnonce"]
+			local gs2_header, gs2_cbind_flag, gs2_cbind_name, authzid, name, clientnonce
 				= client_first_message:match("^(([ynp])=?([%a%-]*),(.*),)n=(.*),r=([^,]*).*");
 
-			local gs2_cbind_flag = self.state.gs2_cbind_flag;
-
 			if not gs2_cbind_flag then
 				return "failure", "malformed-request";
 			end
@@ -135,29 +132,24 @@
 
 			if support_channel_binding and gs2_cbind_flag == "p" then
 				-- check whether we support the proposed channel binding type
-				if not self.profile.cb[self.state.gs2_cbind_name] then
+				if not self.profile.cb[gs2_cbind_name] then
 					return "failure", "malformed-request", "Proposed channel binding type isn't supported.";
 				end
 			else
 				-- no channel binding,
-				self.state.gs2_cbind_name = nil;
+				gs2_cbind_name = nil;
 			end
 
-			if not self.state.name or not self.state.clientnonce then
-				return "failure", "malformed-request", "Channel binding isn't support at this time.";
-			end
-
-			self.state.name = validate_username(self.state.name, self.profile.nodeprep);
-			if not self.state.name then
+			name = validate_username(name, self.profile.nodeprep);
+			if not name then
 				log("debug", "Username violates either SASLprep or contains forbidden character sequences.")
 				return "failure", "malformed-request", "Invalid username.";
 			end
 
-			self.state["servernonce"] = generate_uuid();
-
 			-- retreive credentials
+			local stored_key, server_key, salt, iteration_count;
 			if self.profile.plain then
-				local password, state = self.profile.plain(self, self.state.name, self.realm)
+				local password, state = self.profile.plain(self, name, self.realm)
 				if state == nil then return "failure", "not-authorized"
 				elseif state == false then return "failure", "account-disabled" end
 
@@ -167,64 +159,71 @@
 					return "failure", "not-authorized", "Invalid password."
 				end
 
-				self.state.salt = generate_uuid();
-				self.state.iteration_count = default_i;
+				salt = generate_uuid();
+				iteration_count = default_i;
 
 				local succ = false;
-				succ, self.state.stored_key, self.state.server_key = getAuthenticationDatabaseSHA1(password, self.state.salt, default_i, self.state.iteration_count);
+				succ, stored_key, server_key = getAuthenticationDatabaseSHA1(password, salt, iteration_count);
 				if not succ then
-					log("error", "Generating authentication database failed. Reason: %s", self.state.stored_key);
+					log("error", "Generating authentication database failed. Reason: %s", stored_key);
 					return "failure", "temporary-auth-failure";
 				end
 			elseif self.profile["scram_"..hashprep(hash_name)] then
-				local stored_key, server_key, iteration_count, salt, state = self.profile["scram_"..hashprep(hash_name)](self, self.state.name, self.realm);
+				local state;
+				stored_key, server_key, iteration_count, salt, state = self.profile["scram_"..hashprep(hash_name)](self, name, self.realm);
 				if state == nil then return "failure", "not-authorized"
 				elseif state == false then return "failure", "account-disabled" end
-
-				self.state.stored_key = stored_key;
-				self.state.server_key = server_key;
-				self.state.iteration_count = iteration_count;
-				self.state.salt = salt
 			end
 
-			local server_first_message = "r="..self.state.clientnonce..self.state.servernonce..",s="..base64.encode(self.state.salt)..",i="..self.state.iteration_count;
-			self.state["server_first_message"] = server_first_message;
+			local nonce = clientnonce .. generate_uuid();
+			local server_first_message = "r="..nonce..",s="..base64.encode(salt)..",i="..iteration_count;
+			self.state = {
+				gs2_header = gs2_header;
+				gs2_cbind_name = gs2_cbind_name;
+				name = name;
+				nonce = nonce;
+
+				server_key = server_key;
+				stored_key = stored_key;
+				client_first_message = client_first_message;
+				server_first_message = server_first_message;
+			}
 			return "challenge", server_first_message
 		else
 			-- we are processing client_final_message
 			local client_final_message = message;
 
-			self.state["channelbinding"], self.state["nonce"], self.state["proof"] = client_final_message:match("^c=(.*),r=(.*),.*p=(.*)");
+			local channelbinding, nonce, proof = client_final_message:match("^c=(.*),r=(.*),.*p=(.*)");
 
-			if not self.state.proof or not self.state.nonce or not self.state.channelbinding then
+			if not proof or not nonce or not channelbinding then
 				return "failure", "malformed-request", "Missing an attribute(p, r or c) in SASL message.";
 			end
 
-			local client_gs2_header = base64.decode(self.state.channelbinding)
-			local our_client_gs2_header = self.state["gs2_header"]
-			if self.state.gs2_cbind_name then
+			local client_gs2_header = base64.decode(channelbinding)
+			local our_client_gs2_header = state["gs2_header"]
+			if state.gs2_cbind_name then
 				-- we support channelbinding, so check if the value is valid
-				our_client_gs2_header = our_client_gs2_header .. self.profile.cb[self.state.gs2_cbind_name](self);
+				our_client_gs2_header = our_client_gs2_header .. self.profile.cb[state.gs2_cbind_name](self);
 			end
 			if client_gs2_header ~= our_client_gs2_header then
 				return "failure", "malformed-request", "Invalid channel binding value.";
 			end
 
-			if self.state.nonce ~= self.state.clientnonce..self.state.servernonce then
+			if nonce ~= state.nonce then
 				return "failure", "malformed-request", "Wrong nonce in client-final-message.";
 			end
 
-			local ServerKey = self.state.server_key;
-			local StoredKey = self.state.stored_key;
+			local ServerKey = state.server_key;
+			local StoredKey = state.stored_key;
 
-			local AuthMessage = "n=" .. s_match(self.state.client_first_message,"n=(.+)") .. "," .. self.state.server_first_message .. "," .. s_match(client_final_message, "(.+),p=.+")
+			local AuthMessage = "n=" .. s_match(state.client_first_message,"n=(.+)") .. "," .. state.server_first_message .. "," .. s_match(client_final_message, "(.+),p=.+")
 			local ClientSignature = HMAC_f(StoredKey, AuthMessage)
-			local ClientKey = binaryXOR(ClientSignature, base64.decode(self.state.proof))
+			local ClientKey = binaryXOR(ClientSignature, base64.decode(proof))
 			local ServerSignature = HMAC_f(ServerKey, AuthMessage)
 
 			if StoredKey == H_f(ClientKey) then
 				local server_final_message = "v="..base64.encode(ServerSignature);
-				self["username"] = self.state.name;
+				self["username"] = state.name;
 				return "success", server_final_message;
 			else
 				return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated.";