Diff

net/http/server.lua @ 11200:bf8f2da84007

Merge 0.11->trunk
author Kim Alvefur <zash@zash.se>
date Thu, 05 Nov 2020 22:31:25 +0100
parent 11160:e9eeaefa09a7
child 11371:73f7acf8a61f
line wrap: on
line diff
--- a/net/http/server.lua	Thu Nov 05 22:27:17 2020 +0100
+++ b/net/http/server.lua	Thu Nov 05 22:31:25 2020 +0100
@@ -13,6 +13,8 @@
 local tostring = tostring;
 local cache = require "util.cache";
 local codes = require "net.http.codes";
+local promise = require "util.promise";
+local errors = require "util.error";
 local blocksize = 2^16;
 
 local _M = {};
@@ -170,6 +172,49 @@
 	end
 });
 
+local function handle_result(request, response, result)
+	if result == nil then
+		result = 404;
+	end
+
+	if result == true then
+		return;
+	end
+
+	local body;
+	local result_type = type(result);
+	if result_type == "number" then
+		response.status_code = result;
+		if result >= 400 then
+			body = events.fire_event("http-error", { request = request, response = response, code = result });
+		end
+	elseif result_type == "string" then
+		body = result;
+	elseif errors.is_err(result) then
+		response.status_code = result.code or 500;
+		body = events.fire_event("http-error", { request = request, response = response, code = result.code or 500, error = result });
+	elseif promise.is_promise(result) then
+		result:next(function (ret)
+			handle_result(request, response, ret);
+		end, function (err)
+			response.status_code = 500;
+			handle_result(request, response, err or 500);
+		end);
+		return true;
+	elseif result_type == "table" then
+		for k, v in pairs(result) do
+			if k ~= "headers" then
+				response[k] = v;
+			else
+				for header_name, header_value in pairs(v) do
+					response.headers[header_name] = header_value;
+				end
+			end
+		end
+	end
+	return response:send(body);
+end
+
 function _M.hijack_response(response, listener) -- luacheck: ignore
 	error("TODO");
 end
@@ -194,8 +239,11 @@
 		response_conn_header = httpversion == "1.1" and "close" or nil
 	end
 
+	local is_head_request = request.method == "HEAD";
+
 	local response = {
 		request = request;
+		is_head_request = is_head_request;
 		status_code = 200;
 		headers = { date = date_header, connection = response_conn_header };
 		persistent = persistent;
@@ -227,6 +275,11 @@
 	local payload = { request = request, response = response };
 	log("debug", "Firing event: %s", global_event);
 	local result = events.fire_event(global_event, payload);
+	if result == nil and is_head_request then
+		local global_head_event = "GET "..request.path:match("[^?]*");
+		log("debug", "Firing event: %s", global_head_event);
+		result = events.fire_event(global_head_event, payload);
+	end
 	if result == nil then
 		if not hosts[host] then
 			if hosts[default_host] then
@@ -247,40 +300,17 @@
 		local host_event = request.method.." "..host..request.path:match("[^?]*");
 		log("debug", "Firing event: %s", host_event);
 		result = events.fire_event(host_event, payload);
-	end
-	if result ~= nil then
-		if result ~= true then
-			local body;
-			local result_type = type(result);
-			if result_type == "number" then
-				response.status_code = result;
-				if result >= 400 then
-					payload.code = result;
-					body = events.fire_event("http-error", payload);
-				end
-			elseif result_type == "string" then
-				body = result;
-			elseif result_type == "table" then
-				for k, v in pairs(result) do
-					if k ~= "headers" then
-						response[k] = v;
-					else
-						for header_name, header_value in pairs(v) do
-							response.headers[header_name] = header_value;
-						end
-					end
-				end
-			end
-			response:send(body);
+
+		if result == nil and is_head_request then
+			local host_head_event = "GET "..host..request.path:match("[^?]*");
+			log("debug", "Firing event: %s", host_head_event);
+			result = events.fire_event(host_head_event, payload);
 		end
-		return;
 	end
 
-	-- if handler not called, return 404
-	response.status_code = 404;
-	payload.code = 404;
-	response:send(events.fire_event("http-error", payload));
+	return handle_result(request, response, result);
 end
+
 local function prepare_header(response)
 	local status_line = "HTTP/"..response.request.httpversion.." "..(response.status or codes[response.status_code]);
 	local headers = response.headers;
@@ -292,12 +322,21 @@
 	return output;
 end
 _M.prepare_header = prepare_header;
+function _M.send_head_response(response)
+	if response.finished then return; end
+	local output = prepare_header(response);
+	response.conn:write(t_concat(output));
+	response:done();
+end
 function _M.send_response(response, body)
 	if response.finished then return; end
 	body = body or response.body or "";
 	-- Per RFC 7230, informational (1xx) and 204 (no content) should have no c-l header
 	if response.status_code > 199 and response.status_code ~= 204 then
-		response.headers.content_length = #body;
+		response.headers.content_length = ("%d"):format(#body);
+	end
+	if response.is_head_request then
+		return _M.send_head_response(response)
 	end
 	local output = prepare_header(response);
 	t_insert(output, body);
@@ -305,6 +344,10 @@
 	response:done();
 end
 function _M.send_file(response, f)
+	if response.is_head_request then
+		if f.close then f:close(); end
+		return _M.send_head_response(response);
+	end
 	if response.finished then return; end
 	local chunked = not response.headers.content_length;
 	if chunked then response.headers.transfer_encoding = "chunked"; end