Implement MSC3860 to follow redirects for federated media downloads. Note that the Client-Server API doesn't support this (yet) since the media repository in Synapse doesn't have a way of supporting redirects. (tag: v1.98.0rc1)
@@ -0,0 +1 @@ | |||||
Follow redirects when downloading media over federation (per [MSC3860](https://github.com/matrix-org/matrix-spec-proposals/pull/3860)).
@@ -21,6 +21,7 @@ from typing import ( | |||||
TYPE_CHECKING, | TYPE_CHECKING, | ||||
AbstractSet, | AbstractSet, | ||||
Awaitable, | Awaitable, | ||||
BinaryIO, | |||||
Callable, | Callable, | ||||
Collection, | Collection, | ||||
Container, | Container, | ||||
@@ -1862,6 +1863,43 @@ class FederationClient(FederationBase): | |||||
return filtered_statuses, filtered_failures | return filtered_statuses, filtered_failures | ||||
async def download_media(
    self,
    destination: str,
    media_id: str,
    output_stream: BinaryIO,
    max_size: int,
    max_timeout_ms: int,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
    """Download media from a remote homeserver, preferring the v3 endpoint.

    Attempts the Matrix v3 media download endpoint first. If the remote
    responds in a way that indicates it does not implement that endpoint,
    retries with the legacy r0 endpoint; any other error is re-raised.

    Args:
        destination: The remote server name to download from.
        media_id: The ID of the media to fetch.
        output_stream: Stream the downloaded bytes are written to.
        max_size: Maximum number of bytes to accept.
        max_timeout_ms: Maximum time in ms the remote should wait for the
            media to become available.

    Returns:
        A tuple of (content length, response headers).
    """
    # Both endpoint variants take the same keyword arguments.
    common_kwargs = {
        "output_stream": output_stream,
        "max_size": max_size,
        "max_timeout_ms": max_timeout_ms,
    }
    try:
        return await self.transport_layer.download_media_v3(
            destination, media_id, **common_kwargs
        )
    except HttpResponseException as err:
        # If an error is received that is due to an unrecognised endpoint,
        # fallback to the r0 endpoint. Otherwise, consider it a legitimate
        # error and raise.
        if not is_unknown_endpoint(err):
            raise

        logger.debug(
            "Couldn't download media %s/%s with the v3 API, falling back to the r0 API",
            destination,
            media_id,
        )

        return await self.transport_layer.download_media_r0(
            destination, media_id, **common_kwargs
        )
@attr.s(frozen=True, slots=True, auto_attribs=True) | @attr.s(frozen=True, slots=True, auto_attribs=True) | ||||
class TimestampToEventResponse: | class TimestampToEventResponse: | ||||
@@ -18,6 +18,7 @@ import urllib | |||||
from typing import ( | from typing import ( | ||||
TYPE_CHECKING, | TYPE_CHECKING, | ||||
Any, | Any, | ||||
BinaryIO, | |||||
Callable, | Callable, | ||||
Collection, | Collection, | ||||
Dict, | Dict, | ||||
@@ -804,6 +805,58 @@ class TransportLayerClient: | |||||
destination=destination, path=path, data={"user_ids": user_ids} | destination=destination, path=path, data={"user_ids": user_ids} | ||||
) | ) | ||||
async def download_media_r0(
    self,
    destination: str,
    media_id: str,
    output_stream: BinaryIO,
    max_size: int,
    max_timeout_ms: int,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
    """Download media using the legacy r0 federation media endpoint.

    Args:
        destination: The remote server name to download from.
        media_id: The ID of the media to fetch.
        output_stream: Stream the downloaded bytes are written to.
        max_size: Maximum number of bytes to accept.
        max_timeout_ms: Maximum time in ms the remote should wait for the
            media to become available.

    Returns:
        A tuple of (content length, response headers).
    """
    query_args = {
        # tell the remote server to 404 if it doesn't
        # recognise the server_name, to make sure we don't
        # end up with a routing loop.
        "allow_remote": "false",
        "timeout_ms": str(max_timeout_ms),
    }
    return await self.client.get_file(
        destination,
        f"/_matrix/media/r0/download/{destination}/{media_id}",
        output_stream=output_stream,
        max_size=max_size,
        args=query_args,
    )
async def download_media_v3(
    self,
    destination: str,
    media_id: str,
    output_stream: BinaryIO,
    max_size: int,
    max_timeout_ms: int,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
    """Download media using the v3 federation media endpoint.

    Unlike the r0 variant, this opts in to MSC3860-style redirects by
    sending ``allow_redirect=true`` and following a single 307/308.

    Args:
        destination: The remote server name to download from.
        media_id: The ID of the media to fetch.
        output_stream: Stream the downloaded bytes are written to.
        max_size: Maximum number of bytes to accept.
        max_timeout_ms: Maximum time in ms the remote should wait for the
            media to become available.

    Returns:
        A tuple of (content length, response headers).
    """
    query_args = {
        # tell the remote server to 404 if it doesn't
        # recognise the server_name, to make sure we don't
        # end up with a routing loop.
        "allow_remote": "false",
        "timeout_ms": str(max_timeout_ms),
        # Matrix 1.7 allows for this to redirect to another URL, this should
        # just be ignored for an old homeserver, so always provide it.
        "allow_redirect": "true",
    }
    return await self.client.get_file(
        destination,
        f"/_matrix/media/v3/download/{destination}/{media_id}",
        output_stream=output_stream,
        max_size=max_size,
        args=query_args,
        follow_redirects=True,
    )
def _create_path(federation_prefix: str, path: str, *args: str) -> str: | def _create_path(federation_prefix: str, path: str, *args: str) -> str: | ||||
""" | """ | ||||
@@ -153,12 +153,18 @@ class MatrixFederationRequest: | |||||
"""Query arguments. | """Query arguments. | ||||
""" | """ | ||||
txn_id: Optional[str] = None | |||||
"""Unique ID for this request (for logging) | |||||
txn_id: str = attr.ib(init=False) | |||||
"""Unique ID for this request (for logging), this is autogenerated. | |||||
""" | """ | ||||
uri: bytes = attr.ib(init=False) | |||||
"""The URI of this request | |||||
uri: bytes = b"" | |||||
"""The URI of this request, usually generated from the above information. | |||||
""" | |||||
_generate_uri: bool = True | |||||
"""True to automatically generate the uri field based on the above information. | |||||
Set to False if manually configuring the URI. | |||||
""" | """ | ||||
def __attrs_post_init__(self) -> None: | def __attrs_post_init__(self) -> None: | ||||
@@ -168,22 +174,23 @@ class MatrixFederationRequest: | |||||
object.__setattr__(self, "txn_id", txn_id) | object.__setattr__(self, "txn_id", txn_id) | ||||
destination_bytes = self.destination.encode("ascii") | |||||
path_bytes = self.path.encode("ascii") | |||||
query_bytes = encode_query_args(self.query) | |||||
# The object is frozen so we can pre-compute this. | |||||
uri = urllib.parse.urlunparse( | |||||
( | |||||
b"matrix-federation", | |||||
destination_bytes, | |||||
path_bytes, | |||||
None, | |||||
query_bytes, | |||||
b"", | |||||
if self._generate_uri: | |||||
destination_bytes = self.destination.encode("ascii") | |||||
path_bytes = self.path.encode("ascii") | |||||
query_bytes = encode_query_args(self.query) | |||||
# The object is frozen so we can pre-compute this. | |||||
uri = urllib.parse.urlunparse( | |||||
( | |||||
b"matrix-federation", | |||||
destination_bytes, | |||||
path_bytes, | |||||
None, | |||||
query_bytes, | |||||
b"", | |||||
) | |||||
) | ) | ||||
) | |||||
object.__setattr__(self, "uri", uri) | |||||
object.__setattr__(self, "uri", uri) | |||||
def get_json(self) -> Optional[JsonDict]: | def get_json(self) -> Optional[JsonDict]: | ||||
if self.json_callback: | if self.json_callback: | ||||
@@ -513,6 +520,7 @@ class MatrixFederationHttpClient: | |||||
ignore_backoff: bool = False, | ignore_backoff: bool = False, | ||||
backoff_on_404: bool = False, | backoff_on_404: bool = False, | ||||
backoff_on_all_error_codes: bool = False, | backoff_on_all_error_codes: bool = False, | ||||
follow_redirects: bool = False, | |||||
) -> IResponse: | ) -> IResponse: | ||||
""" | """ | ||||
Sends a request to the given server. | Sends a request to the given server. | ||||
@@ -555,6 +563,9 @@ class MatrixFederationHttpClient: | |||||
backoff_on_404: Back off if we get a 404 | backoff_on_404: Back off if we get a 404 | ||||
backoff_on_all_error_codes: Back off if we get any error response | backoff_on_all_error_codes: Back off if we get any error response | ||||
follow_redirects: True to follow the Location header of 307/308 redirect | |||||
responses. This does not recurse. | |||||
Returns: | Returns: | ||||
Resolves with the HTTP response object on success. | Resolves with the HTTP response object on success. | ||||
@@ -714,6 +725,26 @@ class MatrixFederationHttpClient: | |||||
response.code, | response.code, | ||||
response_phrase, | response_phrase, | ||||
) | ) | ||||
elif ( | |||||
response.code in (307, 308) | |||||
and follow_redirects | |||||
and response.headers.hasHeader("Location") | |||||
): | |||||
# The Location header *might* be relative so resolve it. | |||||
location = response.headers.getRawHeaders(b"Location")[0] | |||||
new_uri = urllib.parse.urljoin(request.uri, location) | |||||
return await self._send_request( | |||||
attr.evolve(request, uri=new_uri, generate_uri=False), | |||||
retry_on_dns_fail, | |||||
timeout, | |||||
long_retries, | |||||
ignore_backoff, | |||||
backoff_on_404, | |||||
backoff_on_all_error_codes, | |||||
# Do not continue following redirects. | |||||
follow_redirects=False, | |||||
) | |||||
else: | else: | ||||
logger.info( | logger.info( | ||||
"{%s} [%s] Got response headers: %d %s", | "{%s} [%s] Got response headers: %d %s", | ||||
@@ -1383,6 +1414,7 @@ class MatrixFederationHttpClient: | |||||
retry_on_dns_fail: bool = True, | retry_on_dns_fail: bool = True, | ||||
max_size: Optional[int] = None, | max_size: Optional[int] = None, | ||||
ignore_backoff: bool = False, | ignore_backoff: bool = False, | ||||
follow_redirects: bool = False, | |||||
) -> Tuple[int, Dict[bytes, List[bytes]]]: | ) -> Tuple[int, Dict[bytes, List[bytes]]]: | ||||
"""GETs a file from a given homeserver | """GETs a file from a given homeserver | ||||
Args: | Args: | ||||
@@ -1392,6 +1424,8 @@ class MatrixFederationHttpClient: | |||||
args: Optional dictionary used to create the query string. | args: Optional dictionary used to create the query string. | ||||
ignore_backoff: true to ignore the historical backoff data | ignore_backoff: true to ignore the historical backoff data | ||||
and try the request anyway. | and try the request anyway. | ||||
follow_redirects: True to follow the Location header of 307/308 redirect | |||||
responses. This does not recurse. | |||||
Returns: | Returns: | ||||
Resolves with an (int,dict) tuple of | Resolves with an (int,dict) tuple of | ||||
@@ -1412,7 +1446,10 @@ class MatrixFederationHttpClient: | |||||
) | ) | ||||
response = await self._send_request( | response = await self._send_request( | ||||
request, retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff | |||||
request, | |||||
retry_on_dns_fail=retry_on_dns_fail, | |||||
ignore_backoff=ignore_backoff, | |||||
follow_redirects=follow_redirects, | |||||
) | ) | ||||
headers = dict(response.headers.getAllRawHeaders()) | headers = dict(response.headers.getAllRawHeaders()) | ||||
@@ -77,7 +77,7 @@ class MediaRepository: | |||||
def __init__(self, hs: "HomeServer"): | def __init__(self, hs: "HomeServer"): | ||||
self.hs = hs | self.hs = hs | ||||
self.auth = hs.get_auth() | self.auth = hs.get_auth() | ||||
self.client = hs.get_federation_http_client() | |||||
self.client = hs.get_federation_client() | |||||
self.clock = hs.get_clock() | self.clock = hs.get_clock() | ||||
self.server_name = hs.hostname | self.server_name = hs.hostname | ||||
self.store = hs.get_datastores().main | self.store = hs.get_datastores().main | ||||
@@ -644,22 +644,13 @@ class MediaRepository: | |||||
file_info = FileInfo(server_name=server_name, file_id=file_id) | file_info = FileInfo(server_name=server_name, file_id=file_id) | ||||
with self.media_storage.store_into_file(file_info) as (f, fname, finish): | with self.media_storage.store_into_file(file_info) as (f, fname, finish): | ||||
request_path = "/".join( | |||||
("/_matrix/media/r0/download", server_name, media_id) | |||||
) | |||||
try: | try: | ||||
length, headers = await self.client.get_file( | |||||
length, headers = await self.client.download_media( | |||||
server_name, | server_name, | ||||
request_path, | |||||
media_id, | |||||
output_stream=f, | output_stream=f, | ||||
max_size=self.max_upload_size, | max_size=self.max_upload_size, | ||||
args={ | |||||
# tell the remote server to 404 if it doesn't | |||||
# recognise the server_name, to make sure we don't | |||||
# end up with a routing loop. | |||||
"allow_remote": "false", | |||||
"timeout_ms": str(max_timeout_ms), | |||||
}, | |||||
max_timeout_ms=max_timeout_ms, | |||||
) | ) | ||||
except RequestSendFailed as e: | except RequestSendFailed as e: | ||||
logger.warning( | logger.warning( | ||||
@@ -27,10 +27,11 @@ from typing_extensions import Literal | |||||
from twisted.internet import defer | from twisted.internet import defer | ||||
from twisted.internet.defer import Deferred | from twisted.internet.defer import Deferred | ||||
from twisted.python.failure import Failure | |||||
from twisted.test.proto_helpers import MemoryReactor | from twisted.test.proto_helpers import MemoryReactor | ||||
from twisted.web.resource import Resource | from twisted.web.resource import Resource | ||||
from synapse.api.errors import Codes | |||||
from synapse.api.errors import Codes, HttpResponseException | |||||
from synapse.events import EventBase | from synapse.events import EventBase | ||||
from synapse.http.types import QueryParams | from synapse.http.types import QueryParams | ||||
from synapse.logging.context import make_deferred_yieldable | from synapse.logging.context import make_deferred_yieldable | ||||
@@ -247,6 +248,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): | |||||
retry_on_dns_fail: bool = True, | retry_on_dns_fail: bool = True, | ||||
max_size: Optional[int] = None, | max_size: Optional[int] = None, | ||||
ignore_backoff: bool = False, | ignore_backoff: bool = False, | ||||
follow_redirects: bool = False, | |||||
) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]": | ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]": | ||||
"""A mock for MatrixFederationHttpClient.get_file.""" | """A mock for MatrixFederationHttpClient.get_file.""" | ||||
@@ -257,10 +259,15 @@ class MediaRepoTests(unittest.HomeserverTestCase): | |||||
output_stream.write(data) | output_stream.write(data) | ||||
return response | return response | ||||
def write_err(f: Failure) -> Failure: | |||||
f.trap(HttpResponseException) | |||||
output_stream.write(f.value.response) | |||||
return f | |||||
d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred() | d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred() | ||||
self.fetches.append((d, destination, path, args)) | self.fetches.append((d, destination, path, args)) | ||||
# Note that this callback changes the value held by d. | # Note that this callback changes the value held by d. | ||||
d_after_callback = d.addCallback(write_to) | |||||
d_after_callback = d.addCallbacks(write_to, write_err) | |||||
return make_deferred_yieldable(d_after_callback) | return make_deferred_yieldable(d_after_callback) | ||||
# Mock out the homeserver's MatrixFederationHttpClient | # Mock out the homeserver's MatrixFederationHttpClient | ||||
@@ -316,10 +323,11 @@ class MediaRepoTests(unittest.HomeserverTestCase): | |||||
self.assertEqual(len(self.fetches), 1) | self.assertEqual(len(self.fetches), 1) | ||||
self.assertEqual(self.fetches[0][1], "example.com") | self.assertEqual(self.fetches[0][1], "example.com") | ||||
self.assertEqual( | self.assertEqual( | ||||
self.fetches[0][2], "/_matrix/media/r0/download/" + self.media_id | |||||
self.fetches[0][2], "/_matrix/media/v3/download/" + self.media_id | |||||
) | ) | ||||
self.assertEqual( | self.assertEqual( | ||||
self.fetches[0][3], {"allow_remote": "false", "timeout_ms": "20000"} | |||||
self.fetches[0][3], | |||||
{"allow_remote": "false", "timeout_ms": "20000", "allow_redirect": "true"}, | |||||
) | ) | ||||
headers = { | headers = { | ||||
@@ -671,6 +679,52 @@ class MediaRepoTests(unittest.HomeserverTestCase): | |||||
[b"cross-origin"], | [b"cross-origin"], | ||||
) | ) | ||||
def test_unknown_v3_endpoint(self) -> None:
    """
    If the v3 endpoint fails, try the r0 one.
    """
    channel = self.make_request(
        "GET",
        f"/_matrix/media/v3/download/{self.media_id}",
        shorthand=False,
        await_result=False,
    )
    self.pump()

    # Exactly one fetch so far: to example.com, using the v3 media URL.
    self.assertEqual(len(self.fetches), 1)
    self.assertEqual(self.fetches[0][1], "example.com")
    self.assertEqual(
        self.fetches[0][2], f"/_matrix/media/v3/download/{self.media_id}"
    )

    # Simulate the remote server replying that the v3 endpoint is unknown.
    unknown_endpoint = b'{"errcode":"M_UNRECOGNIZED","error":"Unknown request"}'
    self.fetches[0][0].errback(
        HttpResponseException(404, "NOT FOUND", unknown_endpoint)
    )
    self.pump()

    # A second fetch should have been issued, this time to the r0 URL.
    self.assertEqual(len(self.fetches), 2)
    self.assertEqual(self.fetches[1][1], "example.com")
    self.assertEqual(
        self.fetches[1][2], f"/_matrix/media/r0/download/{self.media_id}"
    )

    # Complete the r0 request successfully and check the client response.
    headers = {b"Content-Length": [b"%d" % (len(self.test_image.data),)]}
    self.fetches[1][0].callback(
        (self.test_image.data, (len(self.test_image.data), headers))
    )
    self.pump()
    self.assertEqual(channel.code, 200)
class TestSpamCheckerLegacy: | class TestSpamCheckerLegacy: | ||||
"""A spam checker module that rejects all media that includes the bytes | """A spam checker module that rejects all media that includes the bytes | ||||
@@ -133,7 +133,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): | |||||
self.assertEqual(request.method, b"GET") | self.assertEqual(request.method, b"GET") | ||||
self.assertEqual( | self.assertEqual( | ||||
request.path, | request.path, | ||||
f"/_matrix/media/r0/download/{target}/{media_id}".encode(), | |||||
f"/_matrix/media/v3/download/{target}/{media_id}".encode(), | |||||
) | ) | ||||
self.assertEqual( | self.assertEqual( | ||||
request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")] | request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")] | ||||