Software /
code /
prosody
File
util/sql.lua @ 5915:e6fed1d80116
Back out 1b0ac7950129, as SSLv3 appears to still be in moderate use on the network. Also, although obsolete, SSLv3 isn't documented to have any weaknesses that TLS 1.0 (the most common version used today) doesn't also have. Get your act together, clients!
author | Matthew Wild <mwild1@gmail.com> |
---|---|
date | Tue, 12 Nov 2013 02:13:01 +0000 |
parent | 5494:9916f0a2d178 |
child | 5739:2e9ad9cb206f |
line wrap: on
line source
-- util.sql: a small description layer for SQL tables/columns/indices on top
-- of LuaDBI, plus "engine" objects that own one DBI connection, a cache of
-- prepared statements, and transaction handling.

local setmetatable, getmetatable = setmetatable, getmetatable;
local ipairs, unpack, select = ipairs, unpack, select;
local tonumber, tostring = tonumber, tostring;
local assert, xpcall, debug_traceback = assert, xpcall, debug.traceback;
local t_concat = table.concat;
local s_char = string.char;
local log = require "util.logger".init("sql");

local DBI = require "DBI";
-- This loads all available drivers while globals are unlocked
-- LuaDBI should be fixed to not set globals.
DBI.Drivers();
local build_url = require "socket.url".build;

-- module() without package.seeall makes the module table the environment
-- for the rest of this file: only the locals captured above (not standard
-- globals such as pairs/type/error) are reachable below, and every
-- non-local function defined below is exported from the module.
module("sql")

local column_mt = {};
local table_mt = {};
local query_mt = {};
--local op_mt = {};
local index_mt = {};

-- Type predicates: true iff x carries the corresponding metatable.
function is_column(x) return getmetatable(x)==column_mt; end
function is_index(x) return getmetatable(x)==index_mt; end
function is_table(x) return getmetatable(x)==table_mt; end
function is_query(x) return getmetatable(x)==query_mt; end
--function is_op(x) return getmetatable(x)==op_mt; end
--function expr(...) return setmetatable({...}, op_mt); end

-- Placeholder type constructors: they ignore their argument and return a
-- fixed marker string.
function Integer(n) return "Integer()" end
function String(n) return "String()" end

--[[local ops = {
	__add = function(a, b) return "("..a.."+"..b..")" end;
	__sub = function(a, b) return "("..a.."-"..b..")" end;
	__mul = function(a, b) return "("..a.."*"..b..")" end;
	__div = function(a, b) return "("..a.."/"..b..")" end;
	__mod = function(a, b) return "("..a.."%"..b..")" end;
	__pow = function(a, b) return "POW("..a..","..b..")" end;
	__unm = function(a) return "NOT("..a..")" end;
	__len = function(a) return "COUNT("..a..")" end;
	__eq = function(a, b) return "("..a.."=="..b..")" end;
	__lt = function(a, b) return "("..a.."<"..b..")" end;
	__le = function(a, b) return "("..a.."<="..b..")" end;
};

local functions = {
	
};

local cmap = {
	[Integer] = Integer();
	[String] = String();
};]]

--- Mark a definition table as a column description (modifies it in place).
-- Fields read elsewhere in this file: `name`, `type`, and `nullable`
-- (`nullable == false` emits NOT NULL in _create_table).
function Column(definition)
	return setmetatable(definition, column_mt);
end

--- Build a table description.
-- `definition` has a `name` field and Column/Index entries in its array part.
-- The result has `name`, `__table__` (the raw definition) and `c`, which holds
-- the columns both as a sequence in declaration order and keyed by name.
-- Each Index entry gets its `table` field set to the table name.
function Table(definition)
	local c = {}
	for _, col in ipairs(definition) do
		if is_column(col) then
			-- Append instead of reusing the definition index: Index entries
			-- interleaved between columns would otherwise leave holes in `c`,
			-- and the ipairs/# iteration in _create_table would skip columns.
			c[#c+1] = col;
			c[col.name] = col;
		elseif is_index(col) then
			col.table = definition.name;
		end
	end
	return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt);
end

--- Mark a definition table as an index description (modifies it in place).
-- Its array part lists column names and `name` is the index name.
-- Table() fills in `table`.
function Index(definition)
	return setmetatable(definition, index_mt);
end

function table_mt:__tostring()
	local s = { 'name="'..self.__table__.name..'"' }
	for _, col in ipairs(self.__table__) do
		s[#s+1] = tostring(col);
	end
	return 'Table{ '..t_concat(s, ", ")..' }'
end

table_mt.__index = {};

--- Create this table (and its indices) through the given engine.
function table_mt.__index:create(engine)
	return engine:_create_table(self);
end

function table_mt:__call(...)
	-- TODO
end

function column_mt:__tostring()
	return 'Column{ name="'..self.name..'", type="'..self.type..'" }'
end

function index_mt:__tostring()
	local s = 'Index{ name="'..self.name..'"';
	for i=1,#self do
		-- Backslash-escape any '\' or '"' inside column names.
		s = s..', "'..self[i]:gsub("[\\\"]", "\\%1")..'"';
	end
	return s..' }';
--	return 'Index{ name="'..self.name..'", type="'..self.type..'" }'
end

--- Decode %XX escapes in s; nil passes through unchanged.
local function urldecode(s)
	return s and (s:gsub("%%(%x%x)", function (c) return s_char(tonumber(c,16)); end));
end

--- Parse "scheme://[user[:password]@]host[:port][/database]" into a table
-- with scheme (lowercased), username, password (URL-decoded), host,
-- port (number) and database (nil when empty).
-- Raises "Invalid URL format" if the scheme or a non-numeric port is bad.
-- Not referenced elsewhere in this file.
local function parse_url(url)
	local scheme, secondpart, database = url:match("^([%w%+]+)://([^/]*)/?(.*)");
	assert(scheme, "Invalid URL format");
	local username, password, host, port;
	-- "userinfo@hostport": both sides must be non-empty; the host side is
	-- the whole run of non-'@' characters, not just a single character.
	local authpart, hostpart = secondpart:match("([^@]+)@([^@]+)");
	if not authpart then hostpart = secondpart; end
	if authpart then
		username, password = authpart:match("([^:]*):(.*)");
		username = username or authpart;
		password = password and urldecode(password);
	end
	if hostpart then
		host, port = hostpart:match("([^:]*):(.*)");
		host = host or hostpart;
		port = port and assert(tonumber(port), "Invalid URL format");
	end
	return {
		scheme = scheme:lower();
		username = username; password = password;
		host = host; port = port;
		database = #database > 0 and database or nil;
	};
end

--[[local session = {};

function session.query(...)
	local rets = {...};
	local query = setmetatable({ __rets = rets, __filters }, query_mt);
	return query;
end
--

local function db2uri(params)
	return build_url{
		scheme = params.driver,
		user = params.username,
		password = params.password,
		host = params.host,
		port = params.port,
		path = params.database,
	};
end]]

local engine = {};

--- Open the DBI connection described by self.params unless one is open.
-- Autocommit is turned off, so changes are only committed by _transaction.
-- The prepared-statement cache is reset for the new connection.
-- @return true, or nil plus the DBI connection error
function engine:connect()
	if self.conn then return true; end

	local params = self.params;
	assert(params.driver, "no driver")
	local dbh, err = DBI.Connect(
		params.driver, params.database,
		params.username, params.password,
		params.host, params.port
	);
	if not dbh then return nil, err; end
	dbh:autocommit(false); -- don't commit automatically
	self.conn = dbh;
	self.prepared = {};
	return true;
end

--- Execute raw SQL, connecting first if necessary. Prepared statements are
-- cached per SQL string. No backtick rewriting is done here.
-- @return the executed statement, or a falsy value plus an error
function engine:execute(sql, ...)
	local connected, conn_err = self:connect();
	if not connected then return connected, conn_err; end
	local prepared = self.prepared;

	local stmt = prepared[sql];
	if not stmt then
		local prep_err;
		stmt, prep_err = self.conn:prepare(sql);
		if not stmt then return stmt, prep_err; end
		prepared[sql] = stmt;
	end

	local ok, exec_err = stmt:execute(...);
	if not ok then return ok, exec_err; end
	return stmt;
end

-- Wraps the counts reported by an update statement.
local result_mt = { __index = {
	affected = function(self) return self.__affected; end;
	rowcount = function(self) return self.__rowcount; end;
} };

--- Run a query and return stmt:rows() for it.
-- Uses self.conn directly (no connect), so it presumably expects to run
-- inside engine:transaction. Failures raise via assert. A new statement is
-- prepared on every call (not cached). For PostgreSQL, backticks are
-- rewritten to double quotes.
function engine:execute_query(sql, ...)
	if self.params.driver == "PostgreSQL" then
		sql = sql:gsub("`", "\"");
	end
	local stmt = assert(self.conn:prepare(sql));
	assert(stmt:execute(...));
	return stmt:rows();
end

--- Run a data-modifying statement (prepared statements are cached).
-- Uses self.conn directly; failures raise via assert. For PostgreSQL,
-- backticks are rewritten to double quotes.
-- @return an object with :affected() and :rowcount()
function engine:execute_update(sql, ...)
	if self.params.driver == "PostgreSQL" then
		sql = sql:gsub("`", "\"");
	end
	local prepared = self.prepared;
	local stmt = prepared[sql];
	if not stmt then
		stmt = assert(self.conn:prepare(sql));
		prepared[sql] = stmt;
	end
	assert(stmt:execute(...));
	return setmetatable({ __affected = stmt:affected(), __rowcount = stmt:rowcount() }, result_mt);
end

engine.insert = engine.execute_update;
engine.select = engine.execute_query;
engine.delete = engine.execute_update;
engine.update = engine.execute_update;

--- Call func(...) in a transaction, connecting first if needed.
-- Errors raised by func are caught with a traceback. On success the
-- transaction is committed; on failure it is rolled back.
-- @return true plus up to three of func's results; false plus the error
--   text; or the falsy result and error of a failed connect/commit
function engine:_transaction(func, ...)
	if not self.conn then
		local a,b = self:connect();
		if not a then return a,b; end
	end
	--assert(not self.__transaction, "Recursive transactions not allowed");
	-- Lua 5.1's xpcall cannot pass arguments, so capture them in a closure
	-- (keeping the exact count so trailing nils survive).
	local args, n_args = {...}, select("#", ...);
	local function f() return func(unpack(args, 1, n_args)); end
	self.__transaction = true;
	local success, a, b, c = xpcall(f, debug_traceback);
	self.__transaction = nil;
	if success then
		log("debug", "SQL transaction success [%s]", tostring(func));
		local ok, err = self.conn:commit();
		if not ok then return ok, err; end -- commit failed
		return success, a, b, c;
	else
		log("debug", "SQL transaction failure [%s]: %s", tostring(func), a);
		if self.conn then self.conn:rollback(); end
		return success, a;
	end
end

--- Like _transaction, but if it fails while the connection is missing or
-- no longer answers ping(), drop the connection and retry once.
-- @return only the first two values produced by _transaction
function engine:transaction(...)
	local a,b = self:_transaction(...);
	if not a then
		local conn = self.conn;
		if not conn or not conn:ping() then
			self.conn = nil;
			a,b = self:_transaction(...);
		end
	end
	return a,b;
end

--- Issue CREATE INDEX for an Index description via engine:execute.
-- PostgreSQL gets double-quoted identifiers. MySQL gets a "(20)" length
-- suffix after each indexed column.
function engine:_create_index(index)
	local sql = "CREATE INDEX `"..index.name.."` ON `"..index.table.."` (";
	for i=1,#index do
		sql = sql.."`"..index[i].."`";
		if i ~= #index then sql = sql..", "; end
	end
	sql = sql..");"
	if self.params.driver == "PostgreSQL" then
		sql = sql:gsub("`", "\"");
	elseif self.params.driver == "MySQL" then
		sql = sql:gsub("`([,)])", "`(20)%1");
	end
	--print(sql);
	return self:execute(sql);
end

--- Issue CREATE TABLE for a Table description, then create its indices.
-- Index failures are logged but do not change the return value.
-- @return the statement from execute, or a falsy value plus an error
function engine:_create_table(tbl)
	local sql = "CREATE TABLE `"..tbl.name.."` (";
	for i,col in ipairs(tbl.c) do
		sql = sql.."`"..col.name.."` "..col.type;
		if col.nullable == false then
			sql = sql.." NOT NULL";
		end
		if i ~= #tbl.c then
			sql = sql..", ";
		end
	end
	sql = sql.. ");"
	if self.params.driver == "PostgreSQL" then
		sql = sql:gsub("`", "\"");
	end
	local success,err = self:execute(sql);
	if not success then return success,err; end
	for _, v in ipairs(tbl.__table__) do
		if is_index(v) then
			local ok, index_err = self:_create_index(v);
			if not ok then
				log("warn", "Failed to create index %s on %s: %s",
					tostring(v.name), tostring(tbl.name), tostring(index_err));
			end
		end
	end
	return success;
end

local engine_mt = { __index = engine };

--- Build the engine cache key (a URL) from connection parameters.
-- NOTE(review): confirm how socket.url.build treats user/password when host
-- is nil; if they are dropped, parameter sets differing only in credentials
-- would share one cached engine.
local function db2uri(params)
	return build_url{
		scheme = params.driver,
		user = params.username,
		password = params.password,
		host = params.host,
		port = params.port,
		path = params.database,
	};
end

local engine_cache = {}; -- TODO make weak valued

--- Return the engine for params, creating it on first use.
-- Engines (and thus their connection) are shared by every caller whose
-- params produce the same URL. The unused first parameter suggests
-- method-style calls, e.g. sql:create_engine(params) — verify at callers.
function create_engine(self, params)
	local url = db2uri(params);
	if not engine_cache[url] then
		local engine = setmetatable({ url = url, params = params }, engine_mt);
		engine_cache[url] = engine;
	end
	return engine_cache[url];
end


--[[Users = Table {
	name="users";
	Column { name="user_id", type=String(), primary_key=true };
};
print(Users)
print(Users.c.user_id)]]

--local engine = create_engine('postgresql://scott:tiger@localhost:5432/mydatabase');
--[[local engine = create_engine{ driver = "SQLite3", database = "./alchemy.sqlite" };

local i = 0;
for row in assert(engine:execute("select * from sqlite_master")):rows(true) do
	i = i+1;
	print(i);
	for k,v in pairs(row) do
		print("",k,v);
	end
end
print("---")

Prosody = Table {
	name="prosody";
	Column { name="host", type="TEXT", nullable=false };
	Column { name="user", type="TEXT", nullable=false };
	Column { name="store", type="TEXT", nullable=false };
	Column { name="key", type="TEXT", nullable=false };
	Column { name="type", type="TEXT", nullable=false };
	Column { name="value", type="TEXT", nullable=false };
	Index { name="prosody_index", "host", "user", "store", "key" };
};
--print(Prosody);
assert(engine:transaction(function()
	assert(Prosody:create(engine));
end));

for row in assert(engine:execute("select user from prosody")):rows(true) do
	print("username:", row['username'])
end
--result.close();]]

return _M;