@@ -0,0 +1 @@ | |||
Improve type hints. |
@@ -348,8 +348,7 @@ class Porter: | |||
backward_chunk = 0 | |||
already_ported = 0 | |||
else: | |||
forward_chunk = row["forward_rowid"] | |||
backward_chunk = row["backward_rowid"] | |||
forward_chunk, backward_chunk = row | |||
if total_to_port is None: | |||
already_ported, total_to_port = await self._get_total_count_to_port( | |||
@@ -269,7 +269,7 @@ class RoomCreationHandler: | |||
self, | |||
requester: Requester, | |||
old_room_id: str, | |||
old_room: Dict[str, Any], | |||
old_room: Tuple[bool, str, bool], | |||
new_room_id: str, | |||
new_version: RoomVersion, | |||
tombstone_event: EventBase, | |||
@@ -279,7 +279,7 @@ class RoomCreationHandler: | |||
Args: | |||
requester: the user requesting the upgrade | |||
old_room_id: the id of the room to be replaced | |||
old_room: a dict containing room information for the room to be replaced, | |||
old_room: a tuple containing room information for the room to be replaced, | |||
as returned by `RoomWorkerStore.get_room`. | |||
new_room_id: the id of the replacement room | |||
new_version: the version to upgrade the room to | |||
@@ -299,7 +299,7 @@ class RoomCreationHandler: | |||
await self.store.store_room( | |||
room_id=new_room_id, | |||
room_creator_user_id=user_id, | |||
is_public=old_room["is_public"], | |||
is_public=old_room[0], | |||
room_version=new_version, | |||
) | |||
@@ -1260,7 +1260,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): | |||
# Add new room to the room directory if the old room was there | |||
# Remove old room from the room directory | |||
old_room = await self.store.get_room(old_room_id) | |||
if old_room is not None and old_room["is_public"]: | |||
# If the old room exists and is public. | |||
if old_room is not None and old_room[0]: | |||
await self.store.set_room_is_public(old_room_id, False) | |||
await self.store.set_room_is_public(room_id, True) | |||
@@ -1860,7 +1860,8 @@ class PublicRoomListManager: | |||
if not room: | |||
return False | |||
return room.get("is_public", False) | |||
# The first item is whether the room is public. | |||
return room[0] | |||
async def add_room_to_public_room_list(self, room_id: str) -> None: | |||
"""Publishes a room to the public room list. | |||
@@ -413,8 +413,8 @@ class RoomMembersRestServlet(RestServlet): | |||
) -> Tuple[int, JsonDict]: | |||
await assert_requester_is_admin(self.auth, request) | |||
ret = await self.store.get_room(room_id) | |||
if not ret: | |||
room = await self.store.get_room(room_id) | |||
if not room: | |||
raise NotFoundError("Room not found") | |||
members = await self.store.get_users_in_room(room_id) | |||
@@ -442,8 +442,8 @@ class RoomStateRestServlet(RestServlet): | |||
) -> Tuple[int, JsonDict]: | |||
await assert_requester_is_admin(self.auth, request) | |||
ret = await self.store.get_room(room_id) | |||
if not ret: | |||
room = await self.store.get_room(room_id) | |||
if not room: | |||
raise NotFoundError("Room not found") | |||
event_ids = await self._storage_controllers.state.get_current_state_ids(room_id) | |||
@@ -147,7 +147,7 @@ class ClientDirectoryListServer(RestServlet): | |||
if room is None: | |||
raise NotFoundError("Unknown room") | |||
return 200, {"visibility": "public" if room["is_public"] else "private"} | |||
return 200, {"visibility": "public" if room[0] else "private"} | |||
class PutBody(RequestBodyModel): | |||
visibility: Literal["public", "private"] = "public" | |||
@@ -1597,7 +1597,7 @@ class DatabasePool: | |||
retcols: Collection[str], | |||
allow_none: Literal[False] = False, | |||
desc: str = "simple_select_one", | |||
) -> Dict[str, Any]: | |||
) -> Tuple[Any, ...]: | |||
... | |||
@overload | |||
@@ -1608,7 +1608,7 @@ class DatabasePool: | |||
retcols: Collection[str], | |||
allow_none: Literal[True] = True, | |||
desc: str = "simple_select_one", | |||
) -> Optional[Dict[str, Any]]: | |||
) -> Optional[Tuple[Any, ...]]: | |||
... | |||
async def simple_select_one( | |||
@@ -1618,7 +1618,7 @@ class DatabasePool: | |||
retcols: Collection[str], | |||
allow_none: bool = False, | |||
desc: str = "simple_select_one", | |||
) -> Optional[Dict[str, Any]]: | |||
) -> Optional[Tuple[Any, ...]]: | |||
"""Executes a SELECT query on the named table, which is expected to | |||
return a single row, returning multiple columns from it. | |||
@@ -2127,7 +2127,7 @@ class DatabasePool: | |||
keyvalues: Dict[str, Any], | |||
retcols: Collection[str], | |||
allow_none: bool = False, | |||
) -> Optional[Dict[str, Any]]: | |||
) -> Optional[Tuple[Any, ...]]: | |||
select_sql = "SELECT %s FROM %s" % (", ".join(retcols), table) | |||
if keyvalues: | |||
@@ -2145,7 +2145,7 @@ class DatabasePool: | |||
if txn.rowcount > 1: | |||
raise StoreError(500, "More than one row matched (%s)" % (table,)) | |||
return dict(zip(retcols, row)) | |||
return row | |||
async def simple_delete_one( | |||
self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one" | |||
@@ -255,33 +255,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): | |||
A dict containing the device information, or `None` if the device does not | |||
exist. | |||
""" | |||
return await self.db_pool.simple_select_one( | |||
table="devices", | |||
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, | |||
retcols=("user_id", "device_id", "display_name"), | |||
desc="get_device", | |||
allow_none=True, | |||
) | |||
async def get_device_opt( | |||
self, user_id: str, device_id: str | |||
) -> Optional[Dict[str, Any]]: | |||
"""Retrieve a device. Only returns devices that are not marked as | |||
hidden. | |||
Args: | |||
user_id: The ID of the user which owns the device | |||
device_id: The ID of the device to retrieve | |||
Returns: | |||
A dict containing the device information, or None if the device does not exist. | |||
""" | |||
return await self.db_pool.simple_select_one( | |||
row = await self.db_pool.simple_select_one( | |||
table="devices", | |||
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, | |||
retcols=("user_id", "device_id", "display_name"), | |||
desc="get_device", | |||
allow_none=True, | |||
) | |||
if row is None: | |||
return None | |||
return {"user_id": row[0], "device_id": row[1], "display_name": row[2]} | |||
async def get_devices_by_user( | |||
self, user_id: str | |||
@@ -1221,9 +1204,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): | |||
retcols=["device_id", "device_data"], | |||
allow_none=True, | |||
) | |||
return ( | |||
(row["device_id"], json_decoder.decode(row["device_data"])) if row else None | |||
) | |||
return (row[0], json_decoder.decode(row[1])) if row else None | |||
def _store_dehydrated_device_txn( | |||
self, | |||
@@ -2326,13 +2307,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
`FALSE` have not been converted. | |||
""" | |||
row = await self.db_pool.simple_select_one( | |||
table="device_lists_changes_converted_stream_position", | |||
keyvalues={}, | |||
retcols=["stream_id", "room_id"], | |||
desc="get_device_change_last_converted_pos", | |||
return cast( | |||
Tuple[int, str], | |||
await self.db_pool.simple_select_one( | |||
table="device_lists_changes_converted_stream_position", | |||
keyvalues={}, | |||
retcols=["stream_id", "room_id"], | |||
desc="get_device_change_last_converted_pos", | |||
), | |||
) | |||
return row["stream_id"], row["room_id"] | |||
async def set_device_change_last_converted_pos( | |||
self, | |||
@@ -506,19 +506,26 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore): | |||
# it isn't there. | |||
raise StoreError(404, "No backup with that version exists") | |||
result = self.db_pool.simple_select_one_txn( | |||
txn, | |||
table="e2e_room_keys_versions", | |||
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, | |||
retcols=("version", "algorithm", "auth_data", "etag"), | |||
allow_none=False, | |||
row = cast( | |||
Tuple[int, str, str, Optional[int]], | |||
self.db_pool.simple_select_one_txn( | |||
txn, | |||
table="e2e_room_keys_versions", | |||
keyvalues={ | |||
"user_id": user_id, | |||
"version": this_version, | |||
"deleted": 0, | |||
}, | |||
retcols=("version", "algorithm", "auth_data", "etag"), | |||
allow_none=False, | |||
), | |||
) | |||
assert result is not None # see comment on `simple_select_one_txn` | |||
result["auth_data"] = db_to_json(result["auth_data"]) | |||
result["version"] = str(result["version"]) | |||
if result["etag"] is None: | |||
result["etag"] = 0 | |||
return result | |||
return { | |||
"auth_data": db_to_json(row[2]), | |||
"version": str(row[0]), | |||
"algorithm": row[1], | |||
"etag": 0 if row[3] is None else row[3], | |||
} | |||
return await self.db_pool.runInteraction( | |||
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn | |||
@@ -1266,9 +1266,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker | |||
if row is None: | |||
continue | |||
key_id = row["key_id"] | |||
key_json = row["key_json"] | |||
used = row["used"] | |||
key_id, key_json, used = row | |||
# Mark fallback key as used if not already. | |||
if not used and mark_as_used: | |||
@@ -193,7 +193,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas | |||
# Check if we have indexed the room so we can use the chain cover | |||
# algorithm. | |||
room = await self.get_room(room_id) # type: ignore[attr-defined] | |||
if room["has_auth_chain_index"]: | |||
# If the room has an auth chain index. | |||
if room[1]: | |||
try: | |||
return await self.db_pool.runInteraction( | |||
"get_auth_chain_ids_chains", | |||
@@ -411,7 +412,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas | |||
# Check if we have indexed the room so we can use the chain cover | |||
# algorithm. | |||
room = await self.get_room(room_id) # type: ignore[attr-defined] | |||
if room["has_auth_chain_index"]: | |||
# If the room has an auth chain index. | |||
if room[1]: | |||
try: | |||
return await self.db_pool.runInteraction( | |||
"get_auth_chain_difference_chains", | |||
@@ -1437,24 +1439,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas | |||
) | |||
if event_lookup_result is not None: | |||
event_type, depth, stream_ordering = event_lookup_result | |||
logger.debug( | |||
"_get_backfill_events(room_id=%s): seed_event_id=%s depth=%s stream_ordering=%s type=%s", | |||
room_id, | |||
seed_event_id, | |||
event_lookup_result["depth"], | |||
event_lookup_result["stream_ordering"], | |||
event_lookup_result["type"], | |||
depth, | |||
stream_ordering, | |||
event_type, | |||
) | |||
if event_lookup_result["depth"]: | |||
queue.put( | |||
( | |||
-event_lookup_result["depth"], | |||
-event_lookup_result["stream_ordering"], | |||
seed_event_id, | |||
event_lookup_result["type"], | |||
) | |||
) | |||
if depth: | |||
queue.put((-depth, -stream_ordering, seed_event_id, event_type)) | |||
while not queue.empty() and len(event_id_results) < limit: | |||
try: | |||
@@ -1934,8 +1934,7 @@ class PersistEventsStore: | |||
if row is None: | |||
return | |||
redacted_relates_to = row["relates_to_id"] | |||
rel_type = row["relation_type"] | |||
redacted_relates_to, rel_type = row | |||
self.db_pool.simple_delete_txn( | |||
txn, table="event_relations", keyvalues={"event_id": redacted_event_id} | |||
) | |||
@@ -1998,7 +1998,7 @@ class EventsWorkerStore(SQLBaseStore): | |||
if not res: | |||
raise SynapseError(404, "Could not find event %s" % (event_id,)) | |||
return int(res["topological_ordering"]), int(res["stream_ordering"]) | |||
return int(res[0]), int(res[1]) | |||
async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]: | |||
"""Retrieve the entry with the lowest expiry timestamp in the event_expiry | |||
@@ -208,7 +208,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
) | |||
if row is None: | |||
return None | |||
return LocalMedia(media_id=media_id, **row) | |||
return LocalMedia( | |||
media_id=media_id, | |||
media_type=row[0], | |||
media_length=row[1], | |||
upload_name=row[2], | |||
created_ts=row[3], | |||
quarantined_by=row[4], | |||
url_cache=row[5], | |||
last_access_ts=row[6], | |||
safe_from_quarantine=row[7], | |||
) | |||
async def get_local_media_by_user_paginate( | |||
self, | |||
@@ -541,7 +551,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
) | |||
if row is None: | |||
return row | |||
return RemoteMedia(media_origin=origin, media_id=media_id, **row) | |||
return RemoteMedia( | |||
media_origin=origin, | |||
media_id=media_id, | |||
media_type=row[0], | |||
media_length=row[1], | |||
upload_name=row[2], | |||
created_ts=row[3], | |||
filesystem_id=row[4], | |||
last_access_ts=row[5], | |||
quarantined_by=row[6], | |||
) | |||
async def store_cached_remote_media( | |||
self, | |||
@@ -665,11 +685,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
if row is None: | |||
return None | |||
return ThumbnailInfo( | |||
width=row["thumbnail_width"], | |||
height=row["thumbnail_height"], | |||
method=row["thumbnail_method"], | |||
type=row["thumbnail_type"], | |||
length=row["thumbnail_length"], | |||
width=row[0], height=row[1], method=row[2], type=row[3], length=row[4] | |||
) | |||
@trace | |||
@@ -13,7 +13,6 @@ | |||
# limitations under the License. | |||
from typing import TYPE_CHECKING, Optional | |||
from synapse.api.errors import StoreError | |||
from synapse.storage._base import SQLBaseStore | |||
from synapse.storage.database import ( | |||
DatabasePool, | |||
@@ -138,23 +137,18 @@ class ProfileWorkerStore(SQLBaseStore): | |||
return 50 | |||
async def get_profileinfo(self, user_id: UserID) -> ProfileInfo: | |||
try: | |||
profile = await self.db_pool.simple_select_one( | |||
table="profiles", | |||
keyvalues={"full_user_id": user_id.to_string()}, | |||
retcols=("displayname", "avatar_url"), | |||
desc="get_profileinfo", | |||
) | |||
except StoreError as e: | |||
if e.code == 404: | |||
# no match | |||
return ProfileInfo(None, None) | |||
else: | |||
raise | |||
return ProfileInfo( | |||
avatar_url=profile["avatar_url"], display_name=profile["displayname"] | |||
profile = await self.db_pool.simple_select_one( | |||
table="profiles", | |||
keyvalues={"full_user_id": user_id.to_string()}, | |||
retcols=("displayname", "avatar_url"), | |||
desc="get_profileinfo", | |||
allow_none=True, | |||
) | |||
if profile is None: | |||
# no match | |||
return ProfileInfo(None, None) | |||
return ProfileInfo(avatar_url=profile[1], display_name=profile[0]) | |||
async def get_profile_displayname(self, user_id: UserID) -> Optional[str]: | |||
return await self.db_pool.simple_select_one_onecol( | |||
@@ -468,8 +468,7 @@ class PushRuleStore(PushRulesWorkerStore): | |||
"before/after rule not found: %s" % (relative_to_rule,) | |||
) | |||
base_priority_class = res["priority_class"] | |||
base_rule_priority = res["priority"] | |||
base_priority_class, base_rule_priority = res | |||
if base_priority_class != priority_class: | |||
raise InconsistentRuleException( | |||
@@ -701,8 +701,8 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
allow_none=True, | |||
) | |||
stream_ordering = int(res["stream_ordering"]) if res else None | |||
rx_ts = res["received_ts"] if res else 0 | |||
stream_ordering = int(res[0]) if res else None | |||
rx_ts = res[1] if res else 0 | |||
# We don't want to clobber receipts for more recent events, so we | |||
# have to compare orderings of existing receipts | |||
@@ -425,17 +425,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): | |||
account timestamp as milliseconds since the epoch. None if the account | |||
has not been renewed using the current token yet. | |||
""" | |||
ret_dict = await self.db_pool.simple_select_one( | |||
table="account_validity", | |||
keyvalues={"renewal_token": renewal_token}, | |||
retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"], | |||
desc="get_user_from_renewal_token", | |||
) | |||
return ( | |||
ret_dict["user_id"], | |||
ret_dict["expiration_ts_ms"], | |||
ret_dict["token_used_ts_ms"], | |||
return cast( | |||
Tuple[str, int, Optional[int]], | |||
await self.db_pool.simple_select_one( | |||
table="account_validity", | |||
keyvalues={"renewal_token": renewal_token}, | |||
retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"], | |||
desc="get_user_from_renewal_token", | |||
), | |||
) | |||
async def get_renewal_token_for_user(self, user_id: str) -> str: | |||
@@ -989,16 +986,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): | |||
Returns: | |||
user id, or None if no user id/threepid mapping exists | |||
""" | |||
ret = self.db_pool.simple_select_one_txn( | |||
return self.db_pool.simple_select_one_onecol_txn( | |||
txn, | |||
"user_threepids", | |||
{"medium": medium, "address": address}, | |||
["user_id"], | |||
"user_id", | |||
True, | |||
) | |||
if ret: | |||
return ret["user_id"] | |||
return None | |||
async def user_add_threepid( | |||
self, | |||
@@ -1435,16 +1429,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): | |||
if res is None: | |||
return False | |||
uses_allowed, pending, completed, expiry_time = res | |||
# Check if the token has expired | |||
now = self._clock.time_msec() | |||
if res["expiry_time"] and res["expiry_time"] < now: | |||
if expiry_time and expiry_time < now: | |||
return False | |||
# Check if the token has been used up | |||
if ( | |||
res["uses_allowed"] | |||
and res["pending"] + res["completed"] >= res["uses_allowed"] | |||
): | |||
if uses_allowed and pending + completed >= uses_allowed: | |||
return False | |||
# Otherwise, the token is valid | |||
@@ -1490,8 +1483,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): | |||
# Override type because the return type is only optional if | |||
# allow_none is True, and we don't want mypy throwing errors | |||
# about None not being indexable. | |||
res = cast( | |||
Dict[str, Any], | |||
pending, completed = cast( | |||
Tuple[int, int], | |||
self.db_pool.simple_select_one_txn( | |||
txn, | |||
"registration_tokens", | |||
@@ -1506,8 +1499,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): | |||
"registration_tokens", | |||
keyvalues={"token": token}, | |||
updatevalues={ | |||
"completed": res["completed"] + 1, | |||
"pending": res["pending"] - 1, | |||
"completed": completed + 1, | |||
"pending": pending - 1, | |||
}, | |||
) | |||
@@ -1585,13 +1578,22 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): | |||
Returns: | |||
A dict, or None if token doesn't exist. | |||
""" | |||
return await self.db_pool.simple_select_one( | |||
row = await self.db_pool.simple_select_one( | |||
"registration_tokens", | |||
keyvalues={"token": token}, | |||
retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"], | |||
allow_none=True, | |||
desc="get_one_registration_token", | |||
) | |||
if row is None: | |||
return None | |||
return { | |||
"token": row[0], | |||
"uses_allowed": row[1], | |||
"pending": row[2], | |||
"completed": row[3], | |||
"expiry_time": row[4], | |||
} | |||
async def generate_registration_token( | |||
self, length: int, chars: str | |||
@@ -1714,7 +1716,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): | |||
return None | |||
# Get all info about the token so it can be sent in the response | |||
return self.db_pool.simple_select_one_txn( | |||
result = self.db_pool.simple_select_one_txn( | |||
txn, | |||
"registration_tokens", | |||
keyvalues={"token": token}, | |||
@@ -1728,6 +1730,17 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): | |||
allow_none=True, | |||
) | |||
if result is None: | |||
return result | |||
return { | |||
"token": result[0], | |||
"uses_allowed": result[1], | |||
"pending": result[2], | |||
"completed": result[3], | |||
"expiry_time": result[4], | |||
} | |||
return await self.db_pool.runInteraction( | |||
"update_registration_token", _update_registration_token_txn | |||
) | |||
@@ -1939,11 +1952,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): | |||
keyvalues={"token": token}, | |||
updatevalues={"used_ts": ts}, | |||
) | |||
user_id = values["user_id"] | |||
expiry_ts = values["expiry_ts"] | |||
used_ts = values["used_ts"] | |||
auth_provider_id = values["auth_provider_id"] | |||
auth_provider_session_id = values["auth_provider_session_id"] | |||
( | |||
user_id, | |||
expiry_ts, | |||
used_ts, | |||
auth_provider_id, | |||
auth_provider_session_id, | |||
) = values | |||
# Token was already used | |||
if used_ts is not None: | |||
@@ -2756,12 +2771,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): | |||
# reason, the next check is on the client secret, which is NOT NULL, | |||
# so we don't have to worry about the client secret matching by | |||
# accident. | |||
row = {"client_secret": None, "validated_at": None} | |||
row = None, None | |||
else: | |||
raise ThreepidValidationError("Unknown session_id") | |||
retrieved_client_secret = row["client_secret"] | |||
validated_at = row["validated_at"] | |||
retrieved_client_secret, validated_at = row | |||
row = self.db_pool.simple_select_one_txn( | |||
txn, | |||
@@ -2775,8 +2789,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): | |||
raise ThreepidValidationError( | |||
"Validation token not found or has expired" | |||
) | |||
expires = row["expires"] | |||
next_link = row["next_link"] | |||
expires, next_link = row | |||
if retrieved_client_secret != client_secret: | |||
raise ThreepidValidationError( | |||
@@ -213,21 +213,31 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): | |||
logger.error("store_room with room_id=%s failed: %s", room_id, e) | |||
raise StoreError(500, "Problem creating room.") | |||
async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]: | |||
async def get_room(self, room_id: str) -> Optional[Tuple[bool, bool]]: | |||
"""Retrieve a room. | |||
Args: | |||
room_id: The ID of the room to retrieve. | |||
Returns: | |||
A dict containing the room information, or None if the room is unknown. | |||
A tuple containing the room information: | |||
* True if the room is public | |||
* True if the room has an auth chain index | |||
or None if the room is unknown. | |||
""" | |||
return await self.db_pool.simple_select_one( | |||
table="rooms", | |||
keyvalues={"room_id": room_id}, | |||
retcols=("room_id", "is_public", "creator", "has_auth_chain_index"), | |||
desc="get_room", | |||
allow_none=True, | |||
row = cast( | |||
Optional[Tuple[Optional[Union[int, bool]], Optional[Union[int, bool]]]], | |||
await self.db_pool.simple_select_one( | |||
table="rooms", | |||
keyvalues={"room_id": room_id}, | |||
retcols=("is_public", "has_auth_chain_index"), | |||
desc="get_room", | |||
allow_none=True, | |||
), | |||
) | |||
if row is None: | |||
return row | |||
return bool(row[0]), bool(row[1]) | |||
async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]: | |||
"""Retrieve room with statistics. | |||
@@ -794,10 +804,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): | |||
) | |||
if row: | |||
return RatelimitOverride( | |||
messages_per_second=row["messages_per_second"], | |||
burst_count=row["burst_count"], | |||
) | |||
return RatelimitOverride(messages_per_second=row[0], burst_count=row[1]) | |||
else: | |||
return None | |||
@@ -1371,13 +1378,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): | |||
join. | |||
""" | |||
result = await self.db_pool.simple_select_one( | |||
table="partial_state_rooms", | |||
keyvalues={"room_id": room_id}, | |||
retcols=("join_event_id", "device_lists_stream_id"), | |||
desc="get_join_event_id_for_partial_state", | |||
return cast( | |||
Tuple[str, int], | |||
await self.db_pool.simple_select_one( | |||
table="partial_state_rooms", | |||
keyvalues={"room_id": room_id}, | |||
retcols=("join_event_id", "device_lists_stream_id"), | |||
desc="get_join_event_id_for_partial_state", | |||
), | |||
) | |||
return result["join_event_id"], result["device_lists_stream_id"] | |||
def get_un_partial_stated_rooms_token(self, instance_name: str) -> int: | |||
return self._un_partial_stated_rooms_stream_id_gen.get_current_token_for_writer( | |||
@@ -559,17 +559,20 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): | |||
"non-local user %s" % (user_id,), | |||
) | |||
results_dict = await self.db_pool.simple_select_one( | |||
"local_current_membership", | |||
{"room_id": room_id, "user_id": user_id}, | |||
("membership", "event_id"), | |||
allow_none=True, | |||
desc="get_local_current_membership_for_user_in_room", | |||
results = cast( | |||
Optional[Tuple[str, str]], | |||
await self.db_pool.simple_select_one( | |||
"local_current_membership", | |||
{"room_id": room_id, "user_id": user_id}, | |||
("membership", "event_id"), | |||
allow_none=True, | |||
desc="get_local_current_membership_for_user_in_room", | |||
), | |||
) | |||
if not results_dict: | |||
if not results: | |||
return None, None | |||
return results_dict.get("membership"), results_dict.get("event_id") | |||
return results | |||
@cached(max_entries=500000, iterable=True) | |||
async def get_rooms_for_user_with_stream_ordering( | |||
@@ -1014,9 +1014,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): | |||
desc="get_position_for_event", | |||
) | |||
return PersistedEventPosition( | |||
row["instance_name"] or "master", row["stream_ordering"] | |||
) | |||
return PersistedEventPosition(row[1] or "master", row[0]) | |||
async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken: | |||
"""The stream token for an event | |||
@@ -1033,9 +1031,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): | |||
retcols=("stream_ordering", "topological_ordering"), | |||
desc="get_topological_token_for_event", | |||
) | |||
return RoomStreamToken( | |||
topological=row["topological_ordering"], stream=row["stream_ordering"] | |||
) | |||
return RoomStreamToken(topological=row[1], stream=row[0]) | |||
async def get_current_topological_token(self, room_id: str, stream_key: int) -> int: | |||
"""Gets the topological token in a room after or at the given stream | |||
@@ -1180,26 +1176,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): | |||
dict | |||
""" | |||
results = self.db_pool.simple_select_one_txn( | |||
txn, | |||
"events", | |||
keyvalues={"event_id": event_id, "room_id": room_id}, | |||
retcols=["stream_ordering", "topological_ordering"], | |||
stream_ordering, topological_ordering = cast( | |||
Tuple[int, int], | |||
self.db_pool.simple_select_one_txn( | |||
txn, | |||
"events", | |||
keyvalues={"event_id": event_id, "room_id": room_id}, | |||
retcols=["stream_ordering", "topological_ordering"], | |||
), | |||
) | |||
# This cannot happen as `allow_none=False`. | |||
assert results is not None | |||
# Paginating backwards includes the event at the token, but paginating | |||
# forward doesn't. | |||
before_token = RoomStreamToken( | |||
topological=results["topological_ordering"] - 1, | |||
stream=results["stream_ordering"], | |||
topological=topological_ordering - 1, stream=stream_ordering | |||
) | |||
after_token = RoomStreamToken( | |||
topological=results["topological_ordering"], | |||
stream=results["stream_ordering"], | |||
topological=topological_ordering, stream=stream_ordering | |||
) | |||
rows, start_token = self._paginate_room_events_txn( | |||
@@ -183,39 +183,27 @@ class TaskSchedulerWorkerStore(SQLBaseStore): | |||
Returns: the task if available, `None` otherwise | |||
""" | |||
row = await self.db_pool.simple_select_one( | |||
table="scheduled_tasks", | |||
keyvalues={"id": id}, | |||
retcols=( | |||
"id", | |||
"action", | |||
"status", | |||
"timestamp", | |||
"resource_id", | |||
"params", | |||
"result", | |||
"error", | |||
row = cast( | |||
Optional[ScheduledTaskRow], | |||
await self.db_pool.simple_select_one( | |||
table="scheduled_tasks", | |||
keyvalues={"id": id}, | |||
retcols=( | |||
"id", | |||
"action", | |||
"status", | |||
"timestamp", | |||
"resource_id", | |||
"params", | |||
"result", | |||
"error", | |||
), | |||
allow_none=True, | |||
desc="get_scheduled_task", | |||
), | |||
allow_none=True, | |||
desc="get_scheduled_task", | |||
) | |||
return ( | |||
TaskSchedulerWorkerStore._convert_row_to_task( | |||
( | |||
row["id"], | |||
row["action"], | |||
row["status"], | |||
row["timestamp"], | |||
row["resource_id"], | |||
row["params"], | |||
row["result"], | |||
row["error"], | |||
) | |||
) | |||
if row | |||
else None | |||
) | |||
return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None | |||
async def delete_scheduled_task(self, id: str) -> None: | |||
"""Delete a specific task from its id. | |||
@@ -118,19 +118,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): | |||
txn, | |||
table="received_transactions", | |||
keyvalues={"transaction_id": transaction_id, "origin": origin}, | |||
retcols=( | |||
"transaction_id", | |||
"origin", | |||
"ts", | |||
"response_code", | |||
"response_json", | |||
"has_been_referenced", | |||
), | |||
retcols=("response_code", "response_json"), | |||
allow_none=True, | |||
) | |||
if result and result["response_code"]: | |||
return result["response_code"], db_to_json(result["response_json"]) | |||
# If the result exists and the response code is non-0. | |||
if result and result[0]: | |||
return result[0], db_to_json(result[1]) | |||
else: | |||
return None | |||
@@ -200,8 +194,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): | |||
# check we have a row and retry_last_ts is not null or zero | |||
# (retry_last_ts can't be negative) | |||
if result and result["retry_last_ts"]: | |||
return DestinationRetryTimings(**result) | |||
if result and result[1]: | |||
return DestinationRetryTimings( | |||
failure_ts=result[0], retry_last_ts=result[1], retry_interval=result[2] | |||
) | |||
else: | |||
return None | |||
@@ -122,9 +122,13 @@ class UIAuthWorkerStore(SQLBaseStore): | |||
desc="get_ui_auth_session", | |||
) | |||
result["clientdict"] = db_to_json(result["clientdict"]) | |||
return UIAuthSessionData(session_id, **result) | |||
return UIAuthSessionData( | |||
session_id, | |||
clientdict=db_to_json(result[0]), | |||
uri=result[1], | |||
method=result[2], | |||
description=result[3], | |||
) | |||
async def mark_ui_auth_stage_complete( | |||
self, | |||
@@ -231,18 +235,15 @@ class UIAuthWorkerStore(SQLBaseStore): | |||
self, txn: LoggingTransaction, session_id: str, key: str, value: Any | |||
) -> None: | |||
# Get the current value. | |||
result = cast( | |||
Dict[str, Any], | |||
self.db_pool.simple_select_one_txn( | |||
txn, | |||
table="ui_auth_sessions", | |||
keyvalues={"session_id": session_id}, | |||
retcols=("serverdict",), | |||
), | |||
result = self.db_pool.simple_select_one_onecol_txn( | |||
txn, | |||
table="ui_auth_sessions", | |||
keyvalues={"session_id": session_id}, | |||
retcol="serverdict", | |||
) | |||
# Update it and add it back to the database. | |||
serverdict = db_to_json(result["serverdict"]) | |||
serverdict = db_to_json(result) | |||
serverdict[key] = value | |||
self.db_pool.simple_update_one_txn( | |||
@@ -265,14 +266,14 @@ class UIAuthWorkerStore(SQLBaseStore): | |||
Raises: | |||
StoreError if the session cannot be found. | |||
""" | |||
result = await self.db_pool.simple_select_one( | |||
result = await self.db_pool.simple_select_one_onecol( | |||
table="ui_auth_sessions", | |||
keyvalues={"session_id": session_id}, | |||
retcols=("serverdict",), | |||
retcol="serverdict", | |||
desc="get_ui_auth_session_data", | |||
) | |||
serverdict = db_to_json(result["serverdict"]) | |||
serverdict = db_to_json(result) | |||
return serverdict.get(key, default) | |||
@@ -20,7 +20,6 @@ from typing import ( | |||
Collection, | |||
Iterable, | |||
List, | |||
Mapping, | |||
Optional, | |||
Sequence, | |||
Set, | |||
@@ -833,13 +832,25 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): | |||
"delete_all_from_user_dir", _delete_all_from_user_dir_txn | |||
) | |||
async def _get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]: | |||
return await self.db_pool.simple_select_one( | |||
table="user_directory", | |||
keyvalues={"user_id": user_id}, | |||
retcols=("display_name", "avatar_url"), | |||
allow_none=True, | |||
desc="get_user_in_directory", | |||
async def _get_user_in_directory( | |||
self, user_id: str | |||
) -> Optional[Tuple[Optional[str], Optional[str]]]: | |||
""" | |||
Fetch the user information in the user directory. | |||
Returns: | |||
None if the user is unknown, otherwise a tuple of display name and | |||
avatar URL (both of which may be None). | |||
""" | |||
return cast( | |||
Optional[Tuple[Optional[str], Optional[str]]], | |||
await self.db_pool.simple_select_one( | |||
table="user_directory", | |||
keyvalues={"user_id": user_id}, | |||
retcols=("display_name", "avatar_url"), | |||
allow_none=True, | |||
desc="get_user_in_directory", | |||
), | |||
) | |||
async def update_user_directory_stream_pos(self, stream_id: Optional[int]) -> None: | |||
@@ -84,7 +84,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type]) | |||
return self.get_success( | |||
row = self.get_success( | |||
self.store.db_pool.simple_select_one( | |||
table + "_current", | |||
{id_col: stat_id}, | |||
@@ -93,6 +93,8 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
) | |||
) | |||
return None if row is None else dict(zip(cols, row)) | |||
def _perform_background_initial_update(self) -> None: | |||
# Do the initial population of the stats via the background update | |||
self._add_background_updates() | |||
@@ -366,7 +366,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): | |||
) | |||
profile = self.get_success(self.store._get_user_in_directory(regular_user_id)) | |||
assert profile is not None | |||
self.assertTrue(profile["display_name"] == display_name) | |||
self.assertTrue(profile[0] == display_name) | |||
def test_handle_local_profile_change_with_deactivated_user(self) -> None: | |||
# create user | |||
@@ -385,7 +385,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): | |||
# profile is in directory | |||
profile = self.get_success(self.store._get_user_in_directory(r_user_id)) | |||
assert profile is not None | |||
self.assertTrue(profile["display_name"] == display_name) | |||
self.assertEqual(profile[0], display_name) | |||
# deactivate user | |||
self.get_success(self.store.set_user_deactivated_status(r_user_id, True)) | |||
@@ -2706,7 +2706,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): | |||
# is in user directory | |||
profile = self.get_success(self.store._get_user_in_directory(self.other_user)) | |||
assert profile is not None | |||
self.assertTrue(profile["display_name"] == "User") | |||
self.assertEqual(profile[0], "User") | |||
# Deactivate user | |||
channel = self.make_request( | |||
@@ -139,12 +139,12 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): | |||
# | |||
# Note that we don't have the UI Auth session ID, so just pull out the single | |||
# row. | |||
ui_auth_data = self.get_success( | |||
self.store.db_pool.simple_select_one( | |||
"ui_auth_sessions", keyvalues={}, retcols=("clientdict",) | |||
result = self.get_success( | |||
self.store.db_pool.simple_select_one_onecol( | |||
"ui_auth_sessions", keyvalues={}, retcol="clientdict" | |||
) | |||
) | |||
client_dict = db_to_json(ui_auth_data["clientdict"]) | |||
client_dict = db_to_json(result) | |||
self.assertNotIn("new_password", client_dict) | |||
@override_config({"rc_3pid_validation": {"burst_count": 3}}) | |||
@@ -270,15 +270,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): | |||
self.assertLessEqual(det_data.items(), channel.json_body.items()) | |||
# Check the `completed` counter has been incremented and pending is 0 | |||
res = self.get_success( | |||
pending, completed = self.get_success( | |||
store.db_pool.simple_select_one( | |||
"registration_tokens", | |||
keyvalues={"token": token}, | |||
retcols=["pending", "completed"], | |||
) | |||
) | |||
self.assertEqual(res["completed"], 1) | |||
self.assertEqual(res["pending"], 0) | |||
self.assertEqual(completed, 1) | |||
self.assertEqual(pending, 0) | |||
@override_config({"registration_requires_token": True}) | |||
def test_POST_registration_token_invalid(self) -> None: | |||
@@ -372,15 +372,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): | |||
params1["auth"]["type"] = LoginType.DUMMY | |||
self.make_request(b"POST", self.url, params1) | |||
# Check pending=0 and completed=1 | |||
res = self.get_success( | |||
pending, completed = self.get_success( | |||
store.db_pool.simple_select_one( | |||
"registration_tokens", | |||
keyvalues={"token": token}, | |||
retcols=["pending", "completed"], | |||
) | |||
) | |||
self.assertEqual(res["pending"], 0) | |||
self.assertEqual(res["completed"], 1) | |||
self.assertEqual(pending, 0) | |||
self.assertEqual(completed, 1) | |||
# Check auth still fails when using token with session2 | |||
channel = self.make_request(b"POST", self.url, params2) | |||
@@ -222,7 +222,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): | |||
) | |||
) | |||
self.assertEqual({"colA": 1, "colB": 2, "colC": 3}, ret) | |||
self.assertEqual((1, 2, 3), ret) | |||
self.mock_txn.execute.assert_called_once_with( | |||
"SELECT colA, colB, colC FROM tablename WHERE keycol = ?", ["TheKey"] | |||
) | |||
@@ -243,7 +243,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): | |||
) | |||
) | |||
self.assertFalse(ret) | |||
self.assertIsNone(ret) | |||
@defer.inlineCallbacks | |||
def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]: | |||
@@ -42,16 +42,9 @@ class RoomStoreTestCase(HomeserverTestCase): | |||
) | |||
def test_get_room(self) -> None: | |||
res = self.get_success(self.store.get_room(self.room.to_string())) | |||
assert res is not None | |||
self.assertLessEqual( | |||
{ | |||
"room_id": self.room.to_string(), | |||
"creator": self.u_creator.to_string(), | |||
"is_public": True, | |||
}.items(), | |||
res.items(), | |||
) | |||
room = self.get_success(self.store.get_room(self.room.to_string())) | |||
assert room is not None | |||
self.assertTrue(room[0]) | |||
def test_get_room_unknown_room(self) -> None: | |||
self.assertIsNone(self.get_success(self.store.get_room("!uknown:test"))) | |||