|
|
@@ -30,6 +30,7 @@ from typing import ( |
|
|
|
Iterable, |
|
|
|
Iterator, |
|
|
|
List, |
|
|
|
NoReturn, |
|
|
|
Optional, |
|
|
|
Pattern, |
|
|
|
Tuple, |
|
|
@@ -170,7 +171,9 @@ def return_html_error( |
|
|
|
respond_with_html(request, code, body) |
|
|
|
|
|
|
|
|
|
|
|
def wrap_async_request_handler(h): |
|
|
|
def wrap_async_request_handler( |
|
|
|
h: Callable[["_AsyncResource", SynapseRequest], Awaitable[None]] |
|
|
|
) -> Callable[["_AsyncResource", SynapseRequest], "defer.Deferred[None]"]: |
|
|
|
"""Wraps an async request handler so that it calls request.processing. |
|
|
|
|
|
|
|
This helps ensure that work done by the request handler after the request is completed |
|
|
@@ -183,7 +186,9 @@ def wrap_async_request_handler(h): |
|
|
|
logged until the deferred completes. |
|
|
|
""" |
|
|
|
|
|
|
|
async def wrapped_async_request_handler(self, request): |
|
|
|
async def wrapped_async_request_handler( |
|
|
|
self: "_AsyncResource", request: SynapseRequest |
|
|
|
) -> None: |
|
|
|
with request.processing(): |
|
|
|
await h(self, request) |
|
|
|
|
|
|
@@ -240,18 +245,18 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): |
|
|
|
context from the request the servlet is handling. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, extract_context=False): |
|
|
|
def __init__(self, extract_context: bool = False): |
|
|
|
super().__init__() |
|
|
|
|
|
|
|
self._extract_context = extract_context |
|
|
|
|
|
|
|
def render(self, request): |
|
|
|
def render(self, request: SynapseRequest) -> int: |
|
|
|
"""This gets called by twisted every time someone sends us a request.""" |
|
|
|
defer.ensureDeferred(self._async_render_wrapper(request)) |
|
|
|
return NOT_DONE_YET |
|
|
|
|
|
|
|
@wrap_async_request_handler |
|
|
|
async def _async_render_wrapper(self, request: SynapseRequest): |
|
|
|
async def _async_render_wrapper(self, request: SynapseRequest) -> None: |
|
|
|
"""This is a wrapper that delegates to `_async_render` and handles |
|
|
|
exceptions, return values, metrics, etc. |
|
|
|
""" |
|
|
@@ -271,7 +276,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): |
|
|
|
f = failure.Failure() |
|
|
|
self._send_error_response(f, request) |
|
|
|
|
|
|
|
async def _async_render(self, request: Request): |
|
|
|
async def _async_render(self, request: SynapseRequest) -> Optional[Tuple[int, Any]]: |
|
|
|
"""Delegates to `_async_render_<METHOD>` methods, or returns a 400 if |
|
|
|
no appropriate method exists. Can be overridden in sub classes for |
|
|
|
different routing. |
|
|
@@ -318,7 +323,7 @@ class DirectServeJsonResource(_AsyncResource): |
|
|
|
formatting responses and errors as JSON. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, canonical_json=False, extract_context=False): |
|
|
|
def __init__(self, canonical_json: bool = False, extract_context: bool = False): |
|
|
|
super().__init__(extract_context) |
|
|
|
self.canonical_json = canonical_json |
|
|
|
|
|
|
@@ -327,7 +332,7 @@ class DirectServeJsonResource(_AsyncResource): |
|
|
|
request: SynapseRequest, |
|
|
|
code: int, |
|
|
|
response_object: Any, |
|
|
|
): |
|
|
|
) -> None: |
|
|
|
"""Implements _AsyncResource._send_response""" |
|
|
|
# TODO: Only enable CORS for the requests that need it. |
|
|
|
respond_with_json( |
|
|
@@ -368,34 +373,45 @@ class JsonResource(DirectServeJsonResource): |
|
|
|
|
|
|
|
isLeaf = True |
|
|
|
|
|
|
|
def __init__(self, hs: "HomeServer", canonical_json=True, extract_context=False): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
hs: "HomeServer", |
|
|
|
canonical_json: bool = True, |
|
|
|
extract_context: bool = False, |
|
|
|
): |
|
|
|
super().__init__(canonical_json, extract_context) |
|
|
|
self.clock = hs.get_clock() |
|
|
|
self.path_regexs: Dict[bytes, List[_PathEntry]] = {} |
|
|
|
self.hs = hs |
|
|
|
|
|
|
|
def register_paths(self, method, path_patterns, callback, servlet_classname): |
|
|
|
def register_paths( |
|
|
|
self, |
|
|
|
method: str, |
|
|
|
path_patterns: Iterable[Pattern], |
|
|
|
callback: ServletCallback, |
|
|
|
servlet_classname: str, |
|
|
|
) -> None: |
|
|
|
""" |
|
|
|
Registers a request handler against a regular expression. Later request URLs are |
|
|
|
checked against these regular expressions in order to identify an appropriate |
|
|
|
handler for that request. |
|
|
|
|
|
|
|
Args: |
|
|
|
method (str): GET, POST etc |
|
|
|
method: GET, POST etc |
|
|
|
|
|
|
|
path_patterns (Iterable[str]): A list of regular expressions to which |
|
|
|
the request URLs are compared. |
|
|
|
path_patterns: A list of regular expressions to which the request |
|
|
|
URLs are compared. |
|
|
|
|
|
|
|
callback (function): The handler for the request. Usually a Servlet |
|
|
|
callback: The handler for the request. Usually a Servlet |
|
|
|
|
|
|
|
servlet_classname (str): The name of the handler to be used in prometheus |
|
|
|
servlet_classname: The name of the handler to be used in prometheus |
|
|
|
and opentracing logs. |
|
|
|
""" |
|
|
|
method = method.encode("utf-8") # method is bytes on py3 |
|
|
|
method_bytes = method.encode("utf-8") |
|
|
|
|
|
|
|
for path_pattern in path_patterns: |
|
|
|
logger.debug("Registering for %s %s", method, path_pattern.pattern) |
|
|
|
self.path_regexs.setdefault(method, []).append( |
|
|
|
self.path_regexs.setdefault(method_bytes, []).append( |
|
|
|
_PathEntry(path_pattern, callback, servlet_classname) |
|
|
|
) |
|
|
|
|
|
|
@@ -427,7 +443,7 @@ class JsonResource(DirectServeJsonResource): |
|
|
|
# Huh. No one wanted to handle that? Fiiiiiine. Send 400. |
|
|
|
return _unrecognised_request_handler, "unrecognised_request_handler", {} |
|
|
|
|
|
|
|
async def _async_render(self, request): |
|
|
|
async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]: |
|
|
|
callback, servlet_classname, group_dict = self._get_handler_for_request(request) |
|
|
|
|
|
|
|
# Make sure we have an appropriate name for this handler in prometheus |
|
|
@@ -468,7 +484,7 @@ class DirectServeHtmlResource(_AsyncResource): |
|
|
|
request: SynapseRequest, |
|
|
|
code: int, |
|
|
|
response_object: Any, |
|
|
|
): |
|
|
|
) -> None: |
|
|
|
"""Implements _AsyncResource._send_response""" |
|
|
|
# We expect to get bytes for us to write |
|
|
|
assert isinstance(response_object, bytes) |
|
|
@@ -492,12 +508,12 @@ class StaticResource(File): |
|
|
|
Differs from the File resource by adding clickjacking protection. |
|
|
|
""" |
|
|
|
|
|
|
|
def render_GET(self, request: Request): |
|
|
|
def render_GET(self, request: Request) -> bytes: |
|
|
|
set_clickjacking_protection_headers(request) |
|
|
|
return super().render_GET(request) |
|
|
|
|
|
|
|
|
|
|
|
def _unrecognised_request_handler(request): |
|
|
|
def _unrecognised_request_handler(request: Request) -> NoReturn: |
|
|
|
"""Request handler for unrecognised requests |
|
|
|
|
|
|
|
This is a request handler suitable for return from |
|
|
@@ -505,7 +521,7 @@ def _unrecognised_request_handler(request): |
|
|
|
UnrecognizedRequestError. |
|
|
|
|
|
|
|
Args: |
|
|
|
request (twisted.web.http.Request): |
|
|
|
request: Unused, but passed in to match the signature of ServletCallback. |
|
|
|
""" |
|
|
|
raise UnrecognizedRequestError() |
|
|
|
|
|
|
@@ -513,14 +529,14 @@ def _unrecognised_request_handler(request): |
|
|
|
class RootRedirect(resource.Resource): |
|
|
|
"""Redirects the root '/' path to another path.""" |
|
|
|
|
|
|
|
def __init__(self, path): |
|
|
|
def __init__(self, path: str): |
|
|
|
resource.Resource.__init__(self) |
|
|
|
self.url = path |
|
|
|
|
|
|
|
def render_GET(self, request): |
|
|
|
def render_GET(self, request: Request) -> bytes: |
|
|
|
return redirectTo(self.url.encode("ascii"), request) |
|
|
|
|
|
|
|
def getChild(self, name, request): |
|
|
|
def getChild(self, name: str, request: Request) -> resource.Resource: |
|
|
|
if len(name) == 0: |
|
|
|
return self # select ourselves as the child to render |
|
|
|
return resource.Resource.getChild(self, name, request) |
|
|
@@ -529,7 +545,7 @@ class RootRedirect(resource.Resource): |
|
|
|
class OptionsResource(resource.Resource): |
|
|
|
"""Responds to OPTION requests for itself and all children.""" |
|
|
|
|
|
|
|
def render_OPTIONS(self, request): |
|
|
|
def render_OPTIONS(self, request: Request) -> bytes: |
|
|
|
request.setResponseCode(204) |
|
|
|
request.setHeader(b"Content-Length", b"0") |
|
|
|
|
|
|
@@ -537,7 +553,7 @@ class OptionsResource(resource.Resource): |
|
|
|
|
|
|
|
return b"" |
|
|
|
|
|
|
|
def getChildWithDefault(self, path, request): |
|
|
|
def getChildWithDefault(self, path: str, request: Request) -> resource.Resource: |
|
|
|
if request.method == b"OPTIONS": |
|
|
|
return self # select ourselves as the child to render |
|
|
|
return resource.Resource.getChildWithDefault(self, path, request) |
|
|
@@ -649,7 +665,7 @@ def respond_with_json( |
|
|
|
json_object: Any, |
|
|
|
send_cors: bool = False, |
|
|
|
canonical_json: bool = True, |
|
|
|
): |
|
|
|
) -> Optional[int]: |
|
|
|
"""Sends encoded JSON in response to the given request. |
|
|
|
|
|
|
|
Args: |
|
|
@@ -696,7 +712,7 @@ def respond_with_json_bytes( |
|
|
|
code: int, |
|
|
|
json_bytes: bytes, |
|
|
|
send_cors: bool = False, |
|
|
|
): |
|
|
|
) -> Optional[int]: |
|
|
|
"""Sends encoded JSON in response to the given request. |
|
|
|
|
|
|
|
Args: |
|
|
@@ -713,7 +729,7 @@ def respond_with_json_bytes( |
|
|
|
logger.warning( |
|
|
|
"Not sending response to request %s, already disconnected.", request |
|
|
|
) |
|
|
|
return |
|
|
|
return None |
|
|
|
|
|
|
|
request.setResponseCode(code) |
|
|
|
request.setHeader(b"Content-Type", b"application/json") |
|
|
@@ -731,7 +747,7 @@ async def _async_write_json_to_request_in_thread( |
|
|
|
request: SynapseRequest, |
|
|
|
json_encoder: Callable[[Any], bytes], |
|
|
|
json_object: Any, |
|
|
|
): |
|
|
|
) -> None: |
|
|
|
"""Encodes the given JSON object on a thread and then writes it to the |
|
|
|
request. |
|
|
|
|
|
|
@@ -773,7 +789,7 @@ def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None: |
|
|
|
_ByteProducer(request, bytes_generator) |
|
|
|
|
|
|
|
|
|
|
|
def set_cors_headers(request: Request): |
|
|
|
def set_cors_headers(request: Request) -> None: |
|
|
|
"""Set the CORS headers so that javascript running in a web browsers can |
|
|
|
use this API |
|
|
|
|
|
|
@@ -790,14 +806,14 @@ def set_cors_headers(request: Request): |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def respond_with_html(request: Request, code: int, html: str): |
|
|
|
def respond_with_html(request: Request, code: int, html: str) -> None: |
|
|
|
""" |
|
|
|
Wraps `respond_with_html_bytes` by first encoding HTML from a str to UTF-8 bytes. |
|
|
|
""" |
|
|
|
respond_with_html_bytes(request, code, html.encode("utf-8")) |
|
|
|
|
|
|
|
|
|
|
|
def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes): |
|
|
|
def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes) -> None: |
|
|
|
""" |
|
|
|
Sends HTML (encoded as UTF-8 bytes) as the response to the given request. |
|
|
|
|
|
|
@@ -815,7 +831,7 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes): |
|
|
|
logger.warning( |
|
|
|
"Not sending response to request %s, already disconnected.", request |
|
|
|
) |
|
|
|
return |
|
|
|
return None |
|
|
|
|
|
|
|
request.setResponseCode(code) |
|
|
|
request.setHeader(b"Content-Type", b"text/html; charset=utf-8") |
|
|
@@ -828,7 +844,7 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes): |
|
|
|
finish_request(request) |
|
|
|
|
|
|
|
|
|
|
|
def set_clickjacking_protection_headers(request: Request): |
|
|
|
def set_clickjacking_protection_headers(request: Request) -> None: |
|
|
|
""" |
|
|
|
Set headers to guard against clickjacking of embedded content. |
|
|
|
|
|
|
@@ -850,7 +866,7 @@ def respond_with_redirect(request: Request, url: bytes) -> None: |
|
|
|
finish_request(request) |
|
|
|
|
|
|
|
|
|
|
|
def finish_request(request: Request): |
|
|
|
def finish_request(request: Request) -> None: |
|
|
|
"""Finish writing the response to the request. |
|
|
|
|
|
|
|
Twisted throws a RuntimeException if the connection closed before the |
|
|
|