|
- # Copyright 2014-2016 OpenMarket Ltd
- # Copyright 2018 New Vector Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- import logging
- from typing import (
- TYPE_CHECKING,
- Any,
- Dict,
- FrozenSet,
- Iterable,
- List,
- Optional,
- Tuple,
- cast,
- )
-
- from synapse.api.constants import AccountDataTypes
- from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
- from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
- from synapse.storage._base import db_to_json
- from synapse.storage.database import (
- DatabasePool,
- LoggingDatabaseConnection,
- LoggingTransaction,
- )
- from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
- from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
- from synapse.storage.engines import PostgresEngine
- from synapse.storage.util.id_generators import (
- AbstractStreamIdGenerator,
- AbstractStreamIdTracker,
- MultiWriterIdGenerator,
- StreamIdGenerator,
- )
- from synapse.types import JsonDict
- from synapse.util import json_encoder
- from synapse.util.caches.descriptors import cached
- from synapse.util.caches.stream_change_cache import StreamChangeCache
-
- if TYPE_CHECKING:
- from synapse.server import HomeServer
-
- logger = logging.getLogger(__name__)
-
-
- class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore):
def __init__(
    self,
    database: DatabasePool,
    db_conn: LoggingDatabaseConnection,
    hs: "HomeServer",
):
    """Set up the account data stream ID generator/tracker and change cache.

    Args:
        database: The database pool this store operates on.
        db_conn: The connection handed to the stream ID generator/tracker.
        hs: The homeserver; its worker config lists the account data writers.
    """
    super().__init__(database, db_conn, hs)

    account_data_writers = hs.config.worker.writers.account_data

    # `_can_write_to_account_data` indicates whether the current worker is allowed
    # to write account data. A value of `True` implies that `_account_data_id_gen`
    # is an `AbstractStreamIdGenerator` and not just a tracker.
    #
    # It is assigned before branching so that the attribute exists on every
    # path (including a SQLite non-writer), which lets the assertions in the
    # write methods fail with an `AssertionError` rather than an
    # `AttributeError`.
    self._can_write_to_account_data = self._instance_name in account_data_writers
    self._account_data_id_gen: AbstractStreamIdTracker

    if isinstance(database.engine, PostgresEngine):
        self._account_data_id_gen = MultiWriterIdGenerator(
            db_conn=db_conn,
            db=database,
            stream_name="account_data",
            instance_name=self._instance_name,
            tables=[
                ("room_account_data", "instance_name", "stream_id"),
                ("room_tags_revisions", "instance_name", "stream_id"),
                ("account_data", "instance_name", "stream_id"),
            ],
            sequence_name="account_data_sequence",
            writers=account_data_writers,
        )
    elif self._can_write_to_account_data:
        # We shouldn't be running in worker mode with SQLite, but it's useful
        # to support it for unit tests.
        #
        # If this process is the writer then we need to use
        # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
        # updated over replication. (Multiple writers are not supported for
        # SQLite).
        self._account_data_id_gen = StreamIdGenerator(
            db_conn,
            "room_account_data",
            "stream_id",
            extra_tables=[("room_tags_revisions", "stream_id")],
        )
    else:
        self._account_data_id_gen = SlavedIdTracker(
            db_conn,
            "room_account_data",
            "stream_id",
            extra_tables=[("room_tags_revisions", "stream_id")],
        )

    # Seed the change cache with the stream position known at start-up.
    account_max = self.get_max_account_data_stream_id()
    self._account_data_stream_cache = StreamChangeCache(
        "AccountDataAndTagsChangeCache", account_max
    )

    self.db_pool.updates.register_background_update_handler(
        "delete_account_data_for_deactivated_users",
        self._delete_account_data_for_deactivated_users,
    )
-
def get_max_account_data_stream_id(self) -> int:
    """Return the current token of the account data stream.

    Returns:
        The token currently reported by the account data ID generator/tracker.
    """
    id_gen = self._account_data_id_gen
    return id_gen.get_current_token()
-
@cached()
async def get_account_data_for_user(
    self, user_id: str
) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
    """Fetch every piece of client account_data stored for a user.

    Args:
        user_id: The user whose account_data should be loaded.
    Returns:
        A 2-tuple: the global account_data keyed by type, and the per-room
        account_data keyed by room_id and then by type.
    """

    def _load_all_txn(
        txn: LoggingTransaction,
    ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
        # Global (non-room) account data first.
        global_rows = self.db_pool.simple_select_list_txn(
            txn,
            "account_data",
            {"user_id": user_id},
            ["account_data_type", "content"],
        )
        global_data: Dict[str, JsonDict] = {}
        for entry in global_rows:
            global_data[entry["account_data_type"]] = db_to_json(entry["content"])

        # Then the per-room account data, grouped by room.
        room_rows = self.db_pool.simple_select_list_txn(
            txn,
            "room_account_data",
            {"user_id": user_id},
            ["room_id", "account_data_type", "content"],
        )
        per_room: Dict[str, Dict[str, JsonDict]] = {}
        for entry in room_rows:
            data_for_room = per_room.setdefault(entry["room_id"], {})
            data_for_room[entry["account_data_type"]] = db_to_json(entry["content"])

        return global_data, per_room

    return await self.db_pool.runInteraction(
        "get_account_data_for_user", _load_all_txn
    )
-
@cached(num_args=2, max_entries=5000, tree=True)
async def get_global_account_data_by_type_for_user(
    self, user_id: str, data_type: str
) -> Optional[JsonDict]:
    """Look up one type of global account data for a user.

    Args:
        user_id: The user whose account data is wanted.
        data_type: The account data type to look up.
    Returns:
        The decoded account data, or None when no (non-empty) content is
        stored for that user and type.
    """
    raw_content = await self.db_pool.simple_select_one_onecol(
        table="account_data",
        keyvalues={"user_id": user_id, "account_data_type": data_type},
        retcol="content",
        desc="get_global_account_data_by_type_for_user",
        allow_none=True,
    )

    return db_to_json(raw_content) if raw_content else None
-
@cached(num_args=2, tree=True)
async def get_account_data_for_room(
    self, user_id: str, room_id: str
) -> Dict[str, JsonDict]:
    """Fetch all of a user's client account_data for a single room.

    Args:
        user_id: The user whose account_data should be loaded.
        room_id: The room to restrict the lookup to.
    Returns:
        The room's account_data, keyed by account data type.
    """

    def _load_room_txn(
        txn: LoggingTransaction,
    ) -> Dict[str, JsonDict]:
        room_rows = self.db_pool.simple_select_list_txn(
            txn,
            "room_account_data",
            {"user_id": user_id, "room_id": room_id},
            ["account_data_type", "content"],
        )

        data_by_type: Dict[str, JsonDict] = {}
        for entry in room_rows:
            data_by_type[entry["account_data_type"]] = db_to_json(entry["content"])
        return data_by_type

    return await self.db_pool.runInteraction(
        "get_account_data_for_room", _load_room_txn
    )
-
@cached(num_args=3, max_entries=5000, tree=True)
async def get_account_data_for_room_and_type(
    self, user_id: str, room_id: str, account_data_type: str
) -> Optional[JsonDict]:
    """Fetch one type of client account_data for a user in a room.

    Args:
        user_id: The user whose account_data should be loaded.
        room_id: The room to restrict the lookup to.
        account_data_type: The account data type to fetch.
    Returns:
        The decoded room account_data for that type, or None if nothing
        (non-empty) is stored.
    """

    def _load_one_txn(
        txn: LoggingTransaction,
    ) -> Optional[JsonDict]:
        lookup_key = {
            "user_id": user_id,
            "room_id": room_id,
            "account_data_type": account_data_type,
        }
        raw_content = self.db_pool.simple_select_one_onecol_txn(
            txn,
            table="room_account_data",
            keyvalues=lookup_key,
            retcol="content",
            allow_none=True,
        )

        if not raw_content:
            return None
        return db_to_json(raw_content)

    return await self.db_pool.runInteraction(
        "get_account_data_for_room_and_type", _load_one_txn
    )
-
async def get_updated_global_account_data(
    self, last_id: int, current_id: int, limit: int
) -> List[Tuple[int, str, str]]:
    """List global account_data changes for the account_data stream.

    Args:
        last_id: The last stream_id from the previous batch (exclusive bound).
        current_id: The maximum stream_id to return (inclusive bound).
        limit: The maximum number of rows to return.

    Returns:
        A list of (stream_id, user_id, account data type) tuples, ordered by
        ascending stream_id.
    """
    # An empty range cannot contain any rows, so skip the database entirely.
    if last_id == current_id:
        return []

    def _fetch_global_updates_txn(
        txn: LoggingTransaction,
    ) -> List[Tuple[int, str, str]]:
        query = (
            "SELECT stream_id, user_id, account_data_type"
            " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
            " ORDER BY stream_id ASC LIMIT ?"
        )
        query_args = (last_id, current_id, limit)
        txn.execute(query, query_args)
        fetched = txn.fetchall()
        return cast(List[Tuple[int, str, str]], fetched)

    return await self.db_pool.runInteraction(
        "get_updated_global_account_data", _fetch_global_updates_txn
    )
-
async def get_updated_room_account_data(
    self, last_id: int, current_id: int, limit: int
) -> List[Tuple[int, str, str, str]]:
    """List room account_data changes for the account_data stream.

    Args:
        last_id: The last stream_id from the previous batch (exclusive bound).
        current_id: The maximum stream_id to return (inclusive bound).
        limit: The maximum number of rows to return.

    Returns:
        A list of (stream_id, user_id, room_id, account data type) tuples,
        ordered by ascending stream_id.
    """
    # An empty range cannot contain any rows, so skip the database entirely.
    if last_id == current_id:
        return []

    def _fetch_room_updates_txn(
        txn: LoggingTransaction,
    ) -> List[Tuple[int, str, str, str]]:
        query = (
            "SELECT stream_id, user_id, room_id, account_data_type"
            " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
            " ORDER BY stream_id ASC LIMIT ?"
        )
        query_args = (last_id, current_id, limit)
        txn.execute(query, query_args)
        fetched = txn.fetchall()
        return cast(List[Tuple[int, str, str, str]], fetched)

    return await self.db_pool.runInteraction(
        "get_updated_room_account_data", _fetch_room_updates_txn
    )
-
async def get_updated_account_data_for_user(
    self, user_id: str, stream_id: int
) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
    """Fetch the client account_data that has changed for a user.

    Args:
        user_id: The user whose account_data should be loaded.
        stream_id: The point in the stream after which to look for updates.
    Returns:
        A 2-tuple: the changed global account_data keyed by type, and the
        changed per-room account_data keyed by room_id and then by type.
        Both dicts are empty when the stream change cache reports no change.
    """
    # Consult the in-memory change cache first; it can rule out a database hit.
    if not self._account_data_stream_cache.has_entity_changed(
        user_id, int(stream_id)
    ):
        return {}, {}

    def _fetch_changes_txn(
        txn: LoggingTransaction,
    ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
        query_args = (user_id, stream_id)

        global_query = (
            "SELECT account_data_type, content FROM account_data"
            " WHERE user_id = ? AND stream_id > ?"
        )
        txn.execute(global_query, query_args)

        changed_global: Dict[str, JsonDict] = {}
        for record in txn:
            changed_global[record[0]] = db_to_json(record[1])

        room_query = (
            "SELECT room_id, account_data_type, content FROM room_account_data"
            " WHERE user_id = ? AND stream_id > ?"
        )
        txn.execute(room_query, query_args)

        changed_by_room: Dict[str, Dict[str, JsonDict]] = {}
        for record in txn:
            data_for_room = changed_by_room.setdefault(record[0], {})
            data_for_room[record[1]] = db_to_json(record[2])

        return changed_global, changed_by_room

    return await self.db_pool.runInteraction(
        "get_updated_account_data_for_user", _fetch_changes_txn
    )
-
@cached(max_entries=5000, iterable=True)
async def ignored_by(self, user_id: str) -> FrozenSet[str]:
    """
    Get the users who ignore the given user.

    Params:
        user_id: The user ID which might be ignored.

    Return:
        The user IDs which ignore the given user.
    """
    ignorers = await self.db_pool.simple_select_onecol(
        table="ignored_users",
        keyvalues={"ignored_user_id": user_id},
        retcol="ignorer_user_id",
        desc="ignored_by",
    )
    return frozenset(ignorers)
-
@cached(max_entries=5000, iterable=True)
async def ignored_users(self, user_id: str) -> FrozenSet[str]:
    """
    Get the users whom the given user ignores.

    Params:
        user_id: The user ID which is making the request.

    Return:
        The user IDs which are ignored by the given user.
    """
    ignored = await self.db_pool.simple_select_onecol(
        table="ignored_users",
        keyvalues={"ignorer_user_id": user_id},
        retcol="ignored_user_id",
        desc="ignored_users",
    )
    return frozenset(ignored)
-
def process_replication_rows(
    self,
    stream_name: str,
    instance_name: str,
    token: int,
    rows: Iterable[Any],
) -> None:
    """Apply incoming replication rows to the local ID tracker and caches.

    For the tag and account data streams the ID tracker is advanced to
    `token`. For the account data stream the per-user caches are also
    invalidated and the stream change cache is updated for each row. The call
    is always forwarded to the parent implementation afterwards.

    Args:
        stream_name: The name of the replication stream the rows belong to.
        instance_name: The instance that produced the rows.
        token: The stream token the rows bring this process up to.
        rows: The replication rows received.
    """
    if stream_name == TagAccountDataStream.NAME:
        self._account_data_id_gen.advance(instance_name, token)
    elif stream_name == AccountDataStream.NAME:
        self._account_data_id_gen.advance(instance_name, token)
        for row in rows:
            # A row without a room ID describes global account data.
            if not row.room_id:
                self.get_global_account_data_by_type_for_user.invalidate(
                    (row.user_id, row.data_type)
                )

            user_key = (row.user_id,)
            room_key = (row.user_id, row.room_id)
            room_type_key = (row.user_id, row.room_id, row.data_type)

            self.get_account_data_for_user.invalidate(user_key)
            self.get_account_data_for_room.invalidate(room_key)
            self.get_account_data_for_room_and_type.invalidate(room_type_key)
            self._account_data_stream_cache.entity_has_changed(row.user_id, token)

    super().process_replication_rows(stream_name, instance_name, token, rows)
-
async def add_account_data_to_room(
    self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
) -> int:
    """Store a piece of per-room account_data for a user.

    Args:
        user_id: The user the account_data belongs to.
        room_id: The room the account_data is attached to.
        account_data_type: The type of account_data being stored.
        content: The JSON object to store for that type.

    Returns:
        The current token of the account data stream after the write.
    """
    assert self._can_write_to_account_data
    assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)

    encoded_content = json_encoder.encode(content)
    row_key = {
        "user_id": user_id,
        "room_id": room_id,
        "account_data_type": account_data_type,
    }

    async with self._account_data_id_gen.get_next() as next_id:
        # no need to lock here as room_account_data has a unique constraint
        # on (user_id, room_id, account_data_type) so simple_upsert will
        # retry if there is a conflict.
        await self.db_pool.simple_upsert(
            desc="add_room_account_data",
            table="room_account_data",
            keyvalues=row_key,
            values={"stream_id": next_id, "content": encoded_content},
            lock=False,
        )

        # Record the change and refresh the caches that cover this entry.
        self._account_data_stream_cache.entity_has_changed(user_id, next_id)
        self.get_account_data_for_user.invalidate((user_id,))
        self.get_account_data_for_room.invalidate((user_id, room_id))
        self.get_account_data_for_room_and_type.prefill(
            (user_id, room_id, account_data_type), content
        )

    return self._account_data_id_gen.get_current_token()
-
async def add_account_data_for_user(
    self, user_id: str, account_data_type: str, content: JsonDict
) -> int:
    """Store a piece of global account_data for a user.

    Args:
        user_id: The user the account_data belongs to.
        account_data_type: The type of account_data being stored.
        content: The JSON object to store for that type.

    Returns:
        The current token of the account data stream after the write.
    """
    assert self._can_write_to_account_data
    assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)

    id_gen = self._account_data_id_gen

    async with id_gen.get_next() as next_id:
        await self.db_pool.runInteraction(
            "add_user_account_data",
            self._add_account_data_for_user,
            next_id,
            user_id,
            account_data_type,
            content,
        )

        # Record the change and drop the cached values for this user.
        self._account_data_stream_cache.entity_has_changed(user_id, next_id)
        self.get_account_data_for_user.invalidate((user_id,))
        self.get_global_account_data_by_type_for_user.invalidate(
            (user_id, account_data_type)
        )

    return id_gen.get_current_token()
-
def _add_account_data_for_user(
    self,
    txn: LoggingTransaction,
    next_id: int,
    user_id: str,
    account_data_type: str,
    content: JsonDict,
) -> None:
    """Upsert global account data and keep the `ignored_users` table in sync.

    Args:
        txn: The database transaction to run in.
        next_id: The stream ID to record against the account data row.
        user_id: The user the account data belongs to.
        account_data_type: The type of account data being stored.
        content: The JSON object to store for that type.
    """
    # no need to lock here as account_data has a unique constraint on
    # (user_id, account_data_type) so simple_upsert will retry if
    # there is a conflict.
    self.db_pool.simple_upsert_txn(
        txn,
        table="account_data",
        keyvalues={"user_id": user_id, "account_data_type": account_data_type},
        values={"stream_id": next_id, "content": json_encoder.encode(content)},
        lock=False,
    )

    # Ignored users get denormalized into a separate table as an optimisation;
    # every other account data type is finished at this point.
    if account_data_type != AccountDataTypes.IGNORED_USER_LIST:
        return

    # Work out what the table currently holds for this user.
    stored_ignored = self.db_pool.simple_select_onecol_txn(
        txn,
        table="ignored_users",
        keyvalues={"ignorer_user_id": user_id},
        retcol="ignored_user_id",
    )
    previously_ignored_users = set(stored_ignored)

    # If the data is invalid, no one is ignored.
    ignored_users_content = content.get("ignored_users", {})
    currently_ignored_users = (
        set(ignored_users_content)
        if isinstance(ignored_users_content, dict)
        else set()
    )

    # If the data has not changed, nothing to do.
    if previously_ignored_users == currently_ignored_users:
        return

    no_longer_ignored = previously_ignored_users - currently_ignored_users
    newly_ignored = currently_ignored_users - previously_ignored_users

    # Delete entries which are no longer ignored.
    self.db_pool.simple_delete_many_txn(
        txn,
        table="ignored_users",
        column="ignored_user_id",
        values=no_longer_ignored,
        keyvalues={"ignorer_user_id": user_id},
    )

    # Add entries which are newly ignored.
    self.db_pool.simple_insert_many_txn(
        txn,
        table="ignored_users",
        keys=("ignorer_user_id", "ignored_user_id"),
        values=[(user_id, ignored_id) for ignored_id in newly_ignored],
    )

    # Invalidate the cache for any ignored users which were added or removed.
    for affected_user_id in previously_ignored_users ^ currently_ignored_users:
        self._invalidate_cache_and_stream(txn, self.ignored_by, (affected_user_id,))
    self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
-
async def purge_account_data_for_user(self, user_id: str) -> None:
    """
    Delete ALL of a user's account data.
    Intended to be used upon user deactivation.

    The user's rows in the ignored_users table (as the ignorer) and in the
    push rules tables are removed as well.

    Args:
        user_id: The user whose account data should be purged.
    """
    purge_txn = self._purge_account_data_for_user_txn
    await self.db_pool.runInteraction(
        "purge_account_data_for_user_txn",
        purge_txn,
        user_id,
    )
-
def _purge_account_data_for_user_txn(
    self, txn: LoggingTransaction, user_id: str
) -> None:
    """
    Transaction body for `purge_account_data_for_user`.

    Args:
        txn: The database transaction to run in.
        user_id: The user whose account data should be purged.
    """
    # (table, column identifying the user) pairs, deleted in this order:
    #  * the primary account_data tables;
    #  * ignored_users, but only where this user is the ignorer.
    #    N.B. We don't purge where this user is the ignoree, because that
    #    interferes with other users' account data.
    #    It's also not this user's data to delete!
    #  * the push rules tables.
    tables_to_purge = (
        ("account_data", "user_id"),
        ("room_account_data", "user_id"),
        ("ignored_users", "ignorer_user_id"),
        ("push_rules", "user_name"),
        ("push_rules_enable", "user_name"),
        ("push_rules_stream", "user_id"),
    )
    for table_name, user_column in tables_to_purge:
        self.db_pool.simple_delete_txn(
            txn, table=table_name, keyvalues={user_column: user_id}
        )

    # Invalidate caches as appropriate
    per_user_caches = (
        self.get_account_data_for_room_and_type,
        self.get_account_data_for_user,
        self.get_global_account_data_by_type_for_user,
        self.get_account_data_for_room,
        self.get_push_rules_for_user,
        self.get_push_rules_enabled_for_user,
    )
    for cached_method in per_user_caches:
        self._invalidate_cache_and_stream(txn, cached_method, (user_id,))

    # This user might be contained in the ignored_by cache for other users,
    # so we have to invalidate it all.
    self._invalidate_all_cache_and_stream(txn, self.ignored_by)
-
async def _delete_account_data_for_deactivated_users(
    self, progress: dict, batch_size: int
) -> int:
    """
    Retroactively purge account data for users that are already deactivated.
    Gets run as a background update caused by a schema delta.

    Args:
        progress: The saved progress; `last_user` holds the last user name
            handled by a previous batch, if any.
        batch_size: The maximum number of users to process in this batch.

    Returns:
        The number of users processed in this batch.
    """
    update_name = "delete_account_data_for_deactivated_users"
    last_user: str = progress.get("last_user", "")

    def _purge_batch_txn(
        txn: LoggingTransaction,
    ) -> int:
        sql = """
            SELECT name FROM users
            WHERE deactivated = ? and name > ?
            ORDER BY name ASC
            LIMIT ?
        """

        txn.execute(sql, (1, last_user, batch_size))
        deactivated_users = [row[0] for row in txn]

        if not deactivated_users:
            return 0

        for deactivated_user in deactivated_users:
            self._purge_account_data_for_user_txn(txn, user_id=deactivated_user)

        # Remember where this batch stopped so the next one resumes after it.
        self.db_pool.updates._background_update_progress_txn(
            txn,
            update_name,
            {"last_user": deactivated_users[-1]},
        )

        return len(deactivated_users)

    number_deleted = await self.db_pool.runInteraction(
        "_delete_account_data_for_deactivated_users",
        _purge_batch_txn,
    )

    # A short batch means there are no more deactivated users left to handle.
    if number_deleted < batch_size:
        await self.db_pool.updates._end_background_update(update_name)

    return number_deleted
-
-
class AccountDataStore(AccountDataWorkerStore):
    """Account data store; adds nothing on top of `AccountDataWorkerStore`."""
|