Comparison

mod_storage_s3/mod_storage_s3.lua @ 5718:b4632d5f840b

mod_storage_s3: Move request signing into a net.http hook
author Kim Alvefur <zash@zash.se>
date Sat, 11 Nov 2023 17:01:29 +0100
parent 5699:799f69a5921a
child 5719:66986f5271c3
comparison
equal deleted inserted replaced
5717:8afa0fb8a73e 5718:b4632d5f840b
23 local region = module:get_option_string("s3_region", "us-east-1"); 23 local region = module:get_option_string("s3_region", "us-east-1");
24 24
25 local access_key = module:get_option_string("s3_access_key"); 25 local access_key = module:get_option_string("s3_access_key");
26 local secret_key = module:get_option_string("s3_secret_key"); 26 local secret_key = module:get_option_string("s3_secret_key");
27 27
28 function driver:open(store, typ)
29 local mt = self[typ or "keyval"]
30 if not mt then
31 return nil, "unsupported-store";
32 end
33 return setmetatable({ store = store; bucket = bucket; type = typ }, mt);
34 end
35
36 local keyval = { };
37 driver.keyval = { __index = keyval; __name = module.name .. " keyval store" };
38
39 local aws4_format = "AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s"; 28 local aws4_format = "AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s";
40 29
41 local function new_request(method, path, query, payload) 30 local function aws_auth(event)
42 local request = url.parse(base_uri); 31 local request, options = event.request, event.options;
43 request.path = path; 32 local method = options.method or "GET";
33 local query = options.query;
34 local payload = options.body;
44 35
45 local payload_type = nil; 36 local payload_type = nil;
46 if st.is_stanza(payload) then 37 if st.is_stanza(payload) then
47 payload_type = "application/xml"; 38 payload_type = "application/xml";
48 payload = tostring(payload); 39 payload = tostring(payload);
49 elseif payload ~= nil then 40 elseif payload ~= nil then
50 payload_type = "application/json"; 41 payload_type = "application/json";
51 payload = json.encode(payload); 42 payload = json.encode(payload);
52 end 43 end
44 options.body = payload;
53 45
54 local payload_hash = sha256(payload or "", true); 46 local payload_hash = sha256(payload or "", true);
55 47
56 local now = os.time(); 48 local now = os.time();
57 local aws_datetime = os.date("!%Y%m%dT%H%M%SZ", now); 49 local aws_datetime = os.date("!%Y%m%dT%H%M%SZ", now);
110 102
111 local signature = hmac_sha256(signing_key, signature_payload, true); 103 local signature = hmac_sha256(signing_key, signature_payload, true);
112 104
113 headers["Authorization"] = string.format(aws4_format, access_key, scope, signed_headers, signature); 105 headers["Authorization"] = string.format(aws4_format, access_key, scope, signed_headers, signature);
114 106
115 return http.request(url.build(request), { method = method; headers = headers; body = payload }); 107 options.headers = headers;
108 end
109
110 function driver:open(store, typ)
111 local mt = self[typ or "keyval"]
112 if not mt then
113 return nil, "unsupported-store";
114 end
115 local httpclient = http.new({});
116 httpclient.events.add_handler("pre-request", aws_auth);
117 return setmetatable({ store = store; bucket = bucket; type = typ; http = httpclient }, mt);
118 end
119
120 local keyval = { };
121 driver.keyval = { __index = keyval; __name = module.name .. " keyval store" };
122
123 local function new_request(self, method, path, query, payload)
124 local request = url.parse(base_uri);
125 request.path = path;
126
127 return self.http:request(url.build(request), { method = method; body = payload; query = query });
116 end 128 end
117 129
118 -- coerce result back into Prosody data type 130 -- coerce result back into Prosody data type
119 local function on_result(response) 131 local function on_result(response)
120 if response.code == 404 and response.request.method == "GET" then 132 if response.code == 404 and response.request.method == "GET" then
145 jid.escape(key or "@"); 157 jid.escape(key or "@");
146 }) 158 })
147 end 159 end
148 160
149 function keyval:get(user) 161 function keyval:get(user)
150 return async.wait_for(new_request("GET", self:_path(user)):next(on_result)); 162 return async.wait_for(new_request(self, "GET", self:_path(user)):next(on_result));
151 end 163 end
152 164
153 function keyval:set(user, data) 165 function keyval:set(user, data)
154 166
155 if data == nil or (type(data) == "table" and next(data) == nil) then 167 if data == nil or (type(data) == "table" and next(data) == nil) then
156 return async.wait_for(new_request("DELETE", self:_path(user))); 168 return async.wait_for(new_request(self, "DELETE", self:_path(user)));
157 end 169 end
158 170
159 return async.wait_for(new_request("PUT", self:_path(user), nil, data)); 171 return async.wait_for(new_request(self, "PUT", self:_path(user), nil, data));
160 end 172 end
161 173
162 function keyval:users() 174 function keyval:users()
163 local bucket_path = url.build_path({ is_absolute = true; bucket; is_directory = true }); 175 local bucket_path = url.build_path({ is_absolute = true; bucket; is_directory = true });
164 local prefix = url.build_path({ jid.escape(module.host); jid.escape(self.store); is_directory = true }); 176 local prefix = url.build_path({ jid.escape(module.host); jid.escape(self.store); is_directory = true });
165 local list_result, err = async.wait_for(new_request("GET", bucket_path, { prefix = prefix })) 177 local list_result, err = async.wait_for(new_request(self, "GET", bucket_path, { prefix = prefix }))
166 if err or list_result.code ~= 200 then 178 if err or list_result.code ~= 200 then
167 return nil, err; 179 return nil, err;
168 end 180 end
169 local list_bucket_result = xml.parse(list_result.body); 181 local list_bucket_result = xml.parse(list_result.body);
170 if list_bucket_result:get_child_text("IsTruncated") == "true" then 182 if list_bucket_result:get_child_text("IsTruncated") == "true" then
206 local wrapper = st.stanza("wrapper"); 218 local wrapper = st.stanza("wrapper");
207 -- Minio had trouble with timestamps, probably the ':' characters, in paths. 219 -- Minio had trouble with timestamps, probably the ':' characters, in paths.
208 wrapper:tag("delay", { xmlns = "urn:xmpp:delay"; stamp = dt.datetime(when) }):up(); 220 wrapper:tag("delay", { xmlns = "urn:xmpp:delay"; stamp = dt.datetime(when) }):up();
209 wrapper:add_direct_child(value); 221 wrapper:add_direct_child(value);
210 key = key or new_uuid(); 222 key = key or new_uuid();
211 return async.wait_for(new_request("PUT", self:_path(username, nil, when, with, key), nil, wrapper):next(function(r) 223 return async.wait_for(new_request(self, "PUT", self:_path(username, nil, when, with, key), nil, wrapper):next(function(r)
212 if r.code == 200 then 224 if r.code == 200 then
213 return key; 225 return key;
214 else 226 else
215 error(r.body); 227 error(r.body);
216 end 228 end
230 table.insert(prefix, sha256(jid.prep(query["with"]), true):sub(1,24)); 242 table.insert(prefix, sha256(jid.prep(query["with"]), true):sub(1,24));
231 end 243 end
232 end 244 end
233 245
234 prefix = url.build_path(prefix); 246 prefix = url.build_path(prefix);
235 local list_result, err = async.wait_for(new_request("GET", bucket_path, { 247 local list_result, err = async.wait_for(new_request(self, "GET", bucket_path, {
236 prefix = prefix; 248 prefix = prefix;
237 ["max-keys"] = query["max"] and tostring(query["max"]); 249 ["max-keys"] = query["max"] and tostring(query["max"]);
238 })); 250 }));
239 if err or list_result.code ~= 200 then 251 if err or list_result.code ~= 200 then
240 return nil, err; 252 return nil, err;
274 local item = keys[i]; 286 local item = keys[i];
275 if item == nil then 287 if item == nil then
276 return nil; 288 return nil;
277 end 289 end
278 -- luacheck: ignore 431/err 290 -- luacheck: ignore 431/err
279 local value, err = async.wait_for(new_request("GET", self:_path(username or "@", item.date, nil, item.with, item.key)):next(on_result)); 291 local value, err = async.wait_for(new_request(self, "GET", self:_path(username or "@", item.date, nil, item.with, item.key)):next(on_result));
280 if not value then 292 if not value then
281 module:log("error", "%s", err); 293 module:log("error", "%s", err);
282 return nil; 294 return nil;
283 end 295 end
284 local delay = value:get_child("delay", "urn:xmpp:delay"); 296 local delay = value:get_child("delay", "urn:xmpp:delay");