@@ -0,0 +1 @@
Add some type hints to datastore.
@@ -28,8 +28,6 @@ exclude = (?x)
|synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py
|synapse/storage/databases/main/event_federation.py
|synapse/storage/databases/main/push_rule.py
|synapse/storage/databases/main/roommember.py
|synapse/storage/schema/
|tests/api/test_auth.py
@@ -15,7 +15,17 @@
import abc
import logging
from collections import OrderedDict
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Set, Tuple
from typing import (
TYPE_CHECKING,
Collection,
Dict,
Hashable,
Iterable,
List,
Optional,
Set,
Tuple,
)
import attr
from prometheus_client import Counter
@@ -409,7 +419,7 @@ class FederationSender(AbstractFederationSender):
)
return
destinations: Optional[Set[str]] = None
destinations: Optional[Collection[str]] = None
if not event.prev_event_ids():
# If there are no prev event IDs then the state is empty
# and so no remote servers in the room
@@ -444,7 +454,7 @@ class FederationSender(AbstractFederationSender):
)
return
destinations = {
sharded_destinations = {
d
for d in destinations
if self._federation_shard_config.should_handle(
@@ -456,12 +466,12 @@ class FederationSender(AbstractFederationSender):
# If we are sending the event on behalf of another server
# then it already has the event and there is no reason to
# send the event to it.
destinations.discard(send_on_behalf_of)
sharded_destinations.discard(send_on_behalf_of)
logger.debug("Sending %s to %r", event, destinations)
logger.debug("Sending %s to %r", event, sharded_destinations)
if destinations:
await self._send_pdu(event, destinations)
if sharded_destinations:
await self._send_pdu(event, sharded_destinations)
now = self.clock.time_msec()
ts = await self.store.get_received_ts(event.event_id)
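The `sharded_destinations` rename above is doing type-checking work, not just renaming: once `destinations` is declared `Optional[Collection[str]]`, mypy holds every later assignment to that type, and `Collection` offers neither mutability nor `.discard()`. A minimal sketch of the same fix, with illustrative names (nothing below is Synapse API):

```python
from typing import Collection, Optional, Set

def route(hosts: Optional[Collection[str]], skip: str) -> None:
    assert hosts is not None  # narrow away the None case first
    # Re-assigning to `hosts` would keep its declared Collection type, which
    # has no .discard(); binding a fresh name gets an inferred Set[str].
    sharded: Set[str] = {h for h in hosts if "." in h}
    sharded.discard(skip)  # fine now: Set[str] is mutable
    print(sorted(sharded))

route(frozenset({"matrix.org", "example.com"}), "matrix.org")
```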
@@ -411,10 +411,10 @@ class SyncHandler:
set_tag(SynapseTags.SYNC_RESULT, bool(sync_result))
return sync_result
async def push_rules_for_user(self, user: UserID) -> JsonDict:
async def push_rules_for_user(self, user: UserID) -> Dict[str, Dict[str, list]]:
user_id = user.to_string()
rules = await self.store.get_push_rules_for_user(user_id)
rules = format_push_rules_for_user(user, rules)
rules_raw = await self.store.get_push_rules_for_user(user_id)
rules = format_push_rules_for_user(user, rules_raw)
return rules
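This hunk narrows the return annotation from the loose `JsonDict` alias to the shape the endpoint actually produces, and gives the raw database rows their own name so no variable changes type mid-function. A hedged sketch with stand-in names (`format_rules` is illustrative, not the real `format_push_rules_for_user`):

```python
from typing import Any, Dict, List

def format_rules(raw_rows: List[Dict[str, Any]]) -> Dict[str, Dict[str, list]]:
    # Illustrative stand-in: group raw DB rows under the structure
    # the client API expects.
    return {"global": {"override": list(raw_rows)}}

rows_raw = [{"rule_id": ".m.rule.master", "actions": "[]"}]  # "raw" DB shape
rules = format_rules(rows_raw)  # formatted shape, a distinct name and type
print(rules["global"]["override"][0]["rule_id"])
```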
async def ephemeral_by_room(
@@ -148,9 +148,9 @@ class PushRuleRestServlet(RestServlet):
# we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is
# probably not going to make a whole lot of difference
rules = await self.store.get_push_rules_for_user(user_id)
rules_raw = await self.store.get_push_rules_for_user(user_id)
rules = format_push_rules_for_user(requester.user, rules)
rules = format_push_rules_for_user(requester.user, rules_raw)
path_parts = path.split("/")[1:]
@@ -239,13 +239,13 @@ class StateHandler:
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
return await self.store.get_joined_users_from_state(room_id, entry)
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
async def get_current_hosts_in_room(self, room_id: str) -> FrozenSet[str]:
event_ids = await self.store.get_latest_event_ids_in_room(room_id)
return await self.get_hosts_in_room_at_events(room_id, event_ids)
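Switching these methods to `FrozenSet[str]` matches what the underlying cached lookup returns and stops callers from mutating a value that is shared via a cache. A small sketch of the idea, assuming a module-level cache for illustration:

```python
from typing import Dict, FrozenSet

_host_cache: Dict[str, FrozenSet[str]] = {}

def get_current_hosts_in_room(room_id: str) -> FrozenSet[str]:
    # Returning a frozenset advertises that the cached value is shared and
    # must not be mutated by callers; mypy will reject hosts.add(...).
    if room_id not in _host_cache:
        _host_cache[room_id] = frozenset({"matrix.org", "example.com"})
    return _host_cache[room_id]

hosts = get_current_hosts_in_room("!abc:example.com")
print(sorted(hosts))
```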
async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: Collection[str]
) -> Set[str]:
) -> FrozenSet[str]:
"""Get the hosts that were in a room at the given event ids
Args:
@@ -26,11 +26,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.stats import UserSortOrder
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import (
IdGenerator,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -155,8 +151,6 @@ class DataStore(
],
)
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
self._group_updates_id_gen = StreamIdGenerator(
db_conn, "local_group_updates", "stream_id"
)
@@ -14,16 +14,19 @@
import calendar
import logging
import time
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING, Dict, List, Tuple, cast
from synapse.metrics import GaugeBucketCollector
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
from synapse.storage.types import Cursor
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -73,7 +76,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
@wrap_as_background_process("read_forward_extremities")
async def _read_forward_extremities(self) -> None:
def fetch(txn):
def fetch(txn: LoggingTransaction) -> List[Tuple[int, int]]:
txn.execute(
"""
SELECT t1.c, t2.c
@@ -86,7 +89,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
) t2 ON t1.room_id = t2.room_id
"""
)
return txn.fetchall()
return cast(List[Tuple[int, int]], txn.fetchall())
res = await self.db_pool.runInteraction("read_forward_extremities", fetch)
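The `fetch` hunk shows the pattern this file applies throughout: annotate the transaction callback with `LoggingTransaction` and `cast` the DB-API result, since `fetchall()` is typed too loosely to express what the query returns. A runnable sketch, with stdlib `sqlite3` standing in for Synapse's transaction wrapper:

```python
import sqlite3
from typing import List, Tuple, cast

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE t (a INTEGER, b INTEGER)")
cur.execute("INSERT INTO t VALUES (1, 2)")

def fetch(txn: sqlite3.Cursor) -> List[Tuple[int, int]]:
    txn.execute("SELECT a, b FROM t")
    # fetchall() is typed as a list of generic rows; the cast records what
    # this particular query actually returns, at zero runtime cost.
    return cast(List[Tuple[int, int]], txn.fetchall())

print(fetch(cur))  # [(1, 2)]
```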
@@ -104,20 +107,20 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
call to this function, it will return None.
"""
def _count_messages(txn):
def _count_messages(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(*) FROM events
WHERE type = 'm.room.encrypted'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
async def count_daily_sent_e2ee_messages(self) -> int:
def _count_messages(txn):
def _count_messages(txn: LoggingTransaction) -> int:
# This is good enough as if you have silly characters in your own
# hostname then that's your own fault.
like_clause = "%:" + self.hs.hostname
@@ -130,7 +133,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"""
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction(
@@ -138,14 +141,14 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
)
async def count_daily_active_e2ee_rooms(self) -> int:
def _count(txn):
def _count(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(DISTINCT room_id) FROM events
WHERE type = 'm.room.encrypted'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction(
@@ -160,20 +163,20 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
call to this function, it will return None.
"""
def _count_messages(txn):
def _count_messages(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(*) FROM events
WHERE type = 'm.room.message'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction("count_messages", _count_messages)
async def count_daily_sent_messages(self) -> int:
def _count_messages(txn):
def _count_messages(txn: LoggingTransaction) -> int:
# This is good enough as if you have silly characters in your own
# hostname then that's your own fault.
like_clause = "%:" + self.hs.hostname
@@ -186,7 +189,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"""
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction(
@@ -194,14 +197,14 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
)
async def count_daily_active_rooms(self) -> int:
def _count(txn):
def _count(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(DISTINCT room_id) FROM events
WHERE type = 'm.room.message'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
@@ -227,7 +230,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"count_monthly_users", self._count_users, thirty_days_ago
)
def _count_users(self, txn: Cursor, time_from: int) -> int:
def _count_users(self, txn: LoggingTransaction, time_from: int) -> int:
"""
Returns number of users seen in the past time_from period
"""
@@ -242,7 +245,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
# Mypy knows that fetchone() might return None if there are no rows.
# We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
# returns exactly one row.
(count,) = txn.fetchone() # type: ignore[misc]
(count,) = cast(Tuple[int], txn.fetchone())
return count
async def count_r30_users(self) -> Dict[str, int]:
@@ -256,7 +259,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
A mapping of counts globally as well as broken out by platform.
"""
def _count_r30_users(txn):
def _count_r30_users(txn: LoggingTransaction) -> Dict[str, int]:
thirty_days_in_secs = 86400 * 30
now = int(self._clock.time())
thirty_days_ago_in_secs = now - thirty_days_in_secs
@@ -321,7 +324,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
results["all"] = count
return results
@@ -348,7 +351,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
- "web" (any web application -- it's not possible to distinguish Element Web here)
"""
def _count_r30v2_users(txn):
def _count_r30v2_users(txn: LoggingTransaction) -> Dict[str, int]:
thirty_days_in_secs = 86400 * 30
now = int(self._clock.time())
sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs
@@ -445,11 +448,8 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
thirty_days_in_secs * 1000,
),
)
row = txn.fetchone()
if row is None:
results["all"] = 0
else:
results["all"] = row[0]
(count,) = cast(Tuple[int], txn.fetchone())
results["all"] = count
return results
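All of the `(count,) = cast(Tuple[int], txn.fetchone())` hunks rest on the same observation: `SELECT COUNT(...)` with no `GROUP BY` yields exactly one row, so both the old `# type: ignore[misc]` and the defensive `row is None` branch above were compensating for a type mypy gets wrong. A sketch of the pattern, again with `sqlite3` as a stand-in for the real transaction object:

```python
import sqlite3
from typing import Tuple, cast

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (stream_ordering INTEGER)")
txn = conn.cursor()
txn.execute("SELECT COUNT(*) FROM events")
# fetchone() is Optional[...] to mypy, but COUNT(*) without GROUP BY always
# returns exactly one row, so the cast is safe and the None branch is dead.
(count,) = cast(Tuple[int], txn.fetchone())
print(count)  # 0
```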
@@ -471,7 +471,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
Generates daily visit data for use in cohort/retention analysis
"""
def _generate_user_daily_visits(txn):
def _generate_user_daily_visits(txn: LoggingTransaction) -> None:
logger.info("Calling _generate_user_daily_visits")
today_start = self._get_start_of_day()
a_day_in_milliseconds = 24 * 60 * 60 * 1000
@@ -14,14 +14,18 @@
# limitations under the License.
import abc
import logging
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple, Union, cast
from synapse.api.errors import StoreError
from synapse.config.homeserver import ExperimentalConfig
from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.pusher import PusherWorkerStore
@@ -30,9 +34,12 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
AbstractStreamIdTracker,
IdGenerator,
StreamIdGenerator,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -57,7 +64,11 @@ def _is_experimental_rule_enabled(
return True
def _load_rules(rawrules, enabled_map, experimental_config: ExperimentalConfig):
def _load_rules(
rawrules: List[JsonDict],
enabled_map: Dict[str, bool],
experimental_config: ExperimentalConfig,
) -> List[JsonDict]:
ruleslist = []
for rawrule in rawrules:
rule = dict(rawrule)
@@ -137,7 +148,7 @@ class PushRulesWorkerStore(
)
@abc.abstractmethod
def get_max_push_rules_stream_id(self):
def get_max_push_rules_stream_id(self) -> int:
"""Get the position of the push rules stream.
Returns:
@@ -146,7 +157,7 @@ class PushRulesWorkerStore(
raise NotImplementedError()
@cached(max_entries=5000)
async def get_push_rules_for_user(self, user_id):
async def get_push_rules_for_user(self, user_id: str) -> List[JsonDict]:
rows = await self.db_pool.simple_select_list(
table="push_rules",
keyvalues={"user_name": user_id},
@@ -168,7 +179,7 @@ class PushRulesWorkerStore(
return _load_rules(rows, enabled_map, self.hs.config.experimental)
@cached(max_entries=5000)
async def get_push_rules_enabled_for_user(self, user_id) -> Dict[str, bool]:
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
results = await self.db_pool.simple_select_list(
table="push_rules_enable",
keyvalues={"user_name": user_id},
@@ -184,13 +195,13 @@ class PushRulesWorkerStore(
return False
else:
def have_push_rules_changed_txn(txn):
def have_push_rules_changed_txn(txn: LoggingTransaction) -> bool:
sql = (
"SELECT COUNT(stream_id) FROM push_rules_stream"
" WHERE user_id = ? AND ? < stream_id"
)
txn.execute(sql, (user_id, last_id))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return bool(count)
return await self.db_pool.runInteraction(
@@ -202,11 +213,13 @@ class PushRulesWorkerStore(
list_name="user_ids",
num_args=1,
)
async def bulk_get_push_rules(self, user_ids):
async def bulk_get_push_rules(
self, user_ids: Collection[str]
) -> Dict[str, List[JsonDict]]:
if not user_ids:
return {}
results = {user_id: [] for user_id in user_ids}
results: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
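The added `Dict[str, List[JsonDict]]` annotation on `results` is needed because the comprehension builds empty lists, and an empty literal gives mypy nothing to infer the element type from. A minimal illustration:

```python
from typing import Dict, List

user_ids = ["@a:example.com", "@b:example.com"]

# Without the annotation, mypy reports "need type annotation": the empty
# lists carry no information about what will eventually be stored in them.
results: Dict[str, List[dict]] = {user_id: [] for user_id in user_ids}
results["@a:example.com"].append({"rule_id": ".m.rule.master"})
print(results["@a:example.com"])
```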
rows = await self.db_pool.simple_select_many_batch(
table="push_rules",
@@ -250,7 +263,7 @@ class PushRulesWorkerStore(
condition["pattern"] = new_room_id
# Add the rule for the new room
await self.add_push_rule(
await self.add_push_rule( # type: ignore[attr-defined]
user_id=user_id,
rule_id=new_rule_id,
priority_class=rule["priority_class"],
@@ -286,11 +299,13 @@ class PushRulesWorkerStore(
list_name="user_ids",
num_args=1,
)
async def bulk_get_push_rules_enabled(self, user_ids):
async def bulk_get_push_rules_enabled(
self, user_ids: Collection[str]
) -> Dict[str, Dict[str, bool]]:
if not user_ids:
return {}
results = {user_id: {} for user_id in user_ids}
results: Dict[str, Dict[str, bool]] = {user_id: {} for user_id in user_ids}
rows = await self.db_pool.simple_select_many_batch(
table="push_rules_enable",
@@ -306,7 +321,7 @@ class PushRulesWorkerStore(
async def get_all_push_rule_updates(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
"""Get updates for push_rules replication stream.
Args:
@@ -331,7 +346,9 @@ class PushRulesWorkerStore(
if last_id == current_id:
return [], current_id, False
def get_all_push_rule_updates_txn(txn):
def get_all_push_rule_updates_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
sql = """
SELECT stream_id, user_id
FROM push_rules_stream
@@ -340,7 +357,10 @@ class PushRulesWorkerStore(
LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
updates = [(stream_id, (user_id,)) for stream_id, user_id in txn]
updates = cast(
List[Tuple[int, Tuple[str]]],
[(stream_id, (user_id,)) for stream_id, user_id in txn],
)
limited = False
upper_bound = current_id
@@ -356,15 +376,30 @@ class PushRulesWorkerStore(
class PushRuleStore(PushRulesWorkerStore):
# Because we have write access, this will be a StreamIdGenerator
# (see PushRulesWorkerStore.__init__)
_push_rules_stream_id_gen: AbstractStreamIdGenerator
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
async def add_push_rule(
self,
user_id,
rule_id,
priority_class,
conditions,
actions,
before=None,
after=None,
user_id: str,
rule_id: str,
priority_class: int,
conditions: List[Dict[str, str]],
actions: List[Union[JsonDict, str]],
before: Optional[str] = None,
after: Optional[str] = None,
) -> None:
conditions_json = json_encoder.encode(conditions)
actions_json = json_encoder.encode(actions)
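Two related moves here: the `IdGenerator` construction migrates from `DataStore` into `PushRuleStore` (the only writer), and the bare class-body annotation `_push_rules_stream_id_gen: AbstractStreamIdGenerator` re-declares an inherited attribute with a narrower type for mypy without assigning anything at runtime. A self-contained sketch of that narrowing pattern (class names are illustrative, not Synapse's):

```python
class IdTracker:
    def get_current_token(self) -> int:
        return 0

class IdGenerator(IdTracker):
    def get_next(self) -> int:
        return 1

class WorkerStore:
    _id_gen: IdTracker = IdTracker()

class WriterStore(WorkerStore):
    # Only the writer constructs the generator (cf. moving the IdGenerator
    # setup into PushRuleStore), so the subclass can promise the narrower
    # type here; this annotation has no runtime effect.
    _id_gen: IdGenerator

    def __init__(self) -> None:
        self._id_gen = IdGenerator()

    def next_id(self) -> int:
        return self._id_gen.get_next()  # mypy accepts this thanks to the annotation

print(WriterStore().next_id())  # 1
```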
@@ -400,17 +435,17 @@ class PushRuleStore(PushRulesWorkerStore):
def _add_push_rule_relative_txn(
self,
txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
priority_class,
conditions_json,
actions_json,
before,
after,
):
txn: LoggingTransaction,
stream_id: int,
event_stream_ordering: int,
user_id: str,
rule_id: str,
priority_class: int,
conditions_json: str,
actions_json: str,
before: str,
after: str,
) -> None:
# Lock the table since otherwise we'll have annoying races between the
# SELECT here and the UPSERT below.
self.database_engine.lock_table(txn, "push_rules")
@@ -470,15 +505,15 @@ class PushRuleStore(PushRulesWorkerStore):
def _add_push_rule_highest_priority_txn(
self,
txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
priority_class,
conditions_json,
actions_json,
):
txn: LoggingTransaction,
stream_id: int,
event_stream_ordering: int,
user_id: str,
rule_id: str,
priority_class: int,
conditions_json: str,
actions_json: str,
) -> None:
# Lock the table since otherwise we'll have annoying races between the
# SELECT here and the UPSERT below.
self.database_engine.lock_table(txn, "push_rules")
@@ -510,17 +545,17 @@ class PushRuleStore(PushRulesWorkerStore):
def _upsert_push_rule_txn(
self,
txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
priority_class,
priority,
conditions_json,
actions_json,
update_stream=True,
):
txn: LoggingTransaction,
stream_id: int,
event_stream_ordering: int,
user_id: str,
rule_id: str,
priority_class: int,
priority: int,
conditions_json: str,
actions_json: str,
update_stream: bool = True,
) -> None:
"""Specialised version of simple_upsert_txn that picks a push_rule_id
using the _push_rule_id_gen if it needs to insert the rule. It assumes
that the "push_rules" table is locked"""
@@ -600,7 +635,11 @@ class PushRuleStore(PushRulesWorkerStore):
rule_id: The rule_id of the rule to be deleted
"""
def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
def delete_push_rule_txn(
txn: LoggingTransaction,
stream_id: int,
event_stream_ordering: int,
) -> None:
# we don't use simple_delete_one_txn because that would fail if the
# user did not have a push_rule_enable row.
self.db_pool.simple_delete_txn(
@@ -661,14 +700,14 @@ class PushRuleStore(PushRulesWorkerStore):
def _set_push_rule_enabled_txn(
self,
txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
enabled,
is_default_rule,
):
txn: LoggingTransaction,
stream_id: int,
event_stream_ordering: int,
user_id: str,
rule_id: str,
enabled: bool,
is_default_rule: bool,
) -> None:
new_id = self._push_rules_enable_id_gen.get_next()
if not is_default_rule:
@@ -740,7 +779,11 @@ class PushRuleStore(PushRulesWorkerStore):
"""
actions_json = json_encoder.encode(actions)
def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
def set_push_rule_actions_txn(
txn: LoggingTransaction,
stream_id: int,
event_stream_ordering: int,
) -> None:
if is_default_rule:
# Add a dummy rule to the rules table with the user specified
# actions.
@@ -794,8 +837,15 @@ class PushRuleStore(PushRulesWorkerStore):
)
def _insert_push_rules_update_txn(
self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None
):
self,
txn: LoggingTransaction,
stream_id: int,
event_stream_ordering: int,
user_id: str,
rule_id: str,
op: str,
data: Optional[JsonDict] = None,
) -> None:
values = {
"stream_id": stream_id,
"event_stream_ordering": event_stream_ordering,
@@ -814,5 +864,5 @@ class PushRuleStore(PushRulesWorkerStore):
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
)
def get_max_push_rules_stream_id(self):
def get_max_push_rules_stream_id(self) -> int:
return self._push_rules_stream_id_gen.get_current_token()
@@ -37,7 +37,12 @@ from synapse.metrics.background_process_metrics import (
wrap_as_background_process,
)
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import Sqlite3Engine
from synapse.storage.roommember import (
@@ -46,7 +51,7 @@ from synapse.storage.roommember import (
ProfileInfo,
RoomsForUser,
)
from synapse.types import PersistedEventPosition, get_domain_from_id
from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
@@ -115,7 +120,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
@wrap_as_background_process("_count_known_servers")
async def _count_known_servers(self):
async def _count_known_servers(self) -> int:
"""
Count the servers that this server knows about.
@@ -123,7 +128,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
`synapse_federation_known_servers` LaterGauge to collect.
"""
def _transact(txn):
def _transact(txn: LoggingTransaction) -> int:
if isinstance(self.database_engine, Sqlite3Engine):
query = """
SELECT COUNT(DISTINCT substr(out.user_id, pos+1))
@@ -150,7 +155,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
self._known_servers_count = max([count, 1])
return self._known_servers_count
def _check_safe_current_state_events_membership_updated_txn(self, txn):
def _check_safe_current_state_events_membership_updated_txn(
self, txn: LoggingTransaction
) -> None:
"""Checks if it is safe to assume the new current_state_events
membership column is up to date
"""
@@ -182,7 +189,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"get_users_in_room", self.get_users_in_room_txn, room_id
)
def get_users_in_room_txn(self, txn, room_id: str) -> List[str]:
def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]:
# If we can assume current_state_events.membership is up to date
# then we can avoid a join, which is a Very Good Thing given how
# frequently this function gets called.
@@ -222,7 +229,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
A mapping from user ID to ProfileInfo.
"""
def _get_users_in_room_with_profiles(txn) -> Dict[str, ProfileInfo]:
def _get_users_in_room_with_profiles(
txn: LoggingTransaction,
) -> Dict[str, ProfileInfo]:
sql = """
SELECT state_key, display_name, avatar_url FROM room_memberships as m
INNER JOIN current_state_events as c
@@ -250,7 +259,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
dict of membership states, pointing to a MemberSummary named tuple.
"""
def _get_room_summary_txn(txn):
def _get_room_summary_txn(
txn: LoggingTransaction,
) -> Dict[str, MemberSummary]:
# first get counts.
# We do this all in one transaction to keep the cache small.
# FIXME: get rid of this when we have room_stats
@@ -279,7 +290,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""
txn.execute(sql, (room_id,))
res = {}
res: Dict[str, MemberSummary] = {}
for count, membership in txn:
res.setdefault(membership, MemberSummary([], count))
@@ -400,7 +411,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
def _get_rooms_for_local_user_where_membership_is_txn(
self,
txn,
txn: LoggingTransaction,
user_id: str,
membership_list: List[str],
) -> List[RoomsForUser]:
@@ -488,7 +499,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
def _get_rooms_for_user_with_stream_ordering_txn(
self, txn, user_id: str
self, txn: LoggingTransaction, user_id: str
) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
# We use `current_state_events` here and not `local_current_membership`
# as a) this gets called with remote users and b) this only gets called
@@ -542,7 +553,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
def _get_rooms_for_users_with_stream_ordering_txn(
self, txn, user_ids: Collection[str]
self, txn: LoggingTransaction, user_ids: Collection[str]
) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
clause, args = make_in_list_sql_clause(
@@ -575,7 +586,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(sql, [Membership.JOIN] + args)
result = {user_id: set() for user_id in user_ids}
result: Dict[str, Set[GetRoomsForUserWithStreamOrdering]] = {
user_id: set() for user_id in user_ids
}
for user_id, room_id, instance, stream_id in txn:
result[user_id].add(
GetRoomsForUserWithStreamOrdering(
@@ -595,7 +608,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
if not user_ids:
return set()
def _get_users_server_still_shares_room_with_txn(txn):
def _get_users_server_still_shares_room_with_txn(
txn: LoggingTransaction,
) -> Set[str]:
sql = """
SELECT state_key FROM current_state_events
WHERE
@@ -657,7 +672,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
async def get_joined_users_from_context(
self, event: EventBase, context: EventContext
) -> Dict[str, ProfileInfo]:
state_group = context.state_group
state_group: Union[object, int] = context.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
@@ -666,14 +681,16 @@ class RoomMemberWorkerStore(EventsWorkerStore):
state_group = object()
current_state_ids = await context.get_current_state_ids()
assert current_state_ids is not None
assert state_group is not None
return await self._get_joined_users_from_context(
event.room_id, state_group, current_state_ids, event=event, context=context
)
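The `state_group: Union[object, int]` annotations formalise an existing trick: a fresh `object()` compares unequal to everything else, so it is a safe cache key for "no state group assigned yet". A runnable sketch of the sentinel pattern:

```python
from typing import Optional, Union

def cache_key(state_group: Optional[int]) -> Union[object, int]:
    key: Union[object, int] = state_group
    if not key:
        # No state group assigned yet: substitute a unique sentinel so two
        # unassigned entries can never collide (object() != object()).
        key = object()
    assert key is not None  # narrows the type and documents the invariant
    return key

print(cache_key(42))                       # 42
print(cache_key(None) == cache_key(None))  # False: distinct sentinels
```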
async def get_joined_users_from_state(
self, room_id, state_entry
self, room_id: str, state_entry: "_StateCacheEntry"
) -> Dict[str, ProfileInfo]:
state_group = state_entry.state_group
state_group: Union[object, int] = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
@@ -681,6 +698,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
assert state_group is not None
with Measure(self._clock, "get_joined_users_from_state"):
return await self._get_joined_users_from_context(
room_id, state_group, state_entry.state, context=state_entry
@@ -689,12 +707,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
async def _get_joined_users_from_context(
self,
room_id,
state_group,
current_state_ids,
cache_context,
event=None,
context=None,
room_id: str,
state_group: Union[object, int],
current_state_ids: StateMap[str],
cache_context: _CacheContext,
event: Optional[EventBase] = None,
context: Optional[Union[EventContext, "_StateCacheEntry"]] = None,
) -> Dict[str, ProfileInfo]:
# We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states
@@ -765,14 +783,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return users_in_room
@cached(max_entries=10000)
def _get_joined_profile_from_event_id(self, event_id):
def _get_joined_profile_from_event_id(
self, event_id: str
) -> Optional[Tuple[str, ProfileInfo]]:
raise NotImplementedError()
@cachedList(
cached_method_name="_get_joined_profile_from_event_id",
list_name="event_ids",
)
async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
async def _get_joined_profiles_from_event_ids(
self, event_ids: Iterable[str]
) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]:
"""For given set of member event_ids check if they point to a join
event and if so return the associated user and profile info.
@@ -780,8 +802,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
event_ids: The member event IDs to lookup
Returns:
dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
to `user_id` and ProfileInfo (or None if not join event).
Map from event ID to `user_id` and ProfileInfo (or None if not join event).
"""
rows = await self.db_pool.simple_select_many_batch(
@@ -847,8 +868,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return True
async def get_joined_hosts(self, room_id: str, state_entry):
state_group = state_entry.state_group
async def get_joined_hosts(
self, room_id: str, state_entry: "_StateCacheEntry"
) -> FrozenSet[str]:
state_group: Union[object, int] = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
@@ -856,6 +879,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
assert state_group is not None
with Measure(self._clock, "get_joined_hosts"):
return await self._get_joined_hosts(
room_id, state_group, state_entry=state_entry
@@ -863,7 +887,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached(num_args=2, max_entries=10000, iterable=True)
async def _get_joined_hosts(
self, room_id: str, state_group: int, state_entry: "_StateCacheEntry"
self,
room_id: str,
state_group: Union[object, int],
state_entry: "_StateCacheEntry",
) -> FrozenSet[str]:
# We don't use `state_group`, it's there so that we can cache based on
# it. However, it's important that it's never None, since two
@@ -881,7 +908,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# `get_joined_hosts` is called with the "current" state group for the
# room, and so consecutive calls will be for consecutive state groups
# which point to the previous state group.
cache = await self._get_joined_hosts_cache(room_id)
cache = await self._get_joined_hosts_cache(room_id) # type: ignore[misc]
# If the state group in the cache matches, we already have the data we need.
if state_entry.state_group == cache.state_group:
@@ -897,6 +924,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
elif state_entry.prev_group == cache.state_group:
# The cached work is for the previous state group, so we work out
# the delta.
assert state_entry.delta_ids is not None
for (typ, state_key), event_id in state_entry.delta_ids.items():
if typ != EventTypes.Member:
continue
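`assert state_entry.delta_ids is not None` (like the `assert state_group is not None` additions above) narrows an `Optional` using an invariant the code already relied on; the assert documents it and satisfies mypy at once. A sketch under an assumed caller invariant, with illustrative names:

```python
from typing import Dict, FrozenSet, Optional, Tuple

def apply_delta(
    prev_hosts: FrozenSet[str],
    delta_ids: Optional[Dict[Tuple[str, str], str]],
) -> FrozenSet[str]:
    # Caller invariant: on the prev_group fast path delta_ids is populated;
    # the assert both documents this and narrows Optional[...] for mypy.
    assert delta_ids is not None
    hosts = set(prev_hosts)
    for (typ, state_key), _event_id in delta_ids.items():
        if typ == "m.room.member":
            hosts.add(state_key.split(":", 1)[1])
    return frozenset(hosts)

print(apply_delta(frozenset(), {("m.room.member", "@u:example.com"): "$e"}))
```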
@@ -942,7 +970,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
Returns False if they have since re-joined."""
def f(txn):
def f(txn: LoggingTransaction) -> int:
sql = (
"SELECT"
" COUNT(*)"
@@ -973,7 +1001,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
The forgotten rooms.
"""
def _get_forgotten_rooms_for_user_txn(txn):
def _get_forgotten_rooms_for_user_txn(txn: LoggingTransaction) -> Set[str]:
# This is a slightly convoluted query that first looks up all rooms
# that the user has forgotten in the past, then rechecks that list
# to see if any have subsequently been updated. This is done so that
@@ -1076,7 +1104,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
clause,
)
def _is_local_host_in_room_ignoring_users_txn(txn):
def _is_local_host_in_room_ignoring_users_txn(
txn: LoggingTransaction,
) -> bool:
txn.execute(sql, (room_id, Membership.JOIN, *args))
return bool(txn.fetchone())
@@ -1110,15 +1140,17 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
where_clause="forgotten = 1",
)
async def _background_add_membership_profile(self, progress, batch_size):
async def _background_add_membership_profile(
self, progress: JsonDict, batch_size: int
) -> int:
target_min_stream_id = progress.get(
"target_min_stream_id_inclusive", self._min_stream_order_on_start
"target_min_stream_id_inclusive", self._min_stream_order_on_start # type: ignore[attr-defined]
)
max_stream_id = progress.get(
"max_stream_id_exclusive", self._stream_order_on_start + 1
"max_stream_id_exclusive", self._stream_order_on_start + 1 # type: ignore[attr-defined]
)
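The `# type: ignore[attr-defined]` comments are needed because `_min_stream_order_on_start` and `_stream_order_on_start` are initialised by a sibling class in the concrete store's MRO, so mypy cannot see them on this mixin. A sketch of that situation, plus the annotation-based alternative (class names are illustrative):

```python
class StreamTrackingStore:
    def __init__(self) -> None:
        self._stream_order_on_start = 100

class BackgroundUpdateStore:
    # A bare annotation is one alternative to "# type: ignore[attr-defined]":
    # it promises the attribute is provided elsewhere in the final MRO.
    _stream_order_on_start: int

    def default_max_stream_id(self) -> int:
        return self._stream_order_on_start + 1

class DataStore(StreamTrackingStore, BackgroundUpdateStore):
    pass

print(DataStore().default_max_stream_id())  # 101
```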
def add_membership_profile_txn(txn):
def add_membership_profile_txn(txn: LoggingTransaction) -> int:
sql = """
SELECT stream_ordering, event_id, events.room_id, event_json.json
FROM events
@@ -1182,13 +1214,17 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
return result
async def _background_current_state_membership(self, progress, batch_size):
async def _background_current_state_membership(
self, progress: JsonDict, batch_size: int
) -> int:
"""Update the new membership column on current_state_events.
This works by iterating over all rooms in alphabetical order.
""" | |||
def _background_current_state_membership_txn(txn, last_processed_room): | |||
def _background_current_state_membership_txn( | |||
txn: LoggingTransaction, last_processed_room: str | |||
) -> Tuple[int, bool]: | |||
processed = 0 | |||
while processed < batch_size: | |||
txn.execute( | |||
@@ -1242,7 +1278,11 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): | |||
return row_count | |||
class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): | |||
class RoomMemberStore( | |||
RoomMemberWorkerStore, | |||
RoomMemberBackgroundUpdateStore, | |||
CacheInvalidationWorkerStore, | |||
): | |||
def __init__( | |||
self, | |||
database: DatabasePool, | |||
@@ -1254,7 +1294,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): | |||
async def forget(self, user_id: str, room_id: str) -> None: | |||
"""Indicate that user_id wishes to discard history for room_id.""" | |||
def f(txn): | |||
def f(txn: LoggingTransaction) -> None: | |||
sql = ( | |||
"UPDATE" | |||
" room_memberships" | |||
@@ -1288,5 +1328,5 @@ class _JoinedHostsCache: | |||
# equal to anything else). | |||
state_group: Union[object, int] = attr.Factory(object) | |||
def __len__(self): | |||
def __len__(self) -> int: | |||
return sum(len(v) for v in self.hosts_to_joined_users.values()) |
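For reference, the `_JoinedHostsCache` hunks condense into this runnable approximation of the final class; the field names mirror the diff, but the surrounding setup is illustrative rather than Synapse's exact definition:

```python
from typing import Dict, Set, Union

import attr

@attr.s(slots=True, auto_attribs=True)
class JoinedHostsCache:
    hosts_to_joined_users: Dict[str, Set[str]] = attr.Factory(dict)
    # attr.Factory(object) yields a unique sentinel per instance, matching
    # the Union[object, int] state-group pattern used above.
    state_group: Union[object, int] = attr.Factory(object)

    def __len__(self) -> int:
        return sum(len(v) for v in self.hosts_to_joined_users.values())

cache = JoinedHostsCache()
cache.hosts_to_joined_users["example.com"] = {"@a:example.com"}
print(len(cache))  # 1
```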