You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

225 lines
8.0 KiB

  1. # Copyright 2015, 2016 OpenMarket Ltd
  2. # Copyright 2021 The Matrix.org Foundation C.I.C.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
  16. from canonicaljson import encode_canonical_json
  17. from synapse.api.errors import Codes, StoreError, SynapseError
  18. from synapse.storage._base import SQLBaseStore, db_to_json
  19. from synapse.storage.database import (
  20. DatabasePool,
  21. LoggingDatabaseConnection,
  22. LoggingTransaction,
  23. )
  24. from synapse.storage.engines import PostgresEngine
  25. from synapse.types import JsonDict, UserID
  26. from synapse.util.caches.descriptors import cached
  27. if TYPE_CHECKING:
  28. from synapse.server import HomeServer
  29. class FilteringWorkerStore(SQLBaseStore):
  30. def __init__(
  31. self,
  32. database: DatabasePool,
  33. db_conn: LoggingDatabaseConnection,
  34. hs: "HomeServer",
  35. ):
  36. super().__init__(database, db_conn, hs)
  37. self.server_name: str = hs.hostname
  38. self.database_engine = database.engine
  39. self.db_pool.updates.register_background_index_update(
  40. "full_users_filters_unique_idx",
  41. index_name="full_users_unique_idx",
  42. table="user_filters",
  43. columns=["full_user_id, filter_id"],
  44. unique=True,
  45. )
  46. self.db_pool.updates.register_background_update_handler(
  47. "populate_full_user_id_user_filters",
  48. self.populate_full_user_id_user_filters,
  49. )
    async def populate_full_user_id_user_filters(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        """
        Background update to populate the column `full_user_id` of the table
        user_filters from entries in the column `user_local_part` of the same table

        Progress is tracked via `progress["lower_bound_id"]`: the last user_id
        (localpart) already processed, "" on the first run (sorts before every
        user_id). Note: `batch_size` is not consulted here — the batch is fixed
        at 1000 rows by the OFFSET in `_get_last_id`.
        """
        lower_bound_id = progress.get("lower_bound_id", "")

        def _get_last_id(txn: LoggingTransaction) -> Optional[str]:
            # Pick the user_id 1000 rows beyond the lower bound as this
            # batch's inclusive upper bound. Returns None when fewer than
            # 1000 rows remain, signalling that the final batch should run.
            sql = """
                    SELECT user_id FROM user_filters
                    WHERE user_id > ?
                    ORDER BY user_id
                    LIMIT 1 OFFSET 1000
                  """
            txn.execute(sql, (lower_bound_id,))
            res = txn.fetchone()
            if res:
                upper_bound_id = res[0]
                return upper_bound_id
            else:
                return None

        def _process_batch(
            txn: LoggingTransaction, lower_bound_id: str, upper_bound_id: str
        ) -> None:
            # Derive full_user_id as '@' || localpart || ':server_name' for
            # rows in (lower_bound_id, upper_bound_id] that lack it.
            sql = """
                    UPDATE user_filters
                    SET full_user_id = '@' || user_id || ?
                    WHERE ? < user_id AND user_id <= ? AND full_user_id IS NULL
                   """
            txn.execute(sql, (f":{self.server_name}", lower_bound_id, upper_bound_id))

        def _final_batch(txn: LoggingTransaction, lower_bound_id: str) -> None:
            # Same rewrite as _process_batch but with no upper bound: covers
            # all remaining rows past the lower bound in one statement.
            sql = """
                    UPDATE user_filters
                    SET full_user_id = '@' || user_id || ?
                    WHERE ? < user_id AND full_user_id IS NULL
                   """
            txn.execute(
                sql,
                (
                    f":{self.server_name}",
                    lower_bound_id,
                ),
            )

            # Every row now has full_user_id, so the NOT NULL constraint
            # (presumably added NOT VALID earlier — confirm against the
            # schema delta) can be validated on Postgres.
            if isinstance(self.database_engine, PostgresEngine):
                sql = """
                        ALTER TABLE user_filters VALIDATE CONSTRAINT full_user_id_not_null
                      """
                txn.execute(sql)

        upper_bound_id = await self.db_pool.runInteraction(
            "populate_full_user_id_user_filters", _get_last_id
        )

        if upper_bound_id is None:
            # Fewer than 1000 rows left: process them, then mark the
            # background update finished.
            await self.db_pool.runInteraction(
                "populate_full_user_id_user_filters", _final_batch, lower_bound_id
            )

            await self.db_pool.updates._end_background_update(
                "populate_full_user_id_user_filters"
            )

            return 1

        await self.db_pool.runInteraction(
            "populate_full_user_id_user_filters",
            _process_batch,
            lower_bound_id,
            upper_bound_id,
        )

        # Persist the new lower bound so a restart resumes from here.
        progress["lower_bound_id"] = upper_bound_id

        await self.db_pool.runInteraction(
            "populate_full_user_id_user_filters",
            self.db_pool.updates._background_update_progress_txn,
            "populate_full_user_id_user_filters",
            progress,
        )

        return 50
  124. @cached(num_args=2)
  125. async def get_user_filter(
  126. self, user_id: UserID, filter_id: Union[int, str]
  127. ) -> JsonDict:
  128. # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
  129. # with a coherent error message rather than 500 M_UNKNOWN.
  130. try:
  131. int(filter_id)
  132. except ValueError:
  133. raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
  134. def_json = await self.db_pool.simple_select_one_onecol(
  135. table="user_filters",
  136. keyvalues={"full_user_id": user_id.to_string(), "filter_id": filter_id},
  137. retcol="filter_json",
  138. allow_none=False,
  139. desc="get_user_filter",
  140. )
  141. return db_to_json(def_json)
    async def add_user_filter(self, user_id: UserID, user_filter: JsonDict) -> int:
        """Store a filter for a user and return its ID.

        If an identical filter is already stored for the user, its existing
        ID is returned instead of inserting a duplicate.
        """
        # Canonical JSON so that equal filters serialise to identical bytes,
        # letting the duplicate check below compare filter_json directly.
        def_json = encode_canonical_json(user_filter)

        # Need an atomic transaction to SELECT the maximal ID so far then
        # INSERT a new one
        def _do_txn(txn: LoggingTransaction) -> int:
            sql = (
                "SELECT filter_id FROM user_filters "
                "WHERE full_user_id = ? AND filter_json = ?"
            )
            txn.execute(sql, (user_id.to_string(), bytearray(def_json)))
            filter_id_response = txn.fetchone()
            if filter_id_response is not None:
                # Identical filter already stored; reuse its ID.
                return filter_id_response[0]

            # Allocate the next per-user ID (0 if the user has none yet).
            sql = "SELECT MAX(filter_id) FROM user_filters WHERE full_user_id = ?"
            txn.execute(sql, (user_id.to_string(),))
            max_id = cast(Tuple[Optional[int]], txn.fetchone())[0]
            if max_id is None:
                filter_id = 0
            else:
                filter_id = max_id + 1

            sql = (
                "INSERT INTO user_filters (full_user_id, user_id, filter_id, filter_json)"
                "VALUES(?, ?, ?, ?)"
            )
            txn.execute(
                sql,
                (
                    user_id.to_string(),
                    user_id.localpart,
                    filter_id,
                    bytearray(def_json),
                ),
            )

            return filter_id

        attempts = 0
        while True:
            # Try a few times.
            # This is technically needed if a user tries to create two filters at once,
            # leading to two concurrent transactions.
            # The failure case would be:
            # - SELECT filter_id ... filter_json = ? → both transactions return no rows
            # - SELECT MAX(filter_id) ... → both transactions return e.g. 5
            # - INSERT INTO ... → both transactions insert filter_id = 6
            # One of the transactions will commit. The other will get a unique key
            # constraint violation error (IntegrityError). This is not the same as a
            # serialisability violation, which would be automatically retried by
            # `runInteraction`.
            try:
                return await self.db_pool.runInteraction("add_user_filter", _do_txn)
            except self.db_pool.engine.module.IntegrityError:
                attempts += 1
                if attempts >= 5:
                    raise StoreError(500, "Couldn't generate a filter ID.")