Support asynchronous uploads as defined in MSC2246.tags/v1.97.0rc1
@@ -0,0 +1 @@ | |||
Add support for asynchronous uploads as defined by [MSC2246](https://github.com/matrix-org/matrix-spec-proposals/pull/2246). Contributed by @sumnerevans at @beeper. |
@@ -1753,6 +1753,19 @@ rc_third_party_invite: | |||
burst_count: 10 | |||
``` | |||
--- | |||
### `rc_media_create` | |||
This option ratelimits creation of MXC URIs via the `/_matrix/media/v1/create` | |||
endpoint based on the account that's creating the media. Defaults to | |||
`per_second: 10`, `burst_count: 50`. | |||
Example configuration: | |||
```yaml | |||
rc_media_create: | |||
per_second: 10 | |||
burst_count: 50 | |||
``` | |||
--- | |||
### `rc_federation` | |||
Defines limits on federation requests. | |||
@@ -1814,6 +1827,27 @@ Example configuration: | |||
media_store_path: "DATADIR/media_store" | |||
``` | |||
--- | |||
### `max_pending_media_uploads` | |||
How many *pending media uploads* can a given user have? A pending media upload | |||
is a created MXC URI that (a) is not expired (the `unused_expires_at` timestamp | |||
has not passed) and (b) the media has not yet been uploaded for. Defaults to 5. | |||
Example configuration: | |||
```yaml | |||
max_pending_media_uploads: 5 | |||
``` | |||
--- | |||
### `unused_expiration_time` | |||
How long to wait in milliseconds before expiring created media IDs. Defaults to | |||
"24h" | |||
Example configuration: | |||
```yaml | |||
unused_expiration_time: "1h" | |||
``` | |||
--- | |||
### `media_storage_providers` | |||
Media storage providers allow media to be stored in different | |||
@@ -83,6 +83,8 @@ class Codes(str, Enum): | |||
USER_DEACTIVATED = "M_USER_DEACTIVATED" | |||
# USER_LOCKED = "M_USER_LOCKED" | |||
USER_LOCKED = "ORG_MATRIX_MSC3939_USER_LOCKED" | |||
NOT_YET_UPLOADED = "M_NOT_YET_UPLOADED" | |||
CANNOT_OVERWRITE_MEDIA = "M_CANNOT_OVERWRITE_MEDIA" | |||
# Part of MSC3848 | |||
# https://github.com/matrix-org/matrix-spec-proposals/pull/3848 | |||
@@ -204,3 +204,10 @@ class RatelimitConfig(Config): | |||
"rc_third_party_invite", | |||
defaults={"per_second": 0.0025, "burst_count": 5}, | |||
) | |||
# Ratelimit create media requests: | |||
self.rc_media_create = RatelimitSettings.parse( | |||
config, | |||
"rc_media_create", | |||
defaults={"per_second": 10, "burst_count": 50}, | |||
) |
@@ -141,6 +141,12 @@ class ContentRepositoryConfig(Config): | |||
"prevent_media_downloads_from", [] | |||
) | |||
self.unused_expiration_time = self.parse_duration( | |||
config.get("unused_expiration_time", "24h") | |||
) | |||
self.max_pending_media_uploads = config.get("max_pending_media_uploads", 5) | |||
self.media_store_path = self.ensure_directory( | |||
config.get("media_store_path", "media_store") | |||
) | |||
@@ -83,6 +83,12 @@ INLINE_CONTENT_TYPES = [ | |||
"audio/x-flac", | |||
] | |||
# Default timeout_ms for download and thumbnail requests | |||
DEFAULT_MAX_TIMEOUT_MS = 20_000 | |||
# Maximum allowed timeout_ms for download and thumbnail requests | |||
MAXIMUM_ALLOWED_MAX_TIMEOUT_MS = 60_000 | |||
def respond_404(request: SynapseRequest) -> None: | |||
assert request.path is not None | |||
@@ -27,13 +27,16 @@ import twisted.web.http | |||
from twisted.internet.defer import Deferred | |||
from synapse.api.errors import ( | |||
Codes, | |||
FederationDeniedError, | |||
HttpResponseException, | |||
NotFoundError, | |||
RequestSendFailed, | |||
SynapseError, | |||
cs_error, | |||
) | |||
from synapse.config.repository import ThumbnailRequirement | |||
from synapse.http.server import respond_with_json | |||
from synapse.http.site import SynapseRequest | |||
from synapse.logging.context import defer_to_thread | |||
from synapse.logging.opentracing import trace | |||
@@ -51,7 +54,7 @@ from synapse.media.storage_provider import StorageProviderWrapper | |||
from synapse.media.thumbnailer import Thumbnailer, ThumbnailError | |||
from synapse.media.url_previewer import UrlPreviewer | |||
from synapse.metrics.background_process_metrics import run_as_background_process | |||
from synapse.storage.databases.main.media_repository import RemoteMedia | |||
from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia | |||
from synapse.types import UserID | |||
from synapse.util.async_helpers import Linearizer | |||
from synapse.util.retryutils import NotRetryingDestination | |||
@@ -80,6 +83,8 @@ class MediaRepository: | |||
self.store = hs.get_datastores().main | |||
self.max_upload_size = hs.config.media.max_upload_size | |||
self.max_image_pixels = hs.config.media.max_image_pixels | |||
self.unused_expiration_time = hs.config.media.unused_expiration_time | |||
self.max_pending_media_uploads = hs.config.media.max_pending_media_uploads | |||
Thumbnailer.set_limits(self.max_image_pixels) | |||
@@ -185,6 +190,117 @@ class MediaRepository: | |||
else: | |||
self.recently_accessed_locals.add(media_id) | |||
@trace | |||
async def create_media_id(self, auth_user: UserID) -> Tuple[str, int]: | |||
"""Create and store a media ID for a local user and return the MXC URI and its | |||
expiration. | |||
Args: | |||
auth_user: The user_id of the uploader | |||
Returns: | |||
A tuple containing the MXC URI of the stored content and the timestamp at | |||
which the MXC URI expires. | |||
""" | |||
media_id = random_string(24) | |||
now = self.clock.time_msec() | |||
await self.store.store_local_media_id( | |||
media_id=media_id, | |||
time_now_ms=now, | |||
user_id=auth_user, | |||
) | |||
return f"mxc://{self.server_name}/{media_id}", now + self.unused_expiration_time | |||
@trace | |||
async def reached_pending_media_limit(self, auth_user: UserID) -> Tuple[bool, int]: | |||
"""Check if the user is over the limit for pending media uploads. | |||
Args: | |||
auth_user: The user_id of the uploader | |||
Returns: | |||
A tuple with a boolean and an integer indicating whether the user has too | |||
many pending media uploads and the timestamp at which the first pending | |||
media will expire, respectively. | |||
""" | |||
pending, first_expiration_ts = await self.store.count_pending_media( | |||
user_id=auth_user | |||
) | |||
return pending >= self.max_pending_media_uploads, first_expiration_ts | |||
@trace | |||
async def verify_can_upload(self, media_id: str, auth_user: UserID) -> None: | |||
"""Verify that the media ID can be uploaded to by the given user. This | |||
function checks that: | |||
* the media ID exists | |||
* the media ID does not already have content | |||
* the user uploading is the same as the one who created the media ID | |||
* the media ID has not expired | |||
Args: | |||
media_id: The media ID to verify | |||
auth_user: The user_id of the uploader | |||
""" | |||
media = await self.store.get_local_media(media_id) | |||
if media is None: | |||
raise SynapseError(404, "Unknow media ID", errcode=Codes.NOT_FOUND) | |||
if media.user_id != auth_user.to_string(): | |||
raise SynapseError( | |||
403, | |||
"Only the creator of the media ID can upload to it", | |||
errcode=Codes.FORBIDDEN, | |||
) | |||
if media.media_length is not None: | |||
raise SynapseError( | |||
409, | |||
"Media ID already has content", | |||
errcode=Codes.CANNOT_OVERWRITE_MEDIA, | |||
) | |||
expired_time_ms = self.clock.time_msec() - self.unused_expiration_time | |||
if media.created_ts < expired_time_ms: | |||
raise NotFoundError("Media ID has expired") | |||
@trace | |||
async def update_content( | |||
self, | |||
media_id: str, | |||
media_type: str, | |||
upload_name: Optional[str], | |||
content: IO, | |||
content_length: int, | |||
auth_user: UserID, | |||
) -> None: | |||
"""Update the content of the given media ID. | |||
Args: | |||
media_id: The media ID to replace. | |||
media_type: The content type of the file. | |||
upload_name: The name of the file, if provided. | |||
content: A file like object that is the content to store | |||
content_length: The length of the content | |||
auth_user: The user_id of the uploader | |||
""" | |||
file_info = FileInfo(server_name=None, file_id=media_id) | |||
fname = await self.media_storage.store_file(content, file_info) | |||
logger.info("Stored local media in file %r", fname) | |||
await self.store.update_local_media( | |||
media_id=media_id, | |||
media_type=media_type, | |||
upload_name=upload_name, | |||
media_length=content_length, | |||
user_id=auth_user, | |||
) | |||
try: | |||
await self._generate_thumbnails(None, media_id, media_id, media_type) | |||
except Exception as e: | |||
logger.info("Failed to generate thumbnails: %s", e) | |||
@trace | |||
async def create_content( | |||
self, | |||
@@ -231,8 +347,74 @@ class MediaRepository: | |||
return MXCUri(self.server_name, media_id) | |||
def respond_not_yet_uploaded(self, request: SynapseRequest) -> None: | |||
respond_with_json( | |||
request, | |||
504, | |||
cs_error("Media has not been uploaded yet", code=Codes.NOT_YET_UPLOADED), | |||
send_cors=True, | |||
) | |||
async def get_local_media_info( | |||
self, request: SynapseRequest, media_id: str, max_timeout_ms: int | |||
) -> Optional[LocalMedia]: | |||
"""Gets the info dictionary for given local media ID. If the media has | |||
not been uploaded yet, this function will wait up to ``max_timeout_ms`` | |||
milliseconds for the media to be uploaded. | |||
Args: | |||
request: The incoming request. | |||
media_id: The media ID of the content. (This is the same as | |||
the file_id for local content.) | |||
max_timeout_ms: the maximum number of milliseconds to wait for the | |||
media to be uploaded. | |||
Returns: | |||
Either the info dictionary for the given local media ID or | |||
``None``. If ``None``, then no further processing is necessary as | |||
this function will send the necessary JSON response. | |||
""" | |||
wait_until = self.clock.time_msec() + max_timeout_ms | |||
while True: | |||
# Get the info for the media | |||
media_info = await self.store.get_local_media(media_id) | |||
if not media_info: | |||
logger.info("Media %s is unknown", media_id) | |||
respond_404(request) | |||
return None | |||
if media_info.quarantined_by: | |||
logger.info("Media %s is quarantined", media_id) | |||
respond_404(request) | |||
return None | |||
# The file has been uploaded, so stop looping | |||
if media_info.media_length is not None: | |||
return media_info | |||
# Check if the media ID has expired and still hasn't been uploaded to. | |||
now = self.clock.time_msec() | |||
expired_time_ms = now - self.unused_expiration_time | |||
if media_info.created_ts < expired_time_ms: | |||
logger.info("Media %s has expired without being uploaded", media_id) | |||
respond_404(request) | |||
return None | |||
if now >= wait_until: | |||
break | |||
await self.clock.sleep(0.5) | |||
logger.info("Media %s has not yet been uploaded", media_id) | |||
self.respond_not_yet_uploaded(request) | |||
return None | |||
async def get_local_media( | |||
self, request: SynapseRequest, media_id: str, name: Optional[str] | |||
self, | |||
request: SynapseRequest, | |||
media_id: str, | |||
name: Optional[str], | |||
max_timeout_ms: int, | |||
) -> None: | |||
"""Responds to requests for local media, if exists, or returns 404. | |||
@@ -242,13 +424,14 @@ class MediaRepository: | |||
the file_id for local content.) | |||
name: Optional name that, if specified, will be used as | |||
the filename in the Content-Disposition header of the response. | |||
max_timeout_ms: the maximum number of milliseconds to wait for the | |||
media to be uploaded. | |||
Returns: | |||
Resolves once a response has successfully been written to request | |||
""" | |||
media_info = await self.store.get_local_media(media_id) | |||
if not media_info or media_info.quarantined_by: | |||
respond_404(request) | |||
media_info = await self.get_local_media_info(request, media_id, max_timeout_ms) | |||
if not media_info: | |||
return | |||
self.mark_recently_accessed(None, media_id) | |||
@@ -273,6 +456,7 @@ class MediaRepository: | |||
server_name: str, | |||
media_id: str, | |||
name: Optional[str], | |||
max_timeout_ms: int, | |||
) -> None: | |||
"""Respond to requests for remote media. | |||
@@ -282,6 +466,8 @@ class MediaRepository: | |||
media_id: The media ID of the content (as defined by the remote server). | |||
name: Optional name that, if specified, will be used as | |||
the filename in the Content-Disposition header of the response. | |||
max_timeout_ms: the maximum number of milliseconds to wait for the | |||
media to be uploaded. | |||
Returns: | |||
Resolves once a response has successfully been written to request | |||
@@ -307,11 +493,11 @@ class MediaRepository: | |||
key = (server_name, media_id) | |||
async with self.remote_media_linearizer.queue(key): | |||
responder, media_info = await self._get_remote_media_impl( | |||
server_name, media_id | |||
server_name, media_id, max_timeout_ms | |||
) | |||
# We deliberately stream the file outside the lock | |||
if responder: | |||
if responder and media_info: | |||
upload_name = name if name else media_info.upload_name | |||
await respond_with_responder( | |||
request, | |||
@@ -324,7 +510,7 @@ class MediaRepository: | |||
respond_404(request) | |||
async def get_remote_media_info( | |||
self, server_name: str, media_id: str | |||
self, server_name: str, media_id: str, max_timeout_ms: int | |||
) -> RemoteMedia: | |||
"""Gets the media info associated with the remote file, downloading | |||
if necessary. | |||
@@ -332,6 +518,8 @@ class MediaRepository: | |||
Args: | |||
server_name: Remote server_name where the media originated. | |||
media_id: The media ID of the content (as defined by the remote server). | |||
max_timeout_ms: the maximum number of milliseconds to wait for the | |||
media to be uploaded. | |||
Returns: | |||
The media info of the file | |||
@@ -347,7 +535,7 @@ class MediaRepository: | |||
key = (server_name, media_id) | |||
async with self.remote_media_linearizer.queue(key): | |||
responder, media_info = await self._get_remote_media_impl( | |||
server_name, media_id | |||
server_name, media_id, max_timeout_ms | |||
) | |||
# Ensure we actually use the responder so that it releases resources | |||
@@ -358,7 +546,7 @@ class MediaRepository: | |||
return media_info | |||
async def _get_remote_media_impl( | |||
self, server_name: str, media_id: str | |||
self, server_name: str, media_id: str, max_timeout_ms: int | |||
) -> Tuple[Optional[Responder], RemoteMedia]: | |||
"""Looks for media in local cache, if not there then attempt to | |||
download from remote server. | |||
@@ -367,6 +555,8 @@ class MediaRepository: | |||
server_name: Remote server_name where the media originated. | |||
media_id: The media ID of the content (as defined by the | |||
remote server). | |||
max_timeout_ms: the maximum number of milliseconds to wait for the | |||
media to be uploaded. | |||
Returns: | |||
A tuple of responder and the media info of the file. | |||
@@ -399,8 +589,7 @@ class MediaRepository: | |||
try: | |||
media_info = await self._download_remote_file( | |||
server_name, | |||
media_id, | |||
server_name, media_id, max_timeout_ms | |||
) | |||
except SynapseError: | |||
raise | |||
@@ -433,6 +622,7 @@ class MediaRepository: | |||
self, | |||
server_name: str, | |||
media_id: str, | |||
max_timeout_ms: int, | |||
) -> RemoteMedia: | |||
"""Attempt to download the remote file from the given server name, | |||
using the given file_id as the local id. | |||
@@ -442,7 +632,8 @@ class MediaRepository: | |||
media_id: The media ID of the content (as defined by the | |||
remote server). This is different than the file_id, which is | |||
locally generated. | |||
file_id: Local file ID | |||
max_timeout_ms: the maximum number of milliseconds to wait for the | |||
media to be uploaded. | |||
Returns: | |||
The media info of the file. | |||
@@ -466,7 +657,8 @@ class MediaRepository: | |||
# tell the remote server to 404 if it doesn't | |||
# recognise the server_name, to make sure we don't | |||
# end up with a routing loop. | |||
"allow_remote": "false" | |||
"allow_remote": "false", | |||
"timeout_ms": str(max_timeout_ms), | |||
}, | |||
) | |||
except RequestSendFailed as e: | |||
@@ -0,0 +1,83 @@ | |||
# Copyright 2023 Beeper Inc. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import logging | |||
import re | |||
from typing import TYPE_CHECKING | |||
from synapse.api.errors import LimitExceededError | |||
from synapse.api.ratelimiting import Ratelimiter | |||
from synapse.http.server import respond_with_json | |||
from synapse.http.servlet import RestServlet | |||
from synapse.http.site import SynapseRequest | |||
if TYPE_CHECKING: | |||
from synapse.media.media_repository import MediaRepository | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
class CreateResource(RestServlet): | |||
PATTERNS = [re.compile("/_matrix/media/v1/create")] | |||
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): | |||
super().__init__() | |||
self.media_repo = media_repo | |||
self.clock = hs.get_clock() | |||
self.auth = hs.get_auth() | |||
self.max_pending_media_uploads = hs.config.media.max_pending_media_uploads | |||
# A rate limiter for creating new media IDs. | |||
self._create_media_rate_limiter = Ratelimiter( | |||
store=hs.get_datastores().main, | |||
clock=self.clock, | |||
cfg=hs.config.ratelimiting.rc_media_create, | |||
) | |||
async def on_POST(self, request: SynapseRequest) -> None: | |||
requester = await self.auth.get_user_by_req(request) | |||
# If the create media requests for the user are over the limit, drop them. | |||
await self._create_media_rate_limiter.ratelimit(requester) | |||
( | |||
reached_pending_limit, | |||
first_expiration_ts, | |||
) = await self.media_repo.reached_pending_media_limit(requester.user) | |||
if reached_pending_limit: | |||
raise LimitExceededError( | |||
limiter_name="max_pending_media_uploads", | |||
retry_after_ms=first_expiration_ts - self.clock.time_msec(), | |||
) | |||
content_uri, unused_expires_at = await self.media_repo.create_media_id( | |||
requester.user | |||
) | |||
logger.info( | |||
"Created Media URI %r that if unused will expire at %d", | |||
content_uri, | |||
unused_expires_at, | |||
) | |||
respond_with_json( | |||
request, | |||
200, | |||
{ | |||
"content_uri": content_uri, | |||
"unused_expires_at": unused_expires_at, | |||
}, | |||
send_cors=True, | |||
) |
@@ -17,9 +17,13 @@ import re | |||
from typing import TYPE_CHECKING, Optional | |||
from synapse.http.server import set_corp_headers, set_cors_headers | |||
from synapse.http.servlet import RestServlet, parse_boolean | |||
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer | |||
from synapse.http.site import SynapseRequest | |||
from synapse.media._base import respond_404 | |||
from synapse.media._base import ( | |||
DEFAULT_MAX_TIMEOUT_MS, | |||
MAXIMUM_ALLOWED_MAX_TIMEOUT_MS, | |||
respond_404, | |||
) | |||
from synapse.util.stringutils import parse_and_validate_server_name | |||
if TYPE_CHECKING: | |||
@@ -65,12 +69,16 @@ class DownloadResource(RestServlet): | |||
) | |||
# Limited non-standard form of CSP for IE11 | |||
request.setHeader(b"X-Content-Security-Policy", b"sandbox;") | |||
request.setHeader( | |||
b"Referrer-Policy", | |||
b"no-referrer", | |||
request.setHeader(b"Referrer-Policy", b"no-referrer") | |||
max_timeout_ms = parse_integer( | |||
request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS | |||
) | |||
max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS) | |||
if self._is_mine_server_name(server_name): | |||
await self.media_repo.get_local_media(request, media_id, file_name) | |||
await self.media_repo.get_local_media( | |||
request, media_id, file_name, max_timeout_ms | |||
) | |||
else: | |||
allow_remote = parse_boolean(request, "allow_remote", default=True) | |||
if not allow_remote: | |||
@@ -83,5 +91,5 @@ class DownloadResource(RestServlet): | |||
return | |||
await self.media_repo.get_remote_media( | |||
request, server_name, media_id, file_name | |||
request, server_name, media_id, file_name, max_timeout_ms | |||
) |
@@ -18,10 +18,11 @@ from synapse.config._base import ConfigError | |||
from synapse.http.server import HttpServer, JsonResource | |||
from .config_resource import MediaConfigResource | |||
from .create_resource import CreateResource | |||
from .download_resource import DownloadResource | |||
from .preview_url_resource import PreviewUrlResource | |||
from .thumbnail_resource import ThumbnailResource | |||
from .upload_resource import UploadResource | |||
from .upload_resource import AsyncUploadServlet, UploadServlet | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
@@ -91,8 +92,9 @@ class MediaRepositoryResource(JsonResource): | |||
# Note that many of these should not exist as v1 endpoints, but empirically | |||
# a lot of traffic still goes to them. | |||
UploadResource(hs, media_repo).register(http_server) | |||
CreateResource(hs, media_repo).register(http_server) | |||
UploadServlet(hs, media_repo).register(http_server) | |||
AsyncUploadServlet(hs, media_repo).register(http_server) | |||
DownloadResource(hs, media_repo).register(http_server) | |||
ThumbnailResource(hs, media_repo, media_repo.media_storage).register( | |||
http_server | |||
@@ -23,6 +23,8 @@ from synapse.http.server import respond_with_json, set_corp_headers, set_cors_he | |||
from synapse.http.servlet import RestServlet, parse_integer, parse_string | |||
from synapse.http.site import SynapseRequest | |||
from synapse.media._base import ( | |||
DEFAULT_MAX_TIMEOUT_MS, | |||
MAXIMUM_ALLOWED_MAX_TIMEOUT_MS, | |||
FileInfo, | |||
ThumbnailInfo, | |||
respond_404, | |||
@@ -75,15 +77,19 @@ class ThumbnailResource(RestServlet): | |||
method = parse_string(request, "method", "scale") | |||
# TODO Parse the Accept header to get an prioritised list of thumbnail types. | |||
m_type = "image/png" | |||
max_timeout_ms = parse_integer( | |||
request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS | |||
) | |||
max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS) | |||
if self._is_mine_server_name(server_name): | |||
if self.dynamic_thumbnails: | |||
await self._select_or_generate_local_thumbnail( | |||
request, media_id, width, height, method, m_type | |||
request, media_id, width, height, method, m_type, max_timeout_ms | |||
) | |||
else: | |||
await self._respond_local_thumbnail( | |||
request, media_id, width, height, method, m_type | |||
request, media_id, width, height, method, m_type, max_timeout_ms | |||
) | |||
self.media_repo.mark_recently_accessed(None, media_id) | |||
else: | |||
@@ -95,14 +101,21 @@ class ThumbnailResource(RestServlet): | |||
respond_404(request) | |||
return | |||
if self.dynamic_thumbnails: | |||
await self._select_or_generate_remote_thumbnail( | |||
request, server_name, media_id, width, height, method, m_type | |||
) | |||
else: | |||
await self._respond_remote_thumbnail( | |||
request, server_name, media_id, width, height, method, m_type | |||
) | |||
remote_resp_function = ( | |||
self._select_or_generate_remote_thumbnail | |||
if self.dynamic_thumbnails | |||
else self._respond_remote_thumbnail | |||
) | |||
await remote_resp_function( | |||
request, | |||
server_name, | |||
media_id, | |||
width, | |||
height, | |||
method, | |||
m_type, | |||
max_timeout_ms, | |||
) | |||
self.media_repo.mark_recently_accessed(server_name, media_id) | |||
async def _respond_local_thumbnail( | |||
@@ -113,15 +126,12 @@ class ThumbnailResource(RestServlet): | |||
height: int, | |||
method: str, | |||
m_type: str, | |||
max_timeout_ms: int, | |||
) -> None: | |||
media_info = await self.store.get_local_media(media_id) | |||
media_info = await self.media_repo.get_local_media_info( | |||
request, media_id, max_timeout_ms | |||
) | |||
if not media_info: | |||
respond_404(request) | |||
return | |||
if media_info.quarantined_by: | |||
logger.info("Media is quarantined") | |||
respond_404(request) | |||
return | |||
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) | |||
@@ -146,15 +156,13 @@ class ThumbnailResource(RestServlet): | |||
desired_height: int, | |||
desired_method: str, | |||
desired_type: str, | |||
max_timeout_ms: int, | |||
) -> None: | |||
media_info = await self.store.get_local_media(media_id) | |||
media_info = await self.media_repo.get_local_media_info( | |||
request, media_id, max_timeout_ms | |||
) | |||
if not media_info: | |||
respond_404(request) | |||
return | |||
if media_info.quarantined_by: | |||
logger.info("Media is quarantined") | |||
respond_404(request) | |||
return | |||
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) | |||
@@ -206,8 +214,14 @@ class ThumbnailResource(RestServlet): | |||
desired_height: int, | |||
desired_method: str, | |||
desired_type: str, | |||
max_timeout_ms: int, | |||
) -> None: | |||
media_info = await self.media_repo.get_remote_media_info(server_name, media_id) | |||
media_info = await self.media_repo.get_remote_media_info( | |||
server_name, media_id, max_timeout_ms | |||
) | |||
if not media_info: | |||
respond_404(request) | |||
return | |||
thumbnail_infos = await self.store.get_remote_media_thumbnails( | |||
server_name, media_id | |||
@@ -263,11 +277,16 @@ class ThumbnailResource(RestServlet): | |||
height: int, | |||
method: str, | |||
m_type: str, | |||
max_timeout_ms: int, | |||
) -> None: | |||
# TODO: Don't download the whole remote file | |||
# We should proxy the thumbnail from the remote server instead of | |||
# downloading the remote file and generating our own thumbnails. | |||
media_info = await self.media_repo.get_remote_media_info(server_name, media_id) | |||
media_info = await self.media_repo.get_remote_media_info( | |||
server_name, media_id, max_timeout_ms | |||
) | |||
if not media_info: | |||
return | |||
thumbnail_infos = await self.store.get_remote_media_thumbnails( | |||
server_name, media_id | |||
@@ -15,7 +15,7 @@ | |||
import logging | |||
import re | |||
from typing import IO, TYPE_CHECKING, Dict, List, Optional | |||
from typing import IO, TYPE_CHECKING, Dict, List, Optional, Tuple | |||
from synapse.api.errors import Codes, SynapseError | |||
from synapse.http.server import respond_with_json | |||
@@ -29,23 +29,24 @@ if TYPE_CHECKING: | |||
logger = logging.getLogger(__name__) | |||
# The name of the lock to use when uploading media. | |||
_UPLOAD_MEDIA_LOCK_NAME = "upload_media" | |||
class UploadResource(RestServlet): | |||
PATTERNS = [re.compile("/_matrix/media/(r0|v3|v1)/upload")] | |||
class BaseUploadServlet(RestServlet): | |||
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): | |||
super().__init__() | |||
self.media_repo = media_repo | |||
self.filepaths = media_repo.filepaths | |||
self.store = hs.get_datastores().main | |||
self.clock = hs.get_clock() | |||
self.server_name = hs.hostname | |||
self.auth = hs.get_auth() | |||
self.max_upload_size = hs.config.media.max_upload_size | |||
self.clock = hs.get_clock() | |||
async def on_POST(self, request: SynapseRequest) -> None: | |||
requester = await self.auth.get_user_by_req(request) | |||
def _get_file_metadata( | |||
self, request: SynapseRequest | |||
) -> Tuple[int, Optional[str], str]: | |||
raw_content_length = request.getHeader("Content-Length") | |||
if raw_content_length is None: | |||
raise SynapseError(msg="Request must specify a Content-Length", code=400) | |||
@@ -88,6 +89,16 @@ class UploadResource(RestServlet): | |||
# disposition = headers.getRawHeaders(b"Content-Disposition")[0] | |||
# TODO(markjh): parse content-dispostion | |||
return content_length, upload_name, media_type | |||
class UploadServlet(BaseUploadServlet): | |||
PATTERNS = [re.compile("/_matrix/media/(r0|v3|v1)/upload$")] | |||
async def on_POST(self, request: SynapseRequest) -> None: | |||
requester = await self.auth.get_user_by_req(request) | |||
content_length, upload_name, media_type = self._get_file_metadata(request) | |||
try: | |||
content: IO = request.content # type: ignore | |||
content_uri = await self.media_repo.create_content( | |||
@@ -103,3 +114,53 @@ class UploadResource(RestServlet): | |||
respond_with_json( | |||
request, 200, {"content_uri": str(content_uri)}, send_cors=True | |||
) | |||
class AsyncUploadServlet(BaseUploadServlet): | |||
PATTERNS = [ | |||
re.compile( | |||
"/_matrix/media/v3/upload/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$" | |||
) | |||
] | |||
async def on_PUT( | |||
self, request: SynapseRequest, server_name: str, media_id: str | |||
) -> None: | |||
requester = await self.auth.get_user_by_req(request) | |||
if server_name != self.server_name: | |||
raise SynapseError( | |||
404, | |||
"Non-local server name specified", | |||
errcode=Codes.NOT_FOUND, | |||
) | |||
lock = await self.store.try_acquire_lock(_UPLOAD_MEDIA_LOCK_NAME, media_id) | |||
if not lock: | |||
raise SynapseError( | |||
409, | |||
"Media ID cannot be overwritten", | |||
errcode=Codes.CANNOT_OVERWRITE_MEDIA, | |||
) | |||
async with lock: | |||
await self.media_repo.verify_can_upload(media_id, requester.user) | |||
content_length, upload_name, media_type = self._get_file_metadata(request) | |||
try: | |||
content: IO = request.content # type: ignore | |||
await self.media_repo.update_content( | |||
media_id, | |||
media_type, | |||
upload_name, | |||
content, | |||
content_length, | |||
requester.user, | |||
) | |||
except SpamMediaException: | |||
# For uploading of media we want to respond with a 400, instead of | |||
# the default 404, as that would just be confusing. | |||
raise SynapseError(400, "Bad content") | |||
logger.info("Uploaded content for media ID %r", media_id) | |||
respond_with_json(request, 200, {}, send_cors=True) |
@@ -49,13 +49,14 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = ( | |||
class LocalMedia: | |||
media_id: str | |||
media_type: str | |||
media_length: int | |||
media_length: Optional[int] | |||
upload_name: str | |||
created_ts: int | |||
url_cache: Optional[str] | |||
last_access_ts: int | |||
quarantined_by: Optional[str] | |||
safe_from_quarantine: bool | |||
user_id: Optional[str] | |||
@attr.s(slots=True, frozen=True, auto_attribs=True) | |||
@@ -149,6 +150,13 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): | |||
self._drop_media_index_without_method, | |||
) | |||
if hs.config.media.can_load_media_repo: | |||
self.unused_expiration_time: Optional[ | |||
int | |||
] = hs.config.media.unused_expiration_time | |||
else: | |||
self.unused_expiration_time = None | |||
async def _drop_media_index_without_method( | |||
self, progress: JsonDict, batch_size: int | |||
) -> int: | |||
@@ -202,6 +210,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
"url_cache", | |||
"last_access_ts", | |||
"safe_from_quarantine", | |||
"user_id", | |||
), | |||
allow_none=True, | |||
desc="get_local_media", | |||
@@ -218,6 +227,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
url_cache=row[5], | |||
last_access_ts=row[6], | |||
safe_from_quarantine=row[7], | |||
user_id=row[8], | |||
) | |||
async def get_local_media_by_user_paginate( | |||
@@ -272,7 +282,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
url_cache, | |||
last_access_ts, | |||
quarantined_by, | |||
safe_from_quarantine | |||
safe_from_quarantine, | |||
user_id | |||
FROM local_media_repository | |||
WHERE user_id = ? | |||
ORDER BY {order_by_column} {order}, media_id ASC | |||
@@ -295,6 +306,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
last_access_ts=row[6], | |||
quarantined_by=row[7], | |||
safe_from_quarantine=bool(row[8]), | |||
user_id=row[9], | |||
) | |||
for row in txn | |||
] | |||
@@ -391,6 +403,23 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
"get_local_media_ids", _get_local_media_ids_txn | |||
) | |||
@trace | |||
async def store_local_media_id( | |||
self, | |||
media_id: str, | |||
time_now_ms: int, | |||
user_id: UserID, | |||
) -> None: | |||
await self.db_pool.simple_insert( | |||
"local_media_repository", | |||
{ | |||
"media_id": media_id, | |||
"created_ts": time_now_ms, | |||
"user_id": user_id.to_string(), | |||
}, | |||
desc="store_local_media_id", | |||
) | |||
@trace | |||
async def store_local_media( | |||
self, | |||
@@ -416,6 +445,30 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
desc="store_local_media", | |||
) | |||
async def update_local_media( | |||
self, | |||
media_id: str, | |||
media_type: str, | |||
upload_name: Optional[str], | |||
media_length: int, | |||
user_id: UserID, | |||
url_cache: Optional[str] = None, | |||
) -> None: | |||
await self.db_pool.simple_update_one( | |||
"local_media_repository", | |||
keyvalues={ | |||
"user_id": user_id.to_string(), | |||
"media_id": media_id, | |||
}, | |||
updatevalues={ | |||
"media_type": media_type, | |||
"upload_name": upload_name, | |||
"media_length": media_length, | |||
"url_cache": url_cache, | |||
}, | |||
desc="update_local_media", | |||
) | |||
async def mark_local_media_as_safe(self, media_id: str, safe: bool = True) -> None: | |||
"""Mark a local media as safe or unsafe from quarantining.""" | |||
await self.db_pool.simple_update_one( | |||
@@ -425,6 +478,39 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
desc="mark_local_media_as_safe", | |||
) | |||
async def count_pending_media(self, user_id: UserID) -> Tuple[int, int]: | |||
"""Count the number of pending media for a user. | |||
Returns: | |||
A tuple of two integers: the total pending media requests and the earliest | |||
expiration timestamp. | |||
""" | |||
def get_pending_media_txn(txn: LoggingTransaction) -> Tuple[int, int]: | |||
sql = """ | |||
SELECT COUNT(*), MIN(created_ts) | |||
FROM local_media_repository | |||
WHERE user_id = ? | |||
AND created_ts > ? | |||
AND media_length IS NULL | |||
""" | |||
assert self.unused_expiration_time is not None | |||
txn.execute( | |||
sql, | |||
( | |||
user_id.to_string(), | |||
self._clock.time_msec() - self.unused_expiration_time, | |||
), | |||
) | |||
row = txn.fetchone() | |||
if not row: | |||
return 0, 0 | |||
return row[0], (row[1] + self.unused_expiration_time if row[1] else 0) | |||
return await self.db_pool.runInteraction( | |||
"get_pending_media", get_pending_media_txn | |||
) | |||
async def get_url_cache(self, url: str, ts: int) -> Optional[UrlCache]: | |||
"""Get the media_id and ts for a cached URL as of the given timestamp | |||
Returns: | |||
@@ -318,7 +318,9 @@ class MediaRepoTests(unittest.HomeserverTestCase): | |||
self.assertEqual( | |||
self.fetches[0][2], "/_matrix/media/r0/download/" + self.media_id | |||
) | |||
self.assertEqual(self.fetches[0][3], {"allow_remote": "false"}) | |||
self.assertEqual( | |||
self.fetches[0][3], {"allow_remote": "false", "timeout_ms": "20000"} | |||
) | |||
headers = { | |||
b"Content-Length": [b"%d" % (len(self.test_image.data))], | |||