Comparison

plugins/mod_external_services.lua @ 11754:21a9b3f2a728

mod_external_services: Filter services by requested credentials using a Set. Please don't be accidentally quadratic.
author Kim Alvefur <zash@zash.se>
date Mon, 30 Aug 2021 20:19:09 +0200
parent 11753:c4599a7c534c
child 11755:ae565e49289a
comparison
equal deleted inserted replaced
11753:c4599a7c534c 11754:21a9b3f2a728
3 local base64 = require "util.encodings".base64; 3 local base64 = require "util.encodings".base64;
4 local hashes = require "util.hashes"; 4 local hashes = require "util.hashes";
5 local st = require "util.stanza"; 5 local st = require "util.stanza";
6 local jid = require "util.jid"; 6 local jid = require "util.jid";
7 local array = require "util.array"; 7 local array = require "util.array";
8 local set = require "util.set";
8 9
9 local default_host = module:get_option_string("external_service_host", module.host); 10 local default_host = module:get_option_string("external_service_host", module.host);
10 local default_port = module:get_option_number("external_service_port"); 11 local default_port = module:get_option_number("external_service_port");
11 local default_secret = module:get_option_string("external_service_secret"); 12 local default_secret = module:get_option_string("external_service_secret");
12 local default_ttl = module:get_option_number("external_service_ttl", 86400); 13 local default_ttl = module:get_option_number("external_service_ttl", 86400);
184 local services = ( configured_services + extras ) / prepare; 185 local services = ( configured_services + extras ) / prepare;
185 services:filter(function (item) 186 services:filter(function (item)
186 return item.restricted; 187 return item.restricted;
187 end) 188 end)
188 189
189 local requested_credentials = {}; 190 local requested_credentials = set.new();
190 for service in action:childtags("service") do 191 for service in action:childtags("service") do
191 if not service.attr.type or not service.attr.host then 192 if not service.attr.type or not service.attr.host then
192 origin.send(st.error_reply(stanza, "modify", "bad-request")); 193 origin.send(st.error_reply(stanza, "modify", "bad-request"));
193 return true; 194 return true;
194 end 195 end
195 196
196 table.insert(requested_credentials, { 197 requested_credentials:add(string.format("%s:%s:%d", service.attr.type, service.attr.host,
197 type = service.attr.type; 198 tonumber(service.attr.port) or 0));
198 host = service.attr.host;
199 port = tonumber(service.attr.port);
200 });
201 end 199 end
202 200
203 setmetatable(services, services_mt); 201 setmetatable(services, services_mt);
204 setmetatable(requested_credentials, services_mt);
205 202
206 module:fire_event("external_service/credentials", { 203 module:fire_event("external_service/credentials", {
207 origin = origin; 204 origin = origin;
208 stanza = stanza; 205 stanza = stanza;
209 reply = reply; 206 reply = reply;
210 requested_credentials = requested_credentials; 207 requested_credentials = requested_credentials;
211 services = services; 208 services = services;
212 }); 209 });
213 210
214 for req_srv in action:childtags("service") do 211 services:filter(function (srv)
215 for _, srv in ipairs(services) do 212 local port_key = string.format("%s:%s:%d", srv.type, srv.host, srv.port or 0);
216 if srv.type == req_srv.attr.type and srv.host == req_srv.attr.host 213 local portless_key = string.format("%s:%s:%d", srv.type, srv.host, 0);
217 and not req_srv.attr.port or srv.port == tonumber(req_srv.attr.port) then 214 return requested_credentials:contains(port_key) or requested_credentials:contains(portless_key);
218 reply:tag("service", { 215 end);
219 type = srv.type; 216
220 transport = srv.transport; 217 for _, srv in ipairs(services) do
221 host = srv.host; 218 reply:tag("service", {
222 port = srv.port and string.format("%d", srv.port) or nil; 219 type = srv.type;
223 username = srv.username; 220 transport = srv.transport;
224 password = srv.password; 221 host = srv.host;
225 expires = srv.expires and dt.datetime(srv.expires) or nil; 222 port = srv.port and string.format("%d", srv.port) or nil;
226 restricted = srv.restricted and "1" or nil; 223 username = srv.username;
227 }):up(); 224 password = srv.password;
228 end 225 expires = srv.expires and dt.datetime(srv.expires) or nil;
229 end 226 restricted = srv.restricted and "1" or nil;
227 }):up();
230 end 228 end
231 229
232 origin.send(reply); 230 origin.send(reply);
233 return true; 231 return true;
234 end 232 end