Diff

util/serialization.lua @ 9008:ae3c52419ec1

util.serialization: Rewritte for performance and flexibility ... and because rewrites are fun!
author Kim Alvefur <zash@zash.se>
date Tue, 10 Jul 2018 22:04:26 +0200
parent 8555:4f0f5b49bb03
child 9060:69bc3144c2b7
line wrap: on
line diff
--- a/util/serialization.lua	Tue Jul 10 21:10:12 2018 +0200
+++ b/util/serialization.lua	Tue Jul 10 22:04:26 2018 +0200
@@ -1,84 +1,262 @@
 -- Prosody IM
 -- Copyright (C) 2008-2010 Matthew Wild
 -- Copyright (C) 2008-2010 Waqas Hussain
+-- Copyright (C) 2018 Kim Alvefur
 --
 -- This project is MIT/X11 licensed. Please see the
 -- COPYING file in the source package for more information.
 --
 
-local string_rep = string.rep;
-local type = type;
-local tostring = tostring;
-local t_insert = table.insert;
+local getmetatable = getmetatable;
+local next, type = next, type;
+local s_format = string.format;
+local s_gsub = string.gsub;
+local s_rep = string.rep;
+local s_char = string.char;
+local s_match = string.match;
 local t_concat = table.concat;
-local pairs = pairs;
-local next = next;
 
 local pcall = pcall;
-
-local debug_traceback = debug.traceback;
-local log = require "util.logger".init("serialization");
 local envload = require"util.envload".envload;
 
-local _ENV = nil;
--- luacheck: std none
-
-local indent = function(i)
-	return string_rep("\t", i);
+local pos_inf, neg_inf = math.huge, -math.huge;
+local m_log = math.log;
+local m_log10 = math.log10 or function (n)
+	return m_log(n, 10);
 end
-local function basicSerialize (o)
-	if type(o) == "number" or type(o) == "boolean" then
-		-- no need to check for NaN, as that's not a valid table index
-		if o == 1/0 then return "(1/0)";
-		elseif o == -1/0 then return "(-1/0)";
-		else return tostring(o); end
-	else -- assume it is a string -- FIXME make sure it's a string. throw an error otherwise.
-		return (("%q"):format(tostring(o)):gsub("\\\n", "\\n"));
-	end
+local m_floor = math.floor;
+-- luacheck: ignore 143/math
+local m_type = math.type or function (n)
+	return n % 1 == 0 and n <= 9007199254740992 and n >= -9007199254740992 and "integer" or "float";
+end;
+
+local char_to_hex = {};
+for i = 0,255 do
+	char_to_hex[s_char(i)] = s_format("%02x", i);
+end
+
+local function to_hex(s)
+	return (s_gsub(s, ".", char_to_hex));
 end
-local function _simplesave(o, ind, t, func)
-	if type(o) == "number" then
-		if o ~= o then func(t, "(0/0)");
-		elseif o == 1/0 then func(t, "(1/0)");
-		elseif o == -1/0 then func(t, "(-1/0)");
-		else func(t, tostring(o)); end
-	elseif type(o) == "string" then
-		func(t, (("%q"):format(o):gsub("\\\n", "\\n")));
-	elseif type(o) == "table" then
-		if next(o) ~= nil then
-			func(t, "{\n");
-			for k,v in pairs(o) do
-				func(t, indent(ind));
-				func(t, "[");
-				func(t, basicSerialize(k));
-				func(t, "] = ");
-				if ind == 0 then
-					_simplesave(v, 0, t, func);
-				else
-					_simplesave(v, ind+1, t, func);
-				end
-				func(t, ";\n");
-			end
-			func(t, indent(ind-1));
-			func(t, "}");
-		else
-			func(t, "{}");
-		end
-	elseif type(o) == "boolean" then
-		func(t, (o and "true" or "false"));
-	else
-		log("error", "cannot serialize a %s: %s", type(o), debug_traceback())
-		func(t, "nil");
+
+local function fatal_error(obj, why)
+	error("Can't serialize "..type(obj) .. (why and ": ".. why or ""));
+end
+
+local function default_fallback(x, why)
+	return s_format("nil --[[%s: %s]]", type(x), why or "fail");
+end
+
+local string_escapes = {
+	['\a'] = [[\a]]; ['\b'] = [[\b]];
+	['\f'] = [[\f]]; ['\n'] = [[\n]];
+	['\r'] = [[\r]]; ['\t'] = [[\t]];
+	['\v'] = [[\v]]; ['\\'] = [[\\]];
+	['\"'] = [[\"]]; ['\''] = [[\']];
+}
+
+for i = 0, 255 do
+	local c = s_char(i);
+	if not string_escapes[c] then
+		string_escapes[c] = s_format("\\%03d", i);
 	end
 end
 
-local function append(t, o)
-	_simplesave(o, 1, t, t.write or t_insert);
-	return t;
-end
+local default_keywords = {
+	["do"] = true; ["and"] = true; ["else"] = true; ["break"] = true;
+	["if"] = true; ["end"] = true; ["goto"] = true; ["false"] = true;
+	["in"] = true; ["for"] = true; ["then"] = true; ["local"] = true;
+	["or"] = true; ["nil"] = true; ["true"] = true; ["until"] = true;
+	["elseif"] = true; ["function"] = true; ["not"] = true;
+	["repeat"] = true; ["return"] = true; ["while"] = true;
+};
+
+local function new(opt)
+	if type(opt) ~= "table" then
+		opt = { preset = opt };
+	end
+
+	local types = {
+		table = true;
+		string = true;
+		number = true;
+		boolean = true;
+		["nil"] = true;
+	};
+
+	-- presets
+	if opt.preset == "debug" then
+		opt.preset = "oneline";
+		opt.freeze = true;
+		opt.fatal = false;
+		opt.fallback = default_fallback;
+	end
+	if opt.preset == "oneline" then
+		opt.indentwith = opt.indentwith or "";
+		opt.itemstart = opt.itemstart or " ";
+		opt.itemlast = opt.itemlast or "";
+		opt.tend = opt.tend or " }";
+	elseif opt.preset == "compact" then
+		opt.indentwith = opt.indentwith or "";
+		opt.itemstart = opt.itemstart or "";
+		opt.itemlast = opt.itemlast or "";
+		opt.equals = opt.equals or "=";
+	end
+
+	local fallback = opt.fatal and fatal_error or opt.fallback or default_fallback;
+
+	local function ser(v)
+		return (types[type(v)] or fallback)(v);
+	end
+
+	local keywords = opt.keywords or default_keywords;
+
+	-- indented
+	local indentwith = opt.indentwith or "\t";
+	local itemstart = opt.itemstart or "\n";
+	local itemsep = opt.itemsep or ";";
+	local itemlast = opt.itemlast or ";\n";
+	local tstart = opt.tstart or "{";
+	local tend = opt.tend or "}";
+	local kstart = opt.kstart or "[";
+	local kend = opt.kend or "]";
+	local equals = opt.equals or " = ";
+	local unquoted = opt.unquoted == nil and "^[%a_][%w_]*$" or opt.unquoted;
+	local hex = opt.hex;
+	local freeze = opt.freeze;
+	local precision = opt.precision or 10;
+
+	-- serialize one table, recursively
+	-- t - table being serialized
+	-- o - array where tokens are added, concatenate to get final result
+	--   - also used to detect cycles
+	-- l - position in o of where to insert next token
+	-- d - depth, used for indentation
+	local function serialize_table(t, o, l, d)
+		if o[t] or d > 127 then
+			o[l], l = fallback(t, "recursion"), l + 1;
+			return l;
+		end
 
-local function serialize(o)
-	return t_concat(append({}, o));
+		o[t] = true;
+		if freeze then
+			-- opportunity to do pre-serialization
+			local mt = getmetatable(t);
+			local fr = (freeze ~= true and freeze[mt]);
+			local mf = mt and mt.__freeze;
+			local tag;
+			if type(fr) == "string" then
+				tag = fr;
+				fr = mf;
+			elseif mt then
+				tag = mt.__type;
+			end
+			if fr then
+				t = fr(t);
+				if tag then
+					o[l], l = tag, l + 1;
+				end
+			end
+		end
+		o[l], l = tstart, l + 1;
+		local indent = s_rep(indentwith, d);
+		local numkey = 1;
+		local ktyp, vtyp;
+		for k,v in next,t do
+			o[l], l = itemstart, l + 1;
+			o[l], l = indent, l + 1;
+			ktyp, vtyp = type(k), type(v);
+			if k == numkey then
+				-- next index in array part
+				-- assuming that these are found in order
+				numkey = numkey + 1;
+			elseif unquoted and ktyp == "string" and
+				not keywords[k] and s_match(k, unquoted) then
+				-- unquoted keys
+				o[l], l = k, l + 1;
+				o[l], l = equals, l + 1;
+			else
+				-- quoted keys
+				o[l], l = kstart, l + 1;
+				if ktyp == "table" then
+					l = serialize_table(k, o, l, d+1);
+				else
+					o[l], l = ser(k), l + 1;
+				end
+				-- =
+				o[l], o[l+1], l = kend, equals, l + 2;
+			end
+
+			-- the value
+			if vtyp == "table" then
+				l = serialize_table(v, o, l, d+1);
+			else
+				o[l], l = ser(v), l + 1;
+			end
+			-- last item?
+			if next(t, k) ~= nil then
+				o[l], l = itemsep, l + 1;
+			else
+				o[l], l = itemlast, l + 1;
+			end
+		end
+		if next(t) ~= nil then
+			o[l], l = s_rep(indentwith, d-1), l + 1;
+		end
+		o[l], l = tend, l +1;
+		return l;
+	end
+
+	function types.table(t)
+		local o = {};
+		serialize_table(t, o, 1, 1);
+		return t_concat(o);
+	end
+
+	local function serialize_string(s)
+		return '"' .. s_gsub(s, "[%z\1-\31\"\'\\\127-\255]", string_escapes) .. '"';
+	end
+
+	if hex then
+		function types.string(s)
+			local esc = serialize_string(s);
+			if #esc > (#s*2+2+#hex) then
+				return hex .. '"' .. to_hex(s) .. '"';
+			end
+			return esc;
+		end
+	else
+		types.string = serialize_string;
+	end
+
+	function types.number(t)
+		if m_type(t) == "integer" then
+			return s_format("%d", t);
+		elseif t == pos_inf then
+			return "(1/0)";
+		elseif t == neg_inf then
+			return "(-1/0)";
+		elseif t ~= t then
+			return "(0/0)";
+		end
+		local log = m_floor(m_log10(t));
+		if log > precision then
+			return s_format("%.18e", t);
+		else
+			return s_format("%.18g", t);
+		end
+	end
+
+	-- Are these faster than tostring?
+	types["nil"] = function()
+		return "nil";
+	end
+
+	function types.boolean(t)
+		return t and "true" or "false";
+	end
+
+	return ser;
 end
 
 local function deserialize(str)
@@ -92,7 +270,9 @@
 end
 
 return {
-	append = append;
-	serialize = serialize;
+	new = new;
+	serialize = function (x, opt)
+		return new(opt)(x);
+	end;
 	deserialize = deserialize;
 };