Comparison

mod_rest/mod_rest.lua @ 3910:49efd1323a1b

mod_rest: Add support for token authentication
author Matthew Wild <mwild1@gmail.com>
date Wed, 26 Feb 2020 18:36:40 +0000
parent 3909:eb27e51cf2c9
child 3911:064c32a5be7c
comparison
equal deleted inserted replaced
3909:eb27e51cf2c9 3910:49efd1323a1b
2 -- 2 --
3 -- Copyright (c) 2019-2020 Kim Alvefur 3 -- Copyright (c) 2019-2020 Kim Alvefur
4 -- 4 --
5 -- This file is MIT/X11 licensed. 5 -- This file is MIT/X11 licensed.
6 6
7 local encodings = require "util.encodings";
8 local base64 = encodings.base64;
7 local errors = require "util.error"; 9 local errors = require "util.error";
8 local http = require "net.http"; 10 local http = require "net.http";
9 local id = require "util.id"; 11 local id = require "util.id";
10 local jid = require "util.jid"; 12 local jid = require "util.jid";
11 local json = require "util.json"; 13 local json = require "util.json";
12 local st = require "util.stanza"; 14 local st = require "util.stanza";
15 local um = require "core.usermanager";
13 local xml = require "util.xml"; 16 local xml = require "util.xml";
14 17
15 local allow_any_source = module:get_host_type() == "component";
16 local validate_from_addresses = module:get_option_boolean("validate_from_addresses", true);
17 local secret = assert(module:get_option_string("rest_credentials"), "rest_credentials is a required setting");
18 local auth_type = assert(secret:match("^%S+"), "Format of rest_credentials MUST be like 'Bearer secret'");
19 assert(auth_type == "Bearer" or auth_type == "Basic", "Only 'Bearer' and 'Basic' are supported in rest_credentials");
20
21 local jsonmap = module:require"jsonmap"; 18 local jsonmap = module:require"jsonmap";
19
20 local tokens = module:depends("authtokens");
21
22 local auth_mechanisms = module:get_option_set("rest_auth_mechanisms", { "Basic", "Bearer" });
23
24 local www_authenticate_header;
25 do
26 local header, realm = {}, module.host.."/"..module.name;
27 for mech in auth_mechanisms do
28 header[#header+1] = ("%s realm=%q"):format(mech, realm);
29 end
30 www_authenticate_header = table.concat(header, ", ");
31 end
32
22 -- Bearer token 33 -- Bearer token
23 local function check_credentials(request) 34 local function check_credentials(request)
24 return request.headers.authorization == secret; 35 local auth_type, auth_data = string.match(request.headers.authorization, "^(%S+)%s(.+)$");
25 end 36 if not (auth_type and auth_data) or not auth_mechanisms:contains(auth_type) then
26 if secret == "Basic" and module:get_host_type() == "local" then 37 return false;
27 local um = require "core.usermanager"; 38 end
28 local encodings = require "util.encodings"; 39
29 local base64 = encodings.base64; 40 if auth_type == "Basic" then
30 41 local creds = base64.decode(auth_data);
31 function check_credentials(request)
32 local creds = string.match(request.headers.authorization, "^Basic%s+([A-Za-z0-9+/]+=?=?)%s*$");
33 if not creds then return false; end
34 creds = base64.decode(creds);
35 if not creds then return false; end 42 if not creds then return false; end
36 local username, password = string.match(creds, "^([^:]+):(.*)$"); 43 local username, password = string.match(creds, "^([^:]+):(.*)$");
37 if not username then return false; end 44 if not username then return false; end
38 username, password = encodings.stringprep.nodeprep(username), encodings.stringprep.saslprep(password); 45 username, password = encodings.stringprep.nodeprep(username), encodings.stringprep.saslprep(password);
39 if not username then return false; end 46 if not username then return false; end
40 if not um.test_password(username, module.host, password) then 47 if not um.test_password(username, module.host, password) then
41 return false; 48 return false;
42 end 49 end
43 return jid.join(username, module.host); 50 return { username = username, host = module.host };
44 end 51 elseif auth_type == "Bearer" then
52 local token_info = tokens.get_token_info(auth_data);
53 if not token_info or not token_info.session then
54 return false;
55 end
56 return token_info.session;
57 end
58 return nil;
45 end 59 end
46 60
47 local function parse(mimetype, data) 61 local function parse(mimetype, data)
48 mimetype = mimetype and mimetype:match("^[^; ]*"); 62 mimetype = mimetype and mimetype:match("^[^; ]*");
49 if mimetype == "application/xmpp+xml" then 63 if mimetype == "application/xmpp+xml" then
82 return tostring(s); 96 return tostring(s);
83 end 97 end
84 98
85 local function handle_post(event) 99 local function handle_post(event)
86 local request, response = event.request, event.response; 100 local request, response = event.request, event.response;
87 local from = module.host; 101 local from;
102 local origin;
103
88 if not request.headers.authorization then 104 if not request.headers.authorization then
89 response.headers.www_authenticate = ("%s realm=%q"):format(auth_type, module.host.."/"..module.name); 105 response.headers.www_authenticate = www_authenticate_header;
90 return 401; 106 return 401;
91 else 107 else
92 local authz = check_credentials(request); 108 origin = check_credentials(request);
93 if not authz then 109 if not origin then
94 return 401; 110 return 401;
95 end 111 end
96 if type(authz) == "string" then 112 from = jid.join(origin.username, origin.host, origin.resource);
97 from = authz;
98 end
99 end 113 end
100 local payload, err = parse(request.headers.content_type, request.body); 114 local payload, err = parse(request.headers.content_type, request.body);
101 if not payload then 115 if not payload then
102 -- parse fail 116 -- parse fail
103 return errors.new({ code = 400, text = "Failed to parse payload" }, { error = err, type = request.headers.content_type, data = request.body }); 117 return errors.new({ code = 400, text = "Failed to parse payload" }, { error = err, type = request.headers.content_type, data = request.body });
109 end 123 end
110 local to = jid.prep(payload.attr.to); 124 local to = jid.prep(payload.attr.to);
111 if not to then 125 if not to then
112 return errors.new({ code = 422, text = "Invalid destination JID" }); 126 return errors.new({ code = 422, text = "Invalid destination JID" });
113 end 127 end
114 if allow_any_source and payload.attr.from then 128 if payload.attr.from then
115 from = jid.prep(payload.attr.from); 129 local requested_from = jid.prep(payload.attr.from);
116 if not from then 130 if not requested_from then
117 return errors.new({ code = 422, text = "Invalid source JID" }); 131 return errors.new({ code = 422, text = "Invalid source JID" });
118 end 132 end
119 if validate_from_addresses and not jid.compare(from, module.host) then 133 if jid.compare(requested_from, from) then
120 return errors.new({ code = 403, text = "Source JID must belong to current host" }); 134 from = requested_from;
135 else
136 return errors.new({ code = 403, text = "Not authorized to send from "..requested_from });
121 end 137 end
122 end 138 end
123 payload.attr = { 139 payload.attr = {
124 from = from, 140 from = from,
125 to = to, 141 to = to,
128 ["xml:lang"] = payload.attr["xml:lang"], 144 ["xml:lang"] = payload.attr["xml:lang"],
129 }; 145 };
130 module:log("debug", "Received[rest]: %s", payload:top_tag()); 146 module:log("debug", "Received[rest]: %s", payload:top_tag());
131 local send_type = decide_type((request.headers.accept or "") ..",".. request.headers.content_type) 147 local send_type = decide_type((request.headers.accept or "") ..",".. request.headers.content_type)
132 if payload.name == "iq" then 148 if payload.name == "iq" then
149 function origin.send(stanza)
150 prosody.core_route_stanza(nil, stanza);
151 end
133 if payload.attr.type ~= "get" and payload.attr.type ~= "set" then 152 if payload.attr.type ~= "get" and payload.attr.type ~= "set" then
134 return errors.new({ code = 422, text = "'iq' stanza must be of type 'get' or 'set'" }); 153 return errors.new({ code = 422, text = "'iq' stanza must be of type 'get' or 'set'" });
135 elseif #payload.tags ~= 1 then 154 elseif #payload.tags ~= 1 then
136 return errors.new({ code = 422, text = "'iq' stanza must have exactly one child tag" }); 155 return errors.new({ code = 422, text = "'iq' stanza must have exactly one child tag" });
137 end 156 end
138 return module:send_iq(payload):next( 157 return module:send_iq(payload, origin):next(
139 function (result) 158 function (result)
140 module:log("debug", "Sending[rest]: %s", result.stanza:top_tag()); 159 module:log("debug", "Sending[rest]: %s", result.stanza:top_tag());
141 response.headers.content_type = send_type; 160 response.headers.content_type = send_type;
142 return encode(send_type, result.stanza); 161 return encode(send_type, result.stanza);
143 end, 162 end,
152 else 171 else
153 return error; 172 return error;
154 end 173 end
155 end); 174 end);
156 else 175 else
157 local origin = {};
158 function origin.send(stanza) 176 function origin.send(stanza)
159 module:log("debug", "Sending[rest]: %s", stanza:top_tag()); 177 module:log("debug", "Sending[rest]: %s", stanza:top_tag());
160 response.headers.content_type = send_type; 178 response.headers.content_type = send_type;
161 response:send(encode(send_type, stanza)); 179 response:send(encode(send_type, stanza));
162 return true; 180 return true;