Comparison

mod_mam_sql/mod_mam_sql.lua @ 819:1e0d273bcb75

mod_mam_sql: Fork of mod_mam using SQL.
author Kim Alvefur <zash@zash.se>
date Mon, 17 Sep 2012 20:14:26 +0200
child 1206:04bf76c3e4c6
comparison
equal deleted inserted replaced
818:bf23a8966e20 819:1e0d273bcb75
1 -- XEP-0313: Message Archive Management for Prosody
2 -- Copyright (C) 2011-2012 Kim Alvefur
3 --
4 -- This file is MIT/X11 licensed.
5
6 local xmlns_mam = "urn:xmpp:mam:tmp";
7 local xmlns_delay = "urn:xmpp:delay";
8 local xmlns_forward = "urn:xmpp:forward:0";
9
10 local st = require "util.stanza";
11 local rsm = module:require "mod_mam/rsm";
12 local jid_bare = require "util.jid".bare;
13 local jid_split = require "util.jid".split;
14 local jid_prep = require "util.jid".prep;
15 local host = module.host;
16
17 local dm_load = require "util.datamanager".load;
18 local dm_store = require "util.datamanager".store;
19 local rm_load_roster = require "core.rostermanager".load_roster;
20
21 local serialize, deserialize = require"util.json".encode, require"util.json".decode;
22 local unpack = unpack;
23 local tostring = tostring;
24 local time_now = os.time;
25 local t_insert = table.insert;
26 local m_min = math.min;
27 local timestamp, timestamp_parse = require "util.datetime".datetime, require "util.datetime".parse;
28 local default_max_items, max_max_items = 20, module:get_option_number("max_archive_query_results", 50);
29 local global_default_policy = module:get_option("default_archive_policy", false);
30 -- TODO Should be possible to enforce it too
31
32 local sql, setsql, getsql = {};
33 do -- SQL stuff
34 local dburi;
35 local connection;
36 local connections = module:shared "/*/sql/connection-cache";
37 local build_url = require"socket.url".build;
38 local resolve_relative_path = require "core.configmanager".resolve_relative_path;
39 local params = module:get_option("mam_sql", module:get_option("sql"));
40
-- Build a URL-style string from the database settings in `params`.
-- The result is used as the key into the shared connection cache.
local function db2uri(params)
	local parts = {};
	parts.scheme = params.driver;
	parts.user = params.username;
	parts.password = params.password;
	parts.host = params.host;
	parts.port = params.port;
	parts.path = params.database;
	return build_url(parts);
end
51
-- Check whether the current database handle is still usable.
-- Returns nil when there is no handle, true when the handle answers a ping.
-- A handle that fails the ping is forgotten, both locally and in the shared
-- connection cache, and nothing is returned.
local function test_connection()
	if not connection then
		return nil;
	end
	local alive = connection:ping();
	if alive then
		return true;
	end
	module:log("debug", "Database connection closed");
	connection = nil;
	connections[dburi] = nil;
end
-- Return a working database handle, opening a new one if the current handle
-- is missing or no longer answers a ping.
-- Returns the handle, or nil plus the error reported by DBI.Connect.
local function connect()
	if test_connection() then
		return connection;
	end

	-- DBI.Connect is called with Prosody's global protection lifted.
	prosody.unlock_globals();
	local dbh, err = DBI.Connect(
		params.driver, params.database,
		params.username, params.password,
		params.host, params.port
	);
	prosody.lock_globals();

	if not dbh then
		module:log("debug", "Database connection failed: %s", tostring(err));
		return nil, err;
	end

	module:log("debug", "Successfully connected to database");
	dbh:autocommit(false); -- don't commit automatically
	connection = dbh;
	connections[dburi] = dbh;
	return connection;
end
83
do -- process options to get a db connection
	-- Load LuaDBI into the global `DBI` while globals are unlocked.
	prosody.unlock_globals();
	local ok, lib = pcall(require, "DBI");
	DBI = lib;
	if not ok then
		-- On failure `DBI` holds the error message from require.
		package.loaded["DBI"] = {};
		module:log("error", "Failed to load the LuaDBI library for accessing SQL databases: %s", DBI);
		module:log("error", "More information on installing LuaDBI can be found at http://prosody.im/doc/depends#luadbi");
	end
	prosody.lock_globals();
	if not ok or not DBI.Connect then
		return; -- Halt loading of this module
	end

	-- Fall back to SQLite3 when neither "mam_sql" nor "sql" is configured.
	if not params then
		params = { driver = "SQLite3" };
	end

	if params.driver == "SQLite3" then
		local data_path = prosody.paths.data or ".";
		local db_file = params.database or "prosody.sqlite";
		params.database = resolve_relative_path(data_path, db_file);
	end

	assert(params.driver and params.database, "Both the SQL driver and the database need to be specified");

	-- Reuse a handle another module may already have cached for this database.
	dburi = db2uri(params);
	connection = connections[dburi];

	assert(connect());
end
112
-- Prepare and execute an SQL statement on the module's database connection.
-- Backtick identifier quotes are rewritten to double quotes for PostgreSQL.
-- @param sql  statement text, with `?` placeholders
-- @param ...  values bound to the placeholders
-- @return the executed statement handle, or nil plus an error message
-- Raises "connection failed" if preparing or executing fails and the
-- connection no longer answers a ping.
function getsql(sql, ...)
	if params.driver == "PostgreSQL" then
		sql = sql:gsub("`", "\"");
	end
	-- test_connection() clears `connection` after a failed ping; reconnect here
	-- instead of indexing nil on every later query.
	if not connection and not connect() then
		return nil, "connection failed";
	end
	-- do prepared statement stuff
	local stmt, prepare_err = connection:prepare(sql);
	if not stmt and not test_connection() then error("connection failed"); end
	if not stmt then
		module:log("error", "QUERY FAILED: %s %s", tostring(prepare_err), debug.traceback());
		return nil, prepare_err;
	end
	-- run query
	local ok, execute_err = stmt:execute(...);
	if not ok and not test_connection() then error("connection failed"); end
	if not ok then return nil, execute_err; end

	return stmt;
end
-- Run a data-modifying statement through getsql.
-- Returns whatever the statement's affected() reports on success, otherwise
-- the failed getsql result (nil plus error message).
function setsql(sql, ...)
	local stmt, err = getsql(sql, ...);
	if stmt then
		return stmt:affected();
	end
	return stmt, err;
end
-- Roll back the current transaction if there is a connection, then pass
-- every argument straight through to the caller.
function sql.rollback(...)
	if connection then
		connection:rollback(); -- FIXME check for rollback error?
	end
	return ...;
end
-- Commit the current transaction.
-- On success every argument is passed straight through to the caller;
-- on failure returns nil plus an error message.
function sql.commit(...)
	-- The connection is cleared after a failed ping; report that as a failed
	-- commit rather than indexing nil (sql.rollback guards the same way).
	if not connection then return nil, "SQL commit failed"; end
	if not connection:commit() then return nil, "SQL commit failed"; end
	return ...;
end
141
142 end
143
-- Two-way lookup between the textual value of the "default" attribute and
-- the value stored in the preferences table.
local default_attrs = {
	-- attribute text -> stored value
	always = true,
	never = false,
	roster = "roster",
	-- stored boolean -> attribute text
	[true] = "always",
	[false] = "never",
}
150
do
	-- Illustration of the preferences table layout; not used by the code.
	local prefs_format = {
		-- The key `false` holds the user's default:
		--   default ::= true | false | "roster"
		--   true = always, false = never, nil = global default
		[false] = "roster",
		-- Every other key is a JID with its own rule.
		["romeo@montague.net"] = true, -- always
		["montague@montague.net"] = false, -- never
	};
end
160
-- Names of the archive store and of the datamanager store holding prefs.
local archive_store = "archive2";
local prefs_store = archive_store .. "_prefs";

-- Load the stored archiving preferences of `user` on this host. When nothing
-- is stored, a fresh table carrying only the global default is returned.
local function get_prefs(user)
	local stored = dm_load(user, host, prefs_store);
	if stored then
		return stored;
	end
	return { [false] = global_default_policy };
end

-- Persist `prefs` for `user` on this host; returns what dm_store returns.
local function set_prefs(user, prefs)
	return dm_store(user, host, prefs_store, prefs);
end
170
171
-- Handle prefs.
-- "get" replies with the user's default policy plus the always/never JID
-- lists; any other type is treated as "set" and replaces the stored prefs.
module:hook("iq/self/"..xmlns_mam..":prefs", function(event)
	local origin, stanza = event.origin, event.stanza;
	local user = origin.username;
	if stanza.attr.type == "get" then
		local prefs = get_prefs(user);
		local default = prefs[false];
		if default == nil then
			default = global_default_policy;
		end
		-- Attribute values must be strings, so translate the stored value
		-- (true / false / "roster") into its name. Previously the global
		-- default was used as-is, which put a boolean into the attribute.
		local default_name = default_attrs[default];
		if type(default_name) ~= "string" then
			-- `default` was already a name such as "always"; map it back.
			default_name = default_attrs[default_name];
		end
		local reply = st.reply(stanza):tag("prefs", { xmlns = xmlns_mam, default = default_name })
		local always = st.stanza("always");
		local never = st.stanza("never");
		for k,v in pairs(prefs) do
			if k then -- the key `false` is the default, not a JID
				(v and always or never):tag("jid"):text(k):up();
			end
		end
		reply:add_child(always):add_child(never);
		origin.send(reply);
		return true
	else -- type == "set"
		local prefs = {};
		local new_prefs = stanza:get_child("prefs", xmlns_mam);
		local new_default = new_prefs.attr.default;
		if new_default then
			prefs[false] = default_attrs[new_default];
		end

		-- Copy the <jid/> children of one list into `prefs` with the given
		-- rule. Returns false when a JID is missing or fails jid_prep.
		local function read_rules(list_name, rule_value)
			local list = new_prefs:get_child(list_name);
			if not list then return true; end
			for rule in list:childtags("jid") do
				local text = rule:get_text();
				local jid = text and jid_prep(text);
				if not jid then
					return false;
				end
				prefs[jid] = rule_value;
			end
			return true;
		end

		-- "always" is read first so that "never" wins for a JID in both lists,
		-- as before.
		if not (read_rules("always", true) and read_rules("never", false)) then
			origin.send(st.error_reply(stanza, "modify", "bad-request", "Invalid JID"));
			return true
		end

		local ok, err = set_prefs(user, prefs);
		if not ok then
			origin.send(st.error_reply(stanza, "cancel", "internal-server-error", "Error storing preferences: "..tostring(err)));
		else
			origin.send(st.reply(stanza));
		end
		return true
	end
end);
224
-- Handle archive queries
-- Validates the optional with/start/end filters, selects matching rows from
-- `prosodyarchive`, sends each one as a forwarded message, then sends the IQ
-- result carrying the id of the last item.
module:hook("iq/self/"..xmlns_mam..":query", function(event)
	local origin, stanza = event.origin, event.stanza;
	local query = stanza.tags[1];
	if stanza.attr.type == "get" then
		local qid = query.attr.queryid;

		-- Search query parameters
		local qwith = query:get_child_text("with");
		local qstart = query:get_child_text("start");
		local qend = query:get_child_text("end");
		local qset = rsm.get(query);
		module:log("debug", "Archive query, id %s with %s from %s until %s)",
			tostring(qid), qwith or "anyone", qstart or "the dawn of time", qend or "now");

		if qstart or qend then -- Validate timestamps
			local vstart, vend = (qstart and timestamp_parse(qstart)), (qend and timestamp_parse(qend))
			if (qstart and not vstart) or (qend and not vend) then
				origin.send(st.error_reply(stanza, "modify", "bad-request", "Invalid timestamp"))
				return true
			end
			qstart, qend = vstart, vend;
		end

		local qres;
		if qwith then -- Validate the 'with' jid
			local pwith = jid_prep(qwith);
			-- The old test (`pwith and not qwith`) could never be true, so an
			-- invalid JID was not rejected here.
			if not pwith then -- it failed prepping
				origin.send(st.error_reply(stanza, "modify", "bad-request", "Invalid JID"))
				return true
			end
			-- Take both parts from the prepped JID so they are normalized alike.
			local _, _, resource = jid_split(pwith);
			qwith = jid_bare(pwith);
			qres = resource;
		end

		-- RSM stuff
		local qmax = m_min(qset and qset.max or default_max_items, max_max_items);
		local last;

		-- ORDER BY makes the `id` > ? / LIMIT paging well-defined: without it
		-- the database may return rows in any order.
		local sql_query = ([[
		SELECT `id`, `when`, `stanza`
		FROM `prosodyarchive`
		WHERE `host` = ? AND `user` = ? AND `store` = ?
		AND `when` BETWEEN ? AND ?
		%s %s
		AND `id` > ?
		ORDER BY `id`
		LIMIT ?;
		]]):format(qwith and [[AND `with` = ?]] or "", qres and [[AND `resource` = ?]] or "")

		local p = {
			host, origin.username, archive_store,
			qstart or 0, qend or time_now(),
			qset and tonumber(qset.after) or 0,
			qmax
		};
		-- The optional filters sit between the time range and the id in the
		-- statement, so their values go in at position 6: resource first, then
		-- the bare JID in front of it.
		if qwith then
			if qres then
				t_insert(p, 6, qres);
			end
			t_insert(p, 6, qwith);
		end
		local data, err = getsql(sql_query, unpack(p));
		if not data then
			origin.send(st.error_reply(stanza, "cancel", "internal-server-error", "Error loading archive: "..tostring(err)));
			return true
		end

		for item in data:rows() do
			local id, when, orig_stanza = unpack(item);
			--module:log("debug", "id is %s", id);

			local fwd_st = st.message{ to = origin.full_jid }
				:tag("result", { xmlns = xmlns_mam, queryid = qid, id = id }):up()
				:tag("forwarded", { xmlns = xmlns_forward })
					:tag("delay", { xmlns = xmlns_delay, stamp = timestamp(when) }):up();
			orig_stanza = st.deserialize(deserialize(orig_stanza));
			orig_stanza.attr.xmlns = "jabber:client";
			fwd_st:add_child(orig_stanza);
			origin.send(fwd_st);
			last = id;
		end
		-- That's all folks!
		module:log("debug", "Archive query %s completed", tostring(qid));

		local reply = st.reply(stanza);
		if last then
			-- This is a bit redundant, isn't it?
			reply:query(xmlns_mam):add_child(rsm.generate{last = last});
		end
		origin.send(reply);
		return true
	end
end);
319
-- Look up `who` in the roster of `user` on this host.
-- Returns the roster entry (truthy) when present, nil otherwise.
local function has_in_roster(user, who)
	local roster = rm_load_roster(user, host);
	local entry = roster[who];
	module:log("debug", "%s has %s in roster? %s", user, who, entry and "yes" or "no");
	return entry;
end
325
-- Decide whether a message exchanged between `user` and `who` is archived.
-- Order of precedence: the user's rule for that JID, the user's default,
-- the global default. A default of "roster" archives only for roster contacts.
local function shall_store(user, who)
	-- TODO Cache this?
	local prefs = get_prefs(user);
	local rule = prefs[who];
	module:log("debug", "%s's rule for %s is %s", user, who, tostring(rule))
	if rule ~= nil then
		return rule;
	end

	-- Below could be done by a metatable
	local default = prefs[false];
	module:log("debug", "%s's default rule is %s", user, tostring(default))
	if default == nil then
		default = global_default_policy;
		module:log("debug", "Using global default rule, %s", tostring(default))
	end

	if default == "roster" then
		return has_in_roster(user, who);
	end
	return default;
end
346
-- Handle messages
-- Archives a message stanza when the owner's preferences allow it.
-- @param event  hook event carrying `origin` and `stanza`
-- @param c2s    true when the stanza was sent by a local client, in which
--               case the archive owner is the sender rather than the recipient
-- Returns nothing, so the stanza continues through the remaining hooks.
local function message_handler(event, c2s)
	local origin, stanza = event.origin, event.stanza;
	local orig_type = stanza.attr.type or "normal";
	local orig_to = stanza.attr.to;
	local orig_from = stanza.attr.from;

	if not orig_from and c2s then
		orig_from = origin.full_jid;
	end
	orig_to = orig_to or orig_from; -- Weird corner cases

	-- Don't store messages of these types
	if orig_type == "error"
	or orig_type == "headline"
	or orig_type == "groupchat"
	or not stanza:get_child("body") then
		return;
		-- TODO Maybe headlines should be configurable?
	end

	-- The archive belongs to the local side; the remote side is stored in
	-- the `with` and `resource` columns.
	local store_user, store_host = jid_split(c2s and orig_from or orig_to);
	local target_jid = c2s and orig_to or orig_from;
	local target_bare = jid_bare(target_jid);
	local _, _, target_resource = jid_split(target_jid);

	if shall_store(store_user, target_bare) then
		module:log("debug", "Archiving stanza: %s", stanza:top_tag());

		--local id = uuid();
		local when = time_now();
		-- And stash it
		local ok, err = setsql([[
		INSERT INTO `prosodyarchive`
		(`host`, `user`, `store`, `when`, `with`, `resource`, `stanza`)
		VALUES (?, ?, ?, ?, ?, ?, ?);
		]], store_host, store_user, archive_store, when, target_bare, target_resource, serialize(st.preserialize(stanza)))
		if ok then
			-- sql.commit passes its arguments through on success, so hand it
			-- `true` to tell success apart from failure instead of dropping
			-- the result.
			local committed, commit_err = sql.commit(true);
			if not committed then
				module:log("error", "SQL error: %s", tostring(commit_err));
				sql.rollback();
			end
		else
			module:log("error", "SQL error: %s", tostring(err));
			sql.rollback();
		end
		--[[ This was dropped from the spec
		if ok then
			stanza:tag("archived", { xmlns = xmlns_mam, by = host, id = id }):up();
		end
		--]]
	else
		module:log("debug", "Not archiving stanza: %s", stanza:top_tag());
	end
end
399
-- Variant of message_handler for stanzas sent by local clients.
local function c2s_message_handler(event)
	local sent_by_local_client = true;
	return message_handler(event, sent_by_local_client);
end
403
-- Outgoing: stanzas sent by local clients
module:hook("pre-message/bare", c2s_message_handler, 2);
module:hook("pre-message/full", c2s_message_handler, 2);
-- Incoming: stanzas addressed to local clients
module:hook("message/bare", message_handler, 2);
module:hook("message/full", message_handler, 2);

-- Advertise the archive namespace as a feature of this module.
module:add_feature(xmlns_mam);
412
-- In the telnet console, run:
-- >hosts["this host"].modules.mam_sql.environment.create_sql()
--
-- Create the archive table and its indexes, then commit.
-- Each statement is sent on its own: getsql already executes what it
-- prepares (executing the handle again would repeat CREATE TABLE), and a
-- multi-statement string is not reliably run past its first statement.
-- Returns true on success, or nil plus an error message after rolling back.
function create_sql()
	local statements = {
		[[
		CREATE TABLE `prosodyarchive` (
			`host` TEXT,
			`user` TEXT,
			`store` TEXT,
			`id` INTEGER PRIMARY KEY AUTOINCREMENT,
			`when` INTEGER,
			`with` TEXT,
			`resource` TEXT,
			`stanza` TEXT
		);
		]],
		[[CREATE INDEX `hus` ON `prosodyarchive` (`host`, `user`, `store`);]],
		[[CREATE INDEX `with` ON `prosodyarchive` (`with`);]],
		[[CREATE INDEX `thetime` ON `prosodyarchive` (`when`);]],
	};
	for _, statement in ipairs(statements) do
		local stm, err = getsql(statement);
		if not stm then
			module:log("error", "Could not create archive schema: %s", tostring(err));
			return sql.rollback(nil, err);
		end
	end
	return sql.commit(true);
end