600 lines
22 KiB
Python
600 lines
22 KiB
Python
|
import asyncio
|
||
|
import asyncio.streams
|
||
|
import traceback
|
||
|
import warnings
|
||
|
from collections import deque
|
||
|
from contextlib import suppress
|
||
|
from html import escape as html_escape
|
||
|
from http import HTTPStatus
|
||
|
from logging import Logger
|
||
|
from typing import (
|
||
|
TYPE_CHECKING,
|
||
|
Any,
|
||
|
Awaitable,
|
||
|
Callable,
|
||
|
Optional,
|
||
|
Type,
|
||
|
cast,
|
||
|
)
|
||
|
|
||
|
import yarl
|
||
|
|
||
|
from .abc import AbstractAccessLogger, AbstractStreamWriter
|
||
|
from .base_protocol import BaseProtocol
|
||
|
from .helpers import CeilTimeout, current_task
|
||
|
from .http import (
|
||
|
HttpProcessingError,
|
||
|
HttpRequestParser,
|
||
|
HttpVersion10,
|
||
|
RawRequestMessage,
|
||
|
StreamWriter,
|
||
|
)
|
||
|
from .log import access_logger, server_logger
|
||
|
from .streams import EMPTY_PAYLOAD, StreamReader
|
||
|
from .tcp_helpers import tcp_keepalive
|
||
|
from .web_exceptions import HTTPException
|
||
|
from .web_log import AccessLogger
|
||
|
from .web_request import BaseRequest
|
||
|
from .web_response import Response, StreamResponse
|
||
|
|
||
|
# Public names exported by this module.
__all__ = (
    'RequestHandler',
    'RequestPayloadError',
    'PayloadAccessError',
)
|
||
|
|
||
|
if TYPE_CHECKING: # pragma: no cover
|
||
|
from .web_server import Server # noqa
|
||
|
|
||
|
|
||
|
# Signature of the factory that RequestHandler.start() calls to build a
# request object from a parsed message, its payload stream, the protocol,
# a response writer and the request-processing task.
_RequestFactory = Callable[
    [
        RawRequestMessage,
        StreamReader,
        'RequestHandler',
        AbstractStreamWriter,
        'asyncio.Task[None]',
    ],
    BaseRequest,
]

# Signature of the user-facing handler: takes a request and returns an
# awaitable that resolves to a StreamResponse.
_RequestHandler = Callable[[BaseRequest], Awaitable[StreamResponse]]
|
||
|
|
||
|
|
||
|
# Placeholder request message used to build a BaseRequest when an incoming
# request could not be parsed, so that an error response can still be sent.
ERROR = RawRequestMessage(
    'UNKNOWN',
    '/',
    HttpVersion10,
    {},
    {},
    True,
    False,
    False,
    False,
    yarl.URL('/'),
)
|
||
|
|
||
|
|
||
|
class RequestPayloadError(Exception):
    """Error raised while parsing a request payload.

    Supplied to the HTTP request parser as its ``payload_exception``.
    """
|
||
|
|
||
|
|
||
|
class PayloadAccessError(Exception):
    """Error set on a request payload once its response has been sent.

    Any later attempt to read that payload fails with this exception.
    """
|
||
|
|
||
|
|
||
|
class RequestHandler(BaseProtocol):
    """HTTP protocol implementation.

    RequestHandler handles incoming HTTP requests. It reads the request
    line, request headers and request payload, and passes each request to
    the request handler supplied by the server manager.

    RequestHandler handles errors in incoming requests, like a bad
    status line, bad headers or an incomplete payload. If any error
    occurs, the connection gets closed.

    :param keepalive_timeout: number of seconds before closing
                              keep-alive connection; ``None`` disables
                              the keep-alive timer
    :type keepalive_timeout: float or None

    :param bool tcp_keepalive: TCP keep-alive is on, default is on

    :param bool debug: enable debug mode

    :param logger: custom logger object
    :type logger: aiohttp.log.server_logger

    :param access_log_class: custom class for access_logger
    :type access_log_class: aiohttp.abc.AbstractAccessLogger

    :param access_log: custom logging object
    :type access_log: aiohttp.log.server_logger

    :param str access_log_format: access log format string

    :param loop: event loop the protocol runs on

    :param int max_line_size: Optional maximum header line size

    :param int max_field_size: Optional maximum header field size

    :param int max_headers: Optional maximum header size

    :param float lingering_time: maximum number of seconds spent reading
                                 and discarding an unread request body
                                 after the handler finished; if the body
                                 is still incomplete afterwards the
                                 connection is closed

    """

    KEEPALIVE_RESCHEDULE_DELAY = 1

    __slots__ = ('_request_count', '_keepalive', '_manager',
                 '_request_handler', '_request_factory', '_tcp_keepalive',
                 '_keepalive_time', '_keepalive_handle', '_keepalive_timeout',
                 '_lingering_time', '_messages', '_message_tail',
                 '_waiter', '_error_handler', '_task_handler',
                 '_upgrade', '_payload_parser', '_request_parser',
                 '_reading_paused', 'logger', 'debug', 'access_log',
                 'access_logger', '_close', '_force_close')

    def __init__(self, manager: 'Server', *,
                 loop: asyncio.AbstractEventLoop,
                 keepalive_timeout: float=75.,  # NGINX default is 75 secs
                 tcp_keepalive: bool=True,
                 logger: Logger=server_logger,
                 access_log_class: Type[AbstractAccessLogger]=AccessLogger,
                 access_log: Logger=access_logger,
                 access_log_format: str=AccessLogger.LOG_FORMAT,
                 debug: bool=False,
                 max_line_size: int=8190,
                 max_headers: int=32768,
                 max_field_size: int=8190,
                 lingering_time: float=10.0) -> None:

        super().__init__(loop)

        self._request_count = 0
        self._keepalive = False
        self._manager = manager  # type: Optional[Server]
        self._request_handler = manager.request_handler  # type: Optional[_RequestHandler] # noqa
        self._request_factory = manager.request_factory  # type: Optional[_RequestFactory] # noqa

        self._tcp_keepalive = tcp_keepalive
        # placeholder to be replaced on keepalive timeout setup
        self._keepalive_time = 0.0
        self._keepalive_handle = None  # type: Optional[asyncio.Handle]
        self._keepalive_timeout = keepalive_timeout
        self._lingering_time = float(lingering_time)

        # Parsed (message, payload) pairs waiting for start() to process.
        self._messages = deque()  # type: Any # Python 3.5 has no typing.Deque
        # Raw bytes received while no payload parser is installed.
        self._message_tail = b''

        # Future that start() awaits while idle, waiting for a new message.
        self._waiter = None  # type: Optional[asyncio.Future[None]]
        self._error_handler = None  # type: Optional[asyncio.Task[None]]
        self._task_handler = None  # type: Optional[asyncio.Task[None]]

        self._upgrade = False
        self._payload_parser = None  # type: Any
        self._request_parser = HttpRequestParser(
            self, loop,
            max_line_size=max_line_size,
            max_field_size=max_field_size,
            max_headers=max_headers,
            payload_exception=RequestPayloadError)  # type: Optional[HttpRequestParser] # noqa

        self.logger = logger
        self.debug = debug
        self.access_log = access_log
        if access_log:
            self.access_logger = access_log_class(
                access_log, access_log_format)  # type: Optional[AbstractAccessLogger] # noqa
        else:
            self.access_logger = None

        # _close: finish current work, then close; _force_close: stop now.
        self._close = False
        self._force_close = False

    def __repr__(self) -> str:
        return "<{} {}>".format(
            self.__class__.__name__,
            'connected' if self.transport is not None else 'disconnected')

    @property
    def keepalive_timeout(self) -> float:
        """Keep-alive timeout in seconds, as passed to the constructor."""
        return self._keepalive_timeout

    async def shutdown(self, timeout: Optional[float]=15.0) -> None:
        """Stop accepting requests and close the connection.

        Called when the worker process is about to exit; this matters most
        for keep-alive connections. Cancels the keep-alive timer and any
        wait for the next request, gives running handlers up to *timeout*
        seconds to finish, then cancels the request-processing task and
        closes the transport.
        """
        self._force_close = True

        if self._keepalive_handle is not None:
            self._keepalive_handle.cancel()

        if self._waiter is not None:
            self._waiter.cancel()

        # wait for handlers
        with suppress(asyncio.CancelledError, asyncio.TimeoutError):
            with CeilTimeout(timeout, loop=self._loop):
                if (self._error_handler is not None and
                        not self._error_handler.done()):
                    await self._error_handler

                if (self._task_handler is not None and
                        not self._task_handler.done()):
                    await self._task_handler

        # force-close non-idle handler
        if self._task_handler is not None:
            self._task_handler.cancel()

        if self.transport is not None:
            self.transport.close()
            self.transport = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Start serving a newly accepted connection.

        Enables TCP keep-alive if configured, spawns the start() task and
        registers the connection with the manager.
        """
        super().connection_made(transport)

        real_transport = cast(asyncio.Transport, transport)
        if self._tcp_keepalive:
            tcp_keepalive(real_transport)

        self._task_handler = self._loop.create_task(self.start())
        assert self._manager is not None
        self._manager.connection_made(self, real_transport)

    def connection_lost(self, exc: Optional[BaseException]) -> None:
        """Notify the manager and release all per-connection state.

        Only the first call does any work: ``_manager`` is cleared here and
        later calls return immediately.
        """
        if self._manager is None:
            return
        self._manager.connection_lost(self, exc)

        super().connection_lost(exc)

        self._manager = None
        self._force_close = True
        self._request_factory = None
        self._request_handler = None
        self._request_parser = None

        if self._keepalive_handle is not None:
            self._keepalive_handle.cancel()

        if self._task_handler is not None:
            self._task_handler.cancel()

        if self._error_handler is not None:
            self._error_handler.cancel()

        self._task_handler = None

        if self._payload_parser is not None:
            # Let readers of the payload observe end-of-stream.
            self._payload_parser.feed_eof()
            self._payload_parser = None

    def set_parser(self, parser: Any) -> None:
        """Install a payload parser for subsequent incoming data.

        Bytes buffered in ``_message_tail`` while no payload parser was
        installed are fed to the new parser immediately.
        """
        # Actual type is WebReader
        assert self._payload_parser is None

        self._payload_parser = parser

        if self._message_tail:
            self._payload_parser.feed_data(self._message_tail)
            self._message_tail = b''

    def eof_received(self) -> None:
        pass

    def data_received(self, data: bytes) -> None:
        """Feed incoming bytes to the request parser or payload parser.

        Parsed messages are queued for start() and its idle waiter is
        woken. Parse failures schedule handle_parse_error() and close the
        connection.
        """
        if self._force_close or self._close:
            return
        # parse http messages
        if self._payload_parser is None and not self._upgrade:
            assert self._request_parser is not None
            try:
                messages, upgraded, tail = self._request_parser.feed_data(data)
            except HttpProcessingError as exc:
                # something happened during parsing
                self._error_handler = self._loop.create_task(
                    self.handle_parse_error(
                        StreamWriter(self, self._loop),
                        400, exc, exc.message))
                self.close()
            except Exception as exc:
                # 500: internal error
                self._error_handler = self._loop.create_task(
                    self.handle_parse_error(
                        StreamWriter(self, self._loop),
                        500, exc))
                self.close()
            else:
                if messages:
                    # sometimes the parser returns no messages
                    for (msg, payload) in messages:
                        self._request_count += 1
                        self._messages.append((msg, payload))

                    waiter = self._waiter
                    if waiter is not None:
                        if not waiter.done():
                            # don't set result twice
                            waiter.set_result(None)

                self._upgrade = upgraded
                if upgraded and tail:
                    self._message_tail = tail

        # no parser, just store
        elif self._payload_parser is None and self._upgrade and data:
            self._message_tail += data

        # feed payload
        elif data:
            eof, tail = self._payload_parser.feed_data(data)
            if eof:
                self.close()

    def keep_alive(self, val: bool) -> None:
        """Set keep-alive connection mode.

        Any pending keep-alive timer is cancelled.

        :param bool val: new state.
        """
        self._keepalive = val
        if self._keepalive_handle is not None:
            self._keepalive_handle.cancel()
            self._keepalive_handle = None

    def close(self) -> None:
        """Stop accepting new pipelining messages and close the
        connection when handlers are done processing messages."""
        self._close = True
        if self._waiter is not None:
            self._waiter.cancel()

    def force_close(self) -> None:
        """Force close connection: cancel any idle wait and close the
        transport immediately."""
        self._force_close = True
        if self._waiter is not None:
            self._waiter.cancel()
        if self.transport is not None:
            self.transport.close()
            self.transport = None

    def log_access(self,
                   request: BaseRequest,
                   response: StreamResponse,
                   time: float) -> None:
        """Pass a finished request/response pair to the access logger, if
        one is configured."""
        if self.access_logger is not None:
            self.access_logger.log(request, response, time)

    def log_debug(self, *args: Any, **kw: Any) -> None:
        """Log a debug message, only when debug mode is enabled."""
        if self.debug:
            self.logger.debug(*args, **kw)

    def log_exception(self, *args: Any, **kw: Any) -> None:
        """Log an error message with exception information."""
        self.logger.exception(*args, **kw)

    def _process_keepalive(self) -> None:
        """Keep-alive timer callback.

        If the connection is idle (waiting for the next request) past the
        keep-alive deadline, force-close it; otherwise re-arm the timer
        after KEEPALIVE_RESCHEDULE_DELAY seconds.
        """
        if self._force_close or not self._keepalive:
            return

        deadline = self._keepalive_time + self._keepalive_timeout

        # handler in idle state
        if self._waiter is not None:
            if self._loop.time() > deadline:
                self.force_close()
                return

        # not all request handlers are done,
        # reschedule itself to next second
        self._keepalive_handle = self._loop.call_later(
            self.KEEPALIVE_RESCHEDULE_DELAY, self._process_keepalive)

    async def start(self) -> None:
        """Process incoming requests.

        Waits for messages queued by data_received(), builds a request
        with the manager's request factory and awaits the manager's
        request handler for it. It then writes the response, optionally
        drains an unread request body, and decides whether to keep the
        connection alive. Errors raised by the handler are turned into
        error responses. The connection is closed unless the response
        allows keep-alive.
        """
        loop = self._loop
        handler = self._task_handler
        assert handler is not None
        manager = self._manager
        assert manager is not None
        keepalive_timeout = self._keepalive_timeout
        resp = None
        assert self._request_factory is not None
        assert self._request_handler is not None

        while not self._force_close:
            if not self._messages:
                try:
                    # wait for next request
                    self._waiter = loop.create_future()
                    await self._waiter
                except asyncio.CancelledError:
                    # close()/force_close()/shutdown() cancel the waiter
                    break
                finally:
                    self._waiter = None

            message, payload = self._messages.popleft()

            if self.access_log:
                now = loop.time()

            manager.requests_count += 1
            writer = StreamWriter(self, loop)
            request = self._request_factory(
                message, payload, self, writer, handler)
            try:
                # a new task is used for copy context vars (#3406)
                task = self._loop.create_task(
                    self._request_handler(request))
                try:
                    resp = await task
                except HTTPException as exc:
                    resp = exc
                except (asyncio.CancelledError, ConnectionError):
                    self.log_debug('Ignored premature client disconnection')
                    break
                except asyncio.TimeoutError as exc:
                    self.log_debug('Request handler timed out.', exc_info=exc)
                    resp = self.handle_error(request, 504)
                except Exception as exc:
                    resp = self.handle_error(request, 500, exc)
                else:
                    # Deprecation warning (See #2415)
                    if getattr(resp, '__http_exception__', False):
                        warnings.warn(
                            "returning HTTPException object is deprecated "
                            "(#2415) and will be removed, "
                            "please raise the exception instead",
                            DeprecationWarning)

                # Drop the processed task from asyncio.Task.all_tasks() early
                del task

                if self.debug:
                    if not isinstance(resp, StreamResponse):
                        if resp is None:
                            raise RuntimeError("Missing return "
                                               "statement on request handler")
                        else:
                            raise RuntimeError("Web-handler should return "
                                               "a response instance, "
                                               "got {!r}".format(resp))
                try:
                    prepare_meth = resp.prepare
                except AttributeError:
                    if resp is None:
                        raise RuntimeError("Missing return "
                                           "statement on request handler")
                    else:
                        raise RuntimeError("Web-handler should return "
                                           "a response instance, "
                                           "got {!r}".format(resp))
                try:
                    await prepare_meth(request)
                    await resp.write_eof()
                except ConnectionError:
                    self.log_debug('Ignored premature client disconnection 2')
                    break

                # notify server about keep-alive
                self._keepalive = bool(resp.keep_alive)

                # log access
                if self.access_log:
                    self.log_access(request, resp, loop.time() - now)

                # check payload
                if not payload.is_eof():
                    lingering_time = self._lingering_time
                    if not self._force_close and lingering_time:
                        self.log_debug(
                            'Start lingering close timer for %s sec.',
                            lingering_time)

                        now = loop.time()
                        end_t = now + lingering_time

                        with suppress(
                                asyncio.TimeoutError, asyncio.CancelledError):
                            while not payload.is_eof() and now < end_t:
                                with CeilTimeout(end_t - now, loop=loop):
                                    # read and ignore
                                    await payload.readany()
                                now = loop.time()

                    # if payload still uncompleted
                    if not payload.is_eof() and not self._force_close:
                        self.log_debug('Uncompleted request.')
                        self.close()

                # The response is done; any further payload reads fail.
                payload.set_exception(PayloadAccessError())

            except asyncio.CancelledError:
                self.log_debug('Ignored premature client disconnection ')
                break
            except RuntimeError as exc:
                if self.debug:
                    self.log_exception(
                        'Unhandled runtime exception', exc_info=exc)
                self.force_close()
            except Exception as exc:
                self.log_exception('Unhandled exception', exc_info=exc)
                self.force_close()
            finally:
                if self.transport is None and resp is not None:
                    self.log_debug('Ignored premature client disconnection.')
                elif not self._force_close:
                    if self._keepalive and not self._close:
                        # start keep-alive timer
                        if keepalive_timeout is not None:
                            now = self._loop.time()
                            self._keepalive_time = now
                            # An already pending timer re-reads
                            # _keepalive_time when it fires.
                            if self._keepalive_handle is None:
                                self._keepalive_handle = loop.call_at(
                                    now + keepalive_timeout,
                                    self._process_keepalive)
                    else:
                        # Not keep-alive or close() requested: stop serving.
                        break

        # remove handler, close transport if no handlers left
        if not self._force_close:
            self._task_handler = None
            if self.transport is not None and self._error_handler is None:
                self.transport.close()

    def handle_error(self,
                     request: BaseRequest,
                     status: int=500,
                     exc: Optional[BaseException]=None,
                     message: Optional[str]=None) -> StreamResponse:
        """Handle errors.

        Returns an HTTP response with the given status code; the response
        is marked to close the connection. For 500 responses the body is
        a generated plain-text or HTML page (HTML when the request's
        Accept header contains ``text/html``), including a traceback in
        debug mode. For other statuses *message* is used as the body.
        Force-closes the connection if output was already written for this
        request or the transport is gone.
        """
        self.log_exception("Error handling request", exc_info=exc)

        ct = 'text/plain'
        if status == HTTPStatus.INTERNAL_SERVER_ERROR:
            title = '{0.value} {0.phrase}'.format(
                HTTPStatus.INTERNAL_SERVER_ERROR
            )
            msg = HTTPStatus.INTERNAL_SERVER_ERROR.description
            tb = None
            if self.debug:
                with suppress(Exception):
                    tb = traceback.format_exc()

            if 'text/html' in request.headers.get('Accept', ''):
                if tb:
                    # Traceback text may contain markup-like characters.
                    tb = html_escape(tb)
                    msg = '<h2>Traceback:</h2>\n<pre>{}</pre>'.format(tb)
                message = (
                    "<html><head>"
                    "<title>{title}</title>"
                    "</head><body>\n<h1>{title}</h1>"
                    "\n{msg}\n</body></html>\n"
                ).format(title=title, msg=msg)
                ct = 'text/html'
            else:
                if tb:
                    msg = tb
                message = title + '\n\n' + msg

        resp = Response(status=status, text=message, content_type=ct)
        resp.force_close()

        # some data already got sent, connection is broken
        if request.writer.output_size > 0 or self.transport is None:
            self.force_close()

        return resp

    async def handle_parse_error(self,
                                 writer: AbstractStreamWriter,
                                 status: int,
                                 exc: Optional[BaseException]=None,
                                 message: Optional[str]=None) -> None:
        """Send an error response for a request that could not be parsed.

        Builds a placeholder request and writes the response from
        handle_error(). The transport is closed and ``_error_handler`` is
        cleared even if building or writing the response fails. Otherwise
        start(), which skips closing the transport while an error handler
        is set, could leave the connection open.
        """
        try:
            request = BaseRequest(  # type: ignore
                ERROR,
                EMPTY_PAYLOAD,
                self, writer,
                current_task(),
                self._loop)

            resp = self.handle_error(request, status, exc, message)
            await resp.prepare(request)
            await resp.write_eof()
        finally:
            if self.transport is not None:
                self.transport.close()

            self._error_handler = None
|