Software /
code /
prosody
File
util/sqlite3.lua @ 13014:06453c564141
util.startup: Add prosody.started promise to easily execute code after startup
To avoid a race where server-started fires before the promise function body is
run (on next tick), I moved server-started to fire on the next tick, which
seems sensible anyway.
Errors are logged, I'm not sure if we ought to be doing something more here.
I'm sure we'll find out.
author | Matthew Wild <mwild1@gmail.com> |
---|---|
date | Sat, 01 Apr 2023 11:56:38 +0100 |
parent | 12975:d10957394a3c |
child | 13145:af251471d5ae |
line wrap: on
line source
-- luacheck: ignore 113/unpack 211 212 411 213

-- Builtins are captured as locals up front because the global environment
-- is switched off further down (see `_ENV = nil`).
local setmetatable = setmetatable;
local getmetatable = getmetatable;
local ipairs = ipairs;
local unpack = table.unpack or unpack;
local select = select;
local tonumber = tonumber;
local tostring = tostring;
local assert = assert;
local xpcall = xpcall;
local debug_traceback = debug.traceback;
local error = error;
local type = type;
local t_concat = table.concat;
local t_insert = table.insert;
local s_char = string.char;

local log = require("prosody.util.logger").init("sql");

local lsqlite3 = require("lsqlite3");
local build_url = require("socket.url").build;

-- Result codes reported by statement stepping.
local ROW = lsqlite3.ROW;
local DONE = lsqlite3.DONE;

-- Error registry keyed by numeric SQLite result code.
-- from sqlite3.h, no copyright claimed
local sqlite_errors = require("prosody.util.error").init("util.sqlite3", {
	-- FIXME xmpp error conditions?
	[1] = { code = 1, type = "modify", condition = "ERROR", text = "Generic error" },
	[2] = { code = 2, type = "cancel", condition = "INTERNAL", text = "Internal logic error in SQLite" },
	[3] = { code = 3, type = "auth", condition = "PERM", text = "Access permission denied" },
	[4] = { code = 4, type = "cancel", condition = "ABORT", text = "Callback routine requested an abort" },
	[5] = { code = 5, type = "wait", condition = "BUSY", text = "The database file is locked" },
	[6] = { code = 6, type = "wait", condition = "LOCKED", text = "A table in the database is locked" },
	[7] = { code = 7, type = "wait", condition = "NOMEM", text = "A malloc() failed" },
	[8] = { code = 8, type = "cancel", condition = "READONLY", text = "Attempt to write a readonly database" },
	[9] = { code = 9, type = "cancel", condition = "INTERRUPT", text = "Operation terminated by sqlite3_interrupt()" },
	[10] = { code = 10, type = "wait", condition = "IOERR", text = "Some kind of disk I/O error occurred" },
	[11] = { code = 11, type = "cancel", condition = "CORRUPT", text = "The database disk image is malformed" },
	[12] = { code = 12, type = "modify", condition = "NOTFOUND", text = "Unknown opcode in sqlite3_file_control()" },
	[13] = { code = 13, type = "wait", condition = "FULL", text = "Insertion failed because database is full" },
	[14] = { code = 14, type = "auth", condition = "CANTOPEN", text = "Unable to open the database file" },
	[15] = { code = 15, type = "cancel", condition = "PROTOCOL", text = "Database lock protocol error" },
	[16] = { code = 16, type = "continue", condition = "EMPTY", text = "Internal use only" },
	[17] = { code = 17, type = "modify", condition = "SCHEMA", text = "The database schema changed" },
	[18] = { code = 18, type = "modify", condition = "TOOBIG", text = "String or BLOB exceeds size limit" },
	[19] = { code = 19, type = "modify", condition = "CONSTRAINT", text = "Abort due to constraint violation" },
	[20] = { code = 20, type = "modify", condition = "MISMATCH", text = "Data type mismatch" },
	[21] = { code = 21, type = "modify", condition = "MISUSE", text = "Library used incorrectly" },
	[22] = { code = 22, type = "cancel", condition = "NOLFS", text = "Uses OS features not supported on host" },
	[23] = { code = 23, type = "auth", condition = "AUTH", text = "Authorization denied" },
	[24] = { code = 24, type = "modify", condition = "FORMAT", text = "Not used" },
	[25] = { code = 25, type = "modify", condition = "RANGE", text = "2nd parameter to sqlite3_bind out of range" },
	[26] = { code = 26, type = "cancel", condition = "NOTADB", text = "File opened that is not a database file" },
	[27] = { code = 27, type = "continue", condition = "NOTICE", text = "Notifications from sqlite3_log()" },
	[28] = { code = 28, type = "continue", condition = "WARNING", text = "Warnings from sqlite3_log()" },
	[100] = { code = 100, type = "continue", condition = "ROW", text = "sqlite3_step() has another row ready" },
	[101] = { code = 101, type = "continue", condition = "DONE", text = "sqlite3_step() has finished executing" },
});

-- Shadow assert() with a variant that routes its arguments through
-- sqlite_errors.coerce() first. The third argument, when given, takes
-- precedence over the second as the error value handed to coerce().
-- The inner `assert` is the builtin captured above, not this wrapper.
local assert = function (value, code, message)
	return assert(sqlite_errors.coerce(value, message or code));
end

-- With _ENV set to nil, free (global) names are no longer reachable on Lua
-- versions that implement _ENV; the code below must use the locals above.
local _ENV = nil;
-- luacheck: std none

-- Metatables that tag Column and Table objects.
local column_mt = {};
local table_mt = {};
local query_mt = {};
--local op_mt = {};
local index_mt = {};

-- Type predicates: each kind of schema object is recognised by its metatable.
local function is_column(x) return getmetatable(x)==column_mt; end
local function is_index(x) return getmetatable(x)==index_mt; end
local function is_table(x) return getmetatable(x)==table_mt; end
local function is_query(x) return getmetatable(x)==query_mt; end

-- Placeholder type constructors: the argument is ignored and a fixed string
-- is returned.
local function Integer(n) return "Integer()" end
local function String(n) return "String()" end

-- Tags `definition` (modified in place) as a column and returns it.
local function Column(definition)
	return setmetatable(definition, column_mt);
end

-- Builds a table object from a definition list of Column and Index objects.
-- Returns { __table__ = definition, c = columns, name = definition.name }
-- where `c` holds the columns both as a sequence and keyed by column name.
-- Every Index found gets its `table` field set to the table name.
local function Table(definition)
	local c = {}
	local n_columns = 0;
	for _, col in ipairs(definition) do
		if is_column(col) then
			-- Columns are stored densely: an Index listed before a Column must
			-- not leave a hole, because _create_table walks `c` with ipairs.
			n_columns = n_columns + 1;
			c[n_columns] = col;
			c[col.name] = col;
		elseif is_index(col) then
			col.table = definition.name;
		end
	end
	return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt);
end

-- Tags `definition` (modified in place) as an index and returns it.
local function Index(definition)
	return setmetatable(definition, index_mt);
end

-- Debug representation: the table name followed by tostring() of every
-- entry in the definition.
function table_mt:__tostring()
	local s = { 'name="'..self.__table__.name..'"' }
	for _, col in ipairs(self.__table__) do
		s[#s+1] = tostring(col);
	end
	return 'Table{ '..t_concat(s, ", ")..' }'
end
table_mt.__index = {};

-- Creates this table through the given engine; returns whatever
-- engine:_create_table() returns.
function table_mt.__index:create(engine)
	return engine:_create_table(self);
end
function table_mt:__call(...)
	-- TODO
end
function column_mt:__tostring()
	return 'Column{ name="'..self.name..'", type="'..self.type..'" }'
end

-- Debug representation: the index name followed by each indexed column name,
-- with backslashes and double quotes backslash-escaped.
function index_mt:__tostring()
	local parts = { 'Index{ name="'..self.name..'"' };
	for i = 1, #self do
		-- Parentheses drop gsub's second return value (the match count).
		parts[#parts+1] = '"'..(self[i]:gsub("[\\\"]", "\\%1"))..'"';
	end
	return t_concat(parts, ", ")..' }';
--	return 'Index{ name="'..self.name..'", type="'..self.type..'" }'
end

-- Replaces %XX escapes with the corresponding byte; passes nil/false through.
local function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return s_char(tonumber(c,16)); end)); end

-- Splits "scheme://[user[:password]@]host[:port]/database" into a table with
-- the fields scheme (lowercased), username, password (URL-decoded), host,
-- port (number) and database (nil when empty).
-- Raises (through the module's assert wrapper) if the scheme cannot be
-- matched or the port is not numeric.
local function parse_url(url)
	local scheme, secondpart, database = url:match("^([%w%+]+)://([^/]*)/?(.*)");
	assert(scheme, "Invalid URL format");
	local username, password, host, port;
	-- Fixed: the quantifier used to sit inside the character class
	-- ("[^@+]"), which captured only the first character of the host part.
	local authpart, hostpart = secondpart:match("([^@]+)@([^@]+)");
	if not authpart then hostpart = secondpart; end
	if authpart then
		username, password = authpart:match("([^:]*):(.*)");
		username = username or authpart;
		password = password and urldecode(password);
	end
	if hostpart then
		host, port = hostpart:match("([^:]*):(.*)");
		host = host or hostpart;
		port = port and assert(tonumber(port), "Invalid URL format");
	end
	return {
		scheme = scheme:lower();
		username = username; password = password;
		host = host; port = port;
		database = #database > 0 and database or nil;
	};
end

local engine = {};

-- Opens the database named by self.params.database unless a connection is
-- already present. On success stores the handle in self.conn, starts a fresh
-- prepared-statement cache, runs set_encoding() and the onconnect() hook.
-- Returns true, or a falsy value plus an error.
function engine:connect()
	if self.conn then return true; end

	local params = self.params;
	assert(params.driver == "SQLite3", "Only sqlite3 is supported");
	local dbh, open_err = sqlite_errors.coerce(lsqlite3.open(params.database));
	if not dbh then return nil, open_err; end
	self.conn = dbh;
	self.prepared = {};
	local ok, err = self:set_encoding();
	if not ok then
		return ok, err;
	end
	ok, err = self:onconnect();
	-- Only an explicit `false` from the hook aborts; nil counts as success.
	if ok == false then
		return ok, err;
	end
	return true;
end
function engine:onconnect()
	-- Override from create_engine()
end
function engine:ondisconnect() -- luacheck: ignore 212/self
	-- Override from create_engine()
end

-- Runs or prepares a statement, connecting first if necessary.
--  * Without bind values: executes `sql` directly; returns true on success.
--  * With bind values: prepares `sql` (cached per SQL string), binds the
--    values and returns the statement without stepping it.
-- On failure returns nil (or the falsy connect result) plus an error.
function engine:execute(sql, ...)
	local success, err = self:connect();
	if not success then return success, err; end
	local prepared = self.prepared;

	if select('#', ...) == 0 then
		local ret = self.conn:exec(sql);
		if ret ~= lsqlite3.OK then
			-- Fixed: the error is now built from the exec() result code (it
			-- used the outer `err`, which is always nil here) and is returned
			-- as the second value so that `if not success` checks in callers
			-- actually see the failure.
			local exec_err = sqlite_errors.new(ret);
			exec_err.text = self.conn:errmsg();
			return nil, exec_err;
		end
		return true;
	end

	local stmt = prepared[sql];
	if not stmt then
		local prepare_err;
		stmt, prepare_err = self.conn:prepare(sql);
		if not stmt then
			prepare_err = sqlite_errors.new(prepare_err);
			prepare_err.text = self.conn:errmsg();
			return stmt, prepare_err;
		end
		prepared[sql] = stmt;
	end

	local ret = stmt:bind_values(...);
	if ret ~= lsqlite3.OK then return nil, sqlite_errors.new(ret, { message = self.conn:errmsg() }); end
	return stmt;
end

-- Accessors shared by query/update result objects.
local result_mt = {
	__index = {
		affected = function(self) return self.__affected; end;
		rowcount = function(self) return self.__rowcount; end;
	},
};

-- Returns a stateful function that yields list[1], list[2], ... and then
-- nothing once a nil item is reached. Arguments passed to it are ignored.
local function iterator(list)
	local i = 0;
	return function()
		i = i + 1;
		local item = list[i];
		if item ~= nil then
			return item;
		end
	end
end

-- Logs `sql` at debug level with each "?" replaced by the corresponding
-- extra argument. Strings are single-quoted with embedded quotes doubled;
-- other values go through tostring(). Runs of tabs (with an optional
-- preceding newline) are collapsed to one space first.
local function debugquery(where, sql, ...)
	local i = 0; local a = {...}
	sql = sql:gsub("\n?\t+", " ");
	log("debug", "[%s] %s", where, (sql:gsub("%?", function ()
		i = i + 1;
		local v = a[i];
		if type(v) == "string" then
			v = ("'%s'"):format(v:gsub("'", "''"));
		end
		return tostring(v);
	end)));
end

-- Called when stmt:step() returned something other than ROW or DONE.
-- Captures the connection's error message before resetting the statement,
-- then raises it on behalf of the calling function.
local function raise_step_error(conn, stmt)
	local errmsg = conn:errmsg();
	stmt:reset();
	error(errmsg, 2);
end

-- Runs a row-returning statement with the given bind values.
-- Returns an object that yields one row (a table of column values) per call,
-- so it can be used directly in a generic `for`.
-- Raises on prepare, bind or step failure. Expects self.conn to be set.
function engine:execute_query(sql, ...)
	local prepared = self.prepared;
	local stmt = prepared[sql];
	if stmt and stmt:isopen() then
		prepared[sql] = nil; -- Can't be used concurrently
	else
		stmt = assert(self.conn:prepare(sql));
	end
	local ret = stmt:bind_values(...);
	if ret ~= lsqlite3.OK then error(self.conn:errmsg()); end
	local data = {};
	ret = stmt:step();
	while ret == ROW do
		t_insert(data, stmt:get_values());
		ret = stmt:step();
	end
	-- Anything other than DONE (e.g. BUSY, ERROR, MISUSE) means the result set
	-- is incomplete; raise instead of silently returning partial data.
	if ret ~= DONE then
		raise_step_error(self.conn, stmt);
	end
	-- Put the statement back in the cache only if it reset cleanly.
	if stmt:reset() == lsqlite3.OK then
		prepared[sql] = stmt;
	end
	return setmetatable({ __data = data }, { __index = result_mt.__index, __call = iterator(data) });
end

-- Runs a data-modifying statement with the given bind values.
-- Returns a result object whose :affected() is conn:changes() taken after the
-- statement finished and whose :rowcount() is the number of rows it produced.
-- Raises on prepare, bind or step failure. Expects self.conn to be set.
function engine:execute_update(sql, ...)
	local prepared = self.prepared;
	local stmt = prepared[sql];
	if not stmt or not stmt:isopen() then
		stmt = assert(self.conn:prepare(sql));
	else
		-- Taken out of the cache while in use.
		prepared[sql] = nil;
	end
	local ret = stmt:bind_values(...);
	if ret ~= lsqlite3.OK then error(self.conn:errmsg()); end
	local rowcount = 0;
	repeat
		ret = stmt:step();
		if ret == ROW then
			rowcount = rowcount + 1;
		end
	until ret ~= ROW;
	-- A failed step (constraint violation, BUSY, ...) used to be ignored, and
	-- changes() would then report a count unrelated to this statement.
	if ret ~= DONE then
		raise_step_error(self.conn, stmt);
	end
	local affected = self.conn:changes();
	if stmt:reset() == lsqlite3.OK then
		prepared[sql] = stmt;
	end
	return setmetatable({ __affected = affected, __rowcount = rowcount }, result_mt);
end

engine.insert = engine.execute_update;
engine.select = engine.execute_query;
engine.delete = engine.execute_update;
engine.update = engine.execute_update;

-- Wraps `f` so the query is logged via debugquery(name, ...) before it runs.
local function debugwrap(name, f)
	return function (self, sql, ...)
		debugquery(name, sql, ...)
		return f(self, sql, ...)
	end
end

-- Enables or disables query logging.
-- NOTE(review): `_debug` is stored on this instance, but the wrapped methods
-- are installed on the shared `engine` prototype, so the toggle reaches every
-- engine that inherits these methods. Confirm this is intended.
function engine:debug(enable)
	self._debug = enable;
	if enable then
		engine.insert = debugwrap("insert", engine.execute_update);
		engine.select = debugwrap("select", engine.execute_query);
		engine.delete = debugwrap("delete", engine.execute_update);
		engine.update = debugwrap("update", engine.execute_update);
	else
		engine.insert = engine.execute_update;
		engine.select = engine.execute_query;
		engine.delete = engine.execute_update;
		engine.update = engine.execute_update;
	end
end

-- Executes a bare SQL command (used for BEGIN/COMMIT/ROLLBACK).
-- Returns true, or nil plus the connection's error message.
function engine:_(word)
	local ret = self.conn:exec(word);
	if ret ~= lsqlite3.OK then return nil, self.conn:errmsg(); end
	return true;
end

-- Runs func(...) between BEGIN and COMMIT, connecting first if needed.
-- Returns true plus up to three results of `func` on success. If `func`
-- raises, issues ROLLBACK and returns false plus the traceback. If BEGIN or
-- COMMIT fails, returns nil plus the error message.
-- NOTE(review): no ROLLBACK is attempted after a failed COMMIT.
function engine:_transaction(func, ...)
	if not self.conn then
		local a,b = self:connect();
		if not a then return a,b; end
	end
	--assert(not self.__transaction, "Recursive transactions not allowed");
	local ok, err = self:_"BEGIN";
	if not ok then return ok, err; end
	self.__transaction = true;
	local success, a, b, c = xpcall(func, debug_traceback, ...);
	self.__transaction = nil;
	if success then
		log("debug", "SQL transaction success [%s]", tostring(func));
		local ok, err = self:_"COMMIT";
		if not ok then return ok, err; end -- commit failed
		return success, a, b, c;
	else
		log("debug", "SQL transaction failure [%s]: %s", tostring(func), a);
		if self.conn then self:_"ROLLBACK"; end
		return success, a;
	end
end

-- Like _transaction(), but if the attempt fails and the connection turns out
-- to be missing or closed, drops it, fires ondisconnect() and retries once.
-- Only the first two return values of _transaction() are passed on.
function engine:transaction(...)
	local ok, ret = self:_transaction(...);
	if not ok then
		local conn = self.conn;
		if not conn or not conn:isopen() then
			self.conn = nil;
			self:ondisconnect();
			ok, ret = self:_transaction(...);
		end
	end
	return ok, ret;
end

-- Issues CREATE [UNIQUE] INDEX IF NOT EXISTS for an Index object; the array
-- part of `index` lists the column names. Returns the result of execute().
function engine:_create_index(index)
	local columns = {};
	for i = 1, #index do
		columns[i] = "\""..index[i].."\"";
	end
	local sql = (index.unique and "CREATE UNIQUE" or "CREATE")
		.." INDEX IF NOT EXISTS \""..index.name.."\" ON \""..index.table.."\" ("
		..t_concat(columns, ", ")..");";
	if self._debug then
		debugquery("create", sql);
	end
	return self:execute(sql);
end

-- Issues CREATE TABLE IF NOT EXISTS for a Table object, then creates each
-- Index found in its definition. Returns true, or a falsy value plus an
-- error if the table statement fails.
-- NOTE(review): results of index creation are ignored (best effort).
function engine:_create_table(table)
	local column_defs = {};
	for i, col in ipairs(table.c) do
		local def = "\""..col.name.."\" "..col.type;
		if col.nullable == false then def = def.." NOT NULL"; end
		if col.primary_key == true then def = def.." PRIMARY KEY"; end
		if col.auto_increment == true then def = def.." AUTOINCREMENT"; end
		column_defs[i] = def;
	end
	local sql = "CREATE TABLE IF NOT EXISTS \""..table.name.."\" ("..t_concat(column_defs, ", ")..");";
	if self._debug then
		debugquery("create", sql);
	end
	local success,err = self:execute(sql);
	if not success then return success,err; end
	for _, v in ipairs(table.__table__) do
		if is_index(v) then
			self:_create_index(v);
		end
	end
	return success;
end

-- Reads "PRAGMA encoding;" inside a transaction and records
-- self.charset = "utf8" when the database reports UTF-8. Nothing is written
-- to the database here. Returns the result of transaction().
function engine:set_encoding() -- to UTF-8
	return self:transaction(function()
		for encoding in self:select"PRAGMA encoding;" do
			if encoding[1] == "UTF-8" then
				self.charset = "utf8";
			end
		end
	end);
end
local engine_mt = { __index = engine };

-- Renders connection parameters as a URL string via socket.url's build().
local function db2uri(params)
	return build_url{
		scheme = params.driver,
		user = params.username,
		password = params.password,
		host = params.host,
		port = params.port,
		path = params.database,
	};
end

-- Creates an engine for `params` (driver must be "SQLite3"). The optional
-- onconnect/ondisconnect callbacks override the no-op defaults. The first
-- argument is ignored. No connection is opened here.
local function create_engine(_, params, onconnect, ondisconnect)
	assert(params.driver == "SQLite3", "Only SQLite3 is supported without LuaDBI");
	return setmetatable({ url = db2uri(params); params = params; onconnect = onconnect; ondisconnect = ondisconnect }, engine_mt);
end

return {
	is_column = is_column;
	is_index = is_index;
	is_table = is_table;
	is_query = is_query;
	Integer = Integer;
	String = String;
	Column = Column;
	Table = Table;
	Index = Index;
	create_engine = create_engine;
	db2uri = db2uri;
};