Mostly to improve type safety.
@@ -0,0 +1 @@
+Improve type hints.
@@ -19,6 +19,8 @@ import logging
 import urllib.parse
 from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple

+import attr
+
 from synapse.api.errors import (
     CodeMessageException,
     Codes,
@@ -357,9 +359,9 @@ class IdentityHandler:

         # Check to see if a session already exists and that it is not yet
         # marked as validated
-        if session and session.get("validated_at") is None:
-            session_id = session["session_id"]
-            last_send_attempt = session["last_send_attempt"]
+        if session and session.validated_at is None:
+            session_id = session.session_id
+            last_send_attempt = session.last_send_attempt

             # Check that the send_attempt is higher than previous attempts
             if send_attempt <= last_send_attempt:
@@ -480,7 +482,6 @@ class IdentityHandler:

         # We don't actually know which medium this 3PID is. Thus we first assume it's email,
         # and if validation fails we try msisdn
-        validation_session = None

         # Try to validate as email
         if self.hs.config.email.can_verify_email:
@@ -488,19 +489,18 @@ class IdentityHandler:
             validation_session = await self.store.get_threepid_validation_session(
                 "email", client_secret, sid=sid, validated=True
            )
-            if validation_session:
-                return validation_session
+            if validation_session:
+                return attr.asdict(validation_session)

         # Try to validate as msisdn
         if self.hs.config.registration.account_threepid_delegate_msisdn:
             # Ask our delegated msisdn identity server
-            validation_session = await self.threepid_from_creds(
+            return await self.threepid_from_creds(
                 self.hs.config.registration.account_threepid_delegate_msisdn,
                 threepid_creds,
             )
-            return validation_session

         return None

     async def proxy_msisdn_submit_token(
         self, id_server: str, client_secret: str, sid: str, token: str
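Reviewer note: a minimal standalone sketch of the new handler-side flow (the `Session` class and values here are invented stand-ins for the real `ThreepidValidationSession`) — attributes are read with type checking, and `attr.asdict` recreates the plain dict only at the JSON boundary.

```python
# Sketch, not Synapse code: simplified class, invented values.
from typing import Optional

import attr


@attr.s(frozen=True, slots=True, auto_attribs=True)
class Session:
    session_id: str
    validated_at: Optional[int]


def to_response(session: Optional[Session]) -> Optional[dict]:
    if session:
        # attr.asdict turns the attrs instance back into a JSON-ready dict.
        return attr.asdict(session)
    return None


print(to_response(Session(session_id="abc123", validated_at=1_699_000_000_000)))
print(to_response(None))
```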
@@ -187,9 +187,9 @@ class _BaseThreepidAuthChecker:

         if row:
             threepid = {
-                "medium": row["medium"],
-                "address": row["address"],
-                "validated_at": row["validated_at"],
+                "medium": row.medium,
+                "address": row.address,
+                "validated_at": row.validated_at,
             }

             # Valid threepid returned, delete from the db
@@ -949,10 +949,7 @@ class MediaRepository:
         deleted = 0

-        for media in old_media:
-            origin = media["media_origin"]
-            media_id = media["media_id"]
-            file_id = media["filesystem_id"]
+        for origin, media_id, file_id in old_media:

             key = (origin, media_id)
             logger.info("Deleting: %r", key)
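Context for the tuple-unpacking change above — a runnable sketch with made-up rows standing in for what `get_remote_media_ids` now returns:

```python
from typing import List, Tuple

# Made-up rows standing in for get_remote_media_ids()'s new return type.
old_media: List[Tuple[str, str, str]] = [
    ("example.org", "abc123", "fs-abc123"),
    ("matrix.example.com", "def456", "fs-def456"),
]

# (media_origin, media_id, filesystem_id) unpacks directly in the loop header,
# replacing three dict lookups per row.
for origin, media_id, file_id in old_media:
    key = (origin, media_id)
    print(f"Deleting: {key!r} (file {file_id})")
```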
@@ -85,7 +85,19 @@ class ListDestinationsRestServlet(RestServlet):
         destinations, total = await self._store.get_destinations_paginate(
             start, limit, destination, order_by, direction
         )
-        response = {"destinations": destinations, "total": total}
+        response = {
+            "destinations": [
+                {
+                    "destination": r[0],
+                    "retry_last_ts": r[1],
+                    "retry_interval": r[2],
+                    "failure_ts": r[3],
+                    "last_successful_stream_ordering": r[4],
+                }
+                for r in destinations
+            ],
+            "total": total,
+        }
         if (start + limit) < total:
             response["next_token"] = str(start + len(destinations))
@@ -724,7 +724,17 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
         room_id, _ = await self.resolve_room_id(room_identifier)

         extremities = await self.store.get_forward_extremities_for_room(room_id)
-        return HTTPStatus.OK, {"count": len(extremities), "results": extremities}
+        result = [
+            {
+                "event_id": ex[0],
+                "state_group": ex[1],
+                "depth": ex[2],
+                "received_ts": ex[3],
+            }
+            for ex in extremities
+        ]
+        return HTTPStatus.OK, {"count": len(extremities), "results": result}


 class RoomEventContextServlet(RestServlet):
@@ -108,7 +108,18 @@ class UserMediaStatisticsRestServlet(RestServlet):
         users_media, total = await self.store.get_users_media_usage_paginate(
             start, limit, from_ts, until_ts, order_by, direction, search_term
         )
-        ret = {"users": users_media, "total": total}
+        ret = {
+            "users": [
+                {
+                    "user_id": r[0],
+                    "displayname": r[1],
+                    "media_count": r[2],
+                    "media_length": r[3],
+                }
+                for r in users_media
+            ],
+            "total": total,
+        }
         if (start + limit) < total:
             ret["next_token"] = start + len(users_media)
@@ -35,7 +35,6 @@ from typing import (
     Tuple,
     Type,
     TypeVar,
-    Union,
     cast,
     overload,
 )
@@ -1047,43 +1046,20 @@ class DatabasePool:
             results = [dict(zip(col_headers, row)) for row in cursor]
             return results

-    @overload
-    async def execute(
-        self, desc: str, decoder: Literal[None], query: str, *args: Any
-    ) -> List[Tuple[Any, ...]]:
-        ...
-
-    @overload
-    async def execute(
-        self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
-    ) -> R:
-        ...
-
-    async def execute(
-        self,
-        desc: str,
-        decoder: Optional[Callable[[Cursor], R]],
-        query: str,
-        *args: Any,
-    ) -> Union[List[Tuple[Any, ...]], R]:
+    async def execute(self, desc: str, query: str, *args: Any) -> List[Tuple[Any, ...]]:
         """Runs a single query for a result set.

         Args:
             desc: description of the transaction, for logging and metrics
-            decoder - The function which can resolve the cursor results to
-                something meaningful.
             query - The query string to execute
             *args - Query args.

         Returns:
             The result of decoder(results)
         """

-        def interaction(txn: LoggingTransaction) -> Union[List[Tuple[Any, ...]], R]:
+        def interaction(txn: LoggingTransaction) -> List[Tuple[Any, ...]]:
             txn.execute(query, args)
-            if decoder:
-                return decoder(txn)
-            else:
-                return txn.fetchall()
+            return txn.fetchall()

         return await self.runInteraction(desc, interaction)
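To summarize the new contract for callers, a hedged sketch (simplified free functions with canned rows, not the real `DatabasePool`): `execute` always returns the raw rows, and each call site owns a single `cast` to its concrete row shape.

```python
from typing import Any, List, Tuple, cast


async def execute(desc: str, query: str, *args: Any) -> List[Tuple[Any, ...]]:
    """Stand-in for DatabasePool.execute: run one query, fetchall() the rows.

    Returns canned rows here so the sketch runs anywhere.
    """
    return [("alice", 12), ("bob", 34)]


async def caller() -> List[Tuple[str, int]]:
    rows = await execute("demo_query", "SELECT name, count FROM demo WHERE count > ?", 10)
    # One cast at the call site documents (and lets mypy check) the row shape.
    return cast(List[Tuple[str, int]], rows)


if __name__ == "__main__":
    import asyncio

    print(asyncio.run(caller()))
```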
@@ -93,7 +93,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
         """

         rows = await self.db_pool.execute(
-            "_censor_redactions_fetch", None, sql, before_ts, 100
+            "_censor_redactions_fetch", sql, before_ts, 100
         )

         updates = []
@@ -894,7 +894,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         rows = await self.db_pool.execute(
             "get_all_devices_changed",
-            None,
             sql,
             from_key,
             to_key,
@@ -978,7 +977,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             WHERE from_user_id = ? AND stream_id > ?
         """
         rows = await self.db_pool.execute(
-            "get_users_whose_signatures_changed", None, sql, user_id, from_key
+            "get_users_whose_signatures_changed", sql, user_id, from_key
         )
             return {user for row in rows for user in db_to_json(row[0])}
         else:
@@ -155,7 +155,6 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         """
         rows = await self.db_pool.execute(
             "get_e2e_device_keys_for_federation_query_check",
-            None,
             sql,
             now_stream_id,
             user_id,
@@ -1310,12 +1310,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             # ANALYZE the new column to build stats on it, to encourage PostgreSQL to use the
             # indexes on it.
-            # We need to pass execute a dummy function to handle the txn's result otherwise
-            # it tries to call fetchall() on it and fails because there's no result to fetch.
-            await self.db_pool.execute(
+            await self.db_pool.runInteraction(
                 "background_analyze_new_stream_ordering_column",
-                lambda txn: None,
-                "ANALYZE events(stream_ordering2)",
+                lambda txn: txn.execute("ANALYZE events(stream_ordering2)"),
             )

             await self.db_pool.runInteraction(
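The same pattern applies anywhere a statement produces no result set. A toy illustration using plain `sqlite3` in place of Synapse's `DatabasePool` (an assumption for portability; note SQLite's `ANALYZE` takes a bare table name, unlike the PostgreSQL column form in the diff):

```python
import sqlite3
from typing import Callable


def run_interaction(
    conn: sqlite3.Connection, desc: str, func: Callable[[sqlite3.Cursor], object]
) -> None:
    """Toy stand-in for DatabasePool.runInteraction: hand a cursor to a callback."""
    cur = conn.cursor()
    func(cur)
    conn.commit()


conn = sqlite3.connect(":memory:")
run_interaction(
    conn, "create", lambda txn: txn.execute("CREATE TABLE events (stream_ordering2 INTEGER)")
)
# ANALYZE returns no rows, which is why the row-fetching execute() helper was
# the wrong tool for it.
run_interaction(conn, "analyze", lambda txn: txn.execute("ANALYZE events"))
```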
@@ -13,7 +13,7 @@
 # limitations under the License.

 import logging
-from typing import Any, Dict, List
+from typing import List, Optional, Tuple, cast

 from synapse.api.errors import SynapseError
 from synapse.storage.database import LoggingTransaction
@@ -91,12 +91,17 @@ class EventForwardExtremitiesStore(
     async def get_forward_extremities_for_room(
         self, room_id: str
-    ) -> List[Dict[str, Any]]:
-        """Get list of forward extremities for a room."""
+    ) -> List[Tuple[str, int, int, Optional[int]]]:
+        """
+        Get list of forward extremities for a room.
+
+        Returns:
+            A list of tuples of event_id, state_group, depth, and received_ts.
+        """

         def get_forward_extremities_for_room_txn(
             txn: LoggingTransaction,
-        ) -> List[Dict[str, Any]]:
+        ) -> List[Tuple[str, int, int, Optional[int]]]:
             sql = """
                 SELECT event_id, state_group, depth, received_ts
                 FROM event_forward_extremities
@@ -106,7 +111,7 @@ class EventForwardExtremitiesStore(
             """
             txn.execute(sql, (room_id,))
-            return self.db_pool.cursor_to_dict(txn)
+            return cast(List[Tuple[str, int, int, Optional[int]]], txn.fetchall())

         return await self.db_pool.runInteraction(
             "get_forward_extremities_for_room",
@@ -650,7 +650,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
     async def get_remote_media_ids(
         self, before_ts: int, include_quarantined_media: bool
-    ) -> List[Dict[str, str]]:
+    ) -> List[Tuple[str, str, str]]:
         """
         Retrieve a list of server name, media ID tuples from the remote media cache.
@@ -664,12 +664,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             A list of tuples containing:
                 * The server name of homeserver where the media originates from,
                 * The ID of the media.
+                * The filesystem ID.
         """

-        sql = """
-            SELECT media_origin, media_id, filesystem_id
-            FROM remote_media_cache
-            WHERE last_access_ts < ?
-        """
+        sql = (
+            "SELECT media_origin, media_id, filesystem_id"
+            " FROM remote_media_cache"
+            " WHERE last_access_ts < ?"
+        )

         if include_quarantined_media is False:
             # Only include media that has not been quarantined
@@ -677,8 +679,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 AND quarantined_by IS NULL
             """

-        return await self.db_pool.execute(
-            "get_remote_media_ids", self.db_pool.cursor_to_dict, sql, before_ts
+        return cast(
+            List[Tuple[str, str, str]],
+            await self.db_pool.execute("get_remote_media_ids", sql, before_ts),
         )

     async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
@@ -151,6 +151,22 @@ class ThreepidResult:
     added_at: int


+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class ThreepidValidationSession:
+    address: str
+    """address of the 3pid"""
+    medium: str
+    """medium of the 3pid"""
+    client_secret: str
+    """a secret provided by the client for this validation session"""
+    session_id: str
+    """ID of the validation session"""
+    last_send_attempt: int
+    """a number serving to dedupe send attempts for this session"""
+    validated_at: Optional[int]
+    """timestamp of when this session was validated if so"""
+
+
 class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     def __init__(
         self,
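A standalone sketch of how the new class is meant to be used (row values invented; the column order matches the SELECT a few hunks below): a positional row becomes a frozen, typed object, so a typo'd field raises `AttributeError` instead of silently returning `None` as `dict.get` would.

```python
from typing import Optional

import attr


@attr.s(frozen=True, slots=True, auto_attribs=True)
class ThreepidValidationSession:
    address: str
    medium: str
    client_secret: str
    session_id: str
    last_send_attempt: int
    validated_at: Optional[int]


# Column order: address, session_id, medium, client_secret,
# last_send_attempt, validated_at (values invented).
row = ("alice@example.com", "abc123", "email", "s3cret", 1, None)
session = ThreepidValidationSession(
    address=row[0],
    session_id=row[1],
    medium=row[2],
    client_secret=row[3],
    last_send_attempt=row[4],
    validated_at=row[5],
)

assert session.validated_at is None  # typed attribute access
# session.validated_at = 123  # would raise attr.exceptions.FrozenInstanceError
```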
@@ -1172,7 +1188,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         address: Optional[str] = None,
         sid: Optional[str] = None,
         validated: Optional[bool] = True,
-    ) -> Optional[Dict[str, Any]]:
+    ) -> Optional[ThreepidValidationSession]:
         """Gets a session_id and last_send_attempt (if available) for a
         combination of validation metadata
@@ -1187,15 +1203,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 perform no filtering

         Returns:
-            A dict containing the following:
-                * address - address of the 3pid
-                * medium - medium of the 3pid
-                * client_secret - a secret provided by the client for this validation session
-                * session_id - ID of the validation session
-                * send_attempt - a number serving to dedupe send attempts for this session
-                * validated_at - timestamp of when this session was validated if so
-
-            Otherwise None if a validation session is not found
+            A ThreepidValidationSession or None if a validation session is not found
         """

         if not client_secret:
             raise SynapseError(
@@ -1214,7 +1222,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         def get_threepid_validation_session_txn(
             txn: LoggingTransaction,
-        ) -> Optional[Dict[str, Any]]:
+        ) -> Optional[ThreepidValidationSession]:
             sql = """
                 SELECT address, session_id, medium, client_secret,
                     last_send_attempt, validated_at
@@ -1229,11 +1237,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             sql += " LIMIT 1"

             txn.execute(sql, list(keyvalues.values()))
-            rows = self.db_pool.cursor_to_dict(txn)
-            if not rows:
+            row = txn.fetchone()
+            if not row:
                 return None

-            return rows[0]
+            return ThreepidValidationSession(
+                address=row[0],
+                session_id=row[1],
+                medium=row[2],
+                client_secret=row[3],
+                last_send_attempt=row[4],
+                validated_at=row[5],
+            )

         return await self.db_pool.runInteraction(
             "get_threepid_validation_session", get_threepid_validation_session_txn
@@ -940,7 +940,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
         like_clause = "%:" + host

         rows = await self.db_pool.execute(
-            "is_host_joined", None, sql, membership, room_id, like_clause
+            "is_host_joined", sql, membership, room_id, like_clause
         )

         if not rows:
@@ -1168,7 +1168,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
             AND forgotten = 0;
         """

-        rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id)
+        rows = await self.db_pool.execute("is_forgotten_room", sql, room_id)

         # `count(*)` returns always an integer
         # If any rows still exist it means someone has not forgotten this room yet
@@ -26,6 +26,7 @@ from typing import (
     Set,
     Tuple,
     Union,
+    cast,
 )

 import attr
@@ -506,16 +507,18 @@ class SearchStore(SearchBackgroundUpdateStore):
         # entire table from the database.
         sql += " ORDER BY rank DESC LIMIT 500"

-        results = await self.db_pool.execute(
-            "search_msgs", self.db_pool.cursor_to_dict, sql, *args
+        # List of tuples of (rank, room_id, event_id).
+        results = cast(
+            List[Tuple[Union[int, float], str, str]],
+            await self.db_pool.execute("search_msgs", sql, *args),
         )

-        results = list(filter(lambda row: row["room_id"] in room_ids, results))
+        results = list(filter(lambda row: row[1] in room_ids, results))

         # We set redact_behaviour to block here to prevent redacted events being returned in
         # search results (which is a data leak)
         events = await self.get_events_as_list(  # type: ignore[attr-defined]
-            [r["event_id"] for r in results],
+            [r[2] for r in results],
             redact_behaviour=EventRedactBehaviour.block,
         )
@@ -527,16 +530,18 @@ class SearchStore(SearchBackgroundUpdateStore):
             count_sql += " GROUP BY room_id"

-        count_results = await self.db_pool.execute(
-            "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
+        # List of tuples of (room_id, count).
+        count_results = cast(
+            List[Tuple[str, int]],
+            await self.db_pool.execute("search_rooms_count", count_sql, *count_args),
         )

-        count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
+        count = sum(row[1] for row in count_results if row[0] in room_ids)

         return {
             "results": [
-                {"event": event_map[r["event_id"]], "rank": r["rank"]}
+                {"event": event_map[r[2]], "rank": r[0]}
                 for r in results
-                if r["event_id"] in event_map
+                if r[2] in event_map
             ],
             "highlights": highlights,
             "count": count,
@@ -604,7 +609,7 @@ class SearchStore(SearchBackgroundUpdateStore):
             search_query = search_term
             sql = """
                 SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) as rank,
-                origin_server_ts, stream_ordering, room_id, event_id
+                room_id, event_id, origin_server_ts, stream_ordering
                 FROM event_search
                 WHERE vector @@ websearch_to_tsquery('english', ?) AND
             """
@@ -665,16 +670,18 @@ class SearchStore(SearchBackgroundUpdateStore):
         # mypy expects to append only a `str`, not an `int`
         args.append(limit)

-        results = await self.db_pool.execute(
-            "search_rooms", self.db_pool.cursor_to_dict, sql, *args
+        # List of tuples of (rank, room_id, event_id, origin_server_ts, stream_ordering).
+        results = cast(
+            List[Tuple[Union[int, float], str, str, int, int]],
+            await self.db_pool.execute("search_rooms", sql, *args),
         )

-        results = list(filter(lambda row: row["room_id"] in room_ids, results))
+        results = list(filter(lambda row: row[1] in room_ids, results))

         # We set redact_behaviour to block here to prevent redacted events being returned in
         # search results (which is a data leak)
         events = await self.get_events_as_list(  # type: ignore[attr-defined]
-            [r["event_id"] for r in results],
+            [r[2] for r in results],
             redact_behaviour=EventRedactBehaviour.block,
         )
@@ -686,22 +693,23 @@ class SearchStore(SearchBackgroundUpdateStore):
             count_sql += " GROUP BY room_id"

-        count_results = await self.db_pool.execute(
-            "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
+        # List of tuples of (room_id, count).
+        count_results = cast(
+            List[Tuple[str, int]],
+            await self.db_pool.execute("search_rooms_count", count_sql, *count_args),
         )

-        count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
+        count = sum(row[1] for row in count_results if row[0] in room_ids)

         return {
             "results": [
                 {
-                    "event": event_map[r["event_id"]],
-                    "rank": r["rank"],
-                    "pagination_token": "%s,%s"
-                    % (r["origin_server_ts"], r["stream_ordering"]),
+                    "event": event_map[r[2]],
+                    "rank": r[0],
+                    "pagination_token": "%s,%s" % (r[3], r[4]),
                 }
                 for r in results
-                if r["event_id"] in event_map
+                if r[2] in event_map
             ],
             "highlights": highlights,
             "count": count,
@@ -679,7 +679,7 @@ class StatsStore(StateDeltasStore):
         order_by: Optional[str] = UserSortOrder.USER_ID.value,
         direction: Direction = Direction.FORWARDS,
         search_term: Optional[str] = None,
-    ) -> Tuple[List[JsonDict], int]:
+    ) -> Tuple[List[Tuple[str, Optional[str], int, int]], int]:
         """Function to retrieve a paginated list of users and their uploaded local media
         (size and number). This will return a json list of users and the
         total number of users matching the filter criteria.
@@ -692,14 +692,19 @@ class StatsStore(StateDeltasStore):
             order_by: the sort order of the returned list
             direction: sort ascending or descending
             search_term: a string to filter user names by

         Returns:
-            A list of user dicts and an integer representing the total number of
-            users that exist given this query
+            A tuple of:
+                A list of tuples of user information (the user ID, displayname,
+                total number of media, total length of media) and
+                An integer representing the total number of users that exist
+                given this query
         """

         def get_users_media_usage_paginate_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[List[JsonDict], int]:
+        ) -> Tuple[List[Tuple[str, Optional[str], int, int]], int]:
             filters = []
             args: list = []
@@ -773,7 +778,7 @@ class StatsStore(StateDeltasStore):
             args += [limit, start]

             txn.execute(sql, args)
-            users = self.db_pool.cursor_to_dict(txn)
+            users = cast(List[Tuple[str, Optional[str], int, int]], txn.fetchall())

             return users, count
@@ -1078,7 +1078,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         """
         row = await self.db_pool.execute(
-            "get_current_topological_token", None, sql, room_id, room_id, stream_key
+            "get_current_topological_token", sql, room_id, room_id, stream_key
         )
         return row[0][0] if row else 0
@@ -1636,7 +1636,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         rows = await self.db_pool.execute(
             "get_timeline_gaps",
-            None,
             sql,
             room_id,
             from_token.stream if from_token else 0,
@@ -478,7 +478,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
         destination: Optional[str] = None,
         order_by: str = DestinationSortOrder.DESTINATION.value,
         direction: Direction = Direction.FORWARDS,
-    ) -> Tuple[List[JsonDict], int]:
+    ) -> Tuple[
+        List[Tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]]],
+        int,
+    ]:
         """Function to retrieve a paginated list of destinations.
         This will return a json list of destinations and the
         total number of destinations matching the filter criteria.
@@ -490,13 +493,23 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
             order_by: the sort order of the returned list
             direction: sort ascending or descending

         Returns:
-            A tuple of a list of mappings from destination to information
+            A tuple of a list of tuples of destination information:
+                * destination
+                * retry_last_ts
+                * retry_interval
+                * failure_ts
+                * last_successful_stream_ordering
             and a count of total destinations.
         """

         def get_destinations_paginate_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[List[JsonDict], int]:
+        ) -> Tuple[
+            List[
+                Tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]]
+            ],
+            int,
+        ]:
             order_by_column = DestinationSortOrder(order_by).value

             if direction == Direction.BACKWARDS:
@@ -523,7 +536,14 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
                 LIMIT ? OFFSET ?
             """
             txn.execute(sql, args + [limit, start])
-            destinations = self.db_pool.cursor_to_dict(txn)
+            destinations = cast(
+                List[
+                    Tuple[
+                        str, Optional[int], Optional[int], Optional[int], Optional[int]
+                    ]
+                ],
+                txn.fetchall(),
+            )

             return destinations, count

         return await self.db_pool.runInteraction(
@@ -1145,15 +1145,19 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             raise Exception("Unrecognized database engine")

         results = cast(
-            List[UserProfile],
-            await self.db_pool.execute(
-                "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
-            ),
+            List[Tuple[str, Optional[str], Optional[str]]],
+            await self.db_pool.execute("search_user_dir", sql, *args),
         )

         limited = len(results) > limit

-        return {"limited": limited, "results": results[0:limit]}
+        return {
+            "limited": limited,
+            "results": [
+                {"user_id": r[0], "display_name": r[1], "avatar_url": r[2]}
+                for r in results[0:limit]
+            ],
+        }


 def _filter_text_for_index(text: str) -> str:
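A quick check (names invented) that the rewrite preserves the wire format of `search_user_dir` — tuple rows from the store still surface as the same `user_id`/`display_name`/`avatar_url` objects:

```python
from typing import List, Optional, Tuple

rows: List[Tuple[str, Optional[str], Optional[str]]] = [
    ("@alice:example.org", "Alice", None),
]

response = {
    "limited": False,
    "results": [
        {"user_id": r[0], "display_name": r[1], "avatar_url": r[2]} for r in rows
    ],
}
assert response["results"][0]["user_id"] == "@alice:example.org"
```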
@@ -359,7 +359,6 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
         if max_group is None:
             rows = await self.db_pool.execute(
                 "_background_deduplicate_state",
-                None,
                 "SELECT coalesce(max(id), 0) FROM state_groups",
             )
             max_group = rows[0][0]
@@ -100,7 +100,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
         event_id, stream_ordering = self.get_success(
             self.hs.get_datastores().main.db_pool.execute(
                 "test:get_destination_rooms",
-                None,
                 """
                 SELECT event_id, stream_ordering
                 FROM destination_rooms dr
@@ -457,8 +457,8 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
             );
             """
         self.get_success(
-            self.store.db_pool.execute(
-                "test_not_null_constraint", lambda _: None, table_sql
+            self.store.db_pool.runInteraction(
+                "test_not_null_constraint", lambda txn: txn.execute(table_sql)
             )
         )
@@ -466,8 +466,8 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
         # using SQLite.
         index_sql = "CREATE INDEX test_index ON test_constraint(a)"
         self.get_success(
-            self.store.db_pool.execute(
-                "test_not_null_constraint", lambda _: None, index_sql
+            self.store.db_pool.runInteraction(
+                "test_not_null_constraint", lambda txn: txn.execute(index_sql)
             )
         )
@@ -574,13 +574,13 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
             );
             """
         self.get_success(
-            self.store.db_pool.execute(
-                "test_foreign_key_constraint", lambda _: None, base_sql
+            self.store.db_pool.runInteraction(
+                "test_foreign_key_constraint", lambda txn: txn.execute(base_sql)
             )
         )
         self.get_success(
-            self.store.db_pool.execute(
-                "test_foreign_key_constraint", lambda _: None, table_sql
+            self.store.db_pool.runInteraction(
+                "test_foreign_key_constraint", lambda txn: txn.execute(table_sql)
             )
         )
@@ -120,7 +120,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
         res = self.get_success(
             self.store.db_pool.execute(
-                "", None, "SELECT full_user_id from profiles ORDER BY full_user_id"
+                "", "SELECT full_user_id from profiles ORDER BY full_user_id"
             )
         )
         self.assertEqual(len(res), len(expected_values))
@@ -87,7 +87,7 @@ class UserFiltersStoreTestCase(unittest.HomeserverTestCase):
         res = self.get_success(
             self.store.db_pool.execute(
-                "", None, "SELECT full_user_id from user_filters ORDER BY full_user_id"
+                "", "SELECT full_user_id from user_filters ORDER BY full_user_id"
             )
         )
         self.assertEqual(len(res), len(expected_values))