This was the last untyped handler from the HomeServer object. Since it was being treated as Any (and thus unchecked) it was being used incorrectly in a few places.tags/v1.73.0rc1
@@ -0,0 +1 @@ | |||
Add missing type hints to `HomeServer`. |
@@ -16,6 +16,7 @@ import logging | |||
from typing import TYPE_CHECKING, Optional | |||
from synapse.api.errors import SynapseError | |||
from synapse.handlers.device import DeviceHandler | |||
from synapse.metrics.background_process_metrics import run_as_background_process | |||
from synapse.types import Codes, Requester, UserID, create_requester | |||
@@ -76,6 +77,9 @@ class DeactivateAccountHandler: | |||
True if identity server supports removing threepids, otherwise False. | |||
""" | |||
# This can only be called on the main process. | |||
assert isinstance(self._device_handler, DeviceHandler) | |||
# Check if this user can be deactivated | |||
if not await self._third_party_rules.check_can_deactivate_user( | |||
user_id, by_admin | |||
@@ -65,6 +65,8 @@ DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000 | |||
class DeviceWorkerHandler: | |||
device_list_updater: "DeviceListWorkerUpdater" | |||
def __init__(self, hs: "HomeServer"): | |||
self.clock = hs.get_clock() | |||
self.hs = hs | |||
@@ -76,6 +78,8 @@ class DeviceWorkerHandler: | |||
self.server_name = hs.hostname | |||
self._msc3852_enabled = hs.config.experimental.msc3852_enabled | |||
self.device_list_updater = DeviceListWorkerUpdater(hs) | |||
@trace | |||
async def get_devices_by_user(self, user_id: str) -> List[JsonDict]: | |||
""" | |||
@@ -99,6 +103,19 @@ class DeviceWorkerHandler: | |||
log_kv(device_map) | |||
return devices | |||
async def get_dehydrated_device( | |||
self, user_id: str | |||
) -> Optional[Tuple[str, JsonDict]]: | |||
"""Retrieve the information for a dehydrated device. | |||
Args: | |||
user_id: the user whose dehydrated device we are looking for | |||
Returns: | |||
a tuple whose first item is the device ID, and the second item is | |||
the dehydrated device information | |||
""" | |||
return await self.store.get_dehydrated_device(user_id) | |||
@trace | |||
async def get_device(self, user_id: str, device_id: str) -> JsonDict: | |||
"""Retrieve the given device | |||
@@ -127,7 +144,7 @@ class DeviceWorkerHandler: | |||
@cancellable | |||
async def get_device_changes_in_shared_rooms( | |||
self, user_id: str, room_ids: Collection[str], from_token: StreamToken | |||
) -> Collection[str]: | |||
) -> Set[str]: | |||
"""Get the set of users whose devices have changed who share a room with | |||
the given user. | |||
""" | |||
@@ -320,6 +337,8 @@ class DeviceWorkerHandler: | |||
class DeviceHandler(DeviceWorkerHandler): | |||
device_list_updater: "DeviceListUpdater" | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__(hs) | |||
@@ -606,19 +625,6 @@ class DeviceHandler(DeviceWorkerHandler): | |||
await self.delete_devices(user_id, [old_device_id]) | |||
return device_id | |||
async def get_dehydrated_device( | |||
self, user_id: str | |||
) -> Optional[Tuple[str, JsonDict]]: | |||
"""Retrieve the information for a dehydrated device. | |||
Args: | |||
user_id: the user whose dehydrated device we are looking for | |||
Returns: | |||
a tuple whose first item is the device ID, and the second item is | |||
the dehydrated device information | |||
""" | |||
return await self.store.get_dehydrated_device(user_id) | |||
async def rehydrate_device( | |||
self, user_id: str, access_token: str, device_id: str | |||
) -> dict: | |||
@@ -882,7 +888,36 @@ def _update_device_from_client_ips( | |||
) | |||
class DeviceListUpdater: | |||
class DeviceListWorkerUpdater: | |||
"Handles incoming device list updates from federation and contacts the main process over replication" | |||
def __init__(self, hs: "HomeServer"): | |||
from synapse.replication.http.devices import ( | |||
ReplicationUserDevicesResyncRestServlet, | |||
) | |||
self._user_device_resync_client = ( | |||
ReplicationUserDevicesResyncRestServlet.make_client(hs) | |||
) | |||
async def user_device_resync( | |||
self, user_id: str, mark_failed_as_stale: bool = True | |||
) -> Optional[JsonDict]: | |||
"""Fetches all devices for a user and updates the device cache with them. | |||
Args: | |||
user_id: The user's id whose device_list will be updated. | |||
mark_failed_as_stale: Whether to mark the user's device list as stale | |||
if the attempt to resync failed. | |||
Returns: | |||
A dict with device info as under the "devices" in the result of this | |||
request: | |||
https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid | |||
""" | |||
return await self._user_device_resync_client(user_id=user_id) | |||
class DeviceListUpdater(DeviceListWorkerUpdater): | |||
"Handles incoming device list updates from federation and updates the DB" | |||
def __init__(self, hs: "HomeServer", device_handler: DeviceHandler): | |||
@@ -27,9 +27,9 @@ from twisted.internet import defer | |||
from synapse.api.constants import EduTypes | |||
from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError | |||
from synapse.handlers.device import DeviceHandler | |||
from synapse.logging.context import make_deferred_yieldable, run_in_background | |||
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace | |||
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet | |||
from synapse.types import ( | |||
JsonDict, | |||
UserID, | |||
@@ -56,27 +56,23 @@ class E2eKeysHandler: | |||
self.is_mine = hs.is_mine | |||
self.clock = hs.get_clock() | |||
self._edu_updater = SigningKeyEduUpdater(hs, self) | |||
federation_registry = hs.get_federation_registry() | |||
self._is_master = hs.config.worker.worker_app is None | |||
if not self._is_master: | |||
self._user_device_resync_client = ( | |||
ReplicationUserDevicesResyncRestServlet.make_client(hs) | |||
) | |||
else: | |||
is_master = hs.config.worker.worker_app is None | |||
if is_master: | |||
edu_updater = SigningKeyEduUpdater(hs) | |||
# Only register this edu handler on master as it requires writing | |||
# device updates to the db | |||
federation_registry.register_edu_handler( | |||
EduTypes.SIGNING_KEY_UPDATE, | |||
self._edu_updater.incoming_signing_key_update, | |||
edu_updater.incoming_signing_key_update, | |||
) | |||
# also handle the unstable version | |||
# FIXME: remove this when enough servers have upgraded | |||
federation_registry.register_edu_handler( | |||
EduTypes.UNSTABLE_SIGNING_KEY_UPDATE, | |||
self._edu_updater.incoming_signing_key_update, | |||
edu_updater.incoming_signing_key_update, | |||
) | |||
# doesn't really work as part of the generic query API, because the | |||
@@ -319,14 +315,13 @@ class E2eKeysHandler: | |||
# probably be tracking their device lists. However, we haven't | |||
# done an initial sync on the device list so we do it now. | |||
try: | |||
if self._is_master: | |||
resync_results = await self.device_handler.device_list_updater.user_device_resync( | |||
resync_results = ( | |||
await self.device_handler.device_list_updater.user_device_resync( | |||
user_id | |||
) | |||
else: | |||
resync_results = await self._user_device_resync_client( | |||
user_id=user_id | |||
) | |||
) | |||
if resync_results is None: | |||
raise ValueError("Device resync failed") | |||
# Add the device keys to the results. | |||
user_devices = resync_results["devices"] | |||
@@ -605,6 +600,8 @@ class E2eKeysHandler: | |||
async def upload_keys_for_user( | |||
self, user_id: str, device_id: str, keys: JsonDict | |||
) -> JsonDict: | |||
# This can only be called from the main process. | |||
assert isinstance(self.device_handler, DeviceHandler) | |||
time_now = self.clock.time_msec() | |||
@@ -732,6 +729,8 @@ class E2eKeysHandler: | |||
user_id: the user uploading the keys | |||
keys: the signing keys | |||
""" | |||
# This can only be called from the main process. | |||
assert isinstance(self.device_handler, DeviceHandler) | |||
# if a master key is uploaded, then check it. Otherwise, load the | |||
# stored master key, to check signatures on other keys | |||
@@ -823,6 +822,9 @@ class E2eKeysHandler: | |||
Raises: | |||
SynapseError: if the signatures dict is not valid. | |||
""" | |||
# This can only be called from the main process. | |||
assert isinstance(self.device_handler, DeviceHandler) | |||
failures = {} | |||
# signatures to be stored. Each item will be a SignatureListItem | |||
@@ -1200,6 +1202,9 @@ class E2eKeysHandler: | |||
A tuple of the retrieved key content, the key's ID and the matching VerifyKey. | |||
If the key cannot be retrieved, all values in the tuple will instead be None. | |||
""" | |||
# This can only be called from the main process. | |||
assert isinstance(self.device_handler, DeviceHandler) | |||
try: | |||
remote_result = await self.federation.query_user_devices( | |||
user.domain, user.to_string() | |||
@@ -1396,11 +1401,14 @@ class SignatureListItem: | |||
class SigningKeyEduUpdater: | |||
"""Handles incoming signing key updates from federation and updates the DB""" | |||
def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler): | |||
def __init__(self, hs: "HomeServer"): | |||
self.store = hs.get_datastores().main | |||
self.federation = hs.get_federation_client() | |||
self.clock = hs.get_clock() | |||
self.e2e_keys_handler = e2e_keys_handler | |||
device_handler = hs.get_device_handler() | |||
assert isinstance(device_handler, DeviceHandler) | |||
self._device_handler = device_handler | |||
self._remote_edu_linearizer = Linearizer(name="remote_signing_key") | |||
@@ -1445,9 +1453,6 @@ class SigningKeyEduUpdater: | |||
user_id: the user whose updates we are processing | |||
""" | |||
device_handler = self.e2e_keys_handler.device_handler | |||
device_list_updater = device_handler.device_list_updater | |||
async with self._remote_edu_linearizer.queue(user_id): | |||
pending_updates = self._pending_updates.pop(user_id, []) | |||
if not pending_updates: | |||
@@ -1459,13 +1464,11 @@ class SigningKeyEduUpdater: | |||
logger.info("pending updates: %r", pending_updates) | |||
for master_key, self_signing_key in pending_updates: | |||
new_device_ids = ( | |||
await device_list_updater.process_cross_signing_key_update( | |||
user_id, | |||
master_key, | |||
self_signing_key, | |||
) | |||
new_device_ids = await self._device_handler.device_list_updater.process_cross_signing_key_update( | |||
user_id, | |||
master_key, | |||
self_signing_key, | |||
) | |||
device_ids = device_ids + new_device_ids | |||
await device_handler.notify_device_update(user_id, device_ids) | |||
await self._device_handler.notify_device_update(user_id, device_ids) |
@@ -38,6 +38,7 @@ from synapse.api.errors import ( | |||
) | |||
from synapse.appservice import ApplicationService | |||
from synapse.config.server import is_threepid_reserved | |||
from synapse.handlers.device import DeviceHandler | |||
from synapse.http.servlet import assert_params_in_dict | |||
from synapse.replication.http.login import RegisterDeviceReplicationServlet | |||
from synapse.replication.http.register import ( | |||
@@ -841,6 +842,9 @@ class RegistrationHandler: | |||
refresh_token = None | |||
refresh_token_id = None | |||
# This can only run on the main process. | |||
assert isinstance(self.device_handler, DeviceHandler) | |||
registered_device_id = await self.device_handler.check_device_registered( | |||
user_id, | |||
device_id, | |||
@@ -15,6 +15,7 @@ import logging | |||
from typing import TYPE_CHECKING, Optional | |||
from synapse.api.errors import Codes, StoreError, SynapseError | |||
from synapse.handlers.device import DeviceHandler | |||
from synapse.types import Requester | |||
if TYPE_CHECKING: | |||
@@ -29,7 +30,10 @@ class SetPasswordHandler: | |||
def __init__(self, hs: "HomeServer"): | |||
self.store = hs.get_datastores().main | |||
self._auth_handler = hs.get_auth_handler() | |||
self._device_handler = hs.get_device_handler() | |||
# This can only be instantiated on the main process. | |||
device_handler = hs.get_device_handler() | |||
assert isinstance(device_handler, DeviceHandler) | |||
self._device_handler = device_handler | |||
async def set_password( | |||
self, | |||
@@ -37,6 +37,7 @@ from twisted.web.server import Request | |||
from synapse.api.constants import LoginType | |||
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError | |||
from synapse.config.sso import SsoAttributeRequirement | |||
from synapse.handlers.device import DeviceHandler | |||
from synapse.handlers.register import init_counters_for_auth_provider | |||
from synapse.handlers.ui_auth import UIAuthSessionDataConstants | |||
from synapse.http import get_request_user_agent | |||
@@ -1035,6 +1036,8 @@ class SsoHandler: | |||
) -> None: | |||
"""Revoke any devices and in-flight logins tied to a provider session. | |||
Can only be called from the main process. | |||
Args: | |||
auth_provider_id: A unique identifier for this SSO provider, e.g. | |||
"oidc" or "saml". | |||
@@ -1042,6 +1045,12 @@ class SsoHandler: | |||
expected_user_id: The user we're expecting to logout. If set, it will ignore | |||
sessions belonging to other users and log an error. | |||
""" | |||
# It is expected that this is the main process. | |||
assert isinstance( | |||
self._device_handler, DeviceHandler | |||
), "revoking SSO sessions can only be called on the main process" | |||
# Invalidate any running user-mapping sessions | |||
to_delete = [] | |||
for session_id, session in self._username_mapping_sessions.items(): | |||
@@ -86,6 +86,7 @@ from synapse.handlers.auth import ( | |||
ON_LOGGED_OUT_CALLBACK, | |||
AuthHandler, | |||
) | |||
from synapse.handlers.device import DeviceHandler | |||
from synapse.handlers.push_rules import RuleSpec, check_actions | |||
from synapse.http.client import SimpleHttpClient | |||
from synapse.http.server import ( | |||
@@ -207,6 +208,7 @@ class ModuleApi: | |||
self._registration_handler = hs.get_registration_handler() | |||
self._send_email_handler = hs.get_send_email_handler() | |||
self._push_rules_handler = hs.get_push_rules_handler() | |||
self._device_handler = hs.get_device_handler() | |||
self.custom_template_dir = hs.config.server.custom_template_directory | |||
try: | |||
@@ -784,6 +786,8 @@ class ModuleApi: | |||
) -> Generator["defer.Deferred[Any]", Any, None]: | |||
"""Invalidate an access token for a user | |||
Can only be called from the main process. | |||
Added in Synapse v0.25.0. | |||
Args: | |||
@@ -796,6 +800,10 @@ class ModuleApi: | |||
Raises: | |||
synapse.api.errors.AuthError: the access token is invalid | |||
""" | |||
assert isinstance( | |||
self._device_handler, DeviceHandler | |||
), "invalidate_access_token can only be called on the main process" | |||
# see if the access token corresponds to a device | |||
user_info = yield defer.ensureDeferred( | |||
self._auth.get_user_by_access_token(access_token) | |||
@@ -805,7 +813,7 @@ class ModuleApi: | |||
if device_id: | |||
# delete the device, which will also delete its access tokens | |||
yield defer.ensureDeferred( | |||
self._hs.get_device_handler().delete_devices(user_id, [device_id]) | |||
self._device_handler.delete_devices(user_id, [device_id]) | |||
) | |||
else: | |||
# no associated device. Just delete the access token. | |||
@@ -13,7 +13,7 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import TYPE_CHECKING, Tuple | |||
from typing import TYPE_CHECKING, Optional, Tuple | |||
from twisted.web.server import Request | |||
@@ -63,7 +63,12 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__(hs) | |||
self.device_list_updater = hs.get_device_handler().device_list_updater | |||
from synapse.handlers.device import DeviceHandler | |||
handler = hs.get_device_handler() | |||
assert isinstance(handler, DeviceHandler) | |||
self.device_list_updater = handler.device_list_updater | |||
self.store = hs.get_datastores().main | |||
self.clock = hs.get_clock() | |||
@@ -73,7 +78,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): | |||
async def _handle_request( # type: ignore[override] | |||
self, request: Request, user_id: str | |||
) -> Tuple[int, JsonDict]: | |||
) -> Tuple[int, Optional[JsonDict]]: | |||
user_devices = await self.device_list_updater.user_device_resync(user_id) | |||
return 200, user_devices | |||
@@ -238,6 +238,10 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: | |||
""" | |||
Register all the admin servlets. | |||
""" | |||
# Admin servlets aren't registered on workers. | |||
if hs.config.worker.worker_app is not None: | |||
return | |||
register_servlets_for_client_rest_resource(hs, http_server) | |||
BlockRoomRestServlet(hs).register(http_server) | |||
ListRoomRestServlet(hs).register(http_server) | |||
@@ -254,9 +258,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: | |||
UserTokenRestServlet(hs).register(http_server) | |||
UserRestServletV2(hs).register(http_server) | |||
UsersRestServletV2(hs).register(http_server) | |||
DeviceRestServlet(hs).register(http_server) | |||
DevicesRestServlet(hs).register(http_server) | |||
DeleteDevicesRestServlet(hs).register(http_server) | |||
UserMediaStatisticsRestServlet(hs).register(http_server) | |||
EventReportDetailRestServlet(hs).register(http_server) | |||
EventReportsRestServlet(hs).register(http_server) | |||
@@ -280,12 +281,13 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: | |||
UserByExternalId(hs).register(http_server) | |||
UserByThreePid(hs).register(http_server) | |||
# Some servlets only get registered for the main process. | |||
if hs.config.worker.worker_app is None: | |||
SendServerNoticeServlet(hs).register(http_server) | |||
BackgroundUpdateEnabledRestServlet(hs).register(http_server) | |||
BackgroundUpdateRestServlet(hs).register(http_server) | |||
BackgroundUpdateStartJobRestServlet(hs).register(http_server) | |||
DeviceRestServlet(hs).register(http_server) | |||
DevicesRestServlet(hs).register(http_server) | |||
DeleteDevicesRestServlet(hs).register(http_server) | |||
SendServerNoticeServlet(hs).register(http_server) | |||
BackgroundUpdateEnabledRestServlet(hs).register(http_server) | |||
BackgroundUpdateRestServlet(hs).register(http_server) | |||
BackgroundUpdateStartJobRestServlet(hs).register(http_server) | |||
def register_servlets_for_client_rest_resource( | |||
@@ -294,9 +296,11 @@ def register_servlets_for_client_rest_resource( | |||
"""Register only the servlets which need to be exposed on /_matrix/client/xxx""" | |||
WhoisRestServlet(hs).register(http_server) | |||
PurgeHistoryStatusRestServlet(hs).register(http_server) | |||
DeactivateAccountRestServlet(hs).register(http_server) | |||
PurgeHistoryRestServlet(hs).register(http_server) | |||
ResetPasswordRestServlet(hs).register(http_server) | |||
# The following resources can only be run on the main process. | |||
if hs.config.worker.worker_app is None: | |||
DeactivateAccountRestServlet(hs).register(http_server) | |||
ResetPasswordRestServlet(hs).register(http_server) | |||
SearchUsersRestServlet(hs).register(http_server) | |||
UserRegisterServlet(hs).register(http_server) | |||
AccountValidityRenewServlet(hs).register(http_server) | |||
@@ -16,6 +16,7 @@ from http import HTTPStatus | |||
from typing import TYPE_CHECKING, Tuple | |||
from synapse.api.errors import NotFoundError, SynapseError | |||
from synapse.handlers.device import DeviceHandler | |||
from synapse.http.servlet import ( | |||
RestServlet, | |||
assert_params_in_dict, | |||
@@ -43,7 +44,9 @@ class DeviceRestServlet(RestServlet): | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__() | |||
self.auth = hs.get_auth() | |||
self.device_handler = hs.get_device_handler() | |||
handler = hs.get_device_handler() | |||
assert isinstance(handler, DeviceHandler) | |||
self.device_handler = handler | |||
self.store = hs.get_datastores().main | |||
self.is_mine = hs.is_mine | |||
@@ -112,7 +115,9 @@ class DevicesRestServlet(RestServlet): | |||
def __init__(self, hs: "HomeServer"): | |||
self.auth = hs.get_auth() | |||
self.device_handler = hs.get_device_handler() | |||
handler = hs.get_device_handler() | |||
assert isinstance(handler, DeviceHandler) | |||
self.device_handler = handler | |||
self.store = hs.get_datastores().main | |||
self.is_mine = hs.is_mine | |||
@@ -143,7 +148,9 @@ class DeleteDevicesRestServlet(RestServlet): | |||
def __init__(self, hs: "HomeServer"): | |||
self.auth = hs.get_auth() | |||
self.device_handler = hs.get_device_handler() | |||
handler = hs.get_device_handler() | |||
assert isinstance(handler, DeviceHandler) | |||
self.device_handler = handler | |||
self.store = hs.get_datastores().main | |||
self.is_mine = hs.is_mine | |||
@@ -20,6 +20,7 @@ from pydantic import Extra, StrictStr | |||
from synapse.api import errors | |||
from synapse.api.errors import NotFoundError | |||
from synapse.handlers.device import DeviceHandler | |||
from synapse.http.server import HttpServer | |||
from synapse.http.servlet import ( | |||
RestServlet, | |||
@@ -80,7 +81,9 @@ class DeleteDevicesRestServlet(RestServlet): | |||
super().__init__() | |||
self.hs = hs | |||
self.auth = hs.get_auth() | |||
self.device_handler = hs.get_device_handler() | |||
handler = hs.get_device_handler() | |||
assert isinstance(handler, DeviceHandler) | |||
self.device_handler = handler | |||
self.auth_handler = hs.get_auth_handler() | |||
class PostBody(RequestBodyModel): | |||
@@ -125,7 +128,9 @@ class DeviceRestServlet(RestServlet): | |||
super().__init__() | |||
self.hs = hs | |||
self.auth = hs.get_auth() | |||
self.device_handler = hs.get_device_handler() | |||
handler = hs.get_device_handler() | |||
assert isinstance(handler, DeviceHandler) | |||
self.device_handler = handler | |||
self.auth_handler = hs.get_auth_handler() | |||
self._msc3852_enabled = hs.config.experimental.msc3852_enabled | |||
@@ -256,7 +261,9 @@ class DehydratedDeviceServlet(RestServlet): | |||
super().__init__() | |||
self.hs = hs | |||
self.auth = hs.get_auth() | |||
self.device_handler = hs.get_device_handler() | |||
handler = hs.get_device_handler() | |||
assert isinstance(handler, DeviceHandler) | |||
self.device_handler = handler | |||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: | |||
requester = await self.auth.get_user_by_req(request) | |||
@@ -313,7 +320,9 @@ class ClaimDehydratedDeviceServlet(RestServlet): | |||
super().__init__() | |||
self.hs = hs | |||
self.auth = hs.get_auth() | |||
self.device_handler = hs.get_device_handler() | |||
handler = hs.get_device_handler() | |||
assert isinstance(handler, DeviceHandler) | |||
self.device_handler = handler | |||
class PostBody(RequestBodyModel): | |||
device_id: StrictStr | |||
@@ -15,6 +15,7 @@ | |||
import logging | |||
from typing import TYPE_CHECKING, Tuple | |||
from synapse.handlers.device import DeviceHandler | |||
from synapse.http.server import HttpServer | |||
from synapse.http.servlet import RestServlet | |||
from synapse.http.site import SynapseRequest | |||
@@ -34,7 +35,9 @@ class LogoutRestServlet(RestServlet): | |||
super().__init__() | |||
self.auth = hs.get_auth() | |||
self._auth_handler = hs.get_auth_handler() | |||
self._device_handler = hs.get_device_handler() | |||
handler = hs.get_device_handler() | |||
assert isinstance(handler, DeviceHandler) | |||
self._device_handler = handler | |||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: | |||
requester = await self.auth.get_user_by_req(request, allow_expired=True) | |||
@@ -59,7 +62,9 @@ class LogoutAllRestServlet(RestServlet): | |||
super().__init__() | |||
self.auth = hs.get_auth() | |||
self._auth_handler = hs.get_auth_handler() | |||
self._device_handler = hs.get_device_handler() | |||
handler = hs.get_device_handler() | |||
assert isinstance(handler, DeviceHandler) | |||
self._device_handler = handler | |||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: | |||
requester = await self.auth.get_user_by_req(request, allow_expired=True) | |||
@@ -510,7 +510,7 @@ class HomeServer(metaclass=abc.ABCMeta): | |||
) | |||
@cache_in_self | |||
def get_device_handler(self): | |||
def get_device_handler(self) -> DeviceWorkerHandler: | |||
if self.config.worker.worker_app: | |||
return DeviceWorkerHandler(self) | |||
else: | |||
@@ -19,7 +19,7 @@ from typing import Optional | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.api.errors import NotFoundError, SynapseError | |||
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN | |||
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler | |||
from synapse.server import HomeServer | |||
from synapse.util import Clock | |||
@@ -32,7 +32,9 @@ user2 = "@theresa:bbb" | |||
class DeviceTestCase(unittest.HomeserverTestCase): | |||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | |||
hs = self.setup_test_homeserver("server", federation_http_client=None) | |||
self.handler = hs.get_device_handler() | |||
handler = hs.get_device_handler() | |||
assert isinstance(handler, DeviceHandler) | |||
self.handler = handler | |||
self.store = hs.get_datastores().main | |||
return hs | |||
@@ -61,6 +63,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual(res, "fco") | |||
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) | |||
assert dev is not None | |||
self.assertEqual(dev["display_name"], "display name") | |||
def test_device_is_preserved_if_exists(self) -> None: | |||
@@ -83,6 +86,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual(res2, "fco") | |||
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) | |||
assert dev is not None | |||
self.assertEqual(dev["display_name"], "display name") | |||
def test_device_id_is_made_up_if_unspecified(self) -> None: | |||
@@ -95,6 +99,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): | |||
) | |||
dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id)) | |||
assert dev is not None | |||
self.assertEqual(dev["display_name"], "display") | |||
def test_get_devices_by_user(self) -> None: | |||
@@ -264,7 +269,9 @@ class DeviceTestCase(unittest.HomeserverTestCase): | |||
class DehydrationTestCase(unittest.HomeserverTestCase): | |||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | |||
hs = self.setup_test_homeserver("server", federation_http_client=None) | |||
self.handler = hs.get_device_handler() | |||
handler = hs.get_device_handler() | |||
assert isinstance(handler, DeviceHandler) | |||
self.handler = handler | |||
self.registration = hs.get_registration_handler() | |||
self.auth = hs.get_auth() | |||
self.store = hs.get_datastores().main | |||
@@ -284,9 +291,9 @@ class DehydrationTestCase(unittest.HomeserverTestCase): | |||
) | |||
) | |||
retrieved_device_id, device_data = self.get_success( | |||
self.handler.get_dehydrated_device(user_id=user_id) | |||
) | |||
result = self.get_success(self.handler.get_dehydrated_device(user_id=user_id)) | |||
assert result is not None | |||
retrieved_device_id, device_data = result | |||
self.assertEqual(retrieved_device_id, stored_dehydrated_device_id) | |||
self.assertEqual(device_data, {"device_data": {"foo": "bar"}}) | |||
@@ -19,6 +19,7 @@ from twisted.test.proto_helpers import MemoryReactor | |||
import synapse.rest.admin | |||
from synapse.api.errors import Codes | |||
from synapse.handlers.device import DeviceHandler | |||
from synapse.rest.client import login | |||
from synapse.server import HomeServer | |||
from synapse.util import Clock | |||
@@ -34,7 +35,9 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): | |||
] | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.handler = hs.get_device_handler() | |||
handler = hs.get_device_handler() | |||
assert isinstance(handler, DeviceHandler) | |||
self.handler = handler | |||
self.admin_user = self.register_user("admin", "pass", admin=True) | |||
self.admin_user_tok = self.login("admin", "pass") | |||