@@ -0,0 +1 @@ | |||
Convert various parts of the codebase to async/await. |
@@ -15,11 +15,10 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import List | |||
from canonicaljson import json | |||
from twisted.internet import defer | |||
from synapse.metrics.background_process_metrics import run_as_background_process | |||
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json | |||
from synapse.storage.database import Database | |||
@@ -166,8 +165,9 @@ class EventPushActionsWorkerStore(SQLBaseStore): | |||
return {"notify_count": notify_count, "highlight_count": highlight_count} | |||
@defer.inlineCallbacks | |||
def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering): | |||
async def get_push_action_users_in_range( | |||
self, min_stream_ordering, max_stream_ordering | |||
): | |||
def f(txn): | |||
sql = ( | |||
"SELECT DISTINCT(user_id) FROM event_push_actions WHERE" | |||
@@ -176,26 +176,28 @@ class EventPushActionsWorkerStore(SQLBaseStore): | |||
txn.execute(sql, (min_stream_ordering, max_stream_ordering)) | |||
return [r[0] for r in txn] | |||
ret = yield self.db.runInteraction("get_push_action_users_in_range", f) | |||
ret = await self.db.runInteraction("get_push_action_users_in_range", f) | |||
return ret | |||
@defer.inlineCallbacks | |||
def get_unread_push_actions_for_user_in_range_for_http( | |||
self, user_id, min_stream_ordering, max_stream_ordering, limit=20 | |||
): | |||
async def get_unread_push_actions_for_user_in_range_for_http( | |||
self, | |||
user_id: str, | |||
min_stream_ordering: int, | |||
max_stream_ordering: int, | |||
limit: int = 20, | |||
) -> List[dict]: | |||
"""Get a list of the most recent unread push actions for a given user, | |||
within the given stream ordering range. Called by the httppusher. | |||
Args: | |||
user_id (str): The user to fetch push actions for. | |||
min_stream_ordering(int): The exclusive lower bound on the | |||
user_id: The user to fetch push actions for. | |||
min_stream_ordering: The exclusive lower bound on the | |||
stream ordering of event push actions to fetch. | |||
max_stream_ordering(int): The inclusive upper bound on the | |||
max_stream_ordering: The inclusive upper bound on the | |||
stream ordering of event push actions to fetch. | |||
limit (int): The maximum number of rows to return. | |||
limit: The maximum number of rows to return. | |||
Returns: | |||
A promise which resolves to a list of dicts with the keys "event_id", | |||
"room_id", "stream_ordering", "actions". | |||
A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions". | |||
The list will be ordered by ascending stream_ordering. | |||
The list will have between 0~limit entries. | |||
""" | |||
@@ -228,7 +230,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): | |||
txn.execute(sql, args) | |||
return txn.fetchall() | |||
after_read_receipt = yield self.db.runInteraction( | |||
after_read_receipt = await self.db.runInteraction( | |||
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt | |||
) | |||
@@ -256,7 +258,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): | |||
txn.execute(sql, args) | |||
return txn.fetchall() | |||
no_read_receipt = yield self.db.runInteraction( | |||
no_read_receipt = await self.db.runInteraction( | |||
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt | |||
) | |||
@@ -280,23 +282,25 @@ class EventPushActionsWorkerStore(SQLBaseStore): | |||
# one of the subqueries may have hit the limit. | |||
return notifs[:limit] | |||
@defer.inlineCallbacks | |||
def get_unread_push_actions_for_user_in_range_for_email( | |||
self, user_id, min_stream_ordering, max_stream_ordering, limit=20 | |||
): | |||
async def get_unread_push_actions_for_user_in_range_for_email( | |||
self, | |||
user_id: str, | |||
min_stream_ordering: int, | |||
max_stream_ordering: int, | |||
limit: int = 20, | |||
) -> List[dict]: | |||
"""Get a list of the most recent unread push actions for a given user, | |||
within the given stream ordering range. Called by the emailpusher | |||
Args: | |||
user_id (str): The user to fetch push actions for. | |||
min_stream_ordering(int): The exclusive lower bound on the | |||
user_id: The user to fetch push actions for. | |||
min_stream_ordering: The exclusive lower bound on the | |||
stream ordering of event push actions to fetch. | |||
max_stream_ordering(int): The inclusive upper bound on the | |||
max_stream_ordering: The inclusive upper bound on the | |||
stream ordering of event push actions to fetch. | |||
limit (int): The maximum number of rows to return. | |||
limit: The maximum number of rows to return. | |||
Returns: | |||
A promise which resolves to a list of dicts with the keys "event_id", | |||
"room_id", "stream_ordering", "actions", "received_ts". | |||
A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions", "received_ts". | |||
The list will be ordered by descending received_ts. | |||
The list will have between 0~limit entries. | |||
""" | |||
@@ -328,7 +332,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): | |||
txn.execute(sql, args) | |||
return txn.fetchall() | |||
after_read_receipt = yield self.db.runInteraction( | |||
after_read_receipt = await self.db.runInteraction( | |||
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt | |||
) | |||
@@ -356,7 +360,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): | |||
txn.execute(sql, args) | |||
return txn.fetchall() | |||
no_read_receipt = yield self.db.runInteraction( | |||
no_read_receipt = await self.db.runInteraction( | |||
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt | |||
) | |||
@@ -461,17 +465,13 @@ class EventPushActionsWorkerStore(SQLBaseStore): | |||
"add_push_actions_to_staging", _add_push_actions_to_staging_txn | |||
) | |||
@defer.inlineCallbacks | |||
def remove_push_actions_from_staging(self, event_id): | |||
async def remove_push_actions_from_staging(self, event_id: str) -> None: | |||
"""Called if we failed to persist the event to ensure that stale push | |||
actions don't build up in the DB | |||
Args: | |||
event_id (str) | |||
""" | |||
try: | |||
res = yield self.db.simple_delete( | |||
res = await self.db.simple_delete( | |||
table="event_push_actions_staging", | |||
keyvalues={"event_id": event_id}, | |||
desc="remove_push_actions_from_staging", | |||
@@ -606,8 +606,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): | |||
return range_end | |||
@defer.inlineCallbacks | |||
def get_time_of_last_push_action_before(self, stream_ordering): | |||
async def get_time_of_last_push_action_before(self, stream_ordering): | |||
def f(txn): | |||
sql = ( | |||
"SELECT e.received_ts" | |||
@@ -620,7 +619,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): | |||
txn.execute(sql, (stream_ordering,)) | |||
return txn.fetchone() | |||
result = yield self.db.runInteraction("get_time_of_last_push_action_before", f) | |||
result = await self.db.runInteraction("get_time_of_last_push_action_before", f) | |||
return result[0] if result else None | |||
@@ -650,8 +649,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): | |||
self._start_rotate_notifs, 30 * 60 * 1000 | |||
) | |||
@defer.inlineCallbacks | |||
def get_push_actions_for_user( | |||
async def get_push_actions_for_user( | |||
self, user_id, before=None, limit=50, only_highlight=False | |||
): | |||
def f(txn): | |||
@@ -682,18 +680,17 @@ class EventPushActionsStore(EventPushActionsWorkerStore): | |||
txn.execute(sql, args) | |||
return self.db.cursor_to_dict(txn) | |||
push_actions = yield self.db.runInteraction("get_push_actions_for_user", f) | |||
push_actions = await self.db.runInteraction("get_push_actions_for_user", f) | |||
for pa in push_actions: | |||
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) | |||
return push_actions | |||
@defer.inlineCallbacks | |||
def get_latest_push_action_stream_ordering(self): | |||
async def get_latest_push_action_stream_ordering(self): | |||
def f(txn): | |||
txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") | |||
return txn.fetchone() | |||
result = yield self.db.runInteraction( | |||
result = await self.db.runInteraction( | |||
"get_latest_push_action_stream_ordering", f | |||
) | |||
return result[0] or 0 | |||
@@ -747,8 +744,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): | |||
def _start_rotate_notifs(self): | |||
return run_as_background_process("rotate_notifs", self._rotate_notifs) | |||
@defer.inlineCallbacks | |||
def _rotate_notifs(self): | |||
async def _rotate_notifs(self): | |||
if self._doing_notif_rotation or self.stream_ordering_day_ago is None: | |||
return | |||
self._doing_notif_rotation = True | |||
@@ -757,12 +753,12 @@ class EventPushActionsStore(EventPushActionsWorkerStore): | |||
while True: | |||
logger.info("Rotating notifications") | |||
caught_up = yield self.db.runInteraction( | |||
caught_up = await self.db.runInteraction( | |||
"_rotate_notifs", self._rotate_notifs_txn | |||
) | |||
if caught_up: | |||
break | |||
yield self.hs.get_clock().sleep(self._rotate_delay) | |||
await self.hs.get_clock().sleep(self._rotate_delay) | |||
finally: | |||
self._doing_notif_rotation = False | |||
@@ -23,8 +23,6 @@ from typing import Any, Dict, List, Optional, Tuple | |||
from canonicaljson import json | |||
from twisted.internet import defer | |||
from synapse.api.constants import EventTypes | |||
from synapse.api.errors import StoreError | |||
from synapse.api.room_versions import RoomVersion, RoomVersions | |||
@@ -32,7 +30,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json | |||
from synapse.storage.data_stores.main.search import SearchStore | |||
from synapse.storage.database import Database, LoggingTransaction | |||
from synapse.types import ThirdPartyInstanceID | |||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks | |||
from synapse.util.caches.descriptors import cached | |||
logger = logging.getLogger(__name__) | |||
@@ -192,8 +190,7 @@ class RoomWorkerStore(SQLBaseStore): | |||
return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn) | |||
@defer.inlineCallbacks | |||
def get_largest_public_rooms( | |||
async def get_largest_public_rooms( | |||
self, | |||
network_tuple: Optional[ThirdPartyInstanceID], | |||
search_filter: Optional[dict], | |||
@@ -330,10 +327,10 @@ class RoomWorkerStore(SQLBaseStore): | |||
return results | |||
ret_val = yield self.db.runInteraction( | |||
ret_val = await self.db.runInteraction( | |||
"get_largest_public_rooms", _get_largest_public_rooms_txn | |||
) | |||
defer.returnValue(ret_val) | |||
return ret_val | |||
@cached(max_entries=10000) | |||
def is_room_blocked(self, room_id): | |||
@@ -509,8 +506,8 @@ class RoomWorkerStore(SQLBaseStore): | |||
"get_rooms_paginate", _get_rooms_paginate_txn, | |||
) | |||
@cachedInlineCallbacks(max_entries=10000) | |||
def get_ratelimit_for_user(self, user_id): | |||
@cached(max_entries=10000) | |||
async def get_ratelimit_for_user(self, user_id): | |||
"""Check if there are any overrides for ratelimiting for the given | |||
user | |||
@@ -522,7 +519,7 @@ class RoomWorkerStore(SQLBaseStore): | |||
of RatelimitOverride are None or 0 then ratelimitng has been | |||
disabled for that user entirely. | |||
""" | |||
row = yield self.db.simple_select_one( | |||
row = await self.db.simple_select_one( | |||
table="ratelimit_override", | |||
keyvalues={"user_id": user_id}, | |||
retcols=("messages_per_second", "burst_count"), | |||
@@ -538,8 +535,8 @@ class RoomWorkerStore(SQLBaseStore): | |||
else: | |||
return None | |||
@cachedInlineCallbacks() | |||
def get_retention_policy_for_room(self, room_id): | |||
@cached() | |||
async def get_retention_policy_for_room(self, room_id): | |||
"""Get the retention policy for a given room. | |||
If no retention policy has been found for this room, returns a policy defined | |||
@@ -566,19 +563,17 @@ class RoomWorkerStore(SQLBaseStore): | |||
return self.db.cursor_to_dict(txn) | |||
ret = yield self.db.runInteraction( | |||
ret = await self.db.runInteraction( | |||
"get_retention_policy_for_room", get_retention_policy_for_room_txn, | |||
) | |||
# If we don't know this room ID, ret will be None, in this case return the default | |||
# policy. | |||
if not ret: | |||
defer.returnValue( | |||
{ | |||
"min_lifetime": self.config.retention_default_min_lifetime, | |||
"max_lifetime": self.config.retention_default_max_lifetime, | |||
} | |||
) | |||
return { | |||
"min_lifetime": self.config.retention_default_min_lifetime, | |||
"max_lifetime": self.config.retention_default_max_lifetime, | |||
} | |||
row = ret[0] | |||
@@ -592,7 +587,7 @@ class RoomWorkerStore(SQLBaseStore): | |||
if row["max_lifetime"] is None: | |||
row["max_lifetime"] = self.config.retention_default_max_lifetime | |||
defer.returnValue(row) | |||
return row | |||
def get_media_mxcs_in_room(self, room_id): | |||
"""Retrieves all the local and remote media MXC URIs in a given room | |||
@@ -881,8 +876,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): | |||
self._background_add_rooms_room_version_column, | |||
) | |||
@defer.inlineCallbacks | |||
def _background_insert_retention(self, progress, batch_size): | |||
async def _background_insert_retention(self, progress, batch_size): | |||
"""Retrieves a list of all rooms within a range and inserts an entry for each of | |||
them into the room_retention table. | |||
NULLs the property's columns if missing from the retention event in the room's | |||
@@ -940,14 +934,14 @@ class RoomBackgroundUpdateStore(SQLBaseStore): | |||
else: | |||
return False | |||
end = yield self.db.runInteraction( | |||
end = await self.db.runInteraction( | |||
"insert_room_retention", _background_insert_retention_txn, | |||
) | |||
if end: | |||
yield self.db.updates._end_background_update("insert_room_retention") | |||
await self.db.updates._end_background_update("insert_room_retention") | |||
defer.returnValue(batch_size) | |||
return batch_size | |||
async def _background_add_rooms_room_version_column( | |||
self, progress: dict, batch_size: int | |||
@@ -1096,8 +1090,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
lock=False, | |||
) | |||
@defer.inlineCallbacks | |||
def store_room( | |||
async def store_room( | |||
self, | |||
room_id: str, | |||
room_creator_user_id: str, | |||
@@ -1140,7 +1133,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
) | |||
with self._public_room_id_gen.get_next() as next_id: | |||
yield self.db.runInteraction("store_room_txn", store_room_txn, next_id) | |||
await self.db.runInteraction("store_room_txn", store_room_txn, next_id) | |||
except Exception as e: | |||
logger.error("store_room with room_id=%s failed: %s", room_id, e) | |||
raise StoreError(500, "Problem creating room.") | |||
@@ -1165,8 +1158,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
lock=False, | |||
) | |||
@defer.inlineCallbacks | |||
def set_room_is_public(self, room_id, is_public): | |||
async def set_room_is_public(self, room_id, is_public): | |||
def set_room_is_public_txn(txn, next_id): | |||
self.db.simple_update_one_txn( | |||
txn, | |||
@@ -1206,13 +1198,12 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
) | |||
with self._public_room_id_gen.get_next() as next_id: | |||
yield self.db.runInteraction( | |||
await self.db.runInteraction( | |||
"set_room_is_public", set_room_is_public_txn, next_id | |||
) | |||
self.hs.get_notifier().on_new_replication_data() | |||
@defer.inlineCallbacks | |||
def set_room_is_public_appservice( | |||
async def set_room_is_public_appservice( | |||
self, room_id, appservice_id, network_id, is_public | |||
): | |||
"""Edit the appservice/network specific public room list. | |||
@@ -1287,7 +1278,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
) | |||
with self._public_room_id_gen.get_next() as next_id: | |||
yield self.db.runInteraction( | |||
await self.db.runInteraction( | |||
"set_room_is_public_appservice", | |||
set_room_is_public_appservice_txn, | |||
next_id, | |||
@@ -1327,52 +1318,47 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
def get_current_public_room_stream_id(self): | |||
return self._public_room_id_gen.get_current_token() | |||
@defer.inlineCallbacks | |||
def block_room(self, room_id, user_id): | |||
async def block_room(self, room_id: str, user_id: str) -> None: | |||
"""Marks the room as blocked. Can be called multiple times. | |||
Args: | |||
room_id (str): Room to block | |||
user_id (str): Who blocked it | |||
Returns: | |||
Deferred | |||
room_id: Room to block | |||
user_id: Who blocked it | |||
""" | |||
yield self.db.simple_upsert( | |||
await self.db.simple_upsert( | |||
table="blocked_rooms", | |||
keyvalues={"room_id": room_id}, | |||
values={}, | |||
insertion_values={"user_id": user_id}, | |||
desc="block_room", | |||
) | |||
yield self.db.runInteraction( | |||
await self.db.runInteraction( | |||
"block_room_invalidation", | |||
self._invalidate_cache_and_stream, | |||
self.is_room_blocked, | |||
(room_id,), | |||
) | |||
@defer.inlineCallbacks | |||
def get_rooms_for_retention_period_in_range( | |||
self, min_ms, max_ms, include_null=False | |||
): | |||
async def get_rooms_for_retention_period_in_range( | |||
self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False | |||
) -> Dict[str, dict]: | |||
"""Retrieves all of the rooms within the given retention range. | |||
Optionally includes the rooms which don't have a retention policy. | |||
Args: | |||
min_ms (int|None): Duration in milliseconds that define the lower limit of | |||
min_ms: Duration in milliseconds that define the lower limit of | |||
the range to handle (exclusive). If None, doesn't set a lower limit. | |||
max_ms (int|None): Duration in milliseconds that define the upper limit of | |||
max_ms: Duration in milliseconds that define the upper limit of | |||
the range to handle (inclusive). If None, doesn't set an upper limit. | |||
include_null (bool): Whether to include rooms which retention policy is NULL | |||
include_null: Whether to include rooms which retention policy is NULL | |||
in the returned set. | |||
Returns: | |||
dict[str, dict]: The rooms within this range, along with their retention | |||
policy. The key is "room_id", and maps to a dict describing the retention | |||
policy associated with this room ID. The keys for this nested dict are | |||
"min_lifetime" (int|None), and "max_lifetime" (int|None). | |||
The rooms within this range, along with their retention | |||
policy. The key is "room_id", and maps to a dict describing the retention | |||
policy associated with this room ID. The keys for this nested dict are | |||
"min_lifetime" (int|None), and "max_lifetime" (int|None). | |||
""" | |||
def get_rooms_for_retention_period_in_range_txn(txn): | |||
@@ -1431,9 +1417,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
return rooms_dict | |||
rooms = yield self.db.runInteraction( | |||
rooms = await self.db.runInteraction( | |||
"get_rooms_for_retention_period_in_range", | |||
get_rooms_for_retention_period_in_range_txn, | |||
) | |||
defer.returnValue(rooms) | |||
return rooms |
@@ -16,12 +16,12 @@ | |||
import collections.abc | |||
import logging | |||
from collections import namedtuple | |||
from twisted.internet import defer | |||
from typing import Iterable, Optional, Set | |||
from synapse.api.constants import EventTypes, Membership | |||
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError | |||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion | |||
from synapse.events import EventBase | |||
from synapse.storage._base import SQLBaseStore | |||
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore | |||
from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore | |||
@@ -108,28 +108,27 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): | |||
create_event = await self.get_create_event_for_room(room_id) | |||
return create_event.content.get("room_version", "1") | |||
@defer.inlineCallbacks | |||
def get_room_predecessor(self, room_id): | |||
async def get_room_predecessor(self, room_id: str) -> Optional[dict]: | |||
"""Get the predecessor of an upgraded room if it exists. | |||
Otherwise return None. | |||
Args: | |||
room_id (str) | |||
room_id: The room ID. | |||
Returns: | |||
Deferred[dict|None]: A dictionary containing the structure of the predecessor | |||
field from the room's create event. The structure is subject to other servers, | |||
but it is expected to be: | |||
* room_id (str): The room ID of the predecessor room | |||
* event_id (str): The ID of the tombstone event in the predecessor room | |||
A dictionary containing the structure of the predecessor | |||
field from the room's create event. The structure is subject to other servers, | |||
but it is expected to be: | |||
* room_id (str): The room ID of the predecessor room | |||
* event_id (str): The ID of the tombstone event in the predecessor room | |||
None if a predecessor key is not found, or is not a dictionary. | |||
None if a predecessor key is not found, or is not a dictionary. | |||
Raises: | |||
NotFoundError if the given room is unknown | |||
""" | |||
# Retrieve the room's create event | |||
create_event = yield self.get_create_event_for_room(room_id) | |||
create_event = await self.get_create_event_for_room(room_id) | |||
# Retrieve the predecessor key of the create event | |||
predecessor = create_event.content.get("predecessor", None) | |||
@@ -140,20 +139,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): | |||
return predecessor | |||
@defer.inlineCallbacks | |||
def get_create_event_for_room(self, room_id): | |||
async def get_create_event_for_room(self, room_id: str) -> EventBase: | |||
"""Get the create state event for a room. | |||
Args: | |||
room_id (str) | |||
room_id: The room ID. | |||
Returns: | |||
Deferred[EventBase]: The room creation event. | |||
The room creation event. | |||
Raises: | |||
NotFoundError if the room is unknown | |||
""" | |||
state_ids = yield self.get_current_state_ids(room_id) | |||
state_ids = await self.get_current_state_ids(room_id) | |||
create_id = state_ids.get((EventTypes.Create, "")) | |||
# If we can't find the create event, assume we've hit a dead end | |||
@@ -161,7 +159,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): | |||
raise NotFoundError("Unknown room %s" % (room_id,)) | |||
# Retrieve the room's create event and return | |||
create_event = yield self.get_event(create_id) | |||
create_event = await self.get_event(create_id) | |||
return create_event | |||
@cached(max_entries=100000, iterable=True) | |||
@@ -237,18 +235,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): | |||
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn | |||
) | |||
@defer.inlineCallbacks | |||
def get_canonical_alias_for_room(self, room_id): | |||
async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]: | |||
"""Get canonical alias for room, if any | |||
Args: | |||
room_id (str) | |||
room_id: The room ID | |||
Returns: | |||
Deferred[str|None]: The canonical alias, if any | |||
The canonical alias, if any | |||
""" | |||
state = yield self.get_filtered_current_state_ids( | |||
state = await self.get_filtered_current_state_ids( | |||
room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")]) | |||
) | |||
@@ -256,7 +253,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): | |||
if not event_id: | |||
return | |||
event = yield self.get_event(event_id, allow_none=True) | |||
event = await self.get_event(event_id, allow_none=True) | |||
if not event: | |||
return | |||
@@ -292,19 +289,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): | |||
return {row["event_id"]: row["state_group"] for row in rows} | |||
@defer.inlineCallbacks | |||
def get_referenced_state_groups(self, state_groups): | |||
async def get_referenced_state_groups( | |||
self, state_groups: Iterable[int] | |||
) -> Set[int]: | |||
"""Check if the state groups are referenced by events. | |||
Args: | |||
state_groups (Iterable[int]) | |||
state_groups | |||
Returns: | |||
Deferred[set[int]]: The subset of state groups that are | |||
referenced. | |||
The subset of state groups that are referenced. | |||
""" | |||
rows = yield self.db.simple_select_many_batch( | |||
rows = await self.db.simple_select_many_batch( | |||
table="event_to_state_groups", | |||
column="state_group", | |||
iterable=state_groups, | |||
@@ -16,8 +16,8 @@ | |||
import logging | |||
from itertools import chain | |||
from typing import Tuple | |||
from twisted.internet import defer | |||
from twisted.internet.defer import DeferredLock | |||
from synapse.api.constants import EventTypes, Membership | |||
@@ -97,13 +97,12 @@ class StatsStore(StateDeltasStore): | |||
""" | |||
return (ts // self.stats_bucket_size) * self.stats_bucket_size | |||
@defer.inlineCallbacks | |||
def _populate_stats_process_users(self, progress, batch_size): | |||
async def _populate_stats_process_users(self, progress, batch_size): | |||
""" | |||
This is a background update which regenerates statistics for users. | |||
""" | |||
if not self.stats_enabled: | |||
yield self.db.updates._end_background_update("populate_stats_process_users") | |||
await self.db.updates._end_background_update("populate_stats_process_users") | |||
return 1 | |||
last_user_id = progress.get("last_user_id", "") | |||
@@ -118,20 +117,20 @@ class StatsStore(StateDeltasStore): | |||
txn.execute(sql, (last_user_id, batch_size)) | |||
return [r for r, in txn] | |||
users_to_work_on = yield self.db.runInteraction( | |||
users_to_work_on = await self.db.runInteraction( | |||
"_populate_stats_process_users", _get_next_batch | |||
) | |||
# No more rooms -- complete the transaction. | |||
if not users_to_work_on: | |||
yield self.db.updates._end_background_update("populate_stats_process_users") | |||
await self.db.updates._end_background_update("populate_stats_process_users") | |||
return 1 | |||
for user_id in users_to_work_on: | |||
yield self._calculate_and_set_initial_state_for_user(user_id) | |||
await self._calculate_and_set_initial_state_for_user(user_id) | |||
progress["last_user_id"] = user_id | |||
yield self.db.runInteraction( | |||
await self.db.runInteraction( | |||
"populate_stats_process_users", | |||
self.db.updates._background_update_progress_txn, | |||
"populate_stats_process_users", | |||
@@ -140,13 +139,12 @@ class StatsStore(StateDeltasStore): | |||
return len(users_to_work_on) | |||
@defer.inlineCallbacks | |||
def _populate_stats_process_rooms(self, progress, batch_size): | |||
async def _populate_stats_process_rooms(self, progress, batch_size): | |||
""" | |||
This is a background update which regenerates statistics for rooms. | |||
""" | |||
if not self.stats_enabled: | |||
yield self.db.updates._end_background_update("populate_stats_process_rooms") | |||
await self.db.updates._end_background_update("populate_stats_process_rooms") | |||
return 1 | |||
last_room_id = progress.get("last_room_id", "") | |||
@@ -161,20 +159,20 @@ class StatsStore(StateDeltasStore): | |||
txn.execute(sql, (last_room_id, batch_size)) | |||
return [r for r, in txn] | |||
rooms_to_work_on = yield self.db.runInteraction( | |||
rooms_to_work_on = await self.db.runInteraction( | |||
"populate_stats_rooms_get_batch", _get_next_batch | |||
) | |||
# No more rooms -- complete the transaction. | |||
if not rooms_to_work_on: | |||
yield self.db.updates._end_background_update("populate_stats_process_rooms") | |||
await self.db.updates._end_background_update("populate_stats_process_rooms") | |||
return 1 | |||
for room_id in rooms_to_work_on: | |||
yield self._calculate_and_set_initial_state_for_room(room_id) | |||
await self._calculate_and_set_initial_state_for_room(room_id) | |||
progress["last_room_id"] = room_id | |||
yield self.db.runInteraction( | |||
await self.db.runInteraction( | |||
"_populate_stats_process_rooms", | |||
self.db.updates._background_update_progress_txn, | |||
"populate_stats_process_rooms", | |||
@@ -696,16 +694,16 @@ class StatsStore(StateDeltasStore): | |||
return room_deltas, user_deltas | |||
@defer.inlineCallbacks | |||
def _calculate_and_set_initial_state_for_room(self, room_id): | |||
async def _calculate_and_set_initial_state_for_room( | |||
self, room_id: str | |||
) -> Tuple[dict, dict, int]: | |||
"""Calculate and insert an entry into room_stats_current. | |||
Args: | |||
room_id (str) | |||
room_id: The room ID under calculation. | |||
Returns: | |||
Deferred[tuple[dict, dict, int]]: A tuple of room state, membership | |||
counts and stream position. | |||
A tuple of room state, membership counts and stream position. | |||
""" | |||
def _fetch_current_state_stats(txn): | |||
@@ -767,11 +765,11 @@ class StatsStore(StateDeltasStore): | |||
current_state_events_count, | |||
users_in_room, | |||
pos, | |||
) = yield self.db.runInteraction( | |||
) = await self.db.runInteraction( | |||
"get_initial_state_for_room", _fetch_current_state_stats | |||
) | |||
state_event_map = yield self.get_events(event_ids, get_prev_content=False) | |||
state_event_map = await self.get_events(event_ids, get_prev_content=False) | |||
room_state = { | |||
"join_rules": None, | |||
@@ -806,11 +804,11 @@ class StatsStore(StateDeltasStore): | |||
event.content.get("m.federate", True) is True | |||
) | |||
yield self.update_room_state(room_id, room_state) | |||
await self.update_room_state(room_id, room_state) | |||
local_users_in_room = [u for u in users_in_room if self.hs.is_mine_id(u)] | |||
yield self.update_stats_delta( | |||
await self.update_stats_delta( | |||
ts=self.clock.time_msec(), | |||
stats_type="room", | |||
stats_id=room_id, | |||
@@ -826,8 +824,7 @@ class StatsStore(StateDeltasStore): | |||
}, | |||
) | |||
@defer.inlineCallbacks | |||
def _calculate_and_set_initial_state_for_user(self, user_id): | |||
async def _calculate_and_set_initial_state_for_user(self, user_id): | |||
def _calculate_and_set_initial_state_for_user_txn(txn): | |||
pos = self._get_max_stream_id_in_current_state_deltas_txn(txn) | |||
@@ -842,12 +839,12 @@ class StatsStore(StateDeltasStore): | |||
(count,) = txn.fetchone() | |||
return count, pos | |||
joined_rooms, pos = yield self.db.runInteraction( | |||
joined_rooms, pos = await self.db.runInteraction( | |||
"calculate_and_set_initial_state_for_user", | |||
_calculate_and_set_initial_state_for_user_txn, | |||
) | |||
yield self.update_stats_delta( | |||
await self.update_stats_delta( | |||
ts=self.clock.time_msec(), | |||
stats_type="user", | |||
stats_id=user_id, | |||
@@ -139,10 +139,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): | |||
"get_state_group_delta", _get_state_group_delta_txn | |||
) | |||
@defer.inlineCallbacks | |||
def _get_state_groups_from_groups( | |||
async def _get_state_groups_from_groups( | |||
self, groups: List[int], state_filter: StateFilter | |||
): | |||
) -> Dict[int, StateMap[str]]: | |||
"""Returns the state groups for a given set of groups from the | |||
database, filtering on types of state events. | |||
@@ -151,13 +150,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): | |||
state_filter: The state filter used to fetch state | |||
from the database. | |||
Returns: | |||
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. | |||
Dict of state group to state map. | |||
""" | |||
results = {} | |||
chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] | |||
for chunk in chunks: | |||
res = yield self.db.runInteraction( | |||
res = await self.db.runInteraction( | |||
"_get_state_groups_from_groups", | |||
self._get_state_groups_from_groups_txn, | |||
chunk, | |||
@@ -206,10 +205,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): | |||
return state_filter.filter_state(state_dict_ids), not missing_types | |||
@defer.inlineCallbacks | |||
def _get_state_for_groups( | |||
async def _get_state_for_groups( | |||
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() | |||
): | |||
) -> Dict[int, StateMap[str]]: | |||
"""Gets the state at each of a list of state groups, optionally | |||
filtering by type/state_key | |||
@@ -219,7 +217,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): | |||
state_filter: The state filter used to fetch state | |||
from the database. | |||
Returns: | |||
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. | |||
Dict of state group to state map. | |||
""" | |||
member_filter, non_member_filter = state_filter.get_member_split() | |||
@@ -228,14 +226,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): | |||
( | |||
non_member_state, | |||
incomplete_groups_nm, | |||
) = yield self._get_state_for_groups_using_cache( | |||
) = self._get_state_for_groups_using_cache( | |||
groups, self._state_group_cache, state_filter=non_member_filter | |||
) | |||
( | |||
member_state, | |||
incomplete_groups_m, | |||
) = yield self._get_state_for_groups_using_cache( | |||
(member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache( | |||
groups, self._state_group_members_cache, state_filter=member_filter | |||
) | |||
@@ -256,7 +251,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): | |||
# Help the cache hit ratio by expanding the filter a bit | |||
db_state_filter = state_filter.return_expanded() | |||
group_to_state_dict = yield self._get_state_groups_from_groups( | |||
group_to_state_dict = await self._get_state_groups_from_groups( | |||
list(incomplete_groups), state_filter=db_state_filter | |||
) | |||
@@ -576,19 +571,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): | |||
((sg,) for sg in state_groups_to_delete), | |||
) | |||
@defer.inlineCallbacks | |||
def get_previous_state_groups(self, state_groups): | |||
async def get_previous_state_groups( | |||
self, state_groups: Iterable[int] | |||
) -> Dict[int, int]: | |||
"""Fetch the previous groups of the given state groups. | |||
Args: | |||
state_groups (Iterable[int]) | |||
state_groups | |||
Returns: | |||
Deferred[dict[int, int]]: mapping from state group to previous | |||
state group. | |||
A mapping from state group to previous state group. | |||
""" | |||
rows = yield self.db.simple_select_many_batch( | |||
rows = await self.db.simple_select_many_batch( | |||
table="state_group_edges", | |||
column="prev_state_group", | |||
iterable=state_groups, | |||
@@ -14,7 +14,7 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypeVar | |||
from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar | |||
import attr | |||
@@ -419,7 +419,7 @@ class StateGroupStorage(object): | |||
def _get_state_groups_from_groups( | |||
self, groups: List[int], state_filter: StateFilter | |||
): | |||
) -> Awaitable[Dict[int, StateMap[str]]]: | |||
"""Returns the state groups for a given set of groups, filtering on | |||
types of state events. | |||
@@ -429,7 +429,7 @@ class StateGroupStorage(object): | |||
from the database. | |||
Returns: | |||
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. | |||
Dict of state group to state map. | |||
""" | |||
return self.stores.state._get_state_groups_from_groups(groups, state_filter) | |||
@@ -532,7 +532,7 @@ class StateGroupStorage(object): | |||
def _get_state_for_groups( | |||
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() | |||
): | |||
) -> Awaitable[Dict[int, StateMap[str]]]: | |||
"""Gets the state at each of a list of state groups, optionally | |||
filtering by type/state_key | |||
@@ -540,8 +540,9 @@ class StateGroupStorage(object): | |||
groups: list of state groups for which we want to get the state. | |||
state_filter: The state filter used to fetch state | |||
from the database. | |||
Returns: | |||
Deferred[dict[int, StateMap[str]]]: Dict of state group to state map. | |||
Dict of state group to state map. | |||
""" | |||
return self.stores.state._get_state_for_groups(groups, state_filter) | |||
@@ -39,14 +39,18 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): | |||
@defer.inlineCallbacks | |||
def test_get_unread_push_actions_for_user_in_range_for_http(self): | |||
yield self.store.get_unread_push_actions_for_user_in_range_for_http( | |||
USER_ID, 0, 1000, 20 | |||
yield defer.ensureDeferred( | |||
self.store.get_unread_push_actions_for_user_in_range_for_http( | |||
USER_ID, 0, 1000, 20 | |||
) | |||
) | |||
@defer.inlineCallbacks | |||
def test_get_unread_push_actions_for_user_in_range_for_email(self): | |||
yield self.store.get_unread_push_actions_for_user_in_range_for_email( | |||
USER_ID, 0, 1000, 20 | |||
yield defer.ensureDeferred( | |||
self.store.get_unread_push_actions_for_user_in_range_for_email( | |||
USER_ID, 0, 1000, 20 | |||
) | |||
) | |||
@defer.inlineCallbacks | |||
@@ -37,11 +37,13 @@ class RoomStoreTestCase(unittest.TestCase): | |||
self.alias = RoomAlias.from_string("#a-room-name:test") | |||
self.u_creator = UserID.from_string("@creator:test") | |||
yield self.store.store_room( | |||
self.room.to_string(), | |||
room_creator_user_id=self.u_creator.to_string(), | |||
is_public=True, | |||
room_version=RoomVersions.V1, | |||
yield defer.ensureDeferred( | |||
self.store.store_room( | |||
self.room.to_string(), | |||
room_creator_user_id=self.u_creator.to_string(), | |||
is_public=True, | |||
room_version=RoomVersions.V1, | |||
) | |||
) | |||
@defer.inlineCallbacks | |||
@@ -88,11 +90,13 @@ class RoomEventsStoreTestCase(unittest.TestCase): | |||
self.room = RoomID.from_string("!abcde:test") | |||
yield self.store.store_room( | |||
self.room.to_string(), | |||
room_creator_user_id="@creator:text", | |||
is_public=True, | |||
room_version=RoomVersions.V1, | |||
yield defer.ensureDeferred( | |||
self.store.store_room( | |||
self.room.to_string(), | |||
room_creator_user_id="@creator:text", | |||
is_public=True, | |||
room_version=RoomVersions.V1, | |||
) | |||
) | |||
@defer.inlineCallbacks | |||
@@ -44,11 +44,13 @@ class StateStoreTestCase(tests.unittest.TestCase): | |||
self.room = RoomID.from_string("!abc123:test") | |||
yield self.store.store_room( | |||
self.room.to_string(), | |||
room_creator_user_id="@creator:text", | |||
is_public=True, | |||
room_version=RoomVersions.V1, | |||
yield defer.ensureDeferred( | |||
self.store.store_room( | |||
self.room.to_string(), | |||
room_creator_user_id="@creator:text", | |||
is_public=True, | |||
room_version=RoomVersions.V1, | |||
) | |||
) | |||
@defer.inlineCallbacks | |||