No puede seleccionar más de 25 temas. Los temas deben comenzar con una letra o número, pueden incluir guiones ('-') y pueden tener hasta 35 caracteres de largo.
 
 
 
 
 
 

399 líneas
14 KiB

  1. # Copyright 2018-2021 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import contextlib
  15. import logging
  16. import os
  17. import shutil
  18. from types import TracebackType
  19. from typing import (
  20. IO,
  21. TYPE_CHECKING,
  22. Any,
  23. Awaitable,
  24. BinaryIO,
  25. Callable,
  26. Generator,
  27. Optional,
  28. Sequence,
  29. Tuple,
  30. Type,
  31. )
  32. import attr
  33. from twisted.internet.defer import Deferred
  34. from twisted.internet.interfaces import IConsumer
  35. from twisted.protocols.basic import FileSender
  36. from synapse.api.errors import NotFoundError
  37. from synapse.logging.context import defer_to_thread, make_deferred_yieldable
  38. from synapse.logging.opentracing import start_active_span, trace, trace_with_opname
  39. from synapse.util import Clock
  40. from synapse.util.file_consumer import BackgroundFileConsumer
  41. from ._base import FileInfo, Responder
  42. from .filepath import MediaFilePaths
  43. if TYPE_CHECKING:
  44. from synapse.media.storage_provider import StorageProvider
  45. from synapse.server import HomeServer
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
  47. class MediaStorage:
  48. """Responsible for storing/fetching files from local sources.
  49. Args:
  50. hs
  51. local_media_directory: Base path where we store media on disk
  52. filepaths
  53. storage_providers: List of StorageProvider that are used to fetch and store files.
  54. """
  55. def __init__(
  56. self,
  57. hs: "HomeServer",
  58. local_media_directory: str,
  59. filepaths: MediaFilePaths,
  60. storage_providers: Sequence["StorageProvider"],
  61. ):
  62. self.hs = hs
  63. self.reactor = hs.get_reactor()
  64. self.local_media_directory = local_media_directory
  65. self.filepaths = filepaths
  66. self.storage_providers = storage_providers
  67. self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
  68. self.clock = hs.get_clock()
  69. @trace_with_opname("MediaStorage.store_file")
  70. async def store_file(self, source: IO, file_info: FileInfo) -> str:
  71. """Write `source` to the on disk media store, and also any other
  72. configured storage providers
  73. Args:
  74. source: A file like object that should be written
  75. file_info: Info about the file to store
  76. Returns:
  77. the file path written to in the primary media store
  78. """
  79. with self.store_into_file(file_info) as (f, fname, finish_cb):
  80. # Write to the main media repository
  81. await self.write_to_file(source, f)
  82. # Write to the other storage providers
  83. await finish_cb()
  84. return fname
  85. @trace_with_opname("MediaStorage.write_to_file")
  86. async def write_to_file(self, source: IO, output: IO) -> None:
  87. """Asynchronously write the `source` to `output`."""
  88. await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
  89. @trace_with_opname("MediaStorage.store_into_file")
  90. @contextlib.contextmanager
  91. def store_into_file(
  92. self, file_info: FileInfo
  93. ) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]:
  94. """Context manager used to get a file like object to write into, as
  95. described by file_info.
  96. Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
  97. like object that can be written to, fname is the absolute path of file
  98. on disk, and finish_cb is a function that returns an awaitable.
  99. fname can be used to read the contents from after upload, e.g. to
  100. generate thumbnails.
  101. finish_cb must be called and waited on after the file has been successfully been
  102. written to. Should not be called if there was an error. Checks for spam and
  103. stores the file into the configured storage providers.
  104. Args:
  105. file_info: Info about the file to store
  106. Example:
  107. with media_storage.store_into_file(info) as (f, fname, finish_cb):
  108. # .. write into f ...
  109. await finish_cb()
  110. """
  111. path = self._file_info_to_path(file_info)
  112. fname = os.path.join(self.local_media_directory, path)
  113. dirname = os.path.dirname(fname)
  114. os.makedirs(dirname, exist_ok=True)
  115. finished_called = [False]
  116. main_media_repo_write_trace_scope = start_active_span(
  117. "writing to main media repo"
  118. )
  119. main_media_repo_write_trace_scope.__enter__()
  120. try:
  121. with open(fname, "wb") as f:
  122. async def finish() -> None:
  123. # When someone calls finish, we assume they are done writing to the main media repo
  124. main_media_repo_write_trace_scope.__exit__(None, None, None)
  125. with start_active_span("writing to other storage providers"):
  126. # Ensure that all writes have been flushed and close the
  127. # file.
  128. f.flush()
  129. f.close()
  130. spam_check = await self._spam_checker_module_callbacks.check_media_file_for_spam(
  131. ReadableFileWrapper(self.clock, fname), file_info
  132. )
  133. if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
  134. logger.info("Blocking media due to spam checker")
  135. # Note that we'll delete the stored media, due to the
  136. # try/except below. The media also won't be stored in
  137. # the DB.
  138. # We currently ignore any additional field returned by
  139. # the spam-check API.
  140. raise SpamMediaException(errcode=spam_check[0])
  141. for provider in self.storage_providers:
  142. with start_active_span(str(provider)):
  143. await provider.store_file(path, file_info)
  144. finished_called[0] = True
  145. yield f, fname, finish
  146. except Exception as e:
  147. try:
  148. main_media_repo_write_trace_scope.__exit__(
  149. type(e), None, e.__traceback__
  150. )
  151. os.remove(fname)
  152. except Exception:
  153. pass
  154. raise e from None
  155. if not finished_called:
  156. exc = Exception("Finished callback not called")
  157. main_media_repo_write_trace_scope.__exit__(
  158. type(exc), None, exc.__traceback__
  159. )
  160. raise exc
  161. async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
  162. """Attempts to fetch media described by file_info from the local cache
  163. and configured storage providers.
  164. Args:
  165. file_info
  166. Returns:
  167. Returns a Responder if the file was found, otherwise None.
  168. """
  169. paths = [self._file_info_to_path(file_info)]
  170. # fallback for remote thumbnails with no method in the filename
  171. if file_info.thumbnail and file_info.server_name:
  172. paths.append(
  173. self.filepaths.remote_media_thumbnail_rel_legacy(
  174. server_name=file_info.server_name,
  175. file_id=file_info.file_id,
  176. width=file_info.thumbnail.width,
  177. height=file_info.thumbnail.height,
  178. content_type=file_info.thumbnail.type,
  179. )
  180. )
  181. for path in paths:
  182. local_path = os.path.join(self.local_media_directory, path)
  183. if os.path.exists(local_path):
  184. logger.debug("responding with local file %s", local_path)
  185. return FileResponder(open(local_path, "rb"))
  186. logger.debug("local file %s did not exist", local_path)
  187. for provider in self.storage_providers:
  188. for path in paths:
  189. res: Any = await provider.fetch(path, file_info)
  190. if res:
  191. logger.debug("Streaming %s from %s", path, provider)
  192. return res
  193. logger.debug("%s not found on %s", path, provider)
  194. return None
  195. @trace
  196. async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str:
  197. """Ensures that the given file is in the local cache. Attempts to
  198. download it from storage providers if it isn't.
  199. Args:
  200. file_info
  201. Returns:
  202. Full path to local file
  203. """
  204. path = self._file_info_to_path(file_info)
  205. local_path = os.path.join(self.local_media_directory, path)
  206. if os.path.exists(local_path):
  207. return local_path
  208. # Fallback for paths without method names
  209. # Should be removed in the future
  210. if file_info.thumbnail and file_info.server_name:
  211. legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy(
  212. server_name=file_info.server_name,
  213. file_id=file_info.file_id,
  214. width=file_info.thumbnail.width,
  215. height=file_info.thumbnail.height,
  216. content_type=file_info.thumbnail.type,
  217. )
  218. legacy_local_path = os.path.join(self.local_media_directory, legacy_path)
  219. if os.path.exists(legacy_local_path):
  220. return legacy_local_path
  221. dirname = os.path.dirname(local_path)
  222. os.makedirs(dirname, exist_ok=True)
  223. for provider in self.storage_providers:
  224. res: Any = await provider.fetch(path, file_info)
  225. if res:
  226. with res:
  227. consumer = BackgroundFileConsumer(
  228. open(local_path, "wb"), self.reactor
  229. )
  230. await res.write_to_consumer(consumer)
  231. await consumer.wait()
  232. return local_path
  233. raise NotFoundError()
  234. @trace
  235. def _file_info_to_path(self, file_info: FileInfo) -> str:
  236. """Converts file_info into a relative path.
  237. The path is suitable for storing files under a directory, e.g. used to
  238. store files on local FS under the base media repository directory.
  239. """
  240. if file_info.url_cache:
  241. if file_info.thumbnail:
  242. return self.filepaths.url_cache_thumbnail_rel(
  243. media_id=file_info.file_id,
  244. width=file_info.thumbnail.width,
  245. height=file_info.thumbnail.height,
  246. content_type=file_info.thumbnail.type,
  247. method=file_info.thumbnail.method,
  248. )
  249. return self.filepaths.url_cache_filepath_rel(file_info.file_id)
  250. if file_info.server_name:
  251. if file_info.thumbnail:
  252. return self.filepaths.remote_media_thumbnail_rel(
  253. server_name=file_info.server_name,
  254. file_id=file_info.file_id,
  255. width=file_info.thumbnail.width,
  256. height=file_info.thumbnail.height,
  257. content_type=file_info.thumbnail.type,
  258. method=file_info.thumbnail.method,
  259. )
  260. return self.filepaths.remote_media_filepath_rel(
  261. file_info.server_name, file_info.file_id
  262. )
  263. if file_info.thumbnail:
  264. return self.filepaths.local_media_thumbnail_rel(
  265. media_id=file_info.file_id,
  266. width=file_info.thumbnail.width,
  267. height=file_info.thumbnail.height,
  268. content_type=file_info.thumbnail.type,
  269. method=file_info.thumbnail.method,
  270. )
  271. return self.filepaths.local_media_filepath_rel(file_info.file_id)
  272. @trace
  273. def _write_file_synchronously(source: IO, dest: IO) -> None:
  274. """Write `source` to the file like `dest` synchronously. Should be called
  275. from a thread.
  276. Args:
  277. source: A file like object that's to be written
  278. dest: A file like object to be written to
  279. """
  280. source.seek(0) # Ensure we read from the start of the file
  281. shutil.copyfileobj(source, dest)
  282. class FileResponder(Responder):
  283. """Wraps an open file that can be sent to a request.
  284. Args:
  285. open_file: A file like object to be streamed ot the client,
  286. is closed when finished streaming.
  287. """
  288. def __init__(self, open_file: IO):
  289. self.open_file = open_file
  290. def write_to_consumer(self, consumer: IConsumer) -> Deferred:
  291. return make_deferred_yieldable(
  292. FileSender().beginFileTransfer(self.open_file, consumer)
  293. )
  294. def __exit__(
  295. self,
  296. exc_type: Optional[Type[BaseException]],
  297. exc_val: Optional[BaseException],
  298. exc_tb: Optional[TracebackType],
  299. ) -> None:
  300. self.open_file.close()
  301. class SpamMediaException(NotFoundError):
  302. """The media was blocked by a spam checker, so we simply 404 the request (in
  303. the same way as if it was quarantined).
  304. """
  305. @attr.s(slots=True, auto_attribs=True)
  306. class ReadableFileWrapper:
  307. """Wrapper that allows reading a file in chunks, yielding to the reactor,
  308. and writing to a callback.
  309. This is simplified `FileSender` that takes an IO object rather than an
  310. `IConsumer`.
  311. """
  312. CHUNK_SIZE = 2**14
  313. clock: Clock
  314. path: str
  315. async def write_chunks_to(self, callback: Callable[[bytes], object]) -> None:
  316. """Reads the file in chunks and calls the callback with each chunk."""
  317. with open(self.path, "rb") as file:
  318. while True:
  319. chunk = file.read(self.CHUNK_SIZE)
  320. if not chunk:
  321. break
  322. callback(chunk)
  323. # We yield to the reactor by sleeping for 0 seconds.
  324. await self.clock.sleep(0)