@@ -0,0 +1 @@ | |||
Reduce memory allocations. |
@@ -1874,9 +1874,9 @@ class DatabasePool: | |||
keyvalues: Optional[Dict[str, Any]] = None, | |||
desc: str = "simple_select_many_batch", | |||
batch_size: int = 100, | |||
) -> List[Dict[str, Any]]: | |||
) -> List[Tuple[Any, ...]]: | |||
"""Executes a SELECT query on the named table, which may return zero or | |||
more rows, returning the result as a list of dicts. | |||
more rows. | |||
Filters rows by whether the value of `column` is in `iterable`. | |||
@@ -1888,10 +1888,13 @@ class DatabasePool: | |||
keyvalues: dict of column names and values to select the rows with | |||
desc: description of the transaction, for logging and metrics | |||
batch_size: the number of rows for each select query | |||
Returns: | |||
The results as a list of tuples. | |||
""" | |||
keyvalues = keyvalues or {} | |||
results: List[Dict[str, Any]] = [] | |||
results: List[Tuple[Any, ...]] = [] | |||
for chunk in batch_iter(iterable, batch_size): | |||
rows = await self.runInteraction( | |||
@@ -1918,9 +1921,9 @@ class DatabasePool: | |||
iterable: Collection[Any], | |||
keyvalues: Dict[str, Any], | |||
retcols: Iterable[str], | |||
) -> List[Dict[str, Any]]: | |||
) -> List[Tuple[Any, ...]]: | |||
"""Executes a SELECT query on the named table, which may return zero or | |||
more rows, returning the result as a list of dicts. | |||
more rows. | |||
Filters rows by whether the value of `column` is in `iterable`. | |||
@@ -1931,6 +1934,9 @@ class DatabasePool: | |||
iterable: list | |||
keyvalues: dict of column names and values to select the rows with | |||
retcols: list of strings giving the names of the columns to return | |||
Returns: | |||
The results as a list of tuples. | |||
""" | |||
if not iterable: | |||
return [] | |||
@@ -1949,7 +1955,7 @@ class DatabasePool: | |||
) | |||
txn.execute(sql, values) | |||
return cls.cursor_to_dict(txn) | |||
return txn.fetchall() | |||
async def simple_update( | |||
self, | |||
@@ -344,18 +344,19 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
# Note that this is more efficient than just dropping `device_id` from the query, | |||
# since device_inbox has an index on `(user_id, device_id, stream_id)` | |||
if not device_ids_to_query: | |||
user_device_dicts = self.db_pool.simple_select_many_txn( | |||
txn, | |||
table="devices", | |||
column="user_id", | |||
iterable=user_ids_to_query, | |||
keyvalues={"hidden": False}, | |||
retcols=("device_id",), | |||
user_device_dicts = cast( | |||
List[Tuple[str]], | |||
self.db_pool.simple_select_many_txn( | |||
txn, | |||
table="devices", | |||
column="user_id", | |||
iterable=user_ids_to_query, | |||
keyvalues={"hidden": False}, | |||
retcols=("device_id",), | |||
), | |||
) | |||
device_ids_to_query.update( | |||
{row["device_id"] for row in user_device_dicts} | |||
) | |||
device_ids_to_query.update({row[0] for row in user_device_dicts}) | |||
if not device_ids_to_query: | |||
# We've ended up with no devices to query. | |||
@@ -845,20 +846,21 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
# We exclude hidden devices (such as cross-signing keys) here as they are | |||
# not expected to receive to-device messages. | |||
rows = self.db_pool.simple_select_many_txn( | |||
txn, | |||
table="devices", | |||
keyvalues={"user_id": user_id, "hidden": False}, | |||
column="device_id", | |||
iterable=devices, | |||
retcols=("device_id",), | |||
rows = cast( | |||
List[Tuple[str]], | |||
self.db_pool.simple_select_many_txn( | |||
txn, | |||
table="devices", | |||
keyvalues={"user_id": user_id, "hidden": False}, | |||
column="device_id", | |||
iterable=devices, | |||
retcols=("device_id",), | |||
), | |||
) | |||
for row in rows: | |||
for (device_id,) in rows: | |||
# Only insert into the local inbox if the device exists on | |||
# this server | |||
device_id = row["device_id"] | |||
with start_active_span("serialise_to_device_message"): | |||
msg = messages_by_device[device_id] | |||
set_tag(SynapseTags.TO_DEVICE_TYPE, msg["type"]) | |||
@@ -1052,16 +1052,19 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): | |||
async def get_device_list_last_stream_id_for_remotes( | |||
self, user_ids: Iterable[str] | |||
) -> Mapping[str, Optional[str]]: | |||
rows = await self.db_pool.simple_select_many_batch( | |||
table="device_lists_remote_extremeties", | |||
column="user_id", | |||
iterable=user_ids, | |||
retcols=("user_id", "stream_id"), | |||
desc="get_device_list_last_stream_id_for_remotes", | |||
rows = cast( | |||
List[Tuple[str, str]], | |||
await self.db_pool.simple_select_many_batch( | |||
table="device_lists_remote_extremeties", | |||
column="user_id", | |||
iterable=user_ids, | |||
retcols=("user_id", "stream_id"), | |||
desc="get_device_list_last_stream_id_for_remotes", | |||
), | |||
) | |||
results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids} | |||
results.update({row["user_id"]: row["stream_id"] for row in rows}) | |||
results.update(rows) | |||
return results | |||
@@ -1077,22 +1080,30 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): | |||
The IDs of users whose device lists need resync. | |||
""" | |||
if user_ids: | |||
rows = await self.db_pool.simple_select_many_batch( | |||
table="device_lists_remote_resync", | |||
column="user_id", | |||
iterable=user_ids, | |||
retcols=("user_id",), | |||
desc="get_user_ids_requiring_device_list_resync_with_iterable", | |||
row_tuples = cast( | |||
List[Tuple[str]], | |||
await self.db_pool.simple_select_many_batch( | |||
table="device_lists_remote_resync", | |||
column="user_id", | |||
iterable=user_ids, | |||
retcols=("user_id",), | |||
desc="get_user_ids_requiring_device_list_resync_with_iterable", | |||
), | |||
) | |||
return {row[0] for row in row_tuples} | |||
else: | |||
rows = await self.db_pool.simple_select_list( | |||
table="device_lists_remote_resync", | |||
keyvalues=None, | |||
retcols=("user_id",), | |||
desc="get_user_ids_requiring_device_list_resync", | |||
rows = cast( | |||
List[Dict[str, str]], | |||
await self.db_pool.simple_select_list( | |||
table="device_lists_remote_resync", | |||
keyvalues=None, | |||
retcols=("user_id",), | |||
desc="get_user_ids_requiring_device_list_resync", | |||
), | |||
) | |||
return {row["user_id"] for row in rows} | |||
return {row["user_id"] for row in rows} | |||
async def mark_remote_users_device_caches_as_stale( | |||
self, user_ids: StrCollection | |||
@@ -493,15 +493,18 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker | |||
A map from (algorithm, key_id) to json string for key | |||
""" | |||
rows = await self.db_pool.simple_select_many_batch( | |||
table="e2e_one_time_keys_json", | |||
column="key_id", | |||
iterable=key_ids, | |||
retcols=("algorithm", "key_id", "key_json"), | |||
keyvalues={"user_id": user_id, "device_id": device_id}, | |||
desc="add_e2e_one_time_keys_check", | |||
rows = cast( | |||
List[Tuple[str, str, str]], | |||
await self.db_pool.simple_select_many_batch( | |||
table="e2e_one_time_keys_json", | |||
column="key_id", | |||
iterable=key_ids, | |||
retcols=("algorithm", "key_id", "key_json"), | |||
keyvalues={"user_id": user_id, "device_id": device_id}, | |||
desc="add_e2e_one_time_keys_check", | |||
), | |||
) | |||
result = {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows} | |||
result = {(algorithm, key_id): key_json for algorithm, key_id, key_json in rows} | |||
log_kv({"message": "Fetched one time keys for user", "one_time_keys": result}) | |||
return result | |||
@@ -1049,15 +1049,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas | |||
Args: | |||
event_ids: The event IDs to calculate the max depth of. | |||
""" | |||
rows = await self.db_pool.simple_select_many_batch( | |||
table="events", | |||
column="event_id", | |||
iterable=event_ids, | |||
retcols=( | |||
"event_id", | |||
"depth", | |||
rows = cast( | |||
List[Tuple[str, int]], | |||
await self.db_pool.simple_select_many_batch( | |||
table="events", | |||
column="event_id", | |||
iterable=event_ids, | |||
retcols=( | |||
"event_id", | |||
"depth", | |||
), | |||
desc="get_max_depth_of", | |||
), | |||
desc="get_max_depth_of", | |||
) | |||
if not rows: | |||
@@ -1065,10 +1068,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas | |||
else: | |||
max_depth_event_id = "" | |||
current_max_depth = 0 | |||
for row in rows: | |||
if row["depth"] > current_max_depth: | |||
max_depth_event_id = row["event_id"] | |||
current_max_depth = row["depth"] | |||
for event_id, depth in rows: | |||
if depth > current_max_depth: | |||
max_depth_event_id = event_id | |||
current_max_depth = depth | |||
return max_depth_event_id, current_max_depth | |||
@@ -1078,15 +1081,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas | |||
Args: | |||
event_ids: The event IDs to calculate the max depth of. | |||
""" | |||
rows = await self.db_pool.simple_select_many_batch( | |||
table="events", | |||
column="event_id", | |||
iterable=event_ids, | |||
retcols=( | |||
"event_id", | |||
"depth", | |||
rows = cast( | |||
List[Tuple[str, int]], | |||
await self.db_pool.simple_select_many_batch( | |||
table="events", | |||
column="event_id", | |||
iterable=event_ids, | |||
retcols=( | |||
"event_id", | |||
"depth", | |||
), | |||
desc="get_min_depth_of", | |||
), | |||
desc="get_min_depth_of", | |||
) | |||
if not rows: | |||
@@ -1094,10 +1100,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas | |||
else: | |||
min_depth_event_id = "" | |||
current_min_depth = MAX_DEPTH | |||
for row in rows: | |||
if row["depth"] < current_min_depth: | |||
min_depth_event_id = row["event_id"] | |||
current_min_depth = row["depth"] | |||
for event_id, depth in rows: | |||
if depth < current_min_depth: | |||
min_depth_event_id = event_id | |||
current_min_depth = depth | |||
return min_depth_event_id, current_min_depth | |||
@@ -1553,19 +1559,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas | |||
A filtered down list of `event_ids` that have previous failed pull attempts. | |||
""" | |||
rows = await self.db_pool.simple_select_many_batch( | |||
table="event_failed_pull_attempts", | |||
column="event_id", | |||
iterable=event_ids, | |||
keyvalues={}, | |||
retcols=("event_id",), | |||
desc="get_event_ids_with_failed_pull_attempts", | |||
rows = cast( | |||
List[Tuple[str]], | |||
await self.db_pool.simple_select_many_batch( | |||
table="event_failed_pull_attempts", | |||
column="event_id", | |||
iterable=event_ids, | |||
keyvalues={}, | |||
retcols=("event_id",), | |||
desc="get_event_ids_with_failed_pull_attempts", | |||
), | |||
) | |||
event_ids_with_failed_pull_attempts: Set[str] = { | |||
row["event_id"] for row in rows | |||
} | |||
return event_ids_with_failed_pull_attempts | |||
return {row[0] for row in rows} | |||
@trace | |||
async def get_event_ids_to_not_pull_from_backoff( | |||
@@ -1585,32 +1590,34 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas | |||
A dictionary of event_ids that should not be attempted to be pulled and the | |||
next timestamp at which we may try pulling them again. | |||
""" | |||
event_failed_pull_attempts = await self.db_pool.simple_select_many_batch( | |||
table="event_failed_pull_attempts", | |||
column="event_id", | |||
iterable=event_ids, | |||
keyvalues={}, | |||
retcols=( | |||
"event_id", | |||
"last_attempt_ts", | |||
"num_attempts", | |||
event_failed_pull_attempts = cast( | |||
List[Tuple[str, int, int]], | |||
await self.db_pool.simple_select_many_batch( | |||
table="event_failed_pull_attempts", | |||
column="event_id", | |||
iterable=event_ids, | |||
keyvalues={}, | |||
retcols=( | |||
"event_id", | |||
"last_attempt_ts", | |||
"num_attempts", | |||
), | |||
desc="get_event_ids_to_not_pull_from_backoff", | |||
), | |||
desc="get_event_ids_to_not_pull_from_backoff", | |||
) | |||
current_time = self._clock.time_msec() | |||
event_ids_with_backoff = {} | |||
for event_failed_pull_attempt in event_failed_pull_attempts: | |||
event_id = event_failed_pull_attempt["event_id"] | |||
for event_id, last_attempt_ts, num_attempts in event_failed_pull_attempts: | |||
# Exponential back-off (up to the upper bound) so we don't try to | |||
# pull the same event over and over. ex. 2hr, 4hr, 8hr, 16hr, etc. | |||
backoff_end_time = ( | |||
event_failed_pull_attempt["last_attempt_ts"] | |||
last_attempt_ts | |||
+ ( | |||
2 | |||
** min( | |||
event_failed_pull_attempt["num_attempts"], | |||
num_attempts, | |||
BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS, | |||
) | |||
) | |||
@@ -27,6 +27,7 @@ from typing import ( | |||
Optional, | |||
Set, | |||
Tuple, | |||
Union, | |||
cast, | |||
) | |||
@@ -501,16 +502,19 @@ class PersistEventsStore: | |||
# We ignore legacy rooms that we aren't filling the chain cover index | |||
# for. | |||
rows = self.db_pool.simple_select_many_txn( | |||
txn, | |||
table="rooms", | |||
column="room_id", | |||
iterable={event.room_id for event in events if event.is_state()}, | |||
keyvalues={}, | |||
retcols=("room_id", "has_auth_chain_index"), | |||
rows = cast( | |||
List[Tuple[str, Optional[Union[int, bool]]]], | |||
self.db_pool.simple_select_many_txn( | |||
txn, | |||
table="rooms", | |||
column="room_id", | |||
iterable={event.room_id for event in events if event.is_state()}, | |||
keyvalues={}, | |||
retcols=("room_id", "has_auth_chain_index"), | |||
), | |||
) | |||
rooms_using_chain_index = { | |||
row["room_id"] for row in rows if row["has_auth_chain_index"] | |||
room_id for room_id, has_auth_chain_index in rows if has_auth_chain_index | |||
} | |||
state_events = { | |||
@@ -571,19 +575,18 @@ class PersistEventsStore: | |||
# We check if there are any events that need to be handled in the rooms | |||
# we're looking at. These should just be out of band memberships, where | |||
# we didn't have the auth chain when we first persisted. | |||
rows = db_pool.simple_select_many_txn( | |||
txn, | |||
table="event_auth_chain_to_calculate", | |||
keyvalues={}, | |||
column="room_id", | |||
iterable=set(event_to_room_id.values()), | |||
retcols=("event_id", "type", "state_key"), | |||
auth_chain_to_calc_rows = cast( | |||
List[Tuple[str, str, str]], | |||
db_pool.simple_select_many_txn( | |||
txn, | |||
table="event_auth_chain_to_calculate", | |||
keyvalues={}, | |||
column="room_id", | |||
iterable=set(event_to_room_id.values()), | |||
retcols=("event_id", "type", "state_key"), | |||
), | |||
) | |||
for row in rows: | |||
event_id = row["event_id"] | |||
event_type = row["type"] | |||
state_key = row["state_key"] | |||
for event_id, event_type, state_key in auth_chain_to_calc_rows: | |||
# (We could pull out the auth events for all rows at once using | |||
# simple_select_many, but this case happens rarely and almost always | |||
# with a single row.) | |||
@@ -753,23 +756,31 @@ class PersistEventsStore: | |||
# Step 1, fetch all existing links from all the chains we've seen | |||
# referenced. | |||
chain_links = _LinkMap() | |||
rows = db_pool.simple_select_many_txn( | |||
txn, | |||
table="event_auth_chain_links", | |||
column="origin_chain_id", | |||
iterable={chain_id for chain_id, _ in chain_map.values()}, | |||
keyvalues={}, | |||
retcols=( | |||
"origin_chain_id", | |||
"origin_sequence_number", | |||
"target_chain_id", | |||
"target_sequence_number", | |||
auth_chain_rows = cast( | |||
List[Tuple[int, int, int, int]], | |||
db_pool.simple_select_many_txn( | |||
txn, | |||
table="event_auth_chain_links", | |||
column="origin_chain_id", | |||
iterable={chain_id for chain_id, _ in chain_map.values()}, | |||
keyvalues={}, | |||
retcols=( | |||
"origin_chain_id", | |||
"origin_sequence_number", | |||
"target_chain_id", | |||
"target_sequence_number", | |||
), | |||
), | |||
) | |||
for row in rows: | |||
for ( | |||
origin_chain_id, | |||
origin_sequence_number, | |||
target_chain_id, | |||
target_sequence_number, | |||
) in auth_chain_rows: | |||
chain_links.add_link( | |||
(row["origin_chain_id"], row["origin_sequence_number"]), | |||
(row["target_chain_id"], row["target_sequence_number"]), | |||
(origin_chain_id, origin_sequence_number), | |||
(target_chain_id, target_sequence_number), | |||
new=False, | |||
) | |||
@@ -369,18 +369,20 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): | |||
chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)] | |||
for chunk in chunks: | |||
ev_rows = self.db_pool.simple_select_many_txn( | |||
txn, | |||
table="event_json", | |||
column="event_id", | |||
iterable=chunk, | |||
retcols=["event_id", "json"], | |||
keyvalues={}, | |||
ev_rows = cast( | |||
List[Tuple[str, str]], | |||
self.db_pool.simple_select_many_txn( | |||
txn, | |||
table="event_json", | |||
column="event_id", | |||
iterable=chunk, | |||
retcols=["event_id", "json"], | |||
keyvalues={}, | |||
), | |||
) | |||
for row in ev_rows: | |||
event_id = row["event_id"] | |||
event_json = db_to_json(row["json"]) | |||
for event_id, json in ev_rows: | |||
event_json = db_to_json(json) | |||
try: | |||
origin_server_ts = event_json["origin_server_ts"] | |||
except (KeyError, AttributeError): | |||
@@ -563,15 +565,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): | |||
if deleted: | |||
# We now need to invalidate the caches of these rooms | |||
rows = self.db_pool.simple_select_many_txn( | |||
txn, | |||
table="events", | |||
column="event_id", | |||
iterable=to_delete, | |||
keyvalues={}, | |||
retcols=("room_id",), | |||
rows = cast( | |||
List[Tuple[str]], | |||
self.db_pool.simple_select_many_txn( | |||
txn, | |||
table="events", | |||
column="event_id", | |||
iterable=to_delete, | |||
keyvalues={}, | |||
retcols=("room_id",), | |||
), | |||
) | |||
room_ids = {row["room_id"] for row in rows} | |||
room_ids = {row[0] for row in rows} | |||
for room_id in room_ids: | |||
txn.call_after( | |||
self.get_latest_event_ids_in_room.invalidate, (room_id,) # type: ignore[attr-defined] | |||
@@ -1038,18 +1043,21 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): | |||
count = len(rows) | |||
# We also need to fetch the auth events for them. | |||
auth_events = self.db_pool.simple_select_many_txn( | |||
txn, | |||
table="event_auth", | |||
column="event_id", | |||
iterable=event_to_room_id, | |||
keyvalues={}, | |||
retcols=("event_id", "auth_id"), | |||
auth_events = cast( | |||
List[Tuple[str, str]], | |||
self.db_pool.simple_select_many_txn( | |||
txn, | |||
table="event_auth", | |||
column="event_id", | |||
iterable=event_to_room_id, | |||
keyvalues={}, | |||
retcols=("event_id", "auth_id"), | |||
), | |||
) | |||
event_to_auth_chain: Dict[str, List[str]] = {} | |||
for row in auth_events: | |||
event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"]) | |||
for event_id, auth_id in auth_events: | |||
event_to_auth_chain.setdefault(event_id, []).append(auth_id) | |||
# Calculate and persist the chain cover index for this set of events. | |||
# | |||
@@ -1584,16 +1584,19 @@ class EventsWorkerStore(SQLBaseStore): | |||
"""Given a list of event ids, check if we have already processed and | |||
stored them as non outliers. | |||
""" | |||
rows = await self.db_pool.simple_select_many_batch( | |||
table="events", | |||
retcols=("event_id",), | |||
column="event_id", | |||
iterable=list(event_ids), | |||
keyvalues={"outlier": False}, | |||
desc="have_events_in_timeline", | |||
rows = cast( | |||
List[Tuple[str]], | |||
await self.db_pool.simple_select_many_batch( | |||
table="events", | |||
retcols=("event_id",), | |||
column="event_id", | |||
iterable=list(event_ids), | |||
keyvalues={"outlier": False}, | |||
desc="have_events_in_timeline", | |||
), | |||
) | |||
return {r["event_id"] for r in rows} | |||
return {r[0] for r in rows} | |||
@trace | |||
@tag_args | |||
@@ -2336,15 +2339,18 @@ class EventsWorkerStore(SQLBaseStore): | |||
a dict mapping from event id to partial-stateness. We return True for | |||
any of the events which are unknown (or are outliers). | |||
""" | |||
result = await self.db_pool.simple_select_many_batch( | |||
table="partial_state_events", | |||
column="event_id", | |||
iterable=event_ids, | |||
retcols=["event_id"], | |||
desc="get_partial_state_events", | |||
result = cast( | |||
List[Tuple[str]], | |||
await self.db_pool.simple_select_many_batch( | |||
table="partial_state_events", | |||
column="event_id", | |||
iterable=event_ids, | |||
retcols=["event_id"], | |||
desc="get_partial_state_events", | |||
), | |||
) | |||
# convert the result to a dict, to make @cachedList work | |||
partial = {r["event_id"] for r in result} | |||
partial = {r[0] for r in result} | |||
return {e_id: e_id in partial for e_id in event_ids} | |||
@cached() | |||
@@ -16,7 +16,7 @@ | |||
import itertools | |||
import json | |||
import logging | |||
from typing import Dict, Iterable, Mapping, Optional, Tuple | |||
from typing import Dict, Iterable, List, Mapping, Optional, Tuple, Union, cast | |||
from canonicaljson import encode_canonical_json | |||
from signedjson.key import decode_verify_key_bytes | |||
@@ -205,35 +205,39 @@ class KeyStore(CacheInvalidationWorkerStore): | |||
If we have multiple entries for a given key ID, returns the most recent. | |||
""" | |||
rows = await self.db_pool.simple_select_many_batch( | |||
table="server_keys_json", | |||
column="key_id", | |||
iterable=key_ids, | |||
keyvalues={"server_name": server_name}, | |||
retcols=( | |||
"key_id", | |||
"from_server", | |||
"ts_added_ms", | |||
"ts_valid_until_ms", | |||
"key_json", | |||
rows = cast( | |||
List[Tuple[str, str, int, int, Union[bytes, memoryview]]], | |||
await self.db_pool.simple_select_many_batch( | |||
table="server_keys_json", | |||
column="key_id", | |||
iterable=key_ids, | |||
keyvalues={"server_name": server_name}, | |||
retcols=( | |||
"key_id", | |||
"from_server", | |||
"ts_added_ms", | |||
"ts_valid_until_ms", | |||
"key_json", | |||
), | |||
desc="get_server_keys_json_for_remote", | |||
), | |||
desc="get_server_keys_json_for_remote", | |||
) | |||
if not rows: | |||
return {} | |||
# We sort the rows so that the most recently added entry is picked up. | |||
rows.sort(key=lambda r: r["ts_added_ms"]) | |||
# We sort the rows by ts_added_ms so that the most recently added entry | |||
# will stomp over older entries in the dictionary. | |||
rows.sort(key=lambda r: r[2]) | |||
return { | |||
row["key_id"]: FetchKeyResultForRemote( | |||
key_id: FetchKeyResultForRemote( | |||
# Cast to bytes since postgresql returns a memoryview. | |||
key_json=bytes(row["key_json"]), | |||
valid_until_ts=row["ts_valid_until_ms"], | |||
added_ts=row["ts_added_ms"], | |||
key_json=bytes(key_json), | |||
valid_until_ts=ts_valid_until_ms, | |||
added_ts=ts_added_ms, | |||
) | |||
for row in rows | |||
for key_id, from_server, ts_added_ms, ts_valid_until_ms, key_json in rows | |||
} | |||
async def get_all_server_keys_json_for_remote( | |||
@@ -260,6 +264,8 @@ class KeyStore(CacheInvalidationWorkerStore): | |||
if not rows: | |||
return {} | |||
# We sort the rows by ts_added_ms so that the most recently added entry | |||
# will stomp over older entries in the dictionary. | |||
rows.sort(key=lambda r: r["ts_added_ms"]) | |||
return { | |||
@@ -261,27 +261,40 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) | |||
async def get_presence_for_users( | |||
self, user_ids: Iterable[str] | |||
) -> Mapping[str, UserPresenceState]: | |||
rows = await self.db_pool.simple_select_many_batch( | |||
table="presence_stream", | |||
column="user_id", | |||
iterable=user_ids, | |||
keyvalues={}, | |||
retcols=( | |||
"user_id", | |||
"state", | |||
"last_active_ts", | |||
"last_federation_update_ts", | |||
"last_user_sync_ts", | |||
"status_msg", | |||
"currently_active", | |||
# TODO All these columns are nullable, but we don't expect that: | |||
# https://github.com/matrix-org/synapse/issues/16467 | |||
rows = cast( | |||
List[Tuple[str, str, int, int, int, Optional[str], Union[int, bool]]], | |||
await self.db_pool.simple_select_many_batch( | |||
table="presence_stream", | |||
column="user_id", | |||
iterable=user_ids, | |||
keyvalues={}, | |||
retcols=( | |||
"user_id", | |||
"state", | |||
"last_active_ts", | |||
"last_federation_update_ts", | |||
"last_user_sync_ts", | |||
"status_msg", | |||
"currently_active", | |||
), | |||
desc="get_presence_for_users", | |||
), | |||
desc="get_presence_for_users", | |||
) | |||
for row in rows: | |||
row["currently_active"] = bool(row["currently_active"]) | |||
return {row["user_id"]: UserPresenceState(**row) for row in rows} | |||
return { | |||
user_id: UserPresenceState( | |||
user_id=user_id, | |||
state=state, | |||
last_active_ts=last_active_ts, | |||
last_federation_update_ts=last_federation_update_ts, | |||
last_user_sync_ts=last_user_sync_ts, | |||
status_msg=status_msg, | |||
currently_active=bool(currently_active), | |||
) | |||
for user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, status_msg, currently_active in rows | |||
} | |||
async def should_user_receive_full_presence_with_token( | |||
self, | |||
@@ -386,6 +399,8 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) | |||
limit = 100 | |||
offset = 0 | |||
while True: | |||
# TODO All these columns are nullable, but we don't expect that: | |||
# https://github.com/matrix-org/synapse/issues/16467 | |||
rows = cast( | |||
List[Tuple[str, str, int, int, int, Optional[str], Union[int, bool]]], | |||
await self.db_pool.runInteraction( | |||
@@ -62,20 +62,34 @@ logger = logging.getLogger(__name__) | |||
def _load_rules( | |||
rawrules: List[JsonDict], | |||
rawrules: List[Tuple[str, int, str, str]], | |||
enabled_map: Dict[str, bool], | |||
experimental_config: ExperimentalConfig, | |||
) -> FilteredPushRules: | |||
"""Take the DB rows returned from the DB and convert them into a full | |||
`FilteredPushRules` object. | |||
Args: | |||
rawrules: List of tuples of: | |||
* rule ID | |||
* Priority lass | |||
* Conditions (as serialized JSON) | |||
* Actions (as serialized JSON) | |||
enabled_map: A dictionary of rule ID to a boolean of whether the rule is | |||
enabled. This might not include all rule IDs from rawrules. | |||
experimental_config: The `experimental_features` section of the Synapse | |||
config. (Used to check if various features are enabled.) | |||
Returns: | |||
A new FilteredPushRules object. | |||
""" | |||
ruleslist = [ | |||
PushRule.from_db( | |||
rule_id=rawrule["rule_id"], | |||
priority_class=rawrule["priority_class"], | |||
conditions=rawrule["conditions"], | |||
actions=rawrule["actions"], | |||
rule_id=rawrule[0], | |||
priority_class=rawrule[1], | |||
conditions=rawrule[2], | |||
actions=rawrule[3], | |||
) | |||
for rawrule in rawrules | |||
] | |||
@@ -183,7 +197,19 @@ class PushRulesWorkerStore( | |||
enabled_map = await self.get_push_rules_enabled_for_user(user_id) | |||
return _load_rules(rows, enabled_map, self.hs.config.experimental) | |||
return _load_rules( | |||
[ | |||
( | |||
row["rule_id"], | |||
row["priority_class"], | |||
row["conditions"], | |||
row["actions"], | |||
) | |||
for row in rows | |||
], | |||
enabled_map, | |||
self.hs.config.experimental, | |||
) | |||
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]: | |||
results = await self.db_pool.simple_select_list( | |||
@@ -221,21 +247,36 @@ class PushRulesWorkerStore( | |||
if not user_ids: | |||
return {} | |||
raw_rules: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids} | |||
raw_rules: Dict[str, List[Tuple[str, int, str, str]]] = { | |||
user_id: [] for user_id in user_ids | |||
} | |||
rows = await self.db_pool.simple_select_many_batch( | |||
table="push_rules", | |||
column="user_name", | |||
iterable=user_ids, | |||
retcols=("*",), | |||
desc="bulk_get_push_rules", | |||
batch_size=1000, | |||
rows = cast( | |||
List[Tuple[str, str, int, int, str, str]], | |||
await self.db_pool.simple_select_many_batch( | |||
table="push_rules", | |||
column="user_name", | |||
iterable=user_ids, | |||
retcols=( | |||
"user_name", | |||
"rule_id", | |||
"priority_class", | |||
"priority", | |||
"conditions", | |||
"actions", | |||
), | |||
desc="bulk_get_push_rules", | |||
batch_size=1000, | |||
), | |||
) | |||
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) | |||
# Sort by highest priority_class, then highest priority. | |||
rows.sort(key=lambda row: (-int(row[2]), -int(row[3]))) | |||
for row in rows: | |||
raw_rules.setdefault(row["user_name"], []).append(row) | |||
for user_name, rule_id, priority_class, _, conditions, actions in rows: | |||
raw_rules.setdefault(user_name, []).append( | |||
(rule_id, priority_class, conditions, actions) | |||
) | |||
enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids) | |||
@@ -256,17 +297,19 @@ class PushRulesWorkerStore( | |||
results: Dict[str, Dict[str, bool]] = {user_id: {} for user_id in user_ids} | |||
rows = await self.db_pool.simple_select_many_batch( | |||
table="push_rules_enable", | |||
column="user_name", | |||
iterable=user_ids, | |||
retcols=("user_name", "rule_id", "enabled"), | |||
desc="bulk_get_push_rules_enabled", | |||
batch_size=1000, | |||
rows = cast( | |||
List[Tuple[str, str, Optional[int]]], | |||
await self.db_pool.simple_select_many_batch( | |||
table="push_rules_enable", | |||
column="user_name", | |||
iterable=user_ids, | |||
retcols=("user_name", "rule_id", "enabled"), | |||
desc="bulk_get_push_rules_enabled", | |||
batch_size=1000, | |||
), | |||
) | |||
for row in rows: | |||
enabled = bool(row["enabled"]) | |||
results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled | |||
for user_name, rule_id, enabled in rows: | |||
results.setdefault(user_name, {})[rule_id] = bool(enabled) | |||
return results | |||
async def get_all_push_rule_updates( | |||
@@ -349,16 +349,19 @@ class RelationsWorkerStore(SQLBaseStore): | |||
def get_all_relation_ids_for_event_with_types_txn( | |||
txn: LoggingTransaction, | |||
) -> List[str]: | |||
rows = self.db_pool.simple_select_many_txn( | |||
txn=txn, | |||
table="event_relations", | |||
column="relation_type", | |||
iterable=relation_types, | |||
keyvalues={"relates_to_id": event_id}, | |||
retcols=["event_id"], | |||
rows = cast( | |||
List[Tuple[str]], | |||
self.db_pool.simple_select_many_txn( | |||
txn=txn, | |||
table="event_relations", | |||
column="relation_type", | |||
iterable=relation_types, | |||
keyvalues={"relates_to_id": event_id}, | |||
retcols=["event_id"], | |||
), | |||
) | |||
return [row["event_id"] for row in rows] | |||
return [row[0] for row in rows] | |||
return await self.db_pool.runInteraction( | |||
desc="get_all_relation_ids_for_event_with_types", | |||
@@ -1296,14 +1296,17 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): | |||
complete. | |||
""" | |||
rows: List[Dict[str, str]] = await self.db_pool.simple_select_many_batch( | |||
table="partial_state_rooms", | |||
column="room_id", | |||
iterable=room_ids, | |||
retcols=("room_id",), | |||
desc="is_partial_state_room_batched", | |||
) | |||
partial_state_rooms = {row_dict["room_id"] for row_dict in rows} | |||
rows = cast( | |||
List[Tuple[str]], | |||
await self.db_pool.simple_select_many_batch( | |||
table="partial_state_rooms", | |||
column="room_id", | |||
iterable=room_ids, | |||
retcols=("room_id",), | |||
desc="is_partial_state_room_batched", | |||
), | |||
) | |||
partial_state_rooms = {row[0] for row in rows} | |||
return {room_id: room_id in partial_state_rooms for room_id in room_ids} | |||
async def get_join_event_id_and_device_lists_stream_id_for_partial_state( | |||
@@ -27,6 +27,7 @@ from typing import ( | |||
Set, | |||
Tuple, | |||
Union, | |||
cast, | |||
) | |||
import attr | |||
@@ -683,25 +684,28 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): | |||
Map from user_id to set of rooms that is currently in. | |||
""" | |||
rows = await self.db_pool.simple_select_many_batch( | |||
table="current_state_events", | |||
column="state_key", | |||
iterable=user_ids, | |||
retcols=( | |||
"state_key", | |||
"room_id", | |||
rows = cast( | |||
List[Tuple[str, str]], | |||
await self.db_pool.simple_select_many_batch( | |||
table="current_state_events", | |||
column="state_key", | |||
iterable=user_ids, | |||
retcols=( | |||
"state_key", | |||
"room_id", | |||
), | |||
keyvalues={ | |||
"type": EventTypes.Member, | |||
"membership": Membership.JOIN, | |||
}, | |||
desc="get_rooms_for_users", | |||
), | |||
keyvalues={ | |||
"type": EventTypes.Member, | |||
"membership": Membership.JOIN, | |||
}, | |||
desc="get_rooms_for_users", | |||
) | |||
user_rooms: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids} | |||
for row in rows: | |||
user_rooms[row["state_key"]].add(row["room_id"]) | |||
for state_key, room_id in rows: | |||
user_rooms[state_key].add(room_id) | |||
return {key: frozenset(rooms) for key, rooms in user_rooms.items()} | |||
@@ -892,17 +896,20 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): | |||
Map from event ID to `user_id`, or None if event is not a join. | |||
""" | |||
rows = await self.db_pool.simple_select_many_batch( | |||
table="room_memberships", | |||
column="event_id", | |||
iterable=event_ids, | |||
retcols=("user_id", "event_id"), | |||
keyvalues={"membership": Membership.JOIN}, | |||
batch_size=1000, | |||
desc="_get_user_ids_from_membership_event_ids", | |||
rows = cast( | |||
List[Tuple[str, str]], | |||
await self.db_pool.simple_select_many_batch( | |||
table="room_memberships", | |||
column="event_id", | |||
iterable=event_ids, | |||
retcols=("event_id", "user_id"), | |||
keyvalues={"membership": Membership.JOIN}, | |||
batch_size=1000, | |||
desc="_get_user_ids_from_membership_event_ids", | |||
), | |||
) | |||
return {row["event_id"]: row["user_id"] for row in rows} | |||
return dict(rows) | |||
@cached(max_entries=10000) | |||
async def is_host_joined(self, room_id: str, host: str) -> bool: | |||
@@ -1202,21 +1209,22 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): | |||
membership event, otherwise the value is None. | |||
""" | |||
rows = await self.db_pool.simple_select_many_batch( | |||
table="room_memberships", | |||
column="event_id", | |||
iterable=member_event_ids, | |||
retcols=("user_id", "membership", "event_id"), | |||
keyvalues={}, | |||
batch_size=500, | |||
desc="get_membership_from_event_ids", | |||
rows = cast( | |||
List[Tuple[str, str, str]], | |||
await self.db_pool.simple_select_many_batch( | |||
table="room_memberships", | |||
column="event_id", | |||
iterable=member_event_ids, | |||
retcols=("user_id", "membership", "event_id"), | |||
keyvalues={}, | |||
batch_size=500, | |||
desc="get_membership_from_event_ids", | |||
), | |||
) | |||
return { | |||
row["event_id"]: EventIdMembership( | |||
membership=row["membership"], user_id=row["user_id"] | |||
) | |||
for row in rows | |||
event_id: EventIdMembership(membership=membership, user_id=user_id) | |||
for user_id, membership, event_id in rows | |||
} | |||
async def is_local_host_in_room_ignoring_users( | |||
@@ -20,10 +20,12 @@ from typing import ( | |||
Collection, | |||
Dict, | |||
Iterable, | |||
List, | |||
Mapping, | |||
Optional, | |||
Set, | |||
Tuple, | |||
cast, | |||
) | |||
import attr | |||
@@ -388,16 +390,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): | |||
Raises: | |||
RuntimeError if the state is unknown at any of the given events | |||
""" | |||
rows = await self.db_pool.simple_select_many_batch( | |||
table="event_to_state_groups", | |||
column="event_id", | |||
iterable=event_ids, | |||
keyvalues={}, | |||
retcols=("event_id", "state_group"), | |||
desc="_get_state_group_for_events", | |||
rows = cast( | |||
List[Tuple[str, int]], | |||
await self.db_pool.simple_select_many_batch( | |||
table="event_to_state_groups", | |||
column="event_id", | |||
iterable=event_ids, | |||
keyvalues={}, | |||
retcols=("event_id", "state_group"), | |||
desc="_get_state_group_for_events", | |||
), | |||
) | |||
res = {row["event_id"]: row["state_group"] for row in rows} | |||
res = dict(rows) | |||
for e in event_ids: | |||
if e not in res: | |||
raise RuntimeError("No state group for unknown or outlier event %s" % e) | |||
@@ -415,16 +420,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): | |||
The subset of state groups that are referenced. | |||
""" | |||
rows = await self.db_pool.simple_select_many_batch( | |||
table="event_to_state_groups", | |||
column="state_group", | |||
iterable=state_groups, | |||
keyvalues={}, | |||
retcols=("DISTINCT state_group",), | |||
desc="get_referenced_state_groups", | |||
rows = cast( | |||
List[Tuple[int]], | |||
await self.db_pool.simple_select_many_batch( | |||
table="event_to_state_groups", | |||
column="state_group", | |||
iterable=state_groups, | |||
keyvalues={}, | |||
retcols=("DISTINCT state_group",), | |||
desc="get_referenced_state_groups", | |||
), | |||
) | |||
return {row["state_group"] for row in rows} | |||
return {row[0] for row in rows} | |||
async def update_state_for_partial_state_event( | |||
self, | |||
@@ -624,16 +632,22 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): | |||
# potentially stale, since there may have been a period where the | |||
# server didn't share a room with the remote user and therefore may | |||
# have missed any device updates. | |||
rows = self.db_pool.simple_select_many_txn( | |||
txn, | |||
table="current_state_events", | |||
column="room_id", | |||
iterable=to_delete, | |||
keyvalues={"type": EventTypes.Member, "membership": Membership.JOIN}, | |||
retcols=("state_key",), | |||
rows = cast( | |||
List[Tuple[str]], | |||
self.db_pool.simple_select_many_txn( | |||
txn, | |||
table="current_state_events", | |||
column="room_id", | |||
iterable=to_delete, | |||
keyvalues={ | |||
"type": EventTypes.Member, | |||
"membership": Membership.JOIN, | |||
}, | |||
retcols=("state_key",), | |||
), | |||
) | |||
potentially_left_users = {row["state_key"] for row in rows} | |||
potentially_left_users = {row[0] for row in rows} | |||
# Now lets actually delete the rooms from the DB. | |||
self.db_pool.simple_delete_many_txn( | |||
@@ -506,25 +506,28 @@ class StatsStore(StateDeltasStore): | |||
) -> Tuple[List[str], Dict[str, int], int, List[str], int]: | |||
pos = self.get_room_max_stream_ordering() # type: ignore[attr-defined] | |||
rows = self.db_pool.simple_select_many_txn( | |||
txn, | |||
table="current_state_events", | |||
column="type", | |||
iterable=[ | |||
EventTypes.Create, | |||
EventTypes.JoinRules, | |||
EventTypes.RoomHistoryVisibility, | |||
EventTypes.RoomEncryption, | |||
EventTypes.Name, | |||
EventTypes.Topic, | |||
EventTypes.RoomAvatar, | |||
EventTypes.CanonicalAlias, | |||
], | |||
keyvalues={"room_id": room_id, "state_key": ""}, | |||
retcols=["event_id"], | |||
rows = cast( | |||
List[Tuple[str]], | |||
self.db_pool.simple_select_many_txn( | |||
txn, | |||
table="current_state_events", | |||
column="type", | |||
iterable=[ | |||
EventTypes.Create, | |||
EventTypes.JoinRules, | |||
EventTypes.RoomHistoryVisibility, | |||
EventTypes.RoomEncryption, | |||
EventTypes.Name, | |||
EventTypes.Topic, | |||
EventTypes.RoomAvatar, | |||
EventTypes.CanonicalAlias, | |||
], | |||
keyvalues={"room_id": room_id, "state_key": ""}, | |||
retcols=["event_id"], | |||
), | |||
) | |||
event_ids = cast(List[str], [row["event_id"] for row in rows]) | |||
event_ids = [row[0] for row in rows] | |||
txn.execute( | |||
""" | |||
@@ -211,18 +211,28 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): | |||
async def get_destination_retry_timings_batch( | |||
self, destinations: StrCollection | |||
) -> Mapping[str, Optional[DestinationRetryTimings]]: | |||
rows = await self.db_pool.simple_select_many_batch( | |||
table="destinations", | |||
iterable=destinations, | |||
column="destination", | |||
retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"), | |||
desc="get_destination_retry_timings_batch", | |||
rows = cast( | |||
List[Tuple[str, Optional[int], Optional[int], Optional[int]]], | |||
await self.db_pool.simple_select_many_batch( | |||
table="destinations", | |||
iterable=destinations, | |||
column="destination", | |||
retcols=( | |||
"destination", | |||
"failure_ts", | |||
"retry_last_ts", | |||
"retry_interval", | |||
), | |||
desc="get_destination_retry_timings_batch", | |||
), | |||
) | |||
return { | |||
row.pop("destination"): DestinationRetryTimings(**row) | |||
for row in rows | |||
if row["retry_last_ts"] and row["failure_ts"] and row["retry_interval"] | |||
destination: DestinationRetryTimings( | |||
failure_ts, retry_last_ts, retry_interval | |||
) | |||
for destination, failure_ts, retry_last_ts, retry_interval in rows | |||
if retry_last_ts and failure_ts and retry_interval | |||
} | |||
async def set_destination_retry_timings( | |||
@@ -337,13 +337,16 @@ class UIAuthWorkerStore(SQLBaseStore): | |||
# If a registration token was used, decrement the pending counter | |||
# before deleting the session. | |||
rows = self.db_pool.simple_select_many_txn( | |||
txn, | |||
table="ui_auth_sessions_credentials", | |||
column="session_id", | |||
iterable=session_ids, | |||
keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN}, | |||
retcols=["result"], | |||
rows = cast( | |||
List[Tuple[str]], | |||
self.db_pool.simple_select_many_txn( | |||
txn, | |||
table="ui_auth_sessions_credentials", | |||
column="session_id", | |||
iterable=session_ids, | |||
keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN}, | |||
retcols=["result"], | |||
), | |||
) | |||
# Get the tokens used and how much pending needs to be decremented by. | |||
@@ -353,23 +356,25 @@ class UIAuthWorkerStore(SQLBaseStore): | |||
# registration token stage for that session will be True. | |||
# If a token was used to authenticate, but registration was | |||
# never completed, the result will be the token used. | |||
token = db_to_json(r["result"]) | |||
token = db_to_json(r[0]) | |||
if isinstance(token, str): | |||
token_counts[token] = token_counts.get(token, 0) + 1 | |||
# Update the `pending` counters. | |||
if len(token_counts) > 0: | |||
token_rows = self.db_pool.simple_select_many_txn( | |||
txn, | |||
table="registration_tokens", | |||
column="token", | |||
iterable=list(token_counts.keys()), | |||
keyvalues={}, | |||
retcols=["token", "pending"], | |||
token_rows = cast( | |||
List[Tuple[str, int]], | |||
self.db_pool.simple_select_many_txn( | |||
txn, | |||
table="registration_tokens", | |||
column="token", | |||
iterable=list(token_counts.keys()), | |||
keyvalues={}, | |||
retcols=["token", "pending"], | |||
), | |||
) | |||
for token_row in token_rows: | |||
token = token_row["token"] | |||
new_pending = token_row["pending"] - token_counts[token] | |||
for token, pending in token_rows: | |||
new_pending = pending - token_counts[token] | |||
self.db_pool.simple_update_one_txn( | |||
txn, | |||
table="registration_tokens", | |||
@@ -410,25 +410,24 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): | |||
) | |||
# Next fetch their profiles. Note that not all users have profiles. | |||
profile_rows = self.db_pool.simple_select_many_txn( | |||
txn, | |||
table="profiles", | |||
column="full_user_id", | |||
iterable=list(users_to_insert), | |||
retcols=( | |||
"full_user_id", | |||
"displayname", | |||
"avatar_url", | |||
profile_rows = cast( | |||
List[Tuple[str, Optional[str], Optional[str]]], | |||
self.db_pool.simple_select_many_txn( | |||
txn, | |||
table="profiles", | |||
column="full_user_id", | |||
iterable=list(users_to_insert), | |||
retcols=( | |||
"full_user_id", | |||
"displayname", | |||
"avatar_url", | |||
), | |||
keyvalues={}, | |||
), | |||
keyvalues={}, | |||
) | |||
profiles = { | |||
row["full_user_id"]: _UserDirProfile( | |||
row["full_user_id"], | |||
row["displayname"], | |||
row["avatar_url"], | |||
) | |||
for row in profile_rows | |||
full_user_id: _UserDirProfile(full_user_id, displayname, avatar_url) | |||
for full_user_id, displayname, avatar_url in profile_rows | |||
} | |||
profiles_to_insert = [ | |||
@@ -517,18 +516,21 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): | |||
and not self.get_if_app_services_interested_in_user(user) # type: ignore[attr-defined] | |||
] | |||
rows = self.db_pool.simple_select_many_txn( | |||
txn, | |||
table="users", | |||
column="name", | |||
iterable=users, | |||
keyvalues={ | |||
"deactivated": 0, | |||
}, | |||
retcols=("name", "user_type"), | |||
rows = cast( | |||
List[Tuple[str, Optional[str]]], | |||
self.db_pool.simple_select_many_txn( | |||
txn, | |||
table="users", | |||
column="name", | |||
iterable=users, | |||
keyvalues={ | |||
"deactivated": 0, | |||
}, | |||
retcols=("name", "user_type"), | |||
), | |||
) | |||
return [row["name"] for row in rows if row["user_type"] != UserTypes.SUPPORT] | |||
return [name for name, user_type in rows if user_type != UserTypes.SUPPORT] | |||
async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool: | |||
"""Check if the room is either world_readable or publically joinable""" | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Iterable, Mapping | |||
from typing import Iterable, List, Mapping, Tuple, cast | |||
from synapse.storage.database import LoggingTransaction | |||
from synapse.storage.databases.main import CacheInvalidationWorkerStore | |||
@@ -50,14 +50,17 @@ class UserErasureWorkerStore(CacheInvalidationWorkerStore): | |||
Returns: | |||
for each user, whether the user has requested erasure. | |||
""" | |||
rows = await self.db_pool.simple_select_many_batch( | |||
table="erased_users", | |||
column="user_id", | |||
iterable=user_ids, | |||
retcols=("user_id",), | |||
desc="are_users_erased", | |||
rows = cast( | |||
List[Tuple[str]], | |||
await self.db_pool.simple_select_many_batch( | |||
table="erased_users", | |||
column="user_id", | |||
iterable=user_ids, | |||
retcols=("user_id",), | |||
desc="are_users_erased", | |||
), | |||
) | |||
erased_users = {row["user_id"] for row in rows} | |||
erased_users = {row[0] for row in rows} | |||
return {u: u in erased_users for u in user_ids} | |||
@@ -13,7 +13,17 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple | |||
from typing import ( | |||
TYPE_CHECKING, | |||
Collection, | |||
Dict, | |||
Iterable, | |||
List, | |||
Optional, | |||
Set, | |||
Tuple, | |||
cast, | |||
) | |||
import attr | |||
@@ -730,19 +740,22 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): | |||
"[purge] found %i state groups to delete", len(state_groups_to_delete) | |||
) | |||
rows = self.db_pool.simple_select_many_txn( | |||
txn, | |||
table="state_group_edges", | |||
column="prev_state_group", | |||
iterable=state_groups_to_delete, | |||
keyvalues={}, | |||
retcols=("state_group",), | |||
rows = cast( | |||
List[Tuple[int]], | |||
self.db_pool.simple_select_many_txn( | |||
txn, | |||
table="state_group_edges", | |||
column="prev_state_group", | |||
iterable=state_groups_to_delete, | |||
keyvalues={}, | |||
retcols=("state_group",), | |||
), | |||
) | |||
remaining_state_groups = { | |||
row["state_group"] | |||
for row in rows | |||
if row["state_group"] not in state_groups_to_delete | |||
state_group | |||
for state_group, in rows | |||
if state_group not in state_groups_to_delete | |||
} | |||
logger.info( | |||
@@ -799,16 +812,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): | |||
A mapping from state group to previous state group. | |||
""" | |||
rows = await self.db_pool.simple_select_many_batch( | |||
table="state_group_edges", | |||
column="prev_state_group", | |||
iterable=state_groups, | |||
keyvalues={}, | |||
retcols=("prev_state_group", "state_group"), | |||
desc="get_previous_state_groups", | |||
rows = cast( | |||
List[Tuple[int, int]], | |||
await self.db_pool.simple_select_many_batch( | |||
table="state_group_edges", | |||
column="prev_state_group", | |||
iterable=state_groups, | |||
keyvalues={}, | |||
retcols=("state_group", "prev_state_group"), | |||
desc="get_previous_state_groups", | |||
), | |||
) | |||
return {row["state_group"]: row["prev_state_group"] for row in rows} | |||
return dict(rows) | |||
async def purge_room_state( | |||
self, room_id: str, state_groups_to_delete: Collection[int] | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Dict, List, Set, Tuple | |||
from typing import Dict, List, Set, Tuple, cast | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from twisted.trial import unittest | |||
@@ -421,41 +421,53 @@ class EventChainStoreTestCase(HomeserverTestCase): | |||
self, events: List[EventBase] | |||
) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]: | |||
# Fetch the map from event ID -> (chain ID, sequence number) | |||
rows = self.get_success( | |||
self.store.db_pool.simple_select_many_batch( | |||
table="event_auth_chains", | |||
column="event_id", | |||
iterable=[e.event_id for e in events], | |||
retcols=("event_id", "chain_id", "sequence_number"), | |||
keyvalues={}, | |||
) | |||
rows = cast( | |||
List[Tuple[str, int, int]], | |||
self.get_success( | |||
self.store.db_pool.simple_select_many_batch( | |||
table="event_auth_chains", | |||
column="event_id", | |||
iterable=[e.event_id for e in events], | |||
retcols=("event_id", "chain_id", "sequence_number"), | |||
keyvalues={}, | |||
) | |||
), | |||
) | |||
chain_map = { | |||
row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows | |||
event_id: (chain_id, sequence_number) | |||
for event_id, chain_id, sequence_number in rows | |||
} | |||
# Fetch all the links and pass them to the _LinkMap. | |||
rows = self.get_success( | |||
self.store.db_pool.simple_select_many_batch( | |||
table="event_auth_chain_links", | |||
column="origin_chain_id", | |||
iterable=[chain_id for chain_id, _ in chain_map.values()], | |||
retcols=( | |||
"origin_chain_id", | |||
"origin_sequence_number", | |||
"target_chain_id", | |||
"target_sequence_number", | |||
), | |||
keyvalues={}, | |||
) | |||
auth_chain_rows = cast( | |||
List[Tuple[int, int, int, int]], | |||
self.get_success( | |||
self.store.db_pool.simple_select_many_batch( | |||
table="event_auth_chain_links", | |||
column="origin_chain_id", | |||
iterable=[chain_id for chain_id, _ in chain_map.values()], | |||
retcols=( | |||
"origin_chain_id", | |||
"origin_sequence_number", | |||
"target_chain_id", | |||
"target_sequence_number", | |||
), | |||
keyvalues={}, | |||
) | |||
), | |||
) | |||
link_map = _LinkMap() | |||
for row in rows: | |||
for ( | |||
origin_chain_id, | |||
origin_sequence_number, | |||
target_chain_id, | |||
target_sequence_number, | |||
) in auth_chain_rows: | |||
added = link_map.add_link( | |||
(row["origin_chain_id"], row["origin_sequence_number"]), | |||
(row["target_chain_id"], row["target_sequence_number"]), | |||
(origin_chain_id, origin_sequence_number), | |||
(target_chain_id, target_sequence_number), | |||
) | |||
# We shouldn't have persisted any redundant links | |||