This should use fewer allocations and improves type hints.tags/v1.96.0rc1
@@ -0,0 +1 @@ | |||
Reduce memory allocations. |
@@ -103,10 +103,10 @@ class DeactivateAccountHandler: | |||
# Attempt to unbind any known bound threepids to this account from identity | |||
# server(s). | |||
bound_threepids = await self.store.user_get_bound_threepids(user_id) | |||
for threepid in bound_threepids: | |||
for medium, address in bound_threepids: | |||
try: | |||
result = await self._identity_handler.try_unbind_threepid( | |||
user_id, threepid["medium"], threepid["address"], id_server | |||
user_id, medium, address, id_server | |||
) | |||
except Exception: | |||
# Do we want this to be a fatal error or should we carry on? | |||
@@ -1206,10 +1206,7 @@ class SsoHandler: | |||
# We have no guarantee that all the devices of that session are for the same | |||
# `user_id`. Hence, we have to iterate over the list of devices and log them out | |||
# one by one. | |||
for device in devices: | |||
user_id = device["user_id"] | |||
device_id = device["device_id"] | |||
for user_id, device_id in devices: | |||
# If the user_id associated with that device/session is not the one we got | |||
# out of the `sub` claim, skip that device and show log an error. | |||
if expected_user_id is not None and user_id != expected_user_id: | |||
@@ -606,13 +606,16 @@ class DatabasePool: | |||
If the background updates have not completed, wait 15 sec and check again. | |||
""" | |||
updates = await self.simple_select_list( | |||
"background_updates", | |||
keyvalues=None, | |||
retcols=["update_name"], | |||
desc="check_background_updates", | |||
updates = cast( | |||
List[Tuple[str]], | |||
await self.simple_select_list( | |||
"background_updates", | |||
keyvalues=None, | |||
retcols=["update_name"], | |||
desc="check_background_updates", | |||
), | |||
) | |||
background_update_names = [x["update_name"] for x in updates] | |||
background_update_names = [x[0] for x in updates] | |||
for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items(): | |||
if update_name not in background_update_names: | |||
@@ -1804,9 +1807,9 @@ class DatabasePool: | |||
keyvalues: Optional[Dict[str, Any]], | |||
retcols: Collection[str], | |||
desc: str = "simple_select_list", | |||
) -> List[Dict[str, Any]]: | |||
) -> List[Tuple[Any, ...]]: | |||
"""Executes a SELECT query on the named table, which may return zero or | |||
more rows, returning the result as a list of dicts. | |||
more rows, returning the result as a list of tuples. | |||
Args: | |||
table: the table name | |||
@@ -1817,8 +1820,7 @@ class DatabasePool: | |||
desc: description of the transaction, for logging and metrics | |||
Returns: | |||
A list of dictionaries, one per result row, each a mapping between the | |||
column names from `retcols` and that column's value for the row. | |||
A list of tuples, one per result row, each the retcolumn's value for the row. | |||
""" | |||
return await self.runInteraction( | |||
desc, | |||
@@ -1836,9 +1838,9 @@ class DatabasePool: | |||
table: str, | |||
keyvalues: Optional[Dict[str, Any]], | |||
retcols: Iterable[str], | |||
) -> List[Dict[str, Any]]: | |||
) -> List[Tuple[Any, ...]]: | |||
"""Executes a SELECT query on the named table, which may return zero or | |||
more rows, returning the result as a list of dicts. | |||
more rows, returning the result as a list of tuples. | |||
Args: | |||
txn: Transaction object | |||
@@ -1849,8 +1851,7 @@ class DatabasePool: | |||
retcols: the names of the columns to return | |||
Returns: | |||
A list of dictionaries, one per result row, each a mapping between the | |||
column names from `retcols` and that column's value for the row. | |||
A list of tuples, one per result row, each the retcolumn's value for the row. | |||
""" | |||
if keyvalues: | |||
sql = "SELECT %s FROM %s WHERE %s" % ( | |||
@@ -1863,7 +1864,7 @@ class DatabasePool: | |||
sql = "SELECT %s FROM %s" % (", ".join(retcols), table) | |||
txn.execute(sql) | |||
return cls.cursor_to_dict(txn) | |||
return txn.fetchall() | |||
async def simple_select_many_batch( | |||
self, | |||
@@ -286,16 +286,20 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) | |||
def get_account_data_for_room_txn( | |||
txn: LoggingTransaction, | |||
) -> Dict[str, JsonDict]: | |||
rows = self.db_pool.simple_select_list_txn( | |||
txn, | |||
"room_account_data", | |||
{"user_id": user_id, "room_id": room_id}, | |||
["account_data_type", "content"], | |||
) -> Dict[str, JsonMapping]: | |||
rows = cast( | |||
List[Tuple[str, str]], | |||
self.db_pool.simple_select_list_txn( | |||
txn, | |||
table="room_account_data", | |||
keyvalues={"user_id": user_id, "room_id": room_id}, | |||
retcols=["account_data_type", "content"], | |||
), | |||
) | |||
return { | |||
row["account_data_type"]: db_to_json(row["content"]) for row in rows | |||
account_data_type: db_to_json(content) | |||
for account_data_type, content in rows | |||
} | |||
return await self.db_pool.runInteraction( | |||
@@ -197,16 +197,21 @@ class ApplicationServiceTransactionWorkerStore( | |||
Returns: | |||
A list of ApplicationServices, which may be empty. | |||
""" | |||
results = await self.db_pool.simple_select_list( | |||
"application_services_state", {"state": state.value}, ["as_id"] | |||
results = cast( | |||
List[Tuple[str]], | |||
await self.db_pool.simple_select_list( | |||
table="application_services_state", | |||
keyvalues={"state": state.value}, | |||
retcols=("as_id",), | |||
), | |||
) | |||
# NB: This assumes this class is linked with ApplicationServiceStore | |||
as_list = self.get_app_services() | |||
services = [] | |||
for res in results: | |||
for (as_id,) in results: | |||
for service in as_list: | |||
if service.id == res["as_id"]: | |||
if service.id == as_id: | |||
services.append(service) | |||
return services | |||
@@ -508,21 +508,24 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke | |||
if device_id is not None: | |||
keyvalues["device_id"] = device_id | |||
res = await self.db_pool.simple_select_list( | |||
table="devices", | |||
keyvalues=keyvalues, | |||
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), | |||
res = cast( | |||
List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]], | |||
await self.db_pool.simple_select_list( | |||
table="devices", | |||
keyvalues=keyvalues, | |||
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), | |||
), | |||
) | |||
return { | |||
(d["user_id"], d["device_id"]): DeviceLastConnectionInfo( | |||
user_id=d["user_id"], | |||
device_id=d["device_id"], | |||
ip=d["ip"], | |||
user_agent=d["user_agent"], | |||
last_seen=d["last_seen"], | |||
(user_id, device_id): DeviceLastConnectionInfo( | |||
user_id=user_id, | |||
device_id=device_id, | |||
ip=ip, | |||
user_agent=user_agent, | |||
last_seen=last_seen, | |||
) | |||
for d in res | |||
for user_id, ip, user_agent, device_id, last_seen in res | |||
} | |||
async def _get_user_ip_and_agents_from_database( | |||
@@ -283,7 +283,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): | |||
allow_none=True, | |||
) | |||
async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]: | |||
async def get_devices_by_user( | |||
self, user_id: str | |||
) -> Dict[str, Dict[str, Optional[str]]]: | |||
"""Retrieve all of a user's registered devices. Only returns devices | |||
that are not marked as hidden. | |||
@@ -291,20 +293,26 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): | |||
user_id: | |||
Returns: | |||
A mapping from device_id to a dict containing "device_id", "user_id" | |||
and "display_name" for each device. | |||
and "display_name" for each device. Display name may be null. | |||
""" | |||
devices = await self.db_pool.simple_select_list( | |||
table="devices", | |||
keyvalues={"user_id": user_id, "hidden": False}, | |||
retcols=("user_id", "device_id", "display_name"), | |||
desc="get_devices_by_user", | |||
devices = cast( | |||
List[Tuple[str, str, Optional[str]]], | |||
await self.db_pool.simple_select_list( | |||
table="devices", | |||
keyvalues={"user_id": user_id, "hidden": False}, | |||
retcols=("user_id", "device_id", "display_name"), | |||
desc="get_devices_by_user", | |||
), | |||
) | |||
return {d["device_id"]: d for d in devices} | |||
return { | |||
d[1]: {"user_id": d[0], "device_id": d[1], "display_name": d[2]} | |||
for d in devices | |||
} | |||
async def get_devices_by_auth_provider_session_id( | |||
self, auth_provider_id: str, auth_provider_session_id: str | |||
) -> List[Dict[str, Any]]: | |||
) -> List[Tuple[str, str]]: | |||
"""Retrieve the list of devices associated with a SSO IdP session ID. | |||
Args: | |||
@@ -313,14 +321,17 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): | |||
Returns: | |||
A list of dicts containing the device_id and the user_id of each device | |||
""" | |||
return await self.db_pool.simple_select_list( | |||
table="device_auth_providers", | |||
keyvalues={ | |||
"auth_provider_id": auth_provider_id, | |||
"auth_provider_session_id": auth_provider_session_id, | |||
}, | |||
retcols=("user_id", "device_id"), | |||
desc="get_devices_by_auth_provider_session_id", | |||
return cast( | |||
List[Tuple[str, str]], | |||
await self.db_pool.simple_select_list( | |||
table="device_auth_providers", | |||
keyvalues={ | |||
"auth_provider_id": auth_provider_id, | |||
"auth_provider_session_id": auth_provider_session_id, | |||
}, | |||
retcols=("user_id", "device_id"), | |||
desc="get_devices_by_auth_provider_session_id", | |||
), | |||
) | |||
@trace | |||
@@ -821,15 +832,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): | |||
async def get_cached_devices_for_user( | |||
self, user_id: str | |||
) -> Mapping[str, JsonMapping]: | |||
devices = await self.db_pool.simple_select_list( | |||
table="device_lists_remote_cache", | |||
keyvalues={"user_id": user_id}, | |||
retcols=("device_id", "content"), | |||
desc="get_cached_devices_for_user", | |||
devices = cast( | |||
List[Tuple[str, str]], | |||
await self.db_pool.simple_select_list( | |||
table="device_lists_remote_cache", | |||
keyvalues={"user_id": user_id}, | |||
retcols=("device_id", "content"), | |||
desc="get_cached_devices_for_user", | |||
), | |||
) | |||
return { | |||
device["device_id"]: db_to_json(device["content"]) for device in devices | |||
} | |||
return {device[0]: db_to_json(device[1]) for device in devices} | |||
def get_cached_device_list_changes( | |||
self, | |||
@@ -1080,7 +1092,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): | |||
The IDs of users whose device lists need resync. | |||
""" | |||
if user_ids: | |||
row_tuples = cast( | |||
rows = cast( | |||
List[Tuple[str]], | |||
await self.db_pool.simple_select_many_batch( | |||
table="device_lists_remote_resync", | |||
@@ -1090,11 +1102,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): | |||
desc="get_user_ids_requiring_device_list_resync_with_iterable", | |||
), | |||
) | |||
return {row[0] for row in row_tuples} | |||
else: | |||
rows = cast( | |||
List[Dict[str, str]], | |||
List[Tuple[str]], | |||
await self.db_pool.simple_select_list( | |||
table="device_lists_remote_resync", | |||
keyvalues=None, | |||
@@ -1103,7 +1113,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): | |||
), | |||
) | |||
return {row["user_id"] for row in rows} | |||
return {row[0] for row in rows} | |||
async def mark_remote_users_device_caches_as_stale( | |||
self, user_ids: StrCollection | |||
@@ -13,7 +13,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import TYPE_CHECKING, Dict, Iterable, Mapping, Optional, Tuple, cast | |||
from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple, cast | |||
from typing_extensions import Literal, TypedDict | |||
@@ -274,32 +274,41 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore): | |||
if session_id: | |||
keyvalues["session_id"] = session_id | |||
rows = await self.db_pool.simple_select_list( | |||
table="e2e_room_keys", | |||
keyvalues=keyvalues, | |||
retcols=( | |||
"user_id", | |||
"room_id", | |||
"session_id", | |||
"first_message_index", | |||
"forwarded_count", | |||
"is_verified", | |||
"session_data", | |||
rows = cast( | |||
List[Tuple[str, str, int, int, int, str]], | |||
await self.db_pool.simple_select_list( | |||
table="e2e_room_keys", | |||
keyvalues=keyvalues, | |||
retcols=( | |||
"room_id", | |||
"session_id", | |||
"first_message_index", | |||
"forwarded_count", | |||
"is_verified", | |||
"session_data", | |||
), | |||
desc="get_e2e_room_keys", | |||
), | |||
desc="get_e2e_room_keys", | |||
) | |||
sessions: Dict[ | |||
Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]] | |||
] = {"rooms": {}} | |||
for row in rows: | |||
room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}}) | |||
room_entry["sessions"][row["session_id"]] = { | |||
"first_message_index": row["first_message_index"], | |||
"forwarded_count": row["forwarded_count"], | |||
for ( | |||
room_id, | |||
session_id, | |||
first_message_index, | |||
forwarded_count, | |||
is_verified, | |||
session_data, | |||
) in rows: | |||
room_entry = sessions["rooms"].setdefault(room_id, {"sessions": {}}) | |||
room_entry["sessions"][session_id] = { | |||
"first_message_index": first_message_index, | |||
"forwarded_count": forwarded_count, | |||
# is_verified must be returned to the client as a boolean | |||
"is_verified": bool(row["is_verified"]), | |||
"session_data": db_to_json(row["session_data"]), | |||
"is_verified": bool(is_verified), | |||
"session_data": db_to_json(session_data), | |||
} | |||
return sessions | |||
@@ -1898,21 +1898,23 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas | |||
# keeping only the forward extremities (i.e. the events not referenced | |||
# by other events in the queue). We do this so that we can always | |||
# backpaginate in all the events we have dropped. | |||
rows = await self.db_pool.simple_select_list( | |||
table="federation_inbound_events_staging", | |||
keyvalues={"room_id": room_id}, | |||
retcols=("event_id", "event_json"), | |||
desc="prune_staged_events_in_room_fetch", | |||
rows = cast( | |||
List[Tuple[str, str]], | |||
await self.db_pool.simple_select_list( | |||
table="federation_inbound_events_staging", | |||
keyvalues={"room_id": room_id}, | |||
retcols=("event_id", "event_json"), | |||
desc="prune_staged_events_in_room_fetch", | |||
), | |||
) | |||
# Find the set of events referenced by those in the queue, as well as | |||
# collecting all the event IDs in the queue. | |||
referenced_events: Set[str] = set() | |||
seen_events: Set[str] = set() | |||
for row in rows: | |||
event_id = row["event_id"] | |||
for event_id, event_json in rows: | |||
seen_events.add(event_id) | |||
event_d = db_to_json(row["event_json"]) | |||
event_d = db_to_json(event_json) | |||
# We don't bother parsing the dicts into full blown event objects, | |||
# as that is needlessly expensive. | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import TYPE_CHECKING, Dict, FrozenSet | |||
from typing import TYPE_CHECKING, Dict, FrozenSet, List, Tuple, cast | |||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection | |||
from synapse.storage.databases.main import CacheInvalidationWorkerStore | |||
@@ -42,13 +42,16 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore): | |||
Returns: | |||
the features currently enabled for the user | |||
""" | |||
enabled = await self.db_pool.simple_select_list( | |||
"per_user_experimental_features", | |||
{"user_id": user_id, "enabled": True}, | |||
["feature"], | |||
enabled = cast( | |||
List[Tuple[str]], | |||
await self.db_pool.simple_select_list( | |||
table="per_user_experimental_features", | |||
keyvalues={"user_id": user_id, "enabled": True}, | |||
retcols=("feature",), | |||
), | |||
) | |||
return frozenset(feature["feature"] for feature in enabled) | |||
return frozenset(feature[0] for feature in enabled) | |||
async def set_features_for_user( | |||
self, | |||
@@ -248,17 +248,20 @@ class KeyStore(CacheInvalidationWorkerStore): | |||
If we have multiple entries for a given key ID, returns the most recent. | |||
""" | |||
rows = await self.db_pool.simple_select_list( | |||
table="server_keys_json", | |||
keyvalues={"server_name": server_name}, | |||
retcols=( | |||
"key_id", | |||
"from_server", | |||
"ts_added_ms", | |||
"ts_valid_until_ms", | |||
"key_json", | |||
rows = cast( | |||
List[Tuple[str, str, int, int, Union[bytes, memoryview]]], | |||
await self.db_pool.simple_select_list( | |||
table="server_keys_json", | |||
keyvalues={"server_name": server_name}, | |||
retcols=( | |||
"key_id", | |||
"from_server", | |||
"ts_added_ms", | |||
"ts_valid_until_ms", | |||
"key_json", | |||
), | |||
desc="get_server_keys_json_for_remote", | |||
), | |||
desc="get_server_keys_json_for_remote", | |||
) | |||
if not rows: | |||
@@ -266,14 +269,14 @@ class KeyStore(CacheInvalidationWorkerStore): | |||
# We sort the rows by ts_added_ms so that the most recently added entry | |||
# will stomp over older entries in the dictionary. | |||
rows.sort(key=lambda r: r["ts_added_ms"]) | |||
rows.sort(key=lambda r: r[2]) | |||
return { | |||
row["key_id"]: FetchKeyResultForRemote( | |||
key_id: FetchKeyResultForRemote( | |||
# Cast to bytes since postgresql returns a memoryview. | |||
key_json=bytes(row["key_json"]), | |||
valid_until_ts=row["ts_valid_until_ms"], | |||
added_ts=row["ts_added_ms"], | |||
key_json=bytes(key_json), | |||
valid_until_ts=ts_valid_until_ms, | |||
added_ts=ts_added_ms, | |||
) | |||
for row in rows | |||
for key_id, from_server, ts_added_ms, ts_valid_until_ms, key_json in rows | |||
} |
@@ -437,25 +437,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
) | |||
async def get_local_media_thumbnails(self, media_id: str) -> List[ThumbnailInfo]: | |||
rows = await self.db_pool.simple_select_list( | |||
"local_media_repository_thumbnails", | |||
{"media_id": media_id}, | |||
( | |||
"thumbnail_width", | |||
"thumbnail_height", | |||
"thumbnail_method", | |||
"thumbnail_type", | |||
"thumbnail_length", | |||
rows = cast( | |||
List[Tuple[int, int, str, str, int]], | |||
await self.db_pool.simple_select_list( | |||
"local_media_repository_thumbnails", | |||
{"media_id": media_id}, | |||
( | |||
"thumbnail_width", | |||
"thumbnail_height", | |||
"thumbnail_method", | |||
"thumbnail_type", | |||
"thumbnail_length", | |||
), | |||
desc="get_local_media_thumbnails", | |||
), | |||
desc="get_local_media_thumbnails", | |||
) | |||
return [ | |||
ThumbnailInfo( | |||
width=row["thumbnail_width"], | |||
height=row["thumbnail_height"], | |||
method=row["thumbnail_method"], | |||
type=row["thumbnail_type"], | |||
length=row["thumbnail_length"], | |||
width=row[0], height=row[1], method=row[2], type=row[3], length=row[4] | |||
) | |||
for row in rows | |||
] | |||
@@ -568,25 +567,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
async def get_remote_media_thumbnails( | |||
self, origin: str, media_id: str | |||
) -> List[ThumbnailInfo]: | |||
rows = await self.db_pool.simple_select_list( | |||
"remote_media_cache_thumbnails", | |||
{"media_origin": origin, "media_id": media_id}, | |||
( | |||
"thumbnail_width", | |||
"thumbnail_height", | |||
"thumbnail_method", | |||
"thumbnail_type", | |||
"thumbnail_length", | |||
rows = cast( | |||
List[Tuple[int, int, str, str, int]], | |||
await self.db_pool.simple_select_list( | |||
"remote_media_cache_thumbnails", | |||
{"media_origin": origin, "media_id": media_id}, | |||
( | |||
"thumbnail_width", | |||
"thumbnail_height", | |||
"thumbnail_method", | |||
"thumbnail_type", | |||
"thumbnail_length", | |||
), | |||
desc="get_remote_media_thumbnails", | |||
), | |||
desc="get_remote_media_thumbnails", | |||
) | |||
return [ | |||
ThumbnailInfo( | |||
width=row["thumbnail_width"], | |||
height=row["thumbnail_height"], | |||
method=row["thumbnail_method"], | |||
type=row["thumbnail_type"], | |||
length=row["thumbnail_length"], | |||
width=row[0], height=row[1], method=row[2], type=row[3], length=row[4] | |||
) | |||
for row in rows | |||
] | |||
@@ -179,46 +179,44 @@ class PushRulesWorkerStore( | |||
@cached(max_entries=5000) | |||
async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules: | |||
rows = await self.db_pool.simple_select_list( | |||
table="push_rules", | |||
keyvalues={"user_name": user_id}, | |||
retcols=( | |||
"user_name", | |||
"rule_id", | |||
"priority_class", | |||
"priority", | |||
"conditions", | |||
"actions", | |||
rows = cast( | |||
List[Tuple[str, int, int, str, str]], | |||
await self.db_pool.simple_select_list( | |||
table="push_rules", | |||
keyvalues={"user_name": user_id}, | |||
retcols=( | |||
"rule_id", | |||
"priority_class", | |||
"priority", | |||
"conditions", | |||
"actions", | |||
), | |||
desc="get_push_rules_for_user", | |||
), | |||
desc="get_push_rules_for_user", | |||
) | |||
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) | |||
# Sort by highest priority_class, then highest priority. | |||
rows.sort(key=lambda row: (-int(row[1]), -int(row[2]))) | |||
enabled_map = await self.get_push_rules_enabled_for_user(user_id) | |||
return _load_rules( | |||
[ | |||
( | |||
row["rule_id"], | |||
row["priority_class"], | |||
row["conditions"], | |||
row["actions"], | |||
) | |||
for row in rows | |||
], | |||
[(row[0], row[1], row[3], row[4]) for row in rows], | |||
enabled_map, | |||
self.hs.config.experimental, | |||
) | |||
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]: | |||
results = await self.db_pool.simple_select_list( | |||
table="push_rules_enable", | |||
keyvalues={"user_name": user_id}, | |||
retcols=("rule_id", "enabled"), | |||
desc="get_push_rules_enabled_for_user", | |||
results = cast( | |||
List[Tuple[str, Optional[Union[int, bool]]]], | |||
await self.db_pool.simple_select_list( | |||
table="push_rules_enable", | |||
keyvalues={"user_name": user_id}, | |||
retcols=("rule_id", "enabled"), | |||
desc="get_push_rules_enabled_for_user", | |||
), | |||
) | |||
return {r["rule_id"]: bool(r["enabled"]) for r in results} | |||
return {r[0]: bool(r[1]) for r in results} | |||
async def have_push_rules_changed_for_user( | |||
self, user_id: str, last_id: int | |||
@@ -371,18 +371,20 @@ class PusherWorkerStore(SQLBaseStore): | |||
async def get_throttle_params_by_room( | |||
self, pusher_id: int | |||
) -> Dict[str, ThrottleParams]: | |||
res = await self.db_pool.simple_select_list( | |||
"pusher_throttle", | |||
{"pusher": pusher_id}, | |||
["room_id", "last_sent_ts", "throttle_ms"], | |||
desc="get_throttle_params_by_room", | |||
res = cast( | |||
List[Tuple[str, Optional[int], Optional[int]]], | |||
await self.db_pool.simple_select_list( | |||
"pusher_throttle", | |||
{"pusher": pusher_id}, | |||
["room_id", "last_sent_ts", "throttle_ms"], | |||
desc="get_throttle_params_by_room", | |||
), | |||
) | |||
params_by_room = {} | |||
for row in res: | |||
params_by_room[row["room_id"]] = ThrottleParams( | |||
row["last_sent_ts"], | |||
row["throttle_ms"], | |||
for room_id, last_sent_ts, throttle_ms in res: | |||
params_by_room[room_id] = ThrottleParams( | |||
last_sent_ts or 0, throttle_ms or 0 | |||
) | |||
return params_by_room | |||
@@ -855,13 +855,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): | |||
Returns: | |||
Tuples of (auth_provider, external_id) | |||
""" | |||
res = await self.db_pool.simple_select_list( | |||
table="user_external_ids", | |||
keyvalues={"user_id": mxid}, | |||
retcols=("auth_provider", "external_id"), | |||
desc="get_external_ids_by_user", | |||
return cast( | |||
List[Tuple[str, str]], | |||
await self.db_pool.simple_select_list( | |||
table="user_external_ids", | |||
keyvalues={"user_id": mxid}, | |||
retcols=("auth_provider", "external_id"), | |||
desc="get_external_ids_by_user", | |||
), | |||
) | |||
return [(r["auth_provider"], r["external_id"]) for r in res] | |||
async def count_all_users(self) -> int: | |||
"""Counts all users registered on the homeserver.""" | |||
@@ -997,13 +999,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): | |||
) | |||
async def user_get_threepids(self, user_id: str) -> List[ThreepidResult]: | |||
results = await self.db_pool.simple_select_list( | |||
"user_threepids", | |||
keyvalues={"user_id": user_id}, | |||
retcols=["medium", "address", "validated_at", "added_at"], | |||
desc="user_get_threepids", | |||
results = cast( | |||
List[Tuple[str, str, int, int]], | |||
await self.db_pool.simple_select_list( | |||
"user_threepids", | |||
keyvalues={"user_id": user_id}, | |||
retcols=["medium", "address", "validated_at", "added_at"], | |||
desc="user_get_threepids", | |||
), | |||
) | |||
return [ThreepidResult(**r) for r in results] | |||
return [ | |||
ThreepidResult( | |||
medium=r[0], | |||
address=r[1], | |||
validated_at=r[2], | |||
added_at=r[3], | |||
) | |||
for r in results | |||
] | |||
async def user_delete_threepid( | |||
self, user_id: str, medium: str, address: str | |||
@@ -1042,7 +1055,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): | |||
desc="add_user_bound_threepid", | |||
) | |||
async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]: | |||
async def user_get_bound_threepids(self, user_id: str) -> List[Tuple[str, str]]: | |||
"""Get the threepids that a user has bound to an identity server through the homeserver | |||
The homeserver remembers where binds to an identity server occurred. Using this | |||
method can retrieve those threepids. | |||
@@ -1051,15 +1064,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): | |||
user_id: The ID of the user to retrieve threepids for | |||
Returns: | |||
List of dictionaries containing the following keys: | |||
medium (str): The medium of the threepid (e.g "email") | |||
address (str): The address of the threepid (e.g "bob@example.com") | |||
""" | |||
return await self.db_pool.simple_select_list( | |||
table="user_threepid_id_server", | |||
keyvalues={"user_id": user_id}, | |||
retcols=["medium", "address"], | |||
desc="user_get_bound_threepids", | |||
List of tuples of two strings: | |||
medium: The medium of the threepid (e.g "email") | |||
address: The address of the threepid (e.g "bob@example.com") | |||
""" | |||
return cast( | |||
List[Tuple[str, str]], | |||
await self.db_pool.simple_select_list( | |||
table="user_threepid_id_server", | |||
keyvalues={"user_id": user_id}, | |||
retcols=["medium", "address"], | |||
desc="user_get_bound_threepids", | |||
), | |||
) | |||
async def remove_user_bound_threepid( | |||
@@ -384,14 +384,17 @@ class RelationsWorkerStore(SQLBaseStore): | |||
def get_all_relation_ids_for_event_txn( | |||
txn: LoggingTransaction, | |||
) -> List[str]: | |||
rows = self.db_pool.simple_select_list_txn( | |||
txn=txn, | |||
table="event_relations", | |||
keyvalues={"relates_to_id": event_id}, | |||
retcols=["event_id"], | |||
rows = cast( | |||
List[Tuple[str]], | |||
self.db_pool.simple_select_list_txn( | |||
txn=txn, | |||
table="event_relations", | |||
keyvalues={"relates_to_id": event_id}, | |||
retcols=["event_id"], | |||
), | |||
) | |||
return [row["event_id"] for row in rows] | |||
return [row[0] for row in rows] | |||
return await self.db_pool.runInteraction( | |||
desc="get_all_relation_ids_for_event", | |||
@@ -1232,28 +1232,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): | |||
""" | |||
room_servers: Dict[str, PartialStateResyncInfo] = {} | |||
rows = await self.db_pool.simple_select_list( | |||
table="partial_state_rooms", | |||
keyvalues={}, | |||
retcols=("room_id", "joined_via"), | |||
desc="get_server_which_served_partial_join", | |||
rows = cast( | |||
List[Tuple[str, str]], | |||
await self.db_pool.simple_select_list( | |||
table="partial_state_rooms", | |||
keyvalues={}, | |||
retcols=("room_id", "joined_via"), | |||
desc="get_server_which_served_partial_join", | |||
), | |||
) | |||
for row in rows: | |||
room_id = row["room_id"] | |||
joined_via = row["joined_via"] | |||
for room_id, joined_via in rows: | |||
room_servers[room_id] = PartialStateResyncInfo(joined_via=joined_via) | |||
rows = await self.db_pool.simple_select_list( | |||
"partial_state_rooms_servers", | |||
keyvalues=None, | |||
retcols=("room_id", "server_name"), | |||
desc="get_partial_state_rooms", | |||
rows = cast( | |||
List[Tuple[str, str]], | |||
await self.db_pool.simple_select_list( | |||
"partial_state_rooms_servers", | |||
keyvalues=None, | |||
retcols=("room_id", "server_name"), | |||
desc="get_partial_state_rooms", | |||
), | |||
) | |||
for row in rows: | |||
room_id = row["room_id"] | |||
server_name = row["server_name"] | |||
for room_id, server_name in rows: | |||
entry = room_servers.get(room_id) | |||
if entry is None: | |||
# There is a foreign key constraint which enforces that every room_id in | |||
@@ -1070,13 +1070,16 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): | |||
for fully-joined rooms. | |||
""" | |||
rows = await self.db_pool.simple_select_list( | |||
"current_state_events", | |||
keyvalues={"room_id": room_id}, | |||
retcols=("event_id", "membership"), | |||
desc="has_completed_background_updates", | |||
rows = cast( | |||
List[Tuple[str, Optional[str]]], | |||
await self.db_pool.simple_select_list( | |||
"current_state_events", | |||
keyvalues={"room_id": room_id}, | |||
retcols=("event_id", "membership"), | |||
desc="has_completed_background_updates", | |||
), | |||
) | |||
return {row["event_id"]: row["membership"] for row in rows} | |||
return dict(rows) | |||
# TODO This returns a mutable object, which is generally confusing when using a cache. | |||
@cached(max_entries=10000) # type: ignore[synapse-@cached-mutable] | |||
@@ -45,14 +45,17 @@ class TagsWorkerStore(AccountDataWorkerStore): | |||
tag content. | |||
""" | |||
rows = await self.db_pool.simple_select_list( | |||
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] | |||
rows = cast( | |||
List[Tuple[str, str, str]], | |||
await self.db_pool.simple_select_list( | |||
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] | |||
), | |||
) | |||
tags_by_room: Dict[str, Dict[str, JsonDict]] = {} | |||
for row in rows: | |||
room_tags = tags_by_room.setdefault(row["room_id"], {}) | |||
room_tags[row["tag"]] = db_to_json(row["content"]) | |||
for room_id, tag, content in rows: | |||
room_tags = tags_by_room.setdefault(room_id, {}) | |||
room_tags[tag] = db_to_json(content) | |||
return tags_by_room | |||
async def get_all_updated_tags( | |||
@@ -161,13 +164,16 @@ class TagsWorkerStore(AccountDataWorkerStore): | |||
Returns: | |||
A mapping of tags to tag content. | |||
""" | |||
rows = await self.db_pool.simple_select_list( | |||
table="room_tags", | |||
keyvalues={"user_id": user_id, "room_id": room_id}, | |||
retcols=("tag", "content"), | |||
desc="get_tags_for_room", | |||
rows = cast( | |||
List[Tuple[str, str]], | |||
await self.db_pool.simple_select_list( | |||
table="room_tags", | |||
keyvalues={"user_id": user_id, "room_id": room_id}, | |||
retcols=("tag", "content"), | |||
desc="get_tags_for_room", | |||
), | |||
) | |||
return {row["tag"]: db_to_json(row["content"]) for row in rows} | |||
return {tag: db_to_json(content) for tag, content in rows} | |||
async def add_tag_to_room( | |||
self, user_id: str, room_id: str, tag: str, content: JsonDict | |||
@@ -169,13 +169,17 @@ class UIAuthWorkerStore(SQLBaseStore): | |||
that auth-type. | |||
""" | |||
results = {} | |||
for row in await self.db_pool.simple_select_list( | |||
table="ui_auth_sessions_credentials", | |||
keyvalues={"session_id": session_id}, | |||
retcols=("stage_type", "result"), | |||
desc="get_completed_ui_auth_stages", | |||
): | |||
results[row["stage_type"]] = db_to_json(row["result"]) | |||
rows = cast( | |||
List[Tuple[str, str]], | |||
await self.db_pool.simple_select_list( | |||
table="ui_auth_sessions_credentials", | |||
keyvalues={"session_id": session_id}, | |||
retcols=("stage_type", "result"), | |||
desc="get_completed_ui_auth_stages", | |||
), | |||
) | |||
for stage_type, result in rows: | |||
results[stage_type] = db_to_json(result) | |||
return results | |||
@@ -295,13 +299,15 @@ class UIAuthWorkerStore(SQLBaseStore): | |||
Returns: | |||
List of user_agent/ip pairs | |||
""" | |||
rows = await self.db_pool.simple_select_list( | |||
table="ui_auth_sessions_ips", | |||
keyvalues={"session_id": session_id}, | |||
retcols=("user_agent", "ip"), | |||
desc="get_user_agents_ips_to_ui_auth_session", | |||
return cast( | |||
List[Tuple[str, str]], | |||
await self.db_pool.simple_select_list( | |||
table="ui_auth_sessions_ips", | |||
keyvalues={"session_id": session_id}, | |||
retcols=("user_agent", "ip"), | |||
desc="get_user_agents_ips_to_ui_auth_session", | |||
), | |||
) | |||
return [(row["user_agent"], row["ip"]) for row in rows] | |||
async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None: | |||
""" | |||
@@ -154,16 +154,22 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): | |||
if not prev_group: | |||
return _GetStateGroupDelta(None, None) | |||
delta_ids = self.db_pool.simple_select_list_txn( | |||
txn, | |||
table="state_groups_state", | |||
keyvalues={"state_group": state_group}, | |||
retcols=("type", "state_key", "event_id"), | |||
delta_ids = cast( | |||
List[Tuple[str, str, str]], | |||
self.db_pool.simple_select_list_txn( | |||
txn, | |||
table="state_groups_state", | |||
keyvalues={"state_group": state_group}, | |||
retcols=("type", "state_key", "event_id"), | |||
), | |||
) | |||
return _GetStateGroupDelta( | |||
prev_group, | |||
{(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, | |||
{ | |||
(event_type, state_key): event_id | |||
for event_type, state_key, event_id in delta_ids | |||
}, | |||
) | |||
return await self.db_pool.runInteraction( | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Any, Dict, List, Optional | |||
from typing import Any, Dict, List, Optional, Tuple, cast | |||
from twisted.test.proto_helpers import MemoryReactor | |||
@@ -68,10 +68,14 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
) | |||
) | |||
async def get_all_room_state(self) -> List[Dict[str, Any]]: | |||
return await self.store.db_pool.simple_select_list( | |||
"room_stats_state", None, retcols=("name", "topic", "canonical_alias") | |||
async def get_all_room_state(self) -> List[Optional[str]]: | |||
rows = cast( | |||
List[Tuple[Optional[str]]], | |||
await self.store.db_pool.simple_select_list( | |||
"room_stats_state", None, retcols=("topic",) | |||
), | |||
) | |||
return [r[0] for r in rows] | |||
def _get_current_stats( | |||
self, stats_type: str, stat_id: str | |||
@@ -130,7 +134,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
r = self.get_success(self.get_all_room_state()) | |||
self.assertEqual(len(r), 1) | |||
self.assertEqual(r[0]["topic"], "foo") | |||
self.assertEqual(r[0], "foo") | |||
def test_create_user(self) -> None: | |||
""" | |||
@@ -117,7 +117,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): | |||
if expected_row is not None: | |||
columns += expected_row.keys() | |||
rows = self.get_success( | |||
row_tuples = self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
table=table, | |||
keyvalues={ | |||
@@ -134,22 +134,22 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): | |||
if expected_row is not None: | |||
self.assertEqual( | |||
len(rows), | |||
len(row_tuples), | |||
1, | |||
f"Background update did not leave behind latest receipt in {table}", | |||
) | |||
self.assertEqual( | |||
rows[0], | |||
{ | |||
"room_id": room_id, | |||
"receipt_type": receipt_type, | |||
"user_id": user_id, | |||
**expected_row, | |||
}, | |||
row_tuples[0], | |||
( | |||
room_id, | |||
receipt_type, | |||
user_id, | |||
*expected_row.values(), | |||
), | |||
) | |||
else: | |||
self.assertEqual( | |||
len(rows), | |||
len(row_tuples), | |||
0, | |||
f"Background update did not remove all duplicate receipts from {table}", | |||
) | |||
@@ -14,7 +14,7 @@ | |||
# limitations under the License. | |||
import secrets | |||
from typing import Generator, Tuple | |||
from typing import Generator, List, Tuple, cast | |||
from twisted.test.proto_helpers import MemoryReactor | |||
@@ -47,15 +47,15 @@ class UpdateUpsertManyTests(unittest.HomeserverTestCase): | |||
) | |||
def _dump_table_to_tuple(self) -> Generator[Tuple[int, str, str], None, None]: | |||
res = self.get_success( | |||
self.storage.db_pool.simple_select_list( | |||
self.table_name, None, ["id, username, value"] | |||
) | |||
yield from cast( | |||
List[Tuple[int, str, str]], | |||
self.get_success( | |||
self.storage.db_pool.simple_select_list( | |||
self.table_name, None, ["id, username, value"] | |||
) | |||
), | |||
) | |||
for i in res: | |||
yield (i["id"], i["username"], i["value"]) | |||
def test_upsert_many(self) -> None: | |||
""" | |||
Upsert_many will perform the upsert operation across a batch of data. | |||
@@ -12,6 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import logging | |||
from typing import List, Tuple, cast | |||
from unittest.mock import AsyncMock, Mock | |||
import yaml | |||
@@ -526,15 +527,18 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase): | |||
self.wait_for_background_updates() | |||
# Check the correct values are in the new table. | |||
rows = self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
table="test_constraint", | |||
keyvalues={}, | |||
retcols=("a", "b"), | |||
) | |||
rows = cast( | |||
List[Tuple[int, int]], | |||
self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
table="test_constraint", | |||
keyvalues={}, | |||
retcols=("a", "b"), | |||
) | |||
), | |||
) | |||
self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}]) | |||
self.assertCountEqual(rows, [(1, 1), (3, 3)]) | |||
# And check that invalid rows get correctly rejected. | |||
self.get_failure( | |||
@@ -640,14 +644,17 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase): | |||
self.wait_for_background_updates() | |||
# Check the correct values are in the new table. | |||
rows = self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
table="test_constraint", | |||
keyvalues={}, | |||
retcols=("a", "b"), | |||
) | |||
rows = cast( | |||
List[Tuple[int, int]], | |||
self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
table="test_constraint", | |||
keyvalues={}, | |||
retcols=("a", "b"), | |||
) | |||
), | |||
) | |||
self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}]) | |||
self.assertCountEqual(rows, [(1, 1), (3, 3)]) | |||
# And check that invalid rows get correctly rejected. | |||
self.get_failure( | |||
@@ -146,7 +146,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): | |||
@defer.inlineCallbacks | |||
def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]: | |||
self.mock_txn.rowcount = 3 | |||
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) | |||
self.mock_txn.fetchall.return_value = [(1,), (2,), (3,)] | |||
self.mock_txn.description = (("colA", None, None, None, None, None, None),) | |||
ret = yield defer.ensureDeferred( | |||
@@ -155,7 +155,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): | |||
) | |||
) | |||
self.assertEqual([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret) | |||
self.assertEqual([(1,), (2,), (3,)], ret) | |||
self.mock_txn.execute.assert_called_with( | |||
"SELECT colA FROM tablename WHERE keycol = ?", ["A set"] | |||
) | |||
@@ -13,7 +13,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Any, Dict | |||
from typing import Any, Dict, List, Optional, Tuple, cast | |||
from unittest.mock import AsyncMock | |||
from parameterized import parameterized | |||
@@ -97,26 +97,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): | |||
self.reactor.advance(200) | |||
self.pump(0) | |||
result = self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
table="user_ips", | |||
keyvalues={"user_id": user_id}, | |||
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], | |||
desc="get_user_ip_and_agents", | |||
) | |||
result = cast( | |||
List[Tuple[str, str, str, Optional[str], int]], | |||
self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
table="user_ips", | |||
keyvalues={"user_id": user_id}, | |||
retcols=[ | |||
"access_token", | |||
"ip", | |||
"user_agent", | |||
"device_id", | |||
"last_seen", | |||
], | |||
desc="get_user_ip_and_agents", | |||
) | |||
), | |||
) | |||
self.assertEqual( | |||
result, | |||
[ | |||
{ | |||
"access_token": "access_token", | |||
"ip": "ip", | |||
"user_agent": "user_agent", | |||
"device_id": None, | |||
"last_seen": 12345678000, | |||
} | |||
], | |||
result, [("access_token", "ip", "user_agent", None, 12345678000)] | |||
) | |||
# Add another & trigger the storage loop | |||
@@ -128,26 +128,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): | |||
self.reactor.advance(10) | |||
self.pump(0) | |||
result = self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
table="user_ips", | |||
keyvalues={"user_id": user_id}, | |||
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], | |||
desc="get_user_ip_and_agents", | |||
) | |||
result = cast( | |||
List[Tuple[str, str, str, Optional[str], int]], | |||
self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
table="user_ips", | |||
keyvalues={"user_id": user_id}, | |||
retcols=[ | |||
"access_token", | |||
"ip", | |||
"user_agent", | |||
"device_id", | |||
"last_seen", | |||
], | |||
desc="get_user_ip_and_agents", | |||
) | |||
), | |||
) | |||
# Only one result, has been upserted. | |||
self.assertEqual( | |||
result, | |||
[ | |||
{ | |||
"access_token": "access_token", | |||
"ip": "ip", | |||
"user_agent": "user_agent", | |||
"device_id": None, | |||
"last_seen": 12345878000, | |||
} | |||
], | |||
result, [("access_token", "ip", "user_agent", None, 12345878000)] | |||
) | |||
@parameterized.expand([(False,), (True,)]) | |||
@@ -177,25 +177,23 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): | |||
self.reactor.advance(10) | |||
else: | |||
# Check that the new IP and user agent has not been stored yet | |||
db_result = self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
table="devices", | |||
keyvalues={}, | |||
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), | |||
db_result = cast( | |||
List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]], | |||
self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
table="devices", | |||
keyvalues={}, | |||
retcols=( | |||
"user_id", | |||
"ip", | |||
"user_agent", | |||
"device_id", | |||
"last_seen", | |||
), | |||
), | |||
), | |||
) | |||
self.assertEqual( | |||
db_result, | |||
[ | |||
{ | |||
"user_id": user_id, | |||
"device_id": device_id, | |||
"ip": None, | |||
"user_agent": None, | |||
"last_seen": None, | |||
}, | |||
], | |||
) | |||
self.assertEqual(db_result, [(user_id, None, None, device_id, None)]) | |||
result = self.get_success( | |||
self.store.get_last_client_ip_by_device(user_id, device_id) | |||
@@ -261,30 +259,21 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): | |||
) | |||
# Check that the new IP and user agent has not been stored yet | |||
db_result = self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
table="devices", | |||
keyvalues={}, | |||
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), | |||
db_result = cast( | |||
List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]], | |||
self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
table="devices", | |||
keyvalues={}, | |||
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), | |||
), | |||
), | |||
) | |||
self.assertCountEqual( | |||
db_result, | |||
[ | |||
{ | |||
"user_id": user_id, | |||
"device_id": device_id_1, | |||
"ip": "ip_1", | |||
"user_agent": "user_agent_1", | |||
"last_seen": 12345678000, | |||
}, | |||
{ | |||
"user_id": user_id, | |||
"device_id": device_id_2, | |||
"ip": "ip_2", | |||
"user_agent": "user_agent_2", | |||
"last_seen": 12345678000, | |||
}, | |||
(user_id, "ip_1", "user_agent_1", device_id_1, 12345678000), | |||
(user_id, "ip_2", "user_agent_2", device_id_2, 12345678000), | |||
], | |||
) | |||
@@ -385,28 +374,21 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): | |||
) | |||
# Check that the new IP and user agent has not been stored yet | |||
db_result = self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
table="user_ips", | |||
keyvalues={}, | |||
retcols=("access_token", "ip", "user_agent", "last_seen"), | |||
db_result = cast( | |||
List[Tuple[str, str, str, int]], | |||
self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
table="user_ips", | |||
keyvalues={}, | |||
retcols=("access_token", "ip", "user_agent", "last_seen"), | |||
), | |||
), | |||
) | |||
self.assertEqual( | |||
db_result, | |||
[ | |||
{ | |||
"access_token": "access_token", | |||
"ip": "ip_1", | |||
"user_agent": "user_agent_1", | |||
"last_seen": 12345678000, | |||
}, | |||
{ | |||
"access_token": "access_token", | |||
"ip": "ip_2", | |||
"user_agent": "user_agent_2", | |||
"last_seen": 12345678000, | |||
}, | |||
("access_token", "ip_1", "user_agent_1", 12345678000), | |||
("access_token", "ip_2", "user_agent_2", 12345678000), | |||
], | |||
) | |||
@@ -600,39 +582,49 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): | |||
self.reactor.advance(200) | |||
# We should see that in the DB | |||
result = self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
table="user_ips", | |||
keyvalues={"user_id": user_id}, | |||
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], | |||
desc="get_user_ip_and_agents", | |||
) | |||
result = cast( | |||
List[Tuple[str, str, str, Optional[str], int]], | |||
self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
table="user_ips", | |||
keyvalues={"user_id": user_id}, | |||
retcols=[ | |||
"access_token", | |||
"ip", | |||
"user_agent", | |||
"device_id", | |||
"last_seen", | |||
], | |||
desc="get_user_ip_and_agents", | |||
) | |||
), | |||
) | |||
self.assertEqual( | |||
result, | |||
[ | |||
{ | |||
"access_token": "access_token", | |||
"ip": "ip", | |||
"user_agent": "user_agent", | |||
"device_id": device_id, | |||
"last_seen": 0, | |||
} | |||
], | |||
[("access_token", "ip", "user_agent", device_id, 0)], | |||
) | |||
# Now advance by a couple of months | |||
self.reactor.advance(60 * 24 * 60 * 60) | |||
# We should get no results. | |||
result = self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
table="user_ips", | |||
keyvalues={"user_id": user_id}, | |||
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], | |||
desc="get_user_ip_and_agents", | |||
) | |||
result = cast( | |||
List[Tuple[str, str, str, Optional[str], int]], | |||
self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
table="user_ips", | |||
keyvalues={"user_id": user_id}, | |||
retcols=[ | |||
"access_token", | |||
"ip", | |||
"user_agent", | |||
"device_id", | |||
"last_seen", | |||
], | |||
desc="get_user_ip_and_agents", | |||
) | |||
), | |||
) | |||
self.assertEqual(result, []) | |||
@@ -696,28 +688,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): | |||
self.reactor.advance(200) | |||
# We should see that in the DB | |||
result = self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
table="user_ips", | |||
keyvalues={}, | |||
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], | |||
desc="get_user_ip_and_agents", | |||
) | |||
result = cast( | |||
List[Tuple[str, str, str, Optional[str], int]], | |||
self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
table="user_ips", | |||
keyvalues={}, | |||
retcols=[ | |||
"access_token", | |||
"ip", | |||
"user_agent", | |||
"device_id", | |||
"last_seen", | |||
], | |||
desc="get_user_ip_and_agents", | |||
) | |||
), | |||
) | |||
# ensure user1 is filtered out | |||
self.assertEqual( | |||
result, | |||
[ | |||
{ | |||
"access_token": access_token2, | |||
"ip": "ip", | |||
"user_agent": "user_agent", | |||
"device_id": device_id2, | |||
"last_seen": 0, | |||
} | |||
], | |||
) | |||
self.assertEqual(result, [(access_token2, "ip", "user_agent", device_id2, 0)]) | |||
class ClientIpAuthTestCase(unittest.HomeserverTestCase): | |||
@@ -12,6 +12,8 @@ | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import List, Optional, Tuple, cast | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.api.constants import Membership | |||
@@ -110,21 +112,24 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): | |||
def test__null_byte_in_display_name_properly_handled(self) -> None: | |||
room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) | |||
res = self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
"room_memberships", | |||
{"user_id": "@alice:test"}, | |||
["display_name", "event_id"], | |||
) | |||
res = cast( | |||
List[Tuple[Optional[str], str]], | |||
self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
"room_memberships", | |||
{"user_id": "@alice:test"}, | |||
["display_name", "event_id"], | |||
) | |||
), | |||
) | |||
# Check that we only got one result back | |||
self.assertEqual(len(res), 1) | |||
# Check that alice's display name is "alice" | |||
self.assertEqual(res[0]["display_name"], "alice") | |||
self.assertEqual(res[0][0], "alice") | |||
# Grab the event_id to use later | |||
event_id = res[0]["event_id"] | |||
event_id = res[0][1] | |||
# Create a profile with the offending null byte in the display name | |||
new_profile = {"displayname": "ali\u0000ce"} | |||
@@ -139,21 +144,24 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): | |||
tok=self.t_alice, | |||
) | |||
res2 = self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
"room_memberships", | |||
{"user_id": "@alice:test"}, | |||
["display_name", "event_id"], | |||
) | |||
res2 = cast( | |||
List[Tuple[Optional[str], str]], | |||
self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
"room_memberships", | |||
{"user_id": "@alice:test"}, | |||
["display_name", "event_id"], | |||
) | |||
), | |||
) | |||
# Check that we only have two results | |||
self.assertEqual(len(res2), 2) | |||
# Filter out the previous event using the event_id we grabbed above | |||
row = [row for row in res2 if row["event_id"] != event_id] | |||
row = [row for row in res2 if row[1] != event_id] | |||
# Check that alice's display name is now None | |||
self.assertEqual(row[0]["display_name"], None) | |||
self.assertIsNone(row[0][0]) | |||
def test_room_is_locally_forgotten(self) -> None: | |||
"""Test that when the last local user has forgotten a room it is known as forgotten.""" | |||
@@ -13,6 +13,7 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import List, Tuple, cast | |||
from immutabledict import immutabledict | |||
@@ -584,18 +585,21 @@ class StateStoreTestCase(HomeserverTestCase): | |||
) | |||
# check that only state events are in state_groups, and all state events are in state_groups | |||
res = self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
table="state_groups", | |||
keyvalues=None, | |||
retcols=("event_id",), | |||
) | |||
res = cast( | |||
List[Tuple[str]], | |||
self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
table="state_groups", | |||
keyvalues=None, | |||
retcols=("event_id",), | |||
) | |||
), | |||
) | |||
events = [] | |||
for result in res: | |||
self.assertNotIn(event3.event_id, result) | |||
events.append(result.get("event_id")) | |||
self.assertNotIn(event3.event_id, result) # XXX | |||
events.append(result[0]) | |||
for event, _ in processed_events_and_context: | |||
if event.is_state(): | |||
@@ -606,23 +610,29 @@ class StateStoreTestCase(HomeserverTestCase): | |||
# has an entry and prev event in state_group_edges | |||
for event, context in processed_events_and_context: | |||
if event.is_state(): | |||
state = self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
table="state_groups_state", | |||
keyvalues={"state_group": context.state_group_after_event}, | |||
retcols=("type", "state_key"), | |||
) | |||
) | |||
self.assertEqual(event.type, state[0].get("type")) | |||
self.assertEqual(event.state_key, state[0].get("state_key")) | |||
groups = self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
table="state_group_edges", | |||
keyvalues={"state_group": str(context.state_group_after_event)}, | |||
retcols=("*",), | |||
) | |||
state = cast( | |||
List[Tuple[str, str]], | |||
self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
table="state_groups_state", | |||
keyvalues={"state_group": context.state_group_after_event}, | |||
retcols=("type", "state_key"), | |||
) | |||
), | |||
) | |||
self.assertEqual( | |||
context.state_group_before_event, groups[0].get("prev_state_group") | |||
self.assertEqual(event.type, state[0][0]) | |||
self.assertEqual(event.state_key, state[0][1]) | |||
groups = cast( | |||
List[Tuple[str]], | |||
self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
table="state_group_edges", | |||
keyvalues={ | |||
"state_group": str(context.state_group_after_event) | |||
}, | |||
retcols=("prev_state_group",), | |||
) | |||
), | |||
) | |||
self.assertEqual(context.state_group_before_event, groups[0][0]) |
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import re | |||
from typing import Any, Dict, Set, Tuple | |||
from typing import Any, Dict, List, Optional, Set, Tuple, cast | |||
from unittest import mock | |||
from unittest.mock import Mock, patch | |||
@@ -62,14 +62,13 @@ class GetUserDirectoryTables: | |||
Returns a list of tuples (user_id, room_id) where room_id is public and | |||
contains the user with the given id. | |||
""" | |||
r = await self.store.db_pool.simple_select_list( | |||
"users_in_public_rooms", None, ("user_id", "room_id") | |||
r = cast( | |||
List[Tuple[str, str]], | |||
await self.store.db_pool.simple_select_list( | |||
"users_in_public_rooms", None, ("user_id", "room_id") | |||
), | |||
) | |||
retval = set() | |||
for i in r: | |||
retval.add((i["user_id"], i["room_id"])) | |||
return retval | |||
return set(r) | |||
async def get_users_who_share_private_rooms(self) -> Set[Tuple[str, str, str]]: | |||
"""Fetch the entire `users_who_share_private_rooms` table. | |||
@@ -78,27 +77,30 @@ class GetUserDirectoryTables: | |||
to the rows of `users_who_share_private_rooms`. | |||
""" | |||
rows = await self.store.db_pool.simple_select_list( | |||
"users_who_share_private_rooms", | |||
None, | |||
["user_id", "other_user_id", "room_id"], | |||
rows = cast( | |||
List[Tuple[str, str, str]], | |||
await self.store.db_pool.simple_select_list( | |||
"users_who_share_private_rooms", | |||
None, | |||
["user_id", "other_user_id", "room_id"], | |||
), | |||
) | |||
rv = set() | |||
for row in rows: | |||
rv.add((row["user_id"], row["other_user_id"], row["room_id"])) | |||
return rv | |||
return set(rows) | |||
async def get_users_in_user_directory(self) -> Set[str]: | |||
"""Fetch the set of users in the `user_directory` table. | |||
This is useful when checking we've correctly excluded users from the directory. | |||
""" | |||
result = await self.store.db_pool.simple_select_list( | |||
"user_directory", | |||
None, | |||
["user_id"], | |||
result = cast( | |||
List[Tuple[str]], | |||
await self.store.db_pool.simple_select_list( | |||
"user_directory", | |||
None, | |||
["user_id"], | |||
), | |||
) | |||
return {row["user_id"] for row in result} | |||
return {row[0] for row in result} | |||
async def get_profiles_in_user_directory(self) -> Dict[str, ProfileInfo]: | |||
"""Fetch users and their profiles from the `user_directory` table. | |||
@@ -107,16 +109,17 @@ class GetUserDirectoryTables: | |||
It's almost the entire contents of the `user_directory` table: the only | |||
thing missing is an unused room_id column. | |||
""" | |||
rows = await self.store.db_pool.simple_select_list( | |||
"user_directory", | |||
None, | |||
("user_id", "display_name", "avatar_url"), | |||
rows = cast( | |||
List[Tuple[str, Optional[str], Optional[str]]], | |||
await self.store.db_pool.simple_select_list( | |||
"user_directory", | |||
None, | |||
("user_id", "display_name", "avatar_url"), | |||
), | |||
) | |||
return { | |||
row["user_id"]: ProfileInfo( | |||
display_name=row["display_name"], avatar_url=row["avatar_url"] | |||
) | |||
for row in rows | |||
user_id: ProfileInfo(display_name=display_name, avatar_url=avatar_url) | |||
for user_id, display_name, avatar_url in rows | |||
} | |||
async def get_tables( | |||