Implement MSC3860 to follow redirects for federated media downloads. Note that the Client-Server API doesn't support this (yet) since the media repository in Synapse doesn't have a way of supporting redirects.tags/v1.98.0rc1
@@ -0,0 +1 @@ | |||
Follow redirects when downloading media over federation (per [MSC3860](https://github.com/matrix-org/matrix-spec-proposals/pull/3860)). |
@@ -21,6 +21,7 @@ from typing import ( | |||
TYPE_CHECKING, | |||
AbstractSet, | |||
Awaitable, | |||
BinaryIO, | |||
Callable, | |||
Collection, | |||
Container, | |||
@@ -1862,6 +1863,43 @@ class FederationClient(FederationBase): | |||
return filtered_statuses, filtered_failures | |||
async def download_media( | |||
self, | |||
destination: str, | |||
media_id: str, | |||
output_stream: BinaryIO, | |||
max_size: int, | |||
max_timeout_ms: int, | |||
) -> Tuple[int, Dict[bytes, List[bytes]]]: | |||
try: | |||
return await self.transport_layer.download_media_v3( | |||
destination, | |||
media_id, | |||
output_stream=output_stream, | |||
max_size=max_size, | |||
max_timeout_ms=max_timeout_ms, | |||
) | |||
except HttpResponseException as e: | |||
# If an error is received that is due to an unrecognised endpoint, | |||
# fallback to the r0 endpoint. Otherwise, consider it a legitimate error | |||
# and raise. | |||
if not is_unknown_endpoint(e): | |||
raise | |||
logger.debug( | |||
"Couldn't download media %s/%s with the v3 API, falling back to the r0 API", | |||
destination, | |||
media_id, | |||
) | |||
return await self.transport_layer.download_media_r0( | |||
destination, | |||
media_id, | |||
output_stream=output_stream, | |||
max_size=max_size, | |||
max_timeout_ms=max_timeout_ms, | |||
) | |||
@attr.s(frozen=True, slots=True, auto_attribs=True) | |||
class TimestampToEventResponse: | |||
@@ -18,6 +18,7 @@ import urllib | |||
from typing import ( | |||
TYPE_CHECKING, | |||
Any, | |||
BinaryIO, | |||
Callable, | |||
Collection, | |||
Dict, | |||
@@ -804,6 +805,58 @@ class TransportLayerClient: | |||
destination=destination, path=path, data={"user_ids": user_ids} | |||
) | |||
async def download_media_r0( | |||
self, | |||
destination: str, | |||
media_id: str, | |||
output_stream: BinaryIO, | |||
max_size: int, | |||
max_timeout_ms: int, | |||
) -> Tuple[int, Dict[bytes, List[bytes]]]: | |||
path = f"/_matrix/media/r0/download/{destination}/{media_id}" | |||
return await self.client.get_file( | |||
destination, | |||
path, | |||
output_stream=output_stream, | |||
max_size=max_size, | |||
args={ | |||
# tell the remote server to 404 if it doesn't | |||
# recognise the server_name, to make sure we don't | |||
# end up with a routing loop. | |||
"allow_remote": "false", | |||
"timeout_ms": str(max_timeout_ms), | |||
}, | |||
) | |||
async def download_media_v3( | |||
self, | |||
destination: str, | |||
media_id: str, | |||
output_stream: BinaryIO, | |||
max_size: int, | |||
max_timeout_ms: int, | |||
) -> Tuple[int, Dict[bytes, List[bytes]]]: | |||
path = f"/_matrix/media/v3/download/{destination}/{media_id}" | |||
return await self.client.get_file( | |||
destination, | |||
path, | |||
output_stream=output_stream, | |||
max_size=max_size, | |||
args={ | |||
# tell the remote server to 404 if it doesn't | |||
# recognise the server_name, to make sure we don't | |||
# end up with a routing loop. | |||
"allow_remote": "false", | |||
"timeout_ms": str(max_timeout_ms), | |||
# Matrix 1.7 allows for this to redirect to another URL, this should | |||
# just be ignored for an old homeserver, so always provide it. | |||
"allow_redirect": "true", | |||
}, | |||
follow_redirects=True, | |||
) | |||
def _create_path(federation_prefix: str, path: str, *args: str) -> str: | |||
""" | |||
@@ -153,12 +153,18 @@ class MatrixFederationRequest: | |||
"""Query arguments. | |||
""" | |||
txn_id: Optional[str] = None | |||
"""Unique ID for this request (for logging) | |||
txn_id: str = attr.ib(init=False) | |||
"""Unique ID for this request (for logging), this is autogenerated. | |||
""" | |||
uri: bytes = attr.ib(init=False) | |||
"""The URI of this request | |||
uri: bytes = b"" | |||
"""The URI of this request, usually generated from the above information. | |||
""" | |||
_generate_uri: bool = True | |||
"""True to automatically generate the uri field based on the above information. | |||
Set to False if manually configuring the URI. | |||
""" | |||
def __attrs_post_init__(self) -> None: | |||
@@ -168,22 +174,23 @@ class MatrixFederationRequest: | |||
object.__setattr__(self, "txn_id", txn_id) | |||
destination_bytes = self.destination.encode("ascii") | |||
path_bytes = self.path.encode("ascii") | |||
query_bytes = encode_query_args(self.query) | |||
# The object is frozen so we can pre-compute this. | |||
uri = urllib.parse.urlunparse( | |||
( | |||
b"matrix-federation", | |||
destination_bytes, | |||
path_bytes, | |||
None, | |||
query_bytes, | |||
b"", | |||
if self._generate_uri: | |||
destination_bytes = self.destination.encode("ascii") | |||
path_bytes = self.path.encode("ascii") | |||
query_bytes = encode_query_args(self.query) | |||
# The object is frozen so we can pre-compute this. | |||
uri = urllib.parse.urlunparse( | |||
( | |||
b"matrix-federation", | |||
destination_bytes, | |||
path_bytes, | |||
None, | |||
query_bytes, | |||
b"", | |||
) | |||
) | |||
) | |||
object.__setattr__(self, "uri", uri) | |||
object.__setattr__(self, "uri", uri) | |||
def get_json(self) -> Optional[JsonDict]: | |||
if self.json_callback: | |||
@@ -513,6 +520,7 @@ class MatrixFederationHttpClient: | |||
ignore_backoff: bool = False, | |||
backoff_on_404: bool = False, | |||
backoff_on_all_error_codes: bool = False, | |||
follow_redirects: bool = False, | |||
) -> IResponse: | |||
""" | |||
Sends a request to the given server. | |||
@@ -555,6 +563,9 @@ class MatrixFederationHttpClient: | |||
backoff_on_404: Back off if we get a 404 | |||
backoff_on_all_error_codes: Back off if we get any error response | |||
follow_redirects: True to follow the Location header of 307/308 redirect | |||
responses. This does not recurse. | |||
Returns: | |||
Resolves with the HTTP response object on success. | |||
@@ -714,6 +725,26 @@ class MatrixFederationHttpClient: | |||
response.code, | |||
response_phrase, | |||
) | |||
elif ( | |||
response.code in (307, 308) | |||
and follow_redirects | |||
and response.headers.hasHeader("Location") | |||
): | |||
# The Location header *might* be relative so resolve it. | |||
location = response.headers.getRawHeaders(b"Location")[0] | |||
new_uri = urllib.parse.urljoin(request.uri, location) | |||
return await self._send_request( | |||
attr.evolve(request, uri=new_uri, generate_uri=False), | |||
retry_on_dns_fail, | |||
timeout, | |||
long_retries, | |||
ignore_backoff, | |||
backoff_on_404, | |||
backoff_on_all_error_codes, | |||
# Do not continue following redirects. | |||
follow_redirects=False, | |||
) | |||
else: | |||
logger.info( | |||
"{%s} [%s] Got response headers: %d %s", | |||
@@ -1383,6 +1414,7 @@ class MatrixFederationHttpClient: | |||
retry_on_dns_fail: bool = True, | |||
max_size: Optional[int] = None, | |||
ignore_backoff: bool = False, | |||
follow_redirects: bool = False, | |||
) -> Tuple[int, Dict[bytes, List[bytes]]]: | |||
"""GETs a file from a given homeserver | |||
Args: | |||
@@ -1392,6 +1424,8 @@ class MatrixFederationHttpClient: | |||
args: Optional dictionary used to create the query string. | |||
ignore_backoff: true to ignore the historical backoff data | |||
and try the request anyway. | |||
follow_redirects: True to follow the Location header of 307/308 redirect | |||
responses. This does not recurse. | |||
Returns: | |||
Resolves with an (int,dict) tuple of | |||
@@ -1412,7 +1446,10 @@ class MatrixFederationHttpClient: | |||
) | |||
response = await self._send_request( | |||
request, retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff | |||
request, | |||
retry_on_dns_fail=retry_on_dns_fail, | |||
ignore_backoff=ignore_backoff, | |||
follow_redirects=follow_redirects, | |||
) | |||
headers = dict(response.headers.getAllRawHeaders()) | |||
@@ -77,7 +77,7 @@ class MediaRepository: | |||
def __init__(self, hs: "HomeServer"): | |||
self.hs = hs | |||
self.auth = hs.get_auth() | |||
self.client = hs.get_federation_http_client() | |||
self.client = hs.get_federation_client() | |||
self.clock = hs.get_clock() | |||
self.server_name = hs.hostname | |||
self.store = hs.get_datastores().main | |||
@@ -644,22 +644,13 @@ class MediaRepository: | |||
file_info = FileInfo(server_name=server_name, file_id=file_id) | |||
with self.media_storage.store_into_file(file_info) as (f, fname, finish): | |||
request_path = "/".join( | |||
("/_matrix/media/r0/download", server_name, media_id) | |||
) | |||
try: | |||
length, headers = await self.client.get_file( | |||
length, headers = await self.client.download_media( | |||
server_name, | |||
request_path, | |||
media_id, | |||
output_stream=f, | |||
max_size=self.max_upload_size, | |||
args={ | |||
# tell the remote server to 404 if it doesn't | |||
# recognise the server_name, to make sure we don't | |||
# end up with a routing loop. | |||
"allow_remote": "false", | |||
"timeout_ms": str(max_timeout_ms), | |||
}, | |||
max_timeout_ms=max_timeout_ms, | |||
) | |||
except RequestSendFailed as e: | |||
logger.warning( | |||
@@ -27,10 +27,11 @@ from typing_extensions import Literal | |||
from twisted.internet import defer | |||
from twisted.internet.defer import Deferred | |||
from twisted.python.failure import Failure | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from twisted.web.resource import Resource | |||
from synapse.api.errors import Codes | |||
from synapse.api.errors import Codes, HttpResponseException | |||
from synapse.events import EventBase | |||
from synapse.http.types import QueryParams | |||
from synapse.logging.context import make_deferred_yieldable | |||
@@ -247,6 +248,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): | |||
retry_on_dns_fail: bool = True, | |||
max_size: Optional[int] = None, | |||
ignore_backoff: bool = False, | |||
follow_redirects: bool = False, | |||
) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]": | |||
"""A mock for MatrixFederationHttpClient.get_file.""" | |||
@@ -257,10 +259,15 @@ class MediaRepoTests(unittest.HomeserverTestCase): | |||
output_stream.write(data) | |||
return response | |||
def write_err(f: Failure) -> Failure: | |||
f.trap(HttpResponseException) | |||
output_stream.write(f.value.response) | |||
return f | |||
d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred() | |||
self.fetches.append((d, destination, path, args)) | |||
# Note that this callback changes the value held by d. | |||
d_after_callback = d.addCallback(write_to) | |||
d_after_callback = d.addCallbacks(write_to, write_err) | |||
return make_deferred_yieldable(d_after_callback) | |||
# Mock out the homeserver's MatrixFederationHttpClient | |||
@@ -316,10 +323,11 @@ class MediaRepoTests(unittest.HomeserverTestCase): | |||
self.assertEqual(len(self.fetches), 1) | |||
self.assertEqual(self.fetches[0][1], "example.com") | |||
self.assertEqual( | |||
self.fetches[0][2], "/_matrix/media/r0/download/" + self.media_id | |||
self.fetches[0][2], "/_matrix/media/v3/download/" + self.media_id | |||
) | |||
self.assertEqual( | |||
self.fetches[0][3], {"allow_remote": "false", "timeout_ms": "20000"} | |||
self.fetches[0][3], | |||
{"allow_remote": "false", "timeout_ms": "20000", "allow_redirect": "true"}, | |||
) | |||
headers = { | |||
@@ -671,6 +679,52 @@ class MediaRepoTests(unittest.HomeserverTestCase): | |||
[b"cross-origin"], | |||
) | |||
def test_unknown_v3_endpoint(self) -> None: | |||
""" | |||
If the v3 endpoint fails, try the r0 one. | |||
""" | |||
channel = self.make_request( | |||
"GET", | |||
f"/_matrix/media/v3/download/{self.media_id}", | |||
shorthand=False, | |||
await_result=False, | |||
) | |||
self.pump() | |||
# We've made one fetch, to example.com, using the media URL, and asking | |||
# the other server not to do a remote fetch | |||
self.assertEqual(len(self.fetches), 1) | |||
self.assertEqual(self.fetches[0][1], "example.com") | |||
self.assertEqual( | |||
self.fetches[0][2], "/_matrix/media/v3/download/" + self.media_id | |||
) | |||
# The result which says the endpoint is unknown. | |||
unknown_endpoint = b'{"errcode":"M_UNRECOGNIZED","error":"Unknown request"}' | |||
self.fetches[0][0].errback( | |||
HttpResponseException(404, "NOT FOUND", unknown_endpoint) | |||
) | |||
self.pump() | |||
# There should now be another request to the r0 URL. | |||
self.assertEqual(len(self.fetches), 2) | |||
self.assertEqual(self.fetches[1][1], "example.com") | |||
self.assertEqual( | |||
self.fetches[1][2], f"/_matrix/media/r0/download/{self.media_id}" | |||
) | |||
headers = { | |||
b"Content-Length": [b"%d" % (len(self.test_image.data))], | |||
} | |||
self.fetches[1][0].callback( | |||
(self.test_image.data, (len(self.test_image.data), headers)) | |||
) | |||
self.pump() | |||
self.assertEqual(channel.code, 200) | |||
class TestSpamCheckerLegacy: | |||
"""A spam checker module that rejects all media that includes the bytes | |||
@@ -133,7 +133,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): | |||
self.assertEqual(request.method, b"GET") | |||
self.assertEqual( | |||
request.path, | |||
f"/_matrix/media/r0/download/{target}/{media_id}".encode(), | |||
f"/_matrix/media/v3/download/{target}/{media_id}".encode(), | |||
) | |||
self.assertEqual( | |||
request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")] | |||