@@ -0,0 +1 @@ | |||
Improve type hints. |
@@ -283,7 +283,7 @@ class AdminHandler: | |||
start, limit, user_id | |||
) | |||
for media in media_ids: | |||
writer.write_media_id(media["media_id"], media) | |||
writer.write_media_id(media.media_id, attr.asdict(media)) | |||
logger.info( | |||
"[%s] Written %d media_ids of %s", | |||
@@ -33,6 +33,7 @@ from synapse.api.errors import ( | |||
RequestSendFailed, | |||
SynapseError, | |||
) | |||
from synapse.storage.databases.main.room import LargestRoomStats | |||
from synapse.types import JsonDict, JsonMapping, ThirdPartyInstanceID | |||
from synapse.util.caches.descriptors import _CacheContext, cached | |||
from synapse.util.caches.response_cache import ResponseCache | |||
@@ -170,26 +171,24 @@ class RoomListHandler: | |||
ignore_non_federatable=from_federation, | |||
) | |||
def build_room_entry(room: JsonDict) -> JsonDict: | |||
def build_room_entry(room: LargestRoomStats) -> JsonDict: | |||
entry = { | |||
"room_id": room["room_id"], | |||
"name": room["name"], | |||
"topic": room["topic"], | |||
"canonical_alias": room["canonical_alias"], | |||
"num_joined_members": room["joined_members"], | |||
"avatar_url": room["avatar"], | |||
"world_readable": room["history_visibility"] | |||
"room_id": room.room_id, | |||
"name": room.name, | |||
"topic": room.topic, | |||
"canonical_alias": room.canonical_alias, | |||
"num_joined_members": room.joined_members, | |||
"avatar_url": room.avatar, | |||
"world_readable": room.history_visibility | |||
== HistoryVisibility.WORLD_READABLE, | |||
"guest_can_join": room["guest_access"] == "can_join", | |||
"join_rule": room["join_rules"], | |||
"room_type": room["room_type"], | |||
"guest_can_join": room.guest_access == "can_join", | |||
"join_rule": room.join_rules, | |||
"room_type": room.room_type, | |||
} | |||
# Filter out Nones – rather omit the field altogether | |||
return {k: v for k, v in entry.items() if v is not None} | |||
results = [build_room_entry(r) for r in results] | |||
response: JsonDict = {} | |||
num_results = len(results) | |||
if limit is not None: | |||
@@ -212,33 +211,33 @@ class RoomListHandler: | |||
# If there was a token given then we assume that there | |||
# must be previous results. | |||
response["prev_batch"] = RoomListNextBatch( | |||
last_joined_members=initial_entry["num_joined_members"], | |||
last_room_id=initial_entry["room_id"], | |||
last_joined_members=initial_entry.joined_members, | |||
last_room_id=initial_entry.room_id, | |||
direction_is_forward=False, | |||
).to_token() | |||
if more_to_come: | |||
response["next_batch"] = RoomListNextBatch( | |||
last_joined_members=final_entry["num_joined_members"], | |||
last_room_id=final_entry["room_id"], | |||
last_joined_members=final_entry.joined_members, | |||
last_room_id=final_entry.room_id, | |||
direction_is_forward=True, | |||
).to_token() | |||
else: | |||
if has_batch_token: | |||
response["next_batch"] = RoomListNextBatch( | |||
last_joined_members=final_entry["num_joined_members"], | |||
last_room_id=final_entry["room_id"], | |||
last_joined_members=final_entry.joined_members, | |||
last_room_id=final_entry.room_id, | |||
direction_is_forward=True, | |||
).to_token() | |||
if more_to_come: | |||
response["prev_batch"] = RoomListNextBatch( | |||
last_joined_members=initial_entry["num_joined_members"], | |||
last_room_id=initial_entry["room_id"], | |||
last_joined_members=initial_entry.joined_members, | |||
last_room_id=initial_entry.room_id, | |||
direction_is_forward=False, | |||
).to_token() | |||
response["chunk"] = results | |||
response["chunk"] = [build_room_entry(r) for r in results] | |||
response["total_room_count_estimate"] = await self.store.count_public_rooms( | |||
network_tuple, | |||
@@ -703,24 +703,24 @@ class RoomSummaryHandler: | |||
# there should always be an entry | |||
assert stats is not None, "unable to retrieve stats for %s" % (room_id,) | |||
entry = { | |||
"room_id": stats["room_id"], | |||
"name": stats["name"], | |||
"topic": stats["topic"], | |||
"canonical_alias": stats["canonical_alias"], | |||
"num_joined_members": stats["joined_members"], | |||
"avatar_url": stats["avatar"], | |||
"join_rule": stats["join_rules"], | |||
entry: JsonDict = { | |||
"room_id": stats.room_id, | |||
"name": stats.name, | |||
"topic": stats.topic, | |||
"canonical_alias": stats.canonical_alias, | |||
"num_joined_members": stats.joined_members, | |||
"avatar_url": stats.avatar, | |||
"join_rule": stats.join_rules, | |||
"world_readable": ( | |||
stats["history_visibility"] == HistoryVisibility.WORLD_READABLE | |||
stats.history_visibility == HistoryVisibility.WORLD_READABLE | |||
), | |||
"guest_can_join": stats["guest_access"] == "can_join", | |||
"room_type": stats["room_type"], | |||
"guest_can_join": stats.guest_access == "can_join", | |||
"room_type": stats.room_type, | |||
} | |||
if self._msc3266_enabled: | |||
entry["im.nheko.summary.version"] = stats["version"] | |||
entry["im.nheko.summary.encryption"] = stats["encryption"] | |||
entry["im.nheko.summary.version"] = stats.version | |||
entry["im.nheko.summary.encryption"] = stats.encryption | |||
# Federation requests need to provide additional information so the | |||
# requested server is able to filter the response appropriately. | |||
@@ -17,6 +17,8 @@ import logging | |||
from http import HTTPStatus | |||
from typing import TYPE_CHECKING, Optional, Tuple | |||
import attr | |||
from synapse.api.constants import Direction | |||
from synapse.api.errors import Codes, NotFoundError, SynapseError | |||
from synapse.http.server import HttpServer | |||
@@ -418,7 +420,7 @@ class UserMediaRestServlet(RestServlet): | |||
start, limit, user_id, order_by, direction | |||
) | |||
ret = {"media": media, "total": total} | |||
ret = {"media": [attr.asdict(m) for m in media], "total": total} | |||
if (start + limit) < total: | |||
ret["next_token"] = start + len(media) | |||
@@ -477,7 +479,7 @@ class UserMediaRestServlet(RestServlet): | |||
) | |||
deleted_media, total = await self.media_repository.delete_local_media_ids( | |||
[row["media_id"] for row in media] | |||
[m.media_id for m in media] | |||
) | |||
return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total} | |||
@@ -77,7 +77,18 @@ class ListRegistrationTokensRestServlet(RestServlet): | |||
await assert_requester_is_admin(self.auth, request) | |||
valid = parse_boolean(request, "valid") | |||
token_list = await self.store.get_registration_tokens(valid) | |||
return HTTPStatus.OK, {"registration_tokens": token_list} | |||
return HTTPStatus.OK, { | |||
"registration_tokens": [ | |||
{ | |||
"token": t[0], | |||
"uses_allowed": t[1], | |||
"pending": t[2], | |||
"completed": t[3], | |||
"expiry_time": t[4], | |||
} | |||
for t in token_list | |||
] | |||
} | |||
class NewRegistrationTokenRestServlet(RestServlet): | |||
@@ -16,6 +16,8 @@ from http import HTTPStatus | |||
from typing import TYPE_CHECKING, List, Optional, Tuple, cast | |||
from urllib import parse as urlparse | |||
import attr | |||
from synapse.api.constants import Direction, EventTypes, JoinRules, Membership | |||
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError | |||
from synapse.api.filtering import Filter | |||
@@ -306,10 +308,13 @@ class RoomRestServlet(RestServlet): | |||
raise NotFoundError("Room not found") | |||
members = await self.store.get_users_in_room(room_id) | |||
ret["joined_local_devices"] = await self.store.count_devices_by_users(members) | |||
ret["forgotten"] = await self.store.is_locally_forgotten_room(room_id) | |||
result = attr.asdict(ret) | |||
result["joined_local_devices"] = await self.store.count_devices_by_users( | |||
members | |||
) | |||
result["forgotten"] = await self.store.is_locally_forgotten_room(room_id) | |||
return HTTPStatus.OK, ret | |||
return HTTPStatus.OK, result | |||
async def on_DELETE( | |||
self, request: SynapseRequest, room_id: str | |||
@@ -18,6 +18,8 @@ import secrets | |||
from http import HTTPStatus | |||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple | |||
import attr | |||
from synapse.api.constants import Direction, UserTypes | |||
from synapse.api.errors import Codes, NotFoundError, SynapseError | |||
from synapse.http.servlet import ( | |||
@@ -161,11 +163,13 @@ class UsersRestServletV2(RestServlet): | |||
) | |||
# If support for MSC3866 is not enabled, don't show the approval flag. | |||
filter = None | |||
if not self._msc3866_enabled: | |||
for user in users: | |||
del user["approved"] | |||
ret = {"users": users, "total": total} | |||
def _filter(a: attr.Attribute) -> bool: | |||
return a.name != "approved" | |||
ret = {"users": [attr.asdict(u, filter=filter) for u in users], "total": total} | |||
if (start + limit) < total: | |||
ret["next_token"] = str(start + len(users)) | |||
@@ -28,6 +28,7 @@ from typing import ( | |||
Sequence, | |||
Tuple, | |||
Type, | |||
cast, | |||
) | |||
import attr | |||
@@ -488,14 +489,14 @@ class BackgroundUpdater: | |||
True if we have finished running all the background updates, otherwise False | |||
""" | |||
def get_background_updates_txn(txn: Cursor) -> List[Dict[str, Any]]: | |||
def get_background_updates_txn(txn: Cursor) -> List[Tuple[str, Optional[str]]]: | |||
txn.execute( | |||
""" | |||
SELECT update_name, depends_on FROM background_updates | |||
ORDER BY ordering, update_name | |||
""" | |||
) | |||
return self.db_pool.cursor_to_dict(txn) | |||
return cast(List[Tuple[str, Optional[str]]], txn.fetchall()) | |||
if not self._current_background_update: | |||
all_pending_updates = await self.db_pool.runInteraction( | |||
@@ -507,14 +508,13 @@ class BackgroundUpdater: | |||
return True | |||
# find the first update which isn't dependent on another one in the queue. | |||
pending = {update["update_name"] for update in all_pending_updates} | |||
for upd in all_pending_updates: | |||
depends_on = upd["depends_on"] | |||
pending = {update_name for update_name, depends_on in all_pending_updates} | |||
for update_name, depends_on in all_pending_updates: | |||
if not depends_on or depends_on not in pending: | |||
break | |||
logger.info( | |||
"Not starting on bg update %s until %s is done", | |||
upd["update_name"], | |||
update_name, | |||
depends_on, | |||
) | |||
else: | |||
@@ -524,7 +524,7 @@ class BackgroundUpdater: | |||
"another: dependency cycle?" | |||
) | |||
self._current_background_update = upd["update_name"] | |||
self._current_background_update = update_name | |||
# We have a background update to run, otherwise we would have returned | |||
# early. | |||
@@ -18,7 +18,6 @@ import logging | |||
import time | |||
import types | |||
from collections import defaultdict | |||
from sys import intern | |||
from time import monotonic as monotonic_time | |||
from typing import ( | |||
TYPE_CHECKING, | |||
@@ -1042,20 +1041,6 @@ class DatabasePool: | |||
self._db_pool.runWithConnection(inner_func, *args, **kwargs) | |||
) | |||
@staticmethod | |||
def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]: | |||
"""Converts a SQL cursor into an list of dicts. | |||
Args: | |||
cursor: The DBAPI cursor which has executed a query. | |||
Returns: | |||
A list of dicts where the key is the column header. | |||
""" | |||
assert cursor.description is not None, "cursor.description was None" | |||
col_headers = [intern(str(column[0])) for column in cursor.description] | |||
results = [dict(zip(col_headers, row)) for row in cursor] | |||
return results | |||
async def execute(self, desc: str, query: str, *args: Any) -> List[Tuple[Any, ...]]: | |||
"""Runs a single query for a result set. | |||
@@ -17,6 +17,8 @@ | |||
import logging | |||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast | |||
import attr | |||
from synapse.api.constants import Direction | |||
from synapse.config.homeserver import HomeServerConfig | |||
from synapse.storage._base import make_in_list_sql_clause | |||
@@ -28,7 +30,7 @@ from synapse.storage.database import ( | |||
from synapse.storage.databases.main.stats import UserSortOrder | |||
from synapse.storage.engines import BaseDatabaseEngine | |||
from synapse.storage.types import Cursor | |||
from synapse.types import JsonDict, get_domain_from_id | |||
from synapse.types import get_domain_from_id | |||
from .account_data import AccountDataStore | |||
from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore | |||
@@ -82,6 +84,25 @@ if TYPE_CHECKING: | |||
logger = logging.getLogger(__name__) | |||
@attr.s(slots=True, frozen=True, auto_attribs=True) | |||
class UserPaginateResponse: | |||
"""This is very similar to UserInfo, but not quite the same.""" | |||
name: str | |||
user_type: Optional[str] | |||
is_guest: bool | |||
admin: bool | |||
deactivated: bool | |||
shadow_banned: bool | |||
displayname: Optional[str] | |||
avatar_url: Optional[str] | |||
creation_ts: Optional[int] | |||
approved: bool | |||
erased: bool | |||
last_seen_ts: int | |||
locked: bool | |||
class DataStore( | |||
EventsBackgroundUpdatesStore, | |||
ExperimentalFeaturesStore, | |||
@@ -156,7 +177,7 @@ class DataStore( | |||
approved: bool = True, | |||
not_user_types: Optional[List[str]] = None, | |||
locked: bool = False, | |||
) -> Tuple[List[JsonDict], int]: | |||
) -> Tuple[List[UserPaginateResponse], int]: | |||
"""Function to retrieve a paginated list of users from | |||
users list. This will return a json list of users and the | |||
total number of users matching the filter criteria. | |||
@@ -182,7 +203,7 @@ class DataStore( | |||
def get_users_paginate_txn( | |||
txn: LoggingTransaction, | |||
) -> Tuple[List[JsonDict], int]: | |||
) -> Tuple[List[UserPaginateResponse], int]: | |||
filters = [] | |||
args: list = [] | |||
@@ -282,13 +303,24 @@ class DataStore( | |||
""" | |||
args += [limit, start] | |||
txn.execute(sql, args) | |||
users = self.db_pool.cursor_to_dict(txn) | |||
# some of those boolean values are returned as integers when we're on SQLite | |||
columns_to_boolify = ["erased"] | |||
for user in users: | |||
for column in columns_to_boolify: | |||
user[column] = bool(user[column]) | |||
users = [ | |||
UserPaginateResponse( | |||
name=row[0], | |||
user_type=row[1], | |||
is_guest=bool(row[2]), | |||
admin=bool(row[3]), | |||
deactivated=bool(row[4]), | |||
shadow_banned=bool(row[5]), | |||
displayname=row[6], | |||
avatar_url=row[7], | |||
creation_ts=row[8], | |||
approved=bool(row[9]), | |||
erased=bool(row[10]), | |||
last_seen_ts=row[11], | |||
locked=bool(row[12]), | |||
) | |||
for row in txn | |||
] | |||
return users, count | |||
@@ -1620,7 +1620,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): | |||
# | |||
# For each duplicate, we delete all the existing rows and put one back. | |||
KEY_COLS = ["stream_id", "destination", "user_id", "device_id"] | |||
last_row = progress.get( | |||
"last_row", | |||
{"stream_id": 0, "destination": "", "user_id": "", "device_id": ""}, | |||
@@ -1628,44 +1627,62 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): | |||
def _txn(txn: LoggingTransaction) -> int: | |||
clause, args = make_tuple_comparison_clause( | |||
[(x, last_row[x]) for x in KEY_COLS] | |||
[ | |||
("stream_id", last_row["stream_id"]), | |||
("destination", last_row["destination"]), | |||
("user_id", last_row["user_id"]), | |||
("device_id", last_row["device_id"]), | |||
] | |||
) | |||
sql = """ | |||
sql = f""" | |||
SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts | |||
FROM device_lists_outbound_pokes | |||
WHERE %s | |||
GROUP BY %s | |||
WHERE {clause} | |||
GROUP BY stream_id, destination, user_id, device_id | |||
HAVING count(*) > 1 | |||
ORDER BY %s | |||
ORDER BY stream_id, destination, user_id, device_id | |||
LIMIT ? | |||
""" % ( | |||
clause, # WHERE | |||
",".join(KEY_COLS), # GROUP BY | |||
",".join(KEY_COLS), # ORDER BY | |||
) | |||
""" | |||
txn.execute(sql, args + [batch_size]) | |||
rows = self.db_pool.cursor_to_dict(txn) | |||
rows = txn.fetchall() | |||
row = None | |||
for row in rows: | |||
stream_id, destination, user_id, device_id = None, None, None, None | |||
for stream_id, destination, user_id, device_id, _ in rows: | |||
self.db_pool.simple_delete_txn( | |||
txn, | |||
"device_lists_outbound_pokes", | |||
{x: row[x] for x in KEY_COLS}, | |||
{ | |||
"stream_id": stream_id, | |||
"destination": destination, | |||
"user_id": user_id, | |||
"device_id": device_id, | |||
}, | |||
) | |||
row["sent"] = False | |||
self.db_pool.simple_insert_txn( | |||
txn, | |||
"device_lists_outbound_pokes", | |||
row, | |||
{ | |||
"stream_id": stream_id, | |||
"destination": destination, | |||
"user_id": user_id, | |||
"device_id": device_id, | |||
"sent": False, | |||
}, | |||
) | |||
if row: | |||
if rows: | |||
self.db_pool.updates._background_update_progress_txn( | |||
txn, | |||
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, | |||
{"last_row": row}, | |||
{ | |||
"last_row": { | |||
"stream_id": stream_id, | |||
"destination": destination, | |||
"user_id": user_id, | |||
"device_id": device_id, | |||
} | |||
}, | |||
) | |||
return len(rows) | |||
@@ -26,6 +26,8 @@ from typing import ( | |||
cast, | |||
) | |||
import attr | |||
from synapse.api.constants import Direction | |||
from synapse.logging.opentracing import trace | |||
from synapse.media._base import ThumbnailInfo | |||
@@ -45,6 +47,18 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = ( | |||
) | |||
@attr.s(slots=True, frozen=True, auto_attribs=True) | |||
class LocalMedia: | |||
media_id: str | |||
media_type: str | |||
media_length: int | |||
upload_name: str | |||
created_ts: int | |||
last_access_ts: int | |||
quarantined_by: Optional[str] | |||
safe_from_quarantine: bool | |||
class MediaSortOrder(Enum): | |||
""" | |||
Enum to define the sorting method used when returning media with | |||
@@ -180,7 +194,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
user_id: str, | |||
order_by: str = MediaSortOrder.CREATED_TS.value, | |||
direction: Direction = Direction.FORWARDS, | |||
) -> Tuple[List[Dict[str, Any]], int]: | |||
) -> Tuple[List[LocalMedia], int]: | |||
"""Get a paginated list of metadata for a local piece of media | |||
which an user_id has uploaded | |||
@@ -197,7 +211,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
def get_local_media_by_user_paginate_txn( | |||
txn: LoggingTransaction, | |||
) -> Tuple[List[Dict[str, Any]], int]: | |||
) -> Tuple[List[LocalMedia], int]: | |||
# Set ordering | |||
order_by_column = MediaSortOrder(order_by).value | |||
@@ -217,14 +231,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
sql = """ | |||
SELECT | |||
"media_id", | |||
"media_type", | |||
"media_length", | |||
"upload_name", | |||
"created_ts", | |||
"last_access_ts", | |||
"quarantined_by", | |||
"safe_from_quarantine" | |||
media_id, | |||
media_type, | |||
media_length, | |||
upload_name, | |||
created_ts, | |||
last_access_ts, | |||
quarantined_by, | |||
safe_from_quarantine | |||
FROM local_media_repository | |||
WHERE user_id = ? | |||
ORDER BY {order_by_column} {order}, media_id ASC | |||
@@ -236,7 +250,19 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
args += [limit, start] | |||
txn.execute(sql, args) | |||
media = self.db_pool.cursor_to_dict(txn) | |||
media = [ | |||
LocalMedia( | |||
media_id=row[0], | |||
media_type=row[1], | |||
media_length=row[2], | |||
upload_name=row[3], | |||
created_ts=row[4], | |||
last_access_ts=row[5], | |||
quarantined_by=row[6], | |||
safe_from_quarantine=bool(row[7]), | |||
) | |||
for row in txn | |||
] | |||
return media, count | |||
return await self.db_pool.runInteraction( | |||
@@ -1517,7 +1517,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): | |||
async def get_registration_tokens( | |||
self, valid: Optional[bool] = None | |||
) -> List[Dict[str, Any]]: | |||
) -> List[Tuple[str, Optional[int], int, int, Optional[int]]]: | |||
"""List all registration tokens. Used by the admin API. | |||
Args: | |||
@@ -1526,34 +1526,48 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): | |||
Default is None: return all tokens regardless of validity. | |||
Returns: | |||
A list of dicts, each containing details of a token. | |||
A list of tuples containing: | |||
* The token | |||
* The number of users allowed (or None) | |||
* Whether it is pending | |||
* Whether it has been completed | |||
* An expiry time (or None if no expiry) | |||
""" | |||
def select_registration_tokens_txn( | |||
txn: LoggingTransaction, now: int, valid: Optional[bool] | |||
) -> List[Dict[str, Any]]: | |||
) -> List[Tuple[str, Optional[int], int, int, Optional[int]]]: | |||
if valid is None: | |||
# Return all tokens regardless of validity | |||
txn.execute("SELECT * FROM registration_tokens") | |||
txn.execute( | |||
""" | |||
SELECT token, uses_allowed, pending, completed, expiry_time | |||
FROM registration_tokens | |||
""" | |||
) | |||
elif valid: | |||
# Select valid tokens only | |||
sql = ( | |||
"SELECT * FROM registration_tokens WHERE " | |||
"(uses_allowed > pending + completed OR uses_allowed IS NULL) " | |||
"AND (expiry_time > ? OR expiry_time IS NULL)" | |||
) | |||
sql = """ | |||
SELECT token, uses_allowed, pending, completed, expiry_time | |||
FROM registration_tokens | |||
WHERE (uses_allowed > pending + completed OR uses_allowed IS NULL) | |||
AND (expiry_time > ? OR expiry_time IS NULL) | |||
""" | |||
txn.execute(sql, [now]) | |||
else: | |||
# Select invalid tokens only | |||
sql = ( | |||
"SELECT * FROM registration_tokens WHERE " | |||
"uses_allowed <= pending + completed OR expiry_time <= ?" | |||
) | |||
sql = """ | |||
SELECT token, uses_allowed, pending, completed, expiry_time | |||
FROM registration_tokens | |||
WHERE uses_allowed <= pending + completed OR expiry_time <= ? | |||
""" | |||
txn.execute(sql, [now]) | |||
return self.db_pool.cursor_to_dict(txn) | |||
return cast( | |||
List[Tuple[str, Optional[int], int, int, Optional[int]]], txn.fetchall() | |||
) | |||
return await self.db_pool.runInteraction( | |||
"select_registration_tokens", | |||
@@ -78,6 +78,31 @@ class RatelimitOverride: | |||
burst_count: int | |||
@attr.s(slots=True, frozen=True, auto_attribs=True) | |||
class LargestRoomStats: | |||
room_id: str | |||
name: Optional[str] | |||
canonical_alias: Optional[str] | |||
joined_members: int | |||
join_rules: Optional[str] | |||
guest_access: Optional[str] | |||
history_visibility: Optional[str] | |||
state_events: int | |||
avatar: Optional[str] | |||
topic: Optional[str] | |||
room_type: Optional[str] | |||
@attr.s(slots=True, frozen=True, auto_attribs=True) | |||
class RoomStats(LargestRoomStats): | |||
joined_local_members: int | |||
version: Optional[str] | |||
creator: Optional[str] | |||
encryption: Optional[str] | |||
federatable: bool | |||
public: bool | |||
class RoomSortOrder(Enum): | |||
""" | |||
Enum to define the sorting method used when returning rooms with get_rooms_paginate | |||
@@ -204,7 +229,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): | |||
allow_none=True, | |||
) | |||
async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]: | |||
async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]: | |||
"""Retrieve room with statistics. | |||
Args: | |||
@@ -215,7 +240,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): | |||
def get_room_with_stats_txn( | |||
txn: LoggingTransaction, room_id: str | |||
) -> Optional[Dict[str, Any]]: | |||
) -> Optional[RoomStats]: | |||
sql = """ | |||
SELECT room_id, state.name, state.canonical_alias, curr.joined_members, | |||
curr.local_users_in_room AS joined_local_members, rooms.room_version AS version, | |||
@@ -229,15 +254,28 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): | |||
WHERE room_id = ? | |||
""" | |||
txn.execute(sql, [room_id]) | |||
# Catch error if sql returns empty result to return "None" instead of an error | |||
try: | |||
res = self.db_pool.cursor_to_dict(txn)[0] | |||
except IndexError: | |||
row = txn.fetchone() | |||
if not row: | |||
return None | |||
res["federatable"] = bool(res["federatable"]) | |||
res["public"] = bool(res["public"]) | |||
return res | |||
return RoomStats( | |||
room_id=row[0], | |||
name=row[1], | |||
canonical_alias=row[2], | |||
joined_members=row[3], | |||
joined_local_members=row[4], | |||
version=row[5], | |||
creator=row[6], | |||
encryption=row[7], | |||
federatable=bool(row[8]), | |||
public=bool(row[9]), | |||
join_rules=row[10], | |||
guest_access=row[11], | |||
history_visibility=row[12], | |||
state_events=row[13], | |||
avatar=row[14], | |||
topic=row[15], | |||
room_type=row[16], | |||
) | |||
return await self.db_pool.runInteraction( | |||
"get_room_with_stats", get_room_with_stats_txn, room_id | |||
@@ -368,7 +406,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): | |||
bounds: Optional[Tuple[int, str]], | |||
forwards: bool, | |||
ignore_non_federatable: bool = False, | |||
) -> List[Dict[str, Any]]: | |||
) -> List[LargestRoomStats]: | |||
"""Gets the largest public rooms (where largest is in terms of joined | |||
members, as tracked in the statistics table). | |||
@@ -505,20 +543,34 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): | |||
def _get_largest_public_rooms_txn( | |||
txn: LoggingTransaction, | |||
) -> List[Dict[str, Any]]: | |||
) -> List[LargestRoomStats]: | |||
txn.execute(sql, query_args) | |||
results = self.db_pool.cursor_to_dict(txn) | |||
results = [ | |||
LargestRoomStats( | |||
room_id=r[0], | |||
name=r[1], | |||
canonical_alias=r[3], | |||
joined_members=r[4], | |||
join_rules=r[8], | |||
guest_access=r[7], | |||
history_visibility=r[6], | |||
state_events=0, | |||
avatar=r[5], | |||
topic=r[2], | |||
room_type=r[9], | |||
) | |||
for r in txn | |||
] | |||
if not forwards: | |||
results.reverse() | |||
return results | |||
ret_val = await self.db_pool.runInteraction( | |||
return await self.db_pool.runInteraction( | |||
"get_largest_public_rooms", _get_largest_public_rooms_txn | |||
) | |||
return ret_val | |||
@cached(max_entries=10000) | |||
async def is_room_blocked(self, room_id: str) -> Optional[bool]: | |||
@@ -342,10 +342,10 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
# Ensure the room is properly not federated. | |||
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) | |||
assert room is not None | |||
self.assertFalse(room["federatable"]) | |||
self.assertFalse(room["public"]) | |||
self.assertEqual(room["join_rules"], "public") | |||
self.assertIsNone(room["guest_access"]) | |||
self.assertFalse(room.federatable) | |||
self.assertFalse(room.public) | |||
self.assertEqual(room.join_rules, "public") | |||
self.assertIsNone(room.guest_access) | |||
# The user should be in the room. | |||
rooms = self.get_success(self.store.get_rooms_for_user(user_id)) | |||
@@ -372,7 +372,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
# Ensure the room is properly a public room. | |||
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) | |||
assert room is not None | |||
self.assertEqual(room["join_rules"], "public") | |||
self.assertEqual(room.join_rules, "public") | |||
# Both users should be in the room. | |||
rooms = self.get_success(self.store.get_rooms_for_user(inviter)) | |||
@@ -411,9 +411,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
# Ensure the room is properly a private room. | |||
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) | |||
assert room is not None | |||
self.assertFalse(room["public"]) | |||
self.assertEqual(room["join_rules"], "invite") | |||
self.assertEqual(room["guest_access"], "can_join") | |||
self.assertFalse(room.public) | |||
self.assertEqual(room.join_rules, "invite") | |||
self.assertEqual(room.guest_access, "can_join") | |||
# Both users should be in the room. | |||
rooms = self.get_success(self.store.get_rooms_for_user(inviter)) | |||
@@ -455,9 +455,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
# Ensure the room is properly a private room. | |||
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) | |||
assert room is not None | |||
self.assertFalse(room["public"]) | |||
self.assertEqual(room["join_rules"], "invite") | |||
self.assertEqual(room["guest_access"], "can_join") | |||
self.assertFalse(room.public) | |||
self.assertEqual(room.join_rules, "invite") | |||
self.assertEqual(room.guest_access, "can_join") | |||
# Both users should be in the room. | |||
rooms = self.get_success(self.store.get_rooms_for_user(inviter)) | |||
@@ -39,11 +39,11 @@ class DataStoreTestCase(unittest.HomeserverTestCase): | |||
) | |||
self.assertEqual(1, total) | |||
self.assertEqual(self.displayname, users.pop()["displayname"]) | |||
self.assertEqual(self.displayname, users.pop().displayname) | |||
users, total = self.get_success( | |||
self.store.get_users_paginate(0, 10, name="BC", guests=False) | |||
) | |||
self.assertEqual(1, total) | |||
self.assertEqual(self.displayname, users.pop()["displayname"]) | |||
self.assertEqual(self.displayname, users.pop().displayname) |
@@ -59,14 +59,9 @@ class RoomStoreTestCase(HomeserverTestCase): | |||
def test_get_room_with_stats(self) -> None: | |||
res = self.get_success(self.store.get_room_with_stats(self.room.to_string())) | |||
assert res is not None | |||
self.assertLessEqual( | |||
{ | |||
"room_id": self.room.to_string(), | |||
"creator": self.u_creator.to_string(), | |||
"public": True, | |||
}.items(), | |||
res.items(), | |||
) | |||
self.assertEqual(res.room_id, self.room.to_string()) | |||
self.assertEqual(res.creator, self.u_creator.to_string()) | |||
self.assertTrue(res.public) | |||
def test_get_room_with_stats_unknown_room(self) -> None: | |||
self.assertIsNone( | |||