Diff

util/set.lua @ 1028:594a07e753a0

util.set: Add metatable to sets to allow +, -, /, ==, tostring and to double as iterators
author Matthew Wild <mwild1@gmail.com>
date Wed, 22 Apr 2009 18:00:45 +0100
parent 917:f12f88b3d4a1
child 1029:4ead03974759
line wrap: on
line diff
--- a/util/set.lua	Wed Apr 22 17:46:17 2009 +0100
+++ b/util/set.lua	Wed Apr 22 18:00:45 2009 +0100
@@ -1,10 +1,60 @@
-local ipairs, pairs = 
-      ipairs, pairs;
+local ipairs, pairs, setmetatable, next, tostring = 
+      ipairs, pairs, setmetatable, next, tostring;
+local t_concat = table.concat;
 
 module "set"
 
+local set_mt = {};
+function set_mt.__call(set, _, k)
+	return next(set._items, k);
+end
+function set_mt.__add(set1, set2)
+	return _M.union(set1, set2);
+end
+function set_mt.__sub(set1, set2)
+	return _M.difference(set1, set2);
+end
+function set_mt.__div(set, func)
+	local new_set, new_items = _M.new();
+	local items, new_items = set._items, new_set._items;
+	for item in pairs(items) do
+		if func(item) then
+			new_items[item] = true;
+		end
+	end
+	return new_set;
+end
+function set_mt.__eq(set1, set2)
+	local set1, set2 = set1._items, set2._items;
+	for item in pairs(set1) do
+		if not set2[item] then
+			return false;
+		end
+	end
+	
+	for item in pairs(set2) do
+		if not set1[item] then
+			return false;
+		end
+	end
+	
+	return true;
+end
+function set_mt.__tostring(set)
+	local s, items = { }, set._items;
+	for item in pairs(items) do
+		s[#s+1] = tostring(item);
+	end
+	return t_concat(s, ", ");
+end
+
+local items_mt = {};
+function items_mt.__call(items, _, k)
+	return next(items, k);
+end
+
 function new(list)
-	local items = {};
+	local items = setmetatable({}, items_mt);
 	local set = { _items = items };
 	
 	function set:add(item)
@@ -45,7 +95,7 @@
 		set:add_list(list);
 	end
 	
-	return set;
+	return setmetatable(set, set_mt);
 end
 
 function union(set1, set2)