Comparison

util/sqlite3.lua @ 12845:f306336b7e99

util.sqlite3: SQLite3-only variant of util.sql using LuaSQLite3 http://lua.sqlite.org/
author Kim Alvefur <zash@zash.se>
date Mon, 01 Aug 2022 15:23:33 +0200
child 12847:d6cdde74cd9b
comparison
equal deleted inserted replaced
12844:a3ec87ad8e48 12845:f306336b7e99
1
2 -- luacheck: ignore 113/unpack 211 212 411 213
3 local setmetatable, getmetatable = setmetatable, getmetatable;
4 local ipairs, unpack, select = ipairs, table.unpack or unpack, select;
5 local tonumber, tostring = tonumber, tostring;
6 local assert, xpcall, debug_traceback = assert, xpcall, debug.traceback;
7 local error = error
8 local type = type
9 local t_concat = table.concat;
10 local t_insert = table.insert;
11 local s_char = string.char;
12 local log = require "util.logger".init("sql");
13
14 local lsqlite3 = require "lsqlite3";
15 local build_url = require "socket.url".build;
16 local ROW, DONE = lsqlite3.ROW, lsqlite3.DONE;
-- Maps numeric SQLite result codes to their symbolic names (without the
-- SQLITE_ prefix). The positional entries after [0] occupy indices 1..22.
local err2str = {
	[0] = "OK";
	"ERROR";
	"INTERNAL";
	"PERM";
	"ABORT";
	"BUSY";
	"LOCKED";
	"NOMEM";
	"READONLY";
	"INTERRUPT";
	"IOERR";
	"CORRUPT";
	"NOTFOUND";
	"FULL";
	"CANTOPEN";
	"PROTOCOL";
	"EMPTY";
	"SCHEMA";
	"TOOBIG";
	"CONSTRAINT";
	"MISMATCH";
	"MISUSE";
	"NOLFS";
	[23] = "AUTH";
	[24] = "FORMAT";
	[25] = "RANGE";
	[26] = "NOTADB";
	[100] = "ROW";
	[101] = "DONE";
};

-- Replacement for the builtin assert() that understands the
-- `value, errno` return convention used by lsqlite3.
--   cond  : value to test; returned on success
--   errno : numeric result code, translated through err2str
--   err   : optional explicit message, takes precedence over errno
-- When the second argument is not a known result code (for example a plain
-- message string, as in assert(x, "Invalid URL format")), it is used as the
-- message itself instead of being silently dropped.
local assert = function(cond, errno, err)
	local message = err or err2str[errno];
	if not cond and message == nil and errno ~= nil then
		message = tostring(errno);
	end
	-- `assert` here still refers to the builtin: the local being defined is
	-- not in scope inside its own initializer.
	return assert(cond, message);
end
-- Drop the environment: under Lua 5.2+ any global access below this line
-- raises an error, so everything used later must already be held in a local.
local _ENV = nil;
-- luacheck: std none
53
-- Marker metatables: each kind of schema object is recognised purely by the
-- identity of its metatable.
local column_mt = {};
local table_mt = {};
local query_mt = {};
--local op_mt = {};
local index_mt = {};

-- Type predicates: true only when x carries the matching marker metatable.
local function is_column(x)
	local mt = getmetatable(x);
	return mt == column_mt;
end
local function is_index(x)
	local mt = getmetatable(x);
	return mt == index_mt;
end
local function is_table(x)
	local mt = getmetatable(x);
	return mt == table_mt;
end
local function is_query(x)
	local mt = getmetatable(x);
	return mt == query_mt;
end
-- Placeholder type constructors. The argument is accepted but ignored and a
-- fixed descriptive string is returned.
local function Integer(n)
	return "Integer()";
end
local function String(n)
	return "String()";
end
66
-- Tags a definition table as a column, in place, and returns that same table.
local function Column(definition)
	setmetatable(definition, column_mt);
	return definition;
end
-- Builds a table object from a definition list.
-- Columns found in the definition become reachable through `.c`, both by
-- their position in the definition and by name. Index entries are stamped
-- with the owning table's name. The original definition is kept as
-- `.__table__`.
local function Table(definition)
	local table_name = definition.name;
	local columns = {};
	for position, entry in ipairs(definition) do
		if is_column(entry) then
			columns[position] = entry;
			columns[entry.name] = entry;
		elseif is_index(entry) then
			entry.table = table_name;
		end
	end
	local object = { __table__ = definition, c = columns, name = table_name };
	return setmetatable(object, table_mt);
end
-- Tags a definition table as an index, in place, and returns that same table.
local function Index(definition)
	setmetatable(definition, index_mt);
	return definition;
end
84
-- Renders the table as 'Table{ name="...", <entry>, ... }', where each entry
-- of the original definition is converted with tostring().
function table_mt:__tostring()
	local definition = self.__table__;
	local parts = { 'name="'..definition.name..'"' };
	for _, entry in ipairs(definition) do
		parts[#parts+1] = tostring(entry);
	end
	return 'Table{ '..t_concat(parts, ", ")..' }'
end
table_mt.__index = {};
-- Creates this table through the given engine and returns whatever the
-- engine's _create_table() returns.
function table_mt.__index:create(engine)
	return engine:_create_table(self);
end
-- Calling a table object is not implemented.
function table_mt:__call(...)
	-- TODO
end
-- Renders the column as 'Column{ name="...", type="..." }'.
function column_mt:__tostring()
	local name, kind = self.name, self.type;
	return 'Column{ name="'..name..'", type="'..kind..'" }'
end
-- Renders the index as 'Index{ name="...", "col1", "col2" }'. Backslashes
-- and double quotes inside column names are backslash-escaped.
function index_mt:__tostring()
	local parts = { 'Index{ name="'..self.name..'"' };
	for i = 1, #self do
		local escaped = self[i]:gsub("[\\\"]", "\\%1");
		parts[#parts+1] = ', "'..escaped..'"';
	end
	parts[#parts+1] = ' }';
	return t_concat(parts);
	-- return 'Index{ name="'..self.name..'", type="'..self.type..'" }'
end
108
-- Decodes %XX escapes in s. A nil/false argument is returned unchanged.
local function urldecode(s)
	if not s then return s; end
	local decoded = s:gsub("%%(%x%x)", function (hex)
		return s_char(tonumber(hex, 16));
	end);
	return decoded;
end
-- Splits a database URL of the form
--   scheme://[username[:password]@]host[:port]/database
-- into its components.
-- Returns a table with the fields scheme (lower-cased), username, password
-- (URL-decoded), host, port (number) and database; absent parts are nil.
-- Raises an error when the URL has no scheme or the port is not numeric.
local function parse_url(url)
	local scheme, secondpart, database = url:match("^([%w%+]+)://([^/]*)/?(.*)");
	assert(scheme, "Invalid URL format");
	local username, password, host, port;
	-- The host capture must be "[^@]+" (one or more non-@ characters). The
	-- earlier "[^@+]" matched exactly one character, truncating the host to
	-- its first letter whenever credentials were present.
	local authpart, hostpart = secondpart:match("([^@]+)@([^@]+)");
	if not authpart then hostpart = secondpart; end
	if authpart then
		username, password = authpart:match("([^:]*):(.*)");
		username = username or authpart;
		password = password and urldecode(password);
	end
	if hostpart then
		host, port = hostpart:match("([^:]*):(.*)");
		host = host or hostpart;
		port = port and assert(tonumber(port), "Invalid URL format");
	end
	return {
		scheme = scheme:lower();
		username = username; password = password;
		host = host; port = port;
		database = #database > 0 and database or nil;
	};
end
133
local engine = {};
-- Opens the database named by self.params.database unless a connection
-- already exists.
-- Returns true on success, or a falsy value plus an error description.
-- Raises an error when params.driver is not "SQLite3".
function engine:connect()
	if self.conn then return true; end

	local params = self.params;
	assert(params.driver == "SQLite3", "Only sqlite3 is supported");
	-- On failure lsqlite3.open() returns nil, a numeric code and a message.
	-- Prefer the symbolic name and fall back to the message, so the caller
	-- never gets `nil, nil` for a code missing from err2str.
	local dbh, errno, errmsg = lsqlite3.open(params.database);
	if not dbh then return nil, err2str[errno] or errmsg; end
	self.conn = dbh;
	self.prepared = {}; -- cache of prepared statements, keyed by SQL text
	local ok, err = self:set_encoding();
	if not ok then
		return ok, err;
	end
	-- Only an explicit `false` from the hook counts as failure; the default
	-- hook returns nothing.
	local ok, err = self:onconnect();
	if ok == false then
		return ok, err;
	end
	return true;
end
-- Hook invoked after a connection has been opened; does nothing by default.
function engine:onconnect()
	-- Override from create_engine()
end
-- Connects if necessary, then prepares `sql` (reusing a cached statement
-- when one exists) and binds the given values to it.
-- Returns the bound statement; this function itself does not step it.
-- On failure returns a falsy value plus the error reported by the failing
-- call (connect, prepare or bind).
function engine:execute(sql, ...)
	local connected, connect_err = self:connect();
	if not connected then return connected, connect_err; end

	local cache = self.prepared;
	local stmt = cache[sql];
	if not stmt then
		local prepare_err;
		stmt, prepare_err = self.conn:prepare(sql);
		if not stmt then return stmt, prepare_err; end
		cache[sql] = stmt;
	end

	if stmt:bind_values(...) ~= lsqlite3.OK then
		return nil, self.conn:errmsg();
	end
	return stmt;
end
174
-- Metatable for result objects: accessors for the counters recorded on the
-- result table.
local result_mt = { __index = {} };
-- Value stored in the result's __affected field.
function result_mt.__index:affected()
	return self.__affected;
end
-- Value stored in the result's __rowcount field.
function result_mt.__index:rowcount()
	return self.__rowcount;
end
181
-- Returns a closure that yields table[1], table[2], ... on successive calls.
-- When the current slot is nil the call returns no values at all.
local function iterator(table)
	local position = 0;
	return function()
		position = position + 1;
		local value = table[position];
		if value == nil then
			return;
		end
		return value;
	end
end
192
-- Logs `sql` at debug level with every "?" placeholder replaced, in order,
-- by the corresponding extra argument. String arguments are single-quoted
-- with embedded quotes doubled; other values go through tostring().
-- Runs of tabs (with an optional preceding newline) are collapsed to a space.
local function debugquery(where, sql, ...)
	local args = {...};
	local next_arg = 0;
	local function substitute()
		next_arg = next_arg + 1;
		local value = args[next_arg];
		if type(value) == "string" then
			value = ("'%s'"):format(value:gsub("'", "''"));
		end
		return tostring(value);
	end
	sql = sql:gsub("\n?\t+", " ");
	local rendered = sql:gsub("%?", substitute);
	log("debug", "[%s] %s", where, rendered);
end
205
-- Runs a row-returning statement and collects every row.
-- Returns a result object that holds the rows in `__data` and can be called
-- as an iterator over them. Raises an error if preparing or binding fails.
function engine:execute_query(sql, ...)
	local cache = self.prepared;
	local stmt = cache[sql];
	if not (stmt and stmt:isopen()) then
		stmt = assert(self.conn:prepare(sql));
	else
		cache[sql] = nil; -- Can't be used concurrently
	end
	if stmt:bind_values(...) ~= lsqlite3.OK then
		error(self.conn:errmsg());
	end
	local rows = {};
	local status = stmt:step();
	while status == ROW do
		rows[#rows+1] = stmt:get_values();
		status = stmt:step();
	end
	-- FIXME Error handling, BUSY, ERROR, MISUSE
	-- Put the statement back in the cache only if it reset cleanly.
	if stmt:reset() == lsqlite3.OK then
		cache[sql] = stmt;
	end
	local result = { __data = rows };
	return setmetatable(result, { __index = result_mt.__index, __call = iterator(rows) });
end
-- Runs a data-modifying statement to completion.
-- Returns a result object whose affected() is the connection's changes()
-- count and whose rowcount() is the number of rows the statement produced.
-- Raises an error if preparing or binding fails.
function engine:execute_update(sql, ...)
	local cache = self.prepared;
	local stmt = cache[sql];
	if stmt and stmt:isopen() then
		-- Take the statement out of the cache while it is in use.
		cache[sql] = nil;
	else
		stmt = assert(self.conn:prepare(sql));
	end
	if stmt:bind_values(...) ~= lsqlite3.OK then
		error(self.conn:errmsg());
	end
	local rowcount = 0;
	while stmt:step() == lsqlite3.ROW do
		rowcount = rowcount + 1;
	end
	local affected = self.conn:changes();
	-- Put the statement back in the cache only if it reset cleanly.
	if stmt:reset() == lsqlite3.OK then
		cache[sql] = stmt;
	end
	local result = { __affected = affected, __rowcount = rowcount };
	return setmetatable(result, result_mt);
end
-- Default bindings for the verb-style entry points.
engine.select = engine.execute_query;
engine.insert = engine.execute_update;
engine.update = engine.execute_update;
engine.delete = engine.execute_update;

-- Wraps f so that each call is first logged through debugquery() under the
-- given name, then forwarded unchanged.
local function debugwrap(name, f)
	local function traced(self, sql, ...)
		debugquery(name, sql, ...)
		return f(self, sql, ...)
	end
	return traced;
end
-- Turns query logging on or off.
-- The flag is stored on this instance as `_debug`, but the select / insert /
-- update / delete entry points are swapped on the shared `engine` table.
function engine:debug(enable)
	self._debug = enable;
	local verbs = {
		{ "insert", engine.execute_update };
		{ "select", engine.execute_query };
		{ "delete", engine.execute_update };
		{ "update", engine.execute_update };
	};
	for _, verb in ipairs(verbs) do
		local name, impl = verb[1], verb[2];
		if enable then
			engine[name] = debugwrap(name, impl);
		else
			engine[name] = impl;
		end
	end
end
-- Executes a raw SQL string directly on the connection.
-- Returns true on success, or nil plus the connection's error message.
function engine:_(word)
	if self.conn:exec(word) == lsqlite3.OK then
		return true;
	end
	return nil, self.conn:errmsg();
end
-- Runs func(...) between BEGIN and COMMIT, connecting first if needed.
-- On success returns true followed by up to three results of func.
-- If func raises, a ROLLBACK is issued and false plus the traceback is
-- returned. A failed connect, BEGIN or COMMIT returns its own falsy value
-- and error.
function engine:_transaction(func, ...)
	if not self.conn then
		local connected, connect_err = self:connect();
		if not connected then return connected, connect_err; end
	end
	--assert(not self.__transaction, "Recursive transactions not allowed");
	local began, begin_err = self:_("BEGIN");
	if not began then return began, begin_err; end
	self.__transaction = true;
	local success, a, b, c = xpcall(func, debug_traceback, ...);
	self.__transaction = nil;
	if not success then
		log("debug", "SQL transaction failure [%s]: %s", tostring(func), a);
		if self.conn then self:_("ROLLBACK"); end
		return success, a;
	end
	log("debug", "SQL transaction success [%s]", tostring(func));
	local committed, commit_err = self:_("COMMIT");
	if not committed then return committed, commit_err; end -- commit failed
	return success, a, b, c;
end
-- Runs a transaction via _transaction(), retrying once if it failed and the
-- connection is missing or no longer open.
-- Returns everything _transaction() returns: the success flag followed by
-- up to three results of the transaction function (or the error). Earlier
-- this wrapper forwarded only the first result, discarding the others.
function engine:transaction(...)
	local ok, ret, b, c = self:_transaction(...);
	if not ok then
		local conn = self.conn;
		if not conn or not conn:isopen() then
			-- Forget the dead handle so that _transaction() reconnects.
			self.conn = nil;
			ok, ret, b, c = self:_transaction(...);
		end
	end
	return ok, ret, b, c;
end
-- Creates the given index if it does not exist yet.
-- `index` supplies name, table, an optional `unique` flag and the column
-- names as its array part.
-- Returns true on success, or a falsy value plus an error description.
function engine:_create_index(index)
	local sql = "CREATE INDEX IF NOT EXISTS \""..index.name.."\" ON \""..index.table.."\" (";
	for i=1,#index do
		sql = sql.."\""..index[i].."\"";
		if i ~= #index then sql = sql..", "; end
	end
	sql = sql..");"
	if index.unique then
		sql = sql:gsub("^CREATE", "CREATE UNIQUE");
	end
	if self._debug then
		debugquery("create", sql);
	end
	local success, err = self:connect();
	if not success then return success, err; end
	-- Run the statement right away. engine:execute() only prepares and binds
	-- a statement without stepping it, so going through it never actually
	-- created the index.
	return self:_(sql);
end
-- Creates the given table (if it does not exist yet) together with the
-- indices listed in its definition.
-- Column attributes honoured: type, nullable == false, primary_key == true
-- and auto_increment == true.
-- Returns true on success, or a falsy value plus an error description from
-- the first statement that failed.
function engine:_create_table(table)
	local sql = "CREATE TABLE IF NOT EXISTS \""..table.name.."\" (";
	for i,col in ipairs(table.c) do
		local col_type = col.type;
		sql = sql.."\""..col.name.."\" "..col_type;
		if col.nullable == false then sql = sql.." NOT NULL"; end
		if col.primary_key == true then sql = sql.." PRIMARY KEY"; end
		if col.auto_increment == true then
			sql = sql.." AUTOINCREMENT";
		end
		if i ~= #table.c then sql = sql..", "; end
	end
	sql = sql.. ");"
	if self._debug then
		debugquery("create", sql);
	end
	local success, err = self:connect();
	if not success then return success, err; end
	-- Run the statement right away. engine:execute() only prepares and binds
	-- a statement without stepping it, so going through it never actually
	-- created the table.
	success, err = self:_(sql);
	if not success then return success, err; end
	for _, v in ipairs(table.__table__) do
		if is_index(v) then
			-- Report index creation failures instead of dropping them.
			local ok, index_err = self:_create_index(v);
			if not ok then return ok, index_err; end
		end
	end
	return success;
end
-- Queries "PRAGMA encoding;" inside a transaction and records
-- self.charset = "utf8" when the database reports UTF-8.
-- Returns the result of engine:transaction().
function engine:set_encoding() -- to UTF-8
	local function detect()
		local rows = self:select("PRAGMA encoding;");
		for row in rows do
			if row[1] == "UTF-8" then
				self.charset = "utf8";
			end
		end
	end
	return self:transaction(detect);
end
-- Metatable for engine instances: missing fields are looked up on `engine`.
local engine_mt = { __index = engine };
361
-- Builds a URL string from connection parameters: driver becomes the
-- scheme, username the user and database the path; password, host and port
-- are passed through as-is.
local function db2uri(params)
	local parts = {
		scheme = params.driver,
		user = params.username,
		password = params.password,
		host = params.host,
		port = params.port,
		path = params.database,
	};
	return build_url(parts);
end
372
-- Creates an engine object for the given parameters. The first argument is
-- ignored. `onconnect`, when given, overrides the default connect hook.
-- Raises an error unless params.driver is "SQLite3".
local function create_engine(_, params, onconnect)
	assert(params.driver == "SQLite3", "Only SQLite3 is supported without LuaDBI");
	local instance = {
		url = db2uri(params),
		params = params,
		onconnect = onconnect,
	};
	return setmetatable(instance, engine_mt);
end
377
-- Public interface of the module.
local exports = {};
exports.is_column = is_column;
exports.is_index = is_index;
exports.is_table = is_table;
exports.is_query = is_query;
exports.Integer = Integer;
exports.String = String;
exports.Column = Column;
exports.Table = Table;
exports.Index = Index;
exports.create_engine = create_engine;
exports.db2uri = db2uri;
return exports;