# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    FrozenSet,
    Iterable,
    List,
    Optional,
    Tuple,
    cast,
)

from synapse.api.constants import AccountDataTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
from synapse.storage._base import db_to_json
from synapse.storage.database import (
    DatabasePool,
    LoggingDatabaseConnection,
    LoggingTransaction,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
    AbstractStreamIdGenerator,
    AbstractStreamIdTracker,
    MultiWriterIdGenerator,
    StreamIdGenerator,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)
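

# This worker store serves both read-only workers and the worker(s) configured
# to write account data: `_can_write_to_account_data` gates the write paths,
# while the stream ID generator/tracker keeps this process's view of the
# `account_data` stream position up to date.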
class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore):
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        # `_can_write_to_account_data` indicates whether the current worker is allowed
        # to write account data. A value of `True` implies that `_account_data_id_gen`
        # is an `AbstractStreamIdGenerator` and not just a tracker.
        self._account_data_id_gen: AbstractStreamIdTracker

        if isinstance(database.engine, PostgresEngine):
            self._can_write_to_account_data = (
                self._instance_name in hs.config.worker.writers.account_data
            )

            self._account_data_id_gen = MultiWriterIdGenerator(
                db_conn=db_conn,
                db=database,
                stream_name="account_data",
                instance_name=self._instance_name,
                tables=[
                    ("room_account_data", "instance_name", "stream_id"),
                    ("room_tags_revisions", "instance_name", "stream_id"),
                    ("account_data", "instance_name", "stream_id"),
                ],
                sequence_name="account_data_sequence",
                writers=hs.config.worker.writers.account_data,
            )
        else:
            # We shouldn't be running in worker mode with SQLite, but it's useful
            # to support it for unit tests.
            #
            # If this process is the writer then we need to use
            # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
            # updated over replication. (Multiple writers are not supported for
            # SQLite).
            if self._instance_name in hs.config.worker.writers.account_data:
                self._can_write_to_account_data = True
                self._account_data_id_gen = StreamIdGenerator(
                    db_conn,
                    "room_account_data",
                    "stream_id",
                    extra_tables=[("room_tags_revisions", "stream_id")],
                )
            else:
                self._account_data_id_gen = SlavedIdTracker(
                    db_conn,
                    "room_account_data",
                    "stream_id",
                    extra_tables=[("room_tags_revisions", "stream_id")],
                )

        account_max = self.get_max_account_data_stream_id()
        self._account_data_stream_cache = StreamChangeCache(
            "AccountDataAndTagsChangeCache", account_max
        )

        self.db_pool.updates.register_background_update_handler(
            "delete_account_data_for_deactivated_users",
            self._delete_account_data_for_deactivated_users,
        )

    def get_max_account_data_stream_id(self) -> int:
        """Get the current max stream ID for the account data stream.

        Returns:
            The maximum stream ID.
        """
        return self._account_data_id_gen.get_current_token()

    @cached()
    async def get_account_data_for_user(
        self, user_id: str
    ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
        """Get all the client account_data for a user.

        Args:
            user_id: The user to get the account_data for.
        Returns:
            A 2-tuple of a dict of global account_data and a dict mapping from
            room_id string to per room account_data dicts.
        """

        def get_account_data_for_user_txn(
            txn: LoggingTransaction,
        ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
            rows = self.db_pool.simple_select_list_txn(
                txn,
                "account_data",
                {"user_id": user_id},
                ["account_data_type", "content"],
            )

            global_account_data = {
                row["account_data_type"]: db_to_json(row["content"]) for row in rows
            }

            rows = self.db_pool.simple_select_list_txn(
                txn,
                "room_account_data",
                {"user_id": user_id},
                ["room_id", "account_data_type", "content"],
            )

            by_room: Dict[str, Dict[str, JsonDict]] = {}
            for row in rows:
                room_data = by_room.setdefault(row["room_id"], {})
                room_data[row["account_data_type"]] = db_to_json(row["content"])

            return global_account_data, by_room

        return await self.db_pool.runInteraction(
            "get_account_data_for_user", get_account_data_for_user_txn
        )
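
    # Illustrative shape of the value returned above (data is hypothetical):
    #
    #   (
    #       {"m.ignored_user_list": {"ignored_users": {...}}},
    #       {"!room:example.com": {"m.fully_read": {"event_id": "$abc"}}},
    #   )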

    @cached(num_args=2, max_entries=5000, tree=True)
    async def get_global_account_data_by_type_for_user(
        self, user_id: str, data_type: str
    ) -> Optional[JsonDict]:
        """Get the client's global account_data of a given type.

        Args:
            user_id: The user to get the account_data for.
            data_type: The account data type to get.
        Returns:
            The account data, or None if there is none set.
        """
        result = await self.db_pool.simple_select_one_onecol(
            table="account_data",
            keyvalues={"user_id": user_id, "account_data_type": data_type},
            retcol="content",
            desc="get_global_account_data_by_type_for_user",
            allow_none=True,
        )

        if result:
            return db_to_json(result)
        else:
            return None

    @cached(num_args=2, tree=True)
    async def get_account_data_for_room(
        self, user_id: str, room_id: str
    ) -> Dict[str, JsonDict]:
        """Get all the client account_data for a user for a room.

        Args:
            user_id: The user to get the account_data for.
            room_id: The room to get the account_data for.
        Returns:
            A dict of the room account_data
        """

        def get_account_data_for_room_txn(
            txn: LoggingTransaction,
        ) -> Dict[str, JsonDict]:
            rows = self.db_pool.simple_select_list_txn(
                txn,
                "room_account_data",
                {"user_id": user_id, "room_id": room_id},
                ["account_data_type", "content"],
            )

            return {
                row["account_data_type"]: db_to_json(row["content"]) for row in rows
            }

        return await self.db_pool.runInteraction(
            "get_account_data_for_room", get_account_data_for_room_txn
        )

    @cached(num_args=3, max_entries=5000, tree=True)
    async def get_account_data_for_room_and_type(
        self, user_id: str, room_id: str, account_data_type: str
    ) -> Optional[JsonDict]:
        """Get the client account_data of given type for a user for a room.

        Args:
            user_id: The user to get the account_data for.
            room_id: The room to get the account_data for.
            account_data_type: The account data type to get.
        Returns:
            The room account_data for that type, or None if there isn't any set.
        """

        def get_account_data_for_room_and_type_txn(
            txn: LoggingTransaction,
        ) -> Optional[JsonDict]:
            content_json = self.db_pool.simple_select_one_onecol_txn(
                txn,
                table="room_account_data",
                keyvalues={
                    "user_id": user_id,
                    "room_id": room_id,
                    "account_data_type": account_data_type,
                },
                retcol="content",
                allow_none=True,
            )

            return db_to_json(content_json) if content_json else None

        return await self.db_pool.runInteraction(
            "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
        )
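
    # Example call (illustrative only; identifiers are hypothetical):
    #
    #   content = await store.get_account_data_for_room_and_type(
    #       "@alice:example.com", "!room:example.com", "m.fully_read"
    #   )
    #   # -> e.g. {"event_id": "$read_up_to"}, or None if nothing is set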

    async def get_updated_global_account_data(
        self, last_id: int, current_id: int, limit: int
    ) -> List[Tuple[int, str, str]]:
        """Get the global account_data that has changed, for the account_data stream

        Args:
            last_id: the last stream_id from the previous batch.
            current_id: the maximum stream_id to return up to
            limit: the maximum number of rows to return

        Returns:
            A list of tuples of stream_id int, user_id string,
            and type string.
        """
        if last_id == current_id:
            return []

        def get_updated_global_account_data_txn(
            txn: LoggingTransaction,
        ) -> List[Tuple[int, str, str]]:
            sql = (
                "SELECT stream_id, user_id, account_data_type"
                " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC LIMIT ?"
            )
            txn.execute(sql, (last_id, current_id, limit))
            return cast(List[Tuple[int, str, str]], txn.fetchall())

        return await self.db_pool.runInteraction(
            "get_updated_global_account_data", get_updated_global_account_data_txn
        )
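
    # The `get_updated_*` methods above and below return rows in the half-open
    # range (last_id, current_id], ordered by stream_id, which is the batching
    # contract used by the account_data stream. An illustrative row for the
    # global variant: (101, "@alice:example.com", "m.ignored_user_list").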

    async def get_updated_room_account_data(
        self, last_id: int, current_id: int, limit: int
    ) -> List[Tuple[int, str, str, str]]:
        """Get the room account_data that has changed, for the account_data stream

        Args:
            last_id: the last stream_id from the previous batch.
            current_id: the maximum stream_id to return up to
            limit: the maximum number of rows to return

        Returns:
            A list of tuples of stream_id int, user_id string,
            room_id string and type string.
        """
        if last_id == current_id:
            return []

        def get_updated_room_account_data_txn(
            txn: LoggingTransaction,
        ) -> List[Tuple[int, str, str, str]]:
            sql = (
                "SELECT stream_id, user_id, room_id, account_data_type"
                " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC LIMIT ?"
            )
            txn.execute(sql, (last_id, current_id, limit))
            return cast(List[Tuple[int, str, str, str]], txn.fetchall())

        return await self.db_pool.runInteraction(
            "get_updated_room_account_data", get_updated_room_account_data_txn
        )

    async def get_updated_account_data_for_user(
        self, user_id: str, stream_id: int
    ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
        """Get all the client account_data that has changed for a user.

        Args:
            user_id: The user to get the account_data for.
            stream_id: The point in the stream since which to get updates
        Returns:
            A pair of a dict of global account_data and a dict
            mapping from room_id string to per room account_data dicts.
        """

        def get_updated_account_data_for_user_txn(
            txn: LoggingTransaction,
        ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
            sql = (
                "SELECT account_data_type, content FROM account_data"
                " WHERE user_id = ? AND stream_id > ?"
            )
            txn.execute(sql, (user_id, stream_id))

            global_account_data = {row[0]: db_to_json(row[1]) for row in txn}

            sql = (
                "SELECT room_id, account_data_type, content FROM room_account_data"
                " WHERE user_id = ? AND stream_id > ?"
            )
            txn.execute(sql, (user_id, stream_id))

            account_data_by_room: Dict[str, Dict[str, JsonDict]] = {}
            for row in txn:
                room_account_data = account_data_by_room.setdefault(row[0], {})
                room_account_data[row[1]] = db_to_json(row[2])

            return global_account_data, account_data_by_room

        changed = self._account_data_stream_cache.has_entity_changed(
            user_id, int(stream_id)
        )
        if not changed:
            return {}, {}

        return await self.db_pool.runInteraction(
            "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
        )
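
    # Note: the `_account_data_stream_cache` check above lets callers (e.g. the
    # sync code paths) skip the database round-trip entirely when nothing has
    # changed for the user since `stream_id`.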

    @cached(max_entries=5000, iterable=True)
    async def ignored_by(self, user_id: str) -> FrozenSet[str]:
        """
        Get users which ignore the given user.

        Args:
            user_id: The user ID which might be ignored.

        Returns:
            The user IDs which ignore the given user.
        """
        return frozenset(
            await self.db_pool.simple_select_onecol(
                table="ignored_users",
                keyvalues={"ignored_user_id": user_id},
                retcol="ignorer_user_id",
                desc="ignored_by",
            )
        )

    @cached(max_entries=5000, iterable=True)
    async def ignored_users(self, user_id: str) -> FrozenSet[str]:
        """
        Get users which the given user ignores.

        Args:
            user_id: The user ID which is making the request.

        Returns:
            The user IDs which are ignored by the given user.
        """
        return frozenset(
            await self.db_pool.simple_select_onecol(
                table="ignored_users",
                keyvalues={"ignorer_user_id": user_id},
                retcol="ignored_user_id",
                desc="ignored_users",
            )
        )

    def process_replication_rows(
        self,
        stream_name: str,
        instance_name: str,
        token: int,
        rows: Iterable[Any],
    ) -> None:
        if stream_name == TagAccountDataStream.NAME:
            self._account_data_id_gen.advance(instance_name, token)
        elif stream_name == AccountDataStream.NAME:
            self._account_data_id_gen.advance(instance_name, token)
            for row in rows:
                if not row.room_id:
                    self.get_global_account_data_by_type_for_user.invalidate(
                        (row.user_id, row.data_type)
                    )
                self.get_account_data_for_user.invalidate((row.user_id,))
                self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
                self.get_account_data_for_room_and_type.invalidate(
                    (row.user_id, row.room_id, row.data_type)
                )
                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
        super().process_replication_rows(stream_name, instance_name, token, rows)
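
    # When this worker is not the account_data writer, the rows handled above
    # arrive over TCP replication: the stream position tracker is advanced and
    # the per-user caches are invalidated so that reads on this process stay
    # consistent with what the writer persisted.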

    async def add_account_data_to_room(
        self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
    ) -> int:
        """Add some account_data to a room for a user.

        Args:
            user_id: The user to add a tag for.
            room_id: The room to add a tag for.
            account_data_type: The type of account_data to add.
            content: A json object to associate with the tag.
        Returns:
            The maximum stream ID.
        """
        assert self._can_write_to_account_data
        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)

        content_json = json_encoder.encode(content)

        async with self._account_data_id_gen.get_next() as next_id:
            # no need to lock here as room_account_data has a unique constraint
            # on (user_id, room_id, account_data_type) so simple_upsert will
            # retry if there is a conflict.
            await self.db_pool.simple_upsert(
                desc="add_room_account_data",
                table="room_account_data",
                keyvalues={
                    "user_id": user_id,
                    "room_id": room_id,
                    "account_data_type": account_data_type,
                },
                values={"stream_id": next_id, "content": content_json},
                lock=False,
            )

            self._account_data_stream_cache.entity_has_changed(user_id, next_id)
            self.get_account_data_for_user.invalidate((user_id,))
            self.get_account_data_for_room.invalidate((user_id, room_id))
            self.get_account_data_for_room_and_type.prefill(
                (user_id, room_id, account_data_type), content
            )

        return self._account_data_id_gen.get_current_token()
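
    # Example usage (illustrative only; identifiers are hypothetical):
    #
    #   max_stream_id = await store.add_account_data_to_room(
    #       "@alice:example.com",
    #       "!room:example.com",
    #       "m.fully_read",
    #       {"event_id": "$read_up_to"},
    #   )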

    async def add_account_data_for_user(
        self, user_id: str, account_data_type: str, content: JsonDict
    ) -> int:
        """Add some global account_data for a user.

        Args:
            user_id: The user to add a tag for.
            account_data_type: The type of account_data to add.
            content: A json object to associate with the tag.
        Returns:
            The maximum stream ID.
        """
        assert self._can_write_to_account_data
        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)

        async with self._account_data_id_gen.get_next() as next_id:
            await self.db_pool.runInteraction(
                "add_user_account_data",
                self._add_account_data_for_user,
                next_id,
                user_id,
                account_data_type,
                content,
            )

            self._account_data_stream_cache.entity_has_changed(user_id, next_id)
            self.get_account_data_for_user.invalidate((user_id,))
            self.get_global_account_data_by_type_for_user.invalidate(
                (user_id, account_data_type)
            )

        return self._account_data_id_gen.get_current_token()
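
    # Example usage (illustrative only; identifiers are hypothetical). Ignoring
    # a user is stored as ordinary global account data and then denormalised
    # into the `ignored_users` table by the helper below:
    #
    #   await store.add_account_data_for_user(
    #       "@alice:example.com",
    #       AccountDataTypes.IGNORED_USER_LIST,
    #       {"ignored_users": {"@spammer:example.com": {}}},
    #   )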

    def _add_account_data_for_user(
        self,
        txn: LoggingTransaction,
        next_id: int,
        user_id: str,
        account_data_type: str,
        content: JsonDict,
    ) -> None:
        content_json = json_encoder.encode(content)

        # no need to lock here as account_data has a unique constraint on
        # (user_id, account_data_type) so simple_upsert will retry if
        # there is a conflict.
        self.db_pool.simple_upsert_txn(
            txn,
            table="account_data",
            keyvalues={"user_id": user_id, "account_data_type": account_data_type},
            values={"stream_id": next_id, "content": content_json},
            lock=False,
        )

        # Ignored users get denormalized into a separate table as an optimisation.
        if account_data_type != AccountDataTypes.IGNORED_USER_LIST:
            return

        # Insert / delete to sync the list of ignored users.
        previously_ignored_users = set(
            self.db_pool.simple_select_onecol_txn(
                txn,
                table="ignored_users",
                keyvalues={"ignorer_user_id": user_id},
                retcol="ignored_user_id",
            )
        )

        # If the data is invalid, no one is ignored.
        ignored_users_content = content.get("ignored_users", {})
        if isinstance(ignored_users_content, dict):
            currently_ignored_users = set(ignored_users_content)
        else:
            currently_ignored_users = set()

        # If the data has not changed, nothing to do.
        if previously_ignored_users == currently_ignored_users:
            return

        # Delete entries which are no longer ignored.
        self.db_pool.simple_delete_many_txn(
            txn,
            table="ignored_users",
            column="ignored_user_id",
            values=previously_ignored_users - currently_ignored_users,
            keyvalues={"ignorer_user_id": user_id},
        )

        # Add entries which are newly ignored.
        self.db_pool.simple_insert_many_txn(
            txn,
            table="ignored_users",
            keys=("ignorer_user_id", "ignored_user_id"),
            values=[
                (user_id, u) for u in currently_ignored_users - previously_ignored_users
            ],
        )

        # Invalidate the cache for any ignored users which were added or removed.
        for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
            self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
        self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))

    async def purge_account_data_for_user(self, user_id: str) -> None:
        """
        Removes ALL the account data for a user.
        Intended to be used upon user deactivation.

        Also purges the user from the ignored_users cache table
        and the push_rules cache tables.
        """

        await self.db_pool.runInteraction(
            "purge_account_data_for_user_txn",
            self._purge_account_data_for_user_txn,
            user_id,
        )

    def _purge_account_data_for_user_txn(
        self, txn: LoggingTransaction, user_id: str
    ) -> None:
        """
        See `purge_account_data_for_user`.
        """
        # Purge from the primary account_data tables.
        self.db_pool.simple_delete_txn(
            txn, table="account_data", keyvalues={"user_id": user_id}
        )

        self.db_pool.simple_delete_txn(
            txn, table="room_account_data", keyvalues={"user_id": user_id}
        )

        # Purge from ignored_users where this user is the ignorer.
        # N.B. We don't purge where this user is the ignoree, because that
        # interferes with other users' account data.
        # It's also not this user's data to delete!
        self.db_pool.simple_delete_txn(
            txn, table="ignored_users", keyvalues={"ignorer_user_id": user_id}
        )

        # Remove the push rules
        self.db_pool.simple_delete_txn(
            txn, table="push_rules", keyvalues={"user_name": user_id}
        )
        self.db_pool.simple_delete_txn(
            txn, table="push_rules_enable", keyvalues={"user_name": user_id}
        )
        self.db_pool.simple_delete_txn(
            txn, table="push_rules_stream", keyvalues={"user_id": user_id}
        )

        # Invalidate caches as appropriate
        self._invalidate_cache_and_stream(
            txn, self.get_account_data_for_room_and_type, (user_id,)
        )
        self._invalidate_cache_and_stream(
            txn, self.get_account_data_for_user, (user_id,)
        )
        self._invalidate_cache_and_stream(
            txn, self.get_global_account_data_by_type_for_user, (user_id,)
        )
        self._invalidate_cache_and_stream(
            txn, self.get_account_data_for_room, (user_id,)
        )
        self._invalidate_cache_and_stream(txn, self.get_push_rules_for_user, (user_id,))
        self._invalidate_cache_and_stream(
            txn, self.get_push_rules_enabled_for_user, (user_id,)
        )
        # This user might be contained in the ignored_by cache for other users,
        # so we have to invalidate it all.
        self._invalidate_all_cache_and_stream(txn, self.ignored_by)
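
    # The "delete_account_data_for_deactivated_users" background update
    # registered in `__init__` runs the purge above retroactively: it pages
    # through deactivated users in batches, recording the last processed user
    # name as its progress, and de-registers itself once a batch comes back
    # smaller than `batch_size`.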

    async def _delete_account_data_for_deactivated_users(
        self, progress: dict, batch_size: int
    ) -> int:
        """
        Retroactively purges account data for users that have already been deactivated.
        Gets run as a background update caused by a schema delta.
        """

        last_user: str = progress.get("last_user", "")

        def _delete_account_data_for_deactivated_users_txn(
            txn: LoggingTransaction,
        ) -> int:
            sql = """
                SELECT name FROM users
                WHERE deactivated = ? and name > ?
                ORDER BY name ASC
                LIMIT ?
            """
            txn.execute(sql, (1, last_user, batch_size))
            users = [row[0] for row in txn]

            for user in users:
                self._purge_account_data_for_user_txn(txn, user_id=user)

            if users:
                self.db_pool.updates._background_update_progress_txn(
                    txn,
                    "delete_account_data_for_deactivated_users",
                    {"last_user": users[-1]},
                )

            return len(users)

        number_deleted = await self.db_pool.runInteraction(
            "_delete_account_data_for_deactivated_users",
            _delete_account_data_for_deactivated_users_txn,
        )

        if number_deleted < batch_size:
            await self.db_pool.updates._end_background_update(
                "delete_account_data_for_deactivated_users"
            )

        return number_deleted
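

# AccountDataStore adds nothing on top of the worker store: all read and write
# behaviour (writes only when this instance is a configured account_data
# writer) lives in AccountDataWorkerStore above.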
class AccountDataStore(AccountDataWorkerStore):
    pass