You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

737 lines
25 KiB

  1. # Copyright 2014-2016 OpenMarket Ltd
  2. # Copyright 2020-2021 The Matrix.org Foundation C.I.C.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from enum import Enum
  16. from typing import (
  17. TYPE_CHECKING,
  18. Any,
  19. Collection,
  20. Dict,
  21. Iterable,
  22. List,
  23. Optional,
  24. Tuple,
  25. Union,
  26. cast,
  27. )
  28. from synapse.api.constants import Direction
  29. from synapse.storage._base import SQLBaseStore
  30. from synapse.storage.database import (
  31. DatabasePool,
  32. LoggingDatabaseConnection,
  33. LoggingTransaction,
  34. )
  35. from synapse.types import JsonDict, UserID
  36. if TYPE_CHECKING:
  37. from synapse.server import HomeServer
# Name of the background update that drops the old thumbnail uniqueness
# constraints. The handler for it is registered in
# MediaRepositoryBackgroundUpdateStore.__init__.
BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = (
    "media_repository_drop_index_wo_method_2"
)
  41. class MediaSortOrder(Enum):
  42. """
  43. Enum to define the sorting method used when returning media with
  44. get_local_media_by_user_paginate
  45. """
  46. MEDIA_ID = "media_id"
  47. UPLOAD_NAME = "upload_name"
  48. CREATED_TS = "created_ts"
  49. LAST_ACCESS_TS = "last_access_ts"
  50. MEDIA_LENGTH = "media_length"
  51. MEDIA_TYPE = "media_type"
  52. QUARANTINED_BY = "quarantined_by"
  53. SAFE_FROM_QUARANTINE = "safe_from_quarantine"
  54. class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
  55. def __init__(
  56. self,
  57. database: DatabasePool,
  58. db_conn: LoggingDatabaseConnection,
  59. hs: "HomeServer",
  60. ):
  61. super().__init__(database, db_conn, hs)
  62. self.db_pool.updates.register_background_index_update(
  63. update_name="local_media_repository_url_idx",
  64. index_name="local_media_repository_url_idx",
  65. table="local_media_repository",
  66. columns=["created_ts"],
  67. where_clause="url_cache IS NOT NULL",
  68. )
  69. # The following the updates add the method to the unique constraint of
  70. # the thumbnail databases. That fixes an issue, where thumbnails of the
  71. # same resolution, but different methods could overwrite one another.
  72. # This can happen with custom thumbnail configs or with dynamic thumbnailing.
  73. self.db_pool.updates.register_background_index_update(
  74. update_name="local_media_repository_thumbnails_method_idx",
  75. index_name="local_media_repository_thumbn_media_id_width_height_method_key",
  76. table="local_media_repository_thumbnails",
  77. columns=[
  78. "media_id",
  79. "thumbnail_width",
  80. "thumbnail_height",
  81. "thumbnail_type",
  82. "thumbnail_method",
  83. ],
  84. unique=True,
  85. )
  86. self.db_pool.updates.register_background_index_update(
  87. update_name="remote_media_repository_thumbnails_method_idx",
  88. index_name="remote_media_repository_thumbn_media_origin_id_width_height_method_key",
  89. table="remote_media_cache_thumbnails",
  90. columns=[
  91. "media_origin",
  92. "media_id",
  93. "thumbnail_width",
  94. "thumbnail_height",
  95. "thumbnail_type",
  96. "thumbnail_method",
  97. ],
  98. unique=True,
  99. )
  100. self.db_pool.updates.register_background_update_handler(
  101. BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2,
  102. self._drop_media_index_without_method,
  103. )
  104. async def _drop_media_index_without_method(
  105. self, progress: JsonDict, batch_size: int
  106. ) -> int:
  107. """background update handler which removes the old constraints.
  108. Note that this is only run on postgres.
  109. """
  110. def f(txn: LoggingTransaction) -> None:
  111. txn.execute(
  112. "ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key"
  113. )
  114. txn.execute(
  115. "ALTER TABLE remote_media_cache_thumbnails DROP CONSTRAINT IF EXISTS remote_media_cache_thumbnails_media_origin_media_id_thumbna_key"
  116. )
  117. await self.db_pool.runInteraction("drop_media_indices_without_method", f)
  118. await self.db_pool.updates._end_background_update(
  119. BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2
  120. )
  121. return 1
  122. class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
  123. """Persistence for attachments and avatars"""
  124. def __init__(
  125. self,
  126. database: DatabasePool,
  127. db_conn: LoggingDatabaseConnection,
  128. hs: "HomeServer",
  129. ):
  130. super().__init__(database, db_conn, hs)
  131. self.server_name: str = hs.hostname
  132. async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
  133. """Get the metadata for a local piece of media
  134. Returns:
  135. None if the media_id doesn't exist.
  136. """
  137. return await self.db_pool.simple_select_one(
  138. "local_media_repository",
  139. {"media_id": media_id},
  140. (
  141. "media_type",
  142. "media_length",
  143. "upload_name",
  144. "created_ts",
  145. "quarantined_by",
  146. "url_cache",
  147. "safe_from_quarantine",
  148. ),
  149. allow_none=True,
  150. desc="get_local_media",
  151. )
  152. async def get_local_media_by_user_paginate(
  153. self,
  154. start: int,
  155. limit: int,
  156. user_id: str,
  157. order_by: str = MediaSortOrder.CREATED_TS.value,
  158. direction: Direction = Direction.FORWARDS,
  159. ) -> Tuple[List[Dict[str, Any]], int]:
  160. """Get a paginated list of metadata for a local piece of media
  161. which an user_id has uploaded
  162. Args:
  163. start: offset in the list
  164. limit: maximum amount of media_ids to retrieve
  165. user_id: fully-qualified user id
  166. order_by: the sort order of the returned list
  167. direction: sort ascending or descending
  168. Returns:
  169. A paginated list of all metadata of user's media,
  170. plus the total count of all the user's media
  171. """
  172. def get_local_media_by_user_paginate_txn(
  173. txn: LoggingTransaction,
  174. ) -> Tuple[List[Dict[str, Any]], int]:
  175. # Set ordering
  176. order_by_column = MediaSortOrder(order_by).value
  177. if direction == Direction.BACKWARDS:
  178. order = "DESC"
  179. else:
  180. order = "ASC"
  181. args: List[Union[str, int]] = [user_id]
  182. sql = """
  183. SELECT COUNT(*) as total_media
  184. FROM local_media_repository
  185. WHERE user_id = ?
  186. """
  187. txn.execute(sql, args)
  188. count = cast(Tuple[int], txn.fetchone())[0]
  189. sql = """
  190. SELECT
  191. "media_id",
  192. "media_type",
  193. "media_length",
  194. "upload_name",
  195. "created_ts",
  196. "last_access_ts",
  197. "quarantined_by",
  198. "safe_from_quarantine"
  199. FROM local_media_repository
  200. WHERE user_id = ?
  201. ORDER BY {order_by_column} {order}, media_id ASC
  202. LIMIT ? OFFSET ?
  203. """.format(
  204. order_by_column=order_by_column,
  205. order=order,
  206. )
  207. args += [limit, start]
  208. txn.execute(sql, args)
  209. media = self.db_pool.cursor_to_dict(txn)
  210. return media, count
  211. return await self.db_pool.runInteraction(
  212. "get_local_media_by_user_paginate_txn", get_local_media_by_user_paginate_txn
  213. )
  214. async def get_local_media_ids(
  215. self,
  216. before_ts: int,
  217. size_gt: int,
  218. keep_profiles: bool,
  219. include_quarantined_media: bool,
  220. include_protected_media: bool,
  221. ) -> List[str]:
  222. """
  223. Retrieve a list of media IDs from the local media store.
  224. Args:
  225. before_ts: Only retrieve IDs from media that was either last accessed
  226. (or if never accessed, created) before the given UNIX timestamp in ms.
  227. size_gt: Only retrieve IDs from media that has a size (in bytes) greater than
  228. the given integer.
  229. keep_profiles: If True, exclude media IDs from the results that are used in the
  230. following situations:
  231. * global profile user avatar
  232. * per-room profile user avatar
  233. * room avatar
  234. * a user's avatar in the user directory
  235. include_quarantined_media: If False, exclude media IDs from the results that have
  236. been marked as quarantined.
  237. include_protected_media: If False, exclude media IDs from the results that have
  238. been marked as protected from quarantine.
  239. Returns:
  240. A list of local media IDs.
  241. """
  242. # to find files that have never been accessed (last_access_ts IS NULL)
  243. # compare with `created_ts`
  244. sql = """
  245. SELECT media_id
  246. FROM local_media_repository AS lmr
  247. WHERE
  248. ( last_access_ts < ?
  249. OR ( created_ts < ? AND last_access_ts IS NULL ) )
  250. AND media_length > ?
  251. """
  252. if keep_profiles:
  253. sql_keep = """
  254. AND (
  255. NOT EXISTS
  256. (SELECT 1
  257. FROM profiles
  258. WHERE profiles.avatar_url = '{media_prefix}' || lmr.media_id)
  259. AND NOT EXISTS
  260. (SELECT 1
  261. FROM room_memberships
  262. WHERE room_memberships.avatar_url = '{media_prefix}' || lmr.media_id)
  263. AND NOT EXISTS
  264. (SELECT 1
  265. FROM user_directory
  266. WHERE user_directory.avatar_url = '{media_prefix}' || lmr.media_id)
  267. AND NOT EXISTS
  268. (SELECT 1
  269. FROM room_stats_state
  270. WHERE room_stats_state.avatar = '{media_prefix}' || lmr.media_id)
  271. )
  272. """.format(
  273. media_prefix="mxc://%s/" % (self.server_name,),
  274. )
  275. sql += sql_keep
  276. if include_quarantined_media is False:
  277. # Do not include media that has been quarantined
  278. sql += """
  279. AND quarantined_by IS NULL
  280. """
  281. if include_protected_media is False:
  282. # Do not include media that has been protected from quarantine
  283. sql += """
  284. AND NOT safe_from_quarantine
  285. """
  286. def _get_local_media_ids_txn(txn: LoggingTransaction) -> List[str]:
  287. txn.execute(sql, (before_ts, before_ts, size_gt))
  288. return [row[0] for row in txn]
  289. return await self.db_pool.runInteraction(
  290. "get_local_media_ids", _get_local_media_ids_txn
  291. )
  292. async def store_local_media(
  293. self,
  294. media_id: str,
  295. media_type: str,
  296. time_now_ms: int,
  297. upload_name: Optional[str],
  298. media_length: int,
  299. user_id: UserID,
  300. url_cache: Optional[str] = None,
  301. ) -> None:
  302. await self.db_pool.simple_insert(
  303. "local_media_repository",
  304. {
  305. "media_id": media_id,
  306. "media_type": media_type,
  307. "created_ts": time_now_ms,
  308. "upload_name": upload_name,
  309. "media_length": media_length,
  310. "user_id": user_id.to_string(),
  311. "url_cache": url_cache,
  312. },
  313. desc="store_local_media",
  314. )
  315. async def mark_local_media_as_safe(self, media_id: str, safe: bool = True) -> None:
  316. """Mark a local media as safe or unsafe from quarantining."""
  317. await self.db_pool.simple_update_one(
  318. table="local_media_repository",
  319. keyvalues={"media_id": media_id},
  320. updatevalues={"safe_from_quarantine": safe},
  321. desc="mark_local_media_as_safe",
  322. )
  323. async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
  324. """Get the media_id and ts for a cached URL as of the given timestamp
  325. Returns:
  326. None if the URL isn't cached.
  327. """
  328. def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
  329. # get the most recently cached result (relative to the given ts)
  330. sql = (
  331. "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
  332. " FROM local_media_repository_url_cache"
  333. " WHERE url = ? AND download_ts <= ?"
  334. " ORDER BY download_ts DESC LIMIT 1"
  335. )
  336. txn.execute(sql, (url, ts))
  337. row = txn.fetchone()
  338. if not row:
  339. # ...or if we've requested a timestamp older than the oldest
  340. # copy in the cache, return the oldest copy (if any)
  341. sql = (
  342. "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
  343. " FROM local_media_repository_url_cache"
  344. " WHERE url = ? AND download_ts > ?"
  345. " ORDER BY download_ts ASC LIMIT 1"
  346. )
  347. txn.execute(sql, (url, ts))
  348. row = txn.fetchone()
  349. if not row:
  350. return None
  351. return dict(
  352. zip(
  353. (
  354. "response_code",
  355. "etag",
  356. "expires_ts",
  357. "og",
  358. "media_id",
  359. "download_ts",
  360. ),
  361. row,
  362. )
  363. )
  364. return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
  365. async def store_url_cache(
  366. self,
  367. url: str,
  368. response_code: int,
  369. etag: Optional[str],
  370. expires_ts: int,
  371. og: Optional[str],
  372. media_id: str,
  373. download_ts: int,
  374. ) -> None:
  375. await self.db_pool.simple_insert(
  376. "local_media_repository_url_cache",
  377. {
  378. "url": url,
  379. "response_code": response_code,
  380. "etag": etag,
  381. "expires_ts": expires_ts,
  382. "og": og,
  383. "media_id": media_id,
  384. "download_ts": download_ts,
  385. },
  386. desc="store_url_cache",
  387. )
  388. async def get_local_media_thumbnails(self, media_id: str) -> List[Dict[str, Any]]:
  389. return await self.db_pool.simple_select_list(
  390. "local_media_repository_thumbnails",
  391. {"media_id": media_id},
  392. (
  393. "thumbnail_width",
  394. "thumbnail_height",
  395. "thumbnail_method",
  396. "thumbnail_type",
  397. "thumbnail_length",
  398. ),
  399. desc="get_local_media_thumbnails",
  400. )
  401. async def store_local_thumbnail(
  402. self,
  403. media_id: str,
  404. thumbnail_width: int,
  405. thumbnail_height: int,
  406. thumbnail_type: str,
  407. thumbnail_method: str,
  408. thumbnail_length: int,
  409. ) -> None:
  410. await self.db_pool.simple_upsert(
  411. table="local_media_repository_thumbnails",
  412. keyvalues={
  413. "media_id": media_id,
  414. "thumbnail_width": thumbnail_width,
  415. "thumbnail_height": thumbnail_height,
  416. "thumbnail_method": thumbnail_method,
  417. "thumbnail_type": thumbnail_type,
  418. },
  419. values={"thumbnail_length": thumbnail_length},
  420. desc="store_local_thumbnail",
  421. )
  422. async def get_cached_remote_media(
  423. self, origin: str, media_id: str
  424. ) -> Optional[Dict[str, Any]]:
  425. return await self.db_pool.simple_select_one(
  426. "remote_media_cache",
  427. {"media_origin": origin, "media_id": media_id},
  428. (
  429. "media_type",
  430. "media_length",
  431. "upload_name",
  432. "created_ts",
  433. "filesystem_id",
  434. "quarantined_by",
  435. ),
  436. allow_none=True,
  437. desc="get_cached_remote_media",
  438. )
  439. async def store_cached_remote_media(
  440. self,
  441. origin: str,
  442. media_id: str,
  443. media_type: str,
  444. media_length: int,
  445. time_now_ms: int,
  446. upload_name: Optional[str],
  447. filesystem_id: str,
  448. ) -> None:
  449. await self.db_pool.simple_insert(
  450. "remote_media_cache",
  451. {
  452. "media_origin": origin,
  453. "media_id": media_id,
  454. "media_type": media_type,
  455. "media_length": media_length,
  456. "created_ts": time_now_ms,
  457. "upload_name": upload_name,
  458. "filesystem_id": filesystem_id,
  459. "last_access_ts": time_now_ms,
  460. },
  461. desc="store_cached_remote_media",
  462. )
  463. async def update_cached_last_access_time(
  464. self,
  465. local_media: Iterable[str],
  466. remote_media: Iterable[Tuple[str, str]],
  467. time_ms: int,
  468. ) -> None:
  469. """Updates the last access time of the given media
  470. Args:
  471. local_media: Set of media_ids
  472. remote_media: Set of (server_name, media_id)
  473. time_ms: Current time in milliseconds
  474. """
  475. def update_cache_txn(txn: LoggingTransaction) -> None:
  476. sql = (
  477. "UPDATE remote_media_cache SET last_access_ts = ?"
  478. " WHERE media_origin = ? AND media_id = ?"
  479. )
  480. txn.execute_batch(
  481. sql,
  482. (
  483. (time_ms, media_origin, media_id)
  484. for media_origin, media_id in remote_media
  485. ),
  486. )
  487. sql = (
  488. "UPDATE local_media_repository SET last_access_ts = ?"
  489. " WHERE media_id = ?"
  490. )
  491. txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media))
  492. await self.db_pool.runInteraction(
  493. "update_cached_last_access_time", update_cache_txn
  494. )
  495. async def get_remote_media_thumbnails(
  496. self, origin: str, media_id: str
  497. ) -> List[Dict[str, Any]]:
  498. return await self.db_pool.simple_select_list(
  499. "remote_media_cache_thumbnails",
  500. {"media_origin": origin, "media_id": media_id},
  501. (
  502. "thumbnail_width",
  503. "thumbnail_height",
  504. "thumbnail_method",
  505. "thumbnail_type",
  506. "thumbnail_length",
  507. "filesystem_id",
  508. ),
  509. desc="get_remote_media_thumbnails",
  510. )
  511. async def get_remote_media_thumbnail(
  512. self,
  513. origin: str,
  514. media_id: str,
  515. t_width: int,
  516. t_height: int,
  517. t_type: str,
  518. ) -> Optional[Dict[str, Any]]:
  519. """Fetch the thumbnail info of given width, height and type."""
  520. return await self.db_pool.simple_select_one(
  521. table="remote_media_cache_thumbnails",
  522. keyvalues={
  523. "media_origin": origin,
  524. "media_id": media_id,
  525. "thumbnail_width": t_width,
  526. "thumbnail_height": t_height,
  527. "thumbnail_type": t_type,
  528. },
  529. retcols=(
  530. "thumbnail_width",
  531. "thumbnail_height",
  532. "thumbnail_method",
  533. "thumbnail_type",
  534. "thumbnail_length",
  535. "filesystem_id",
  536. ),
  537. allow_none=True,
  538. desc="get_remote_media_thumbnail",
  539. )
  540. async def store_remote_media_thumbnail(
  541. self,
  542. origin: str,
  543. media_id: str,
  544. filesystem_id: str,
  545. thumbnail_width: int,
  546. thumbnail_height: int,
  547. thumbnail_type: str,
  548. thumbnail_method: str,
  549. thumbnail_length: int,
  550. ) -> None:
  551. await self.db_pool.simple_upsert(
  552. table="remote_media_cache_thumbnails",
  553. keyvalues={
  554. "media_origin": origin,
  555. "media_id": media_id,
  556. "thumbnail_width": thumbnail_width,
  557. "thumbnail_height": thumbnail_height,
  558. "thumbnail_method": thumbnail_method,
  559. "thumbnail_type": thumbnail_type,
  560. },
  561. values={"thumbnail_length": thumbnail_length},
  562. insertion_values={"filesystem_id": filesystem_id},
  563. desc="store_remote_media_thumbnail",
  564. )
  565. async def get_remote_media_ids(
  566. self, before_ts: int, include_quarantined_media: bool
  567. ) -> List[Dict[str, str]]:
  568. """
  569. Retrieve a list of server name, media ID tuples from the remote media cache.
  570. Args:
  571. before_ts: Only retrieve IDs from media that was either last accessed
  572. (or if never accessed, created) before the given UNIX timestamp in ms.
  573. include_quarantined_media: If False, exclude media IDs from the results that have
  574. been marked as quarantined.
  575. Returns:
  576. A list of tuples containing:
  577. * The server name of homeserver where the media originates from,
  578. * The ID of the media.
  579. """
  580. sql = (
  581. "SELECT media_origin, media_id, filesystem_id"
  582. " FROM remote_media_cache"
  583. " WHERE last_access_ts < ?"
  584. )
  585. if include_quarantined_media is False:
  586. # Only include media that has not been quarantined
  587. sql += """
  588. AND quarantined_by IS NULL
  589. """
  590. return await self.db_pool.execute(
  591. "get_remote_media_ids", self.db_pool.cursor_to_dict, sql, before_ts
  592. )
  593. async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
  594. def delete_remote_media_txn(txn: LoggingTransaction) -> None:
  595. self.db_pool.simple_delete_txn(
  596. txn,
  597. "remote_media_cache",
  598. keyvalues={"media_origin": media_origin, "media_id": media_id},
  599. )
  600. self.db_pool.simple_delete_txn(
  601. txn,
  602. "remote_media_cache_thumbnails",
  603. keyvalues={"media_origin": media_origin, "media_id": media_id},
  604. )
  605. await self.db_pool.runInteraction(
  606. "delete_remote_media", delete_remote_media_txn
  607. )
  608. async def get_expired_url_cache(self, now_ts: int) -> List[str]:
  609. sql = (
  610. "SELECT media_id FROM local_media_repository_url_cache"
  611. " WHERE expires_ts < ?"
  612. " ORDER BY expires_ts ASC"
  613. " LIMIT 500"
  614. )
  615. def _get_expired_url_cache_txn(txn: LoggingTransaction) -> List[str]:
  616. txn.execute(sql, (now_ts,))
  617. return [row[0] for row in txn]
  618. return await self.db_pool.runInteraction(
  619. "get_expired_url_cache", _get_expired_url_cache_txn
  620. )
  621. async def delete_url_cache(self, media_ids: Collection[str]) -> None:
  622. if len(media_ids) == 0:
  623. return
  624. sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
  625. def _delete_url_cache_txn(txn: LoggingTransaction) -> None:
  626. txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
  627. await self.db_pool.runInteraction("delete_url_cache", _delete_url_cache_txn)
  628. async def get_url_cache_media_before(self, before_ts: int) -> List[str]:
  629. sql = (
  630. "SELECT media_id FROM local_media_repository"
  631. " WHERE created_ts < ? AND url_cache IS NOT NULL"
  632. " ORDER BY created_ts ASC"
  633. " LIMIT 500"
  634. )
  635. def _get_url_cache_media_before_txn(txn: LoggingTransaction) -> List[str]:
  636. txn.execute(sql, (before_ts,))
  637. return [row[0] for row in txn]
  638. return await self.db_pool.runInteraction(
  639. "get_url_cache_media_before", _get_url_cache_media_before_txn
  640. )
  641. async def delete_url_cache_media(self, media_ids: Collection[str]) -> None:
  642. if len(media_ids) == 0:
  643. return
  644. def _delete_url_cache_media_txn(txn: LoggingTransaction) -> None:
  645. sql = "DELETE FROM local_media_repository WHERE media_id = ?"
  646. txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
  647. sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?"
  648. txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
  649. await self.db_pool.runInteraction(
  650. "delete_url_cache_media", _delete_url_cache_media_txn
  651. )