Pārlūkot izejas kodu

Add a type hint for `get_device_handler()` and fix incorrect types. (#14055)

This was the last untyped handler from the HomeServer object. Since
it was being treated as Any (and thus unchecked) it was being used
incorrectly in a few places.
tags/v1.73.0rc1
Patrick Cloke pirms 1 gada
committed by GitHub
vecāks
revīzija
6d47b7e325
Šim parakstam datu bāzē netika atrasta zināma atslēga GPG atslēgas ID: 4AEE18F83AFDEB23
16 mainītis faili ar 185 papildinājumiem un 77 dzēšanām
  1. +1
    -0
      changelog.d/14055.misc
  2. +4
    -0
      synapse/handlers/deactivate_account.py
  3. +50
    -15
      synapse/handlers/device.py
  4. +32
    -29
      synapse/handlers/e2e_keys.py
  5. +4
    -0
      synapse/handlers/register.py
  6. +5
    -1
      synapse/handlers/set_password.py
  7. +9
    -0
      synapse/handlers/sso.py
  8. +9
    -1
      synapse/module_api/__init__.py
  9. +8
    -3
      synapse/replication/http/devices.py
  10. +15
    -11
      synapse/rest/admin/__init__.py
  11. +10
    -3
      synapse/rest/admin/devices.py
  12. +13
    -4
      synapse/rest/client/devices.py
  13. +7
    -2
      synapse/rest/client/logout.py
  14. +1
    -1
      synapse/server.py
  15. +13
    -6
      tests/handlers/test_device.py
  16. +4
    -1
      tests/rest/admin/test_device.py

+ 1
- 0
changelog.d/14055.misc Parādīt failu

@@ -0,0 +1 @@
Add missing type hints to `HomeServer`.

+ 4
- 0
synapse/handlers/deactivate_account.py Parādīt failu

@@ -16,6 +16,7 @@ import logging
from typing import TYPE_CHECKING, Optional

from synapse.api.errors import SynapseError
from synapse.handlers.device import DeviceHandler
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import Codes, Requester, UserID, create_requester

@@ -76,6 +77,9 @@ class DeactivateAccountHandler:
True if identity server supports removing threepids, otherwise False.
"""

# This can only be called on the main process.
assert isinstance(self._device_handler, DeviceHandler)

# Check if this user can be deactivated
if not await self._third_party_rules.check_can_deactivate_user(
user_id, by_admin


+ 50
- 15
synapse/handlers/device.py Parādīt failu

@@ -65,6 +65,8 @@ DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000


class DeviceWorkerHandler:
device_list_updater: "DeviceListWorkerUpdater"

def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.hs = hs
@@ -76,6 +78,8 @@ class DeviceWorkerHandler:
self.server_name = hs.hostname
self._msc3852_enabled = hs.config.experimental.msc3852_enabled

self.device_list_updater = DeviceListWorkerUpdater(hs)

@trace
async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
"""
@@ -99,6 +103,19 @@ class DeviceWorkerHandler:
log_kv(device_map)
return devices

async def get_dehydrated_device(
self, user_id: str
) -> Optional[Tuple[str, JsonDict]]:
"""Retrieve the information for a dehydrated device.

Args:
user_id: the user whose dehydrated device we are looking for
Returns:
a tuple whose first item is the device ID, and the second item is
the dehydrated device information
"""
return await self.store.get_dehydrated_device(user_id)

@trace
async def get_device(self, user_id: str, device_id: str) -> JsonDict:
"""Retrieve the given device
@@ -127,7 +144,7 @@ class DeviceWorkerHandler:
@cancellable
async def get_device_changes_in_shared_rooms(
self, user_id: str, room_ids: Collection[str], from_token: StreamToken
) -> Collection[str]:
) -> Set[str]:
"""Get the set of users whose devices have changed who share a room with
the given user.
"""
@@ -320,6 +337,8 @@ class DeviceWorkerHandler:


class DeviceHandler(DeviceWorkerHandler):
device_list_updater: "DeviceListUpdater"

def __init__(self, hs: "HomeServer"):
super().__init__(hs)

@@ -606,19 +625,6 @@ class DeviceHandler(DeviceWorkerHandler):
await self.delete_devices(user_id, [old_device_id])
return device_id

async def get_dehydrated_device(
self, user_id: str
) -> Optional[Tuple[str, JsonDict]]:
"""Retrieve the information for a dehydrated device.

Args:
user_id: the user whose dehydrated device we are looking for
Returns:
a tuple whose first item is the device ID, and the second item is
the dehydrated device information
"""
return await self.store.get_dehydrated_device(user_id)

async def rehydrate_device(
self, user_id: str, access_token: str, device_id: str
) -> dict:
@@ -882,7 +888,36 @@ def _update_device_from_client_ips(
)


class DeviceListUpdater:
class DeviceListWorkerUpdater:
"Handles incoming device list updates from federation and contacts the main process over replication"

def __init__(self, hs: "HomeServer"):
from synapse.replication.http.devices import (
ReplicationUserDevicesResyncRestServlet,
)

self._user_device_resync_client = (
ReplicationUserDevicesResyncRestServlet.make_client(hs)
)

async def user_device_resync(
self, user_id: str, mark_failed_as_stale: bool = True
) -> Optional[JsonDict]:
"""Fetches all devices for a user and updates the device cache with them.

Args:
user_id: The user's id whose device_list will be updated.
mark_failed_as_stale: Whether to mark the user's device list as stale
if the attempt to resync failed.
Returns:
A dict with device info as under the "devices" in the result of this
request:
https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
"""
return await self._user_device_resync_client(user_id=user_id)


class DeviceListUpdater(DeviceListWorkerUpdater):
"Handles incoming device list updates from federation and updates the DB"

def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):


+ 32
- 29
synapse/handlers/e2e_keys.py Parādīt failu

@@ -27,9 +27,9 @@ from twisted.internet import defer

from synapse.api.constants import EduTypes
from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError
from synapse.handlers.device import DeviceHandler
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.types import (
JsonDict,
UserID,
@@ -56,27 +56,23 @@ class E2eKeysHandler:
self.is_mine = hs.is_mine
self.clock = hs.get_clock()

self._edu_updater = SigningKeyEduUpdater(hs, self)

federation_registry = hs.get_federation_registry()

self._is_master = hs.config.worker.worker_app is None
if not self._is_master:
self._user_device_resync_client = (
ReplicationUserDevicesResyncRestServlet.make_client(hs)
)
else:
is_master = hs.config.worker.worker_app is None
if is_master:
edu_updater = SigningKeyEduUpdater(hs)

# Only register this edu handler on master as it requires writing
# device updates to the db
federation_registry.register_edu_handler(
EduTypes.SIGNING_KEY_UPDATE,
self._edu_updater.incoming_signing_key_update,
edu_updater.incoming_signing_key_update,
)
# also handle the unstable version
# FIXME: remove this when enough servers have upgraded
federation_registry.register_edu_handler(
EduTypes.UNSTABLE_SIGNING_KEY_UPDATE,
self._edu_updater.incoming_signing_key_update,
edu_updater.incoming_signing_key_update,
)

# doesn't really work as part of the generic query API, because the
@@ -319,14 +315,13 @@ class E2eKeysHandler:
# probably be tracking their device lists. However, we haven't
# done an initial sync on the device list so we do it now.
try:
if self._is_master:
resync_results = await self.device_handler.device_list_updater.user_device_resync(
resync_results = (
await self.device_handler.device_list_updater.user_device_resync(
user_id
)
else:
resync_results = await self._user_device_resync_client(
user_id=user_id
)
)
if resync_results is None:
raise ValueError("Device resync failed")

# Add the device keys to the results.
user_devices = resync_results["devices"]
@@ -605,6 +600,8 @@ class E2eKeysHandler:
async def upload_keys_for_user(
self, user_id: str, device_id: str, keys: JsonDict
) -> JsonDict:
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)

time_now = self.clock.time_msec()

@@ -732,6 +729,8 @@ class E2eKeysHandler:
user_id: the user uploading the keys
keys: the signing keys
"""
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)

# if a master key is uploaded, then check it. Otherwise, load the
# stored master key, to check signatures on other keys
@@ -823,6 +822,9 @@ class E2eKeysHandler:
Raises:
SynapseError: if the signatures dict is not valid.
"""
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)

failures = {}

# signatures to be stored. Each item will be a SignatureListItem
@@ -1200,6 +1202,9 @@ class E2eKeysHandler:
A tuple of the retrieved key content, the key's ID and the matching VerifyKey.
If the key cannot be retrieved, all values in the tuple will instead be None.
"""
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)

try:
remote_result = await self.federation.query_user_devices(
user.domain, user.to_string()
@@ -1396,11 +1401,14 @@ class SignatureListItem:
class SigningKeyEduUpdater:
"""Handles incoming signing key updates from federation and updates the DB"""

def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.federation = hs.get_federation_client()
self.clock = hs.get_clock()
self.e2e_keys_handler = e2e_keys_handler

device_handler = hs.get_device_handler()
assert isinstance(device_handler, DeviceHandler)
self._device_handler = device_handler

self._remote_edu_linearizer = Linearizer(name="remote_signing_key")

@@ -1445,9 +1453,6 @@ class SigningKeyEduUpdater:
user_id: the user whose updates we are processing
"""

device_handler = self.e2e_keys_handler.device_handler
device_list_updater = device_handler.device_list_updater

async with self._remote_edu_linearizer.queue(user_id):
pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates:
@@ -1459,13 +1464,11 @@ class SigningKeyEduUpdater:
logger.info("pending updates: %r", pending_updates)

for master_key, self_signing_key in pending_updates:
new_device_ids = (
await device_list_updater.process_cross_signing_key_update(
user_id,
master_key,
self_signing_key,
)
new_device_ids = await self._device_handler.device_list_updater.process_cross_signing_key_update(
user_id,
master_key,
self_signing_key,
)
device_ids = device_ids + new_device_ids

await device_handler.notify_device_update(user_id, device_ids)
await self._device_handler.notify_device_update(user_id, device_ids)

+ 4
- 0
synapse/handlers/register.py Parādīt failu

@@ -38,6 +38,7 @@ from synapse.api.errors import (
)
from synapse.appservice import ApplicationService
from synapse.config.server import is_threepid_reserved
from synapse.handlers.device import DeviceHandler
from synapse.http.servlet import assert_params_in_dict
from synapse.replication.http.login import RegisterDeviceReplicationServlet
from synapse.replication.http.register import (
@@ -841,6 +842,9 @@ class RegistrationHandler:
refresh_token = None
refresh_token_id = None

# This can only run on the main process.
assert isinstance(self.device_handler, DeviceHandler)

registered_device_id = await self.device_handler.check_device_registered(
user_id,
device_id,


+ 5
- 1
synapse/handlers/set_password.py Parādīt failu

@@ -15,6 +15,7 @@ import logging
from typing import TYPE_CHECKING, Optional

from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.handlers.device import DeviceHandler
from synapse.types import Requester

if TYPE_CHECKING:
@@ -29,7 +30,10 @@ class SetPasswordHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
# This can only be instantiated on the main process.
device_handler = hs.get_device_handler()
assert isinstance(device_handler, DeviceHandler)
self._device_handler = device_handler

async def set_password(
self,


+ 9
- 0
synapse/handlers/sso.py Parādīt failu

@@ -37,6 +37,7 @@ from twisted.web.server import Request
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
from synapse.config.sso import SsoAttributeRequirement
from synapse.handlers.device import DeviceHandler
from synapse.handlers.register import init_counters_for_auth_provider
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent
@@ -1035,6 +1036,8 @@ class SsoHandler:
) -> None:
"""Revoke any devices and in-flight logins tied to a provider session.

Can only be called from the main process.

Args:
auth_provider_id: A unique identifier for this SSO provider, e.g.
"oidc" or "saml".
@@ -1042,6 +1045,12 @@ class SsoHandler:
expected_user_id: The user we're expecting to logout. If set, it will ignore
sessions belonging to other users and log an error.
"""

# It is expected that this is the main process.
assert isinstance(
self._device_handler, DeviceHandler
), "revoking SSO sessions can only be called on the main process"

# Invalidate any running user-mapping sessions
to_delete = []
for session_id, session in self._username_mapping_sessions.items():


+ 9
- 1
synapse/module_api/__init__.py Parādīt failu

@@ -86,6 +86,7 @@ from synapse.handlers.auth import (
ON_LOGGED_OUT_CALLBACK,
AuthHandler,
)
from synapse.handlers.device import DeviceHandler
from synapse.handlers.push_rules import RuleSpec, check_actions
from synapse.http.client import SimpleHttpClient
from synapse.http.server import (
@@ -207,6 +208,7 @@ class ModuleApi:
self._registration_handler = hs.get_registration_handler()
self._send_email_handler = hs.get_send_email_handler()
self._push_rules_handler = hs.get_push_rules_handler()
self._device_handler = hs.get_device_handler()
self.custom_template_dir = hs.config.server.custom_template_directory

try:
@@ -784,6 +786,8 @@ class ModuleApi:
) -> Generator["defer.Deferred[Any]", Any, None]:
"""Invalidate an access token for a user

Can only be called from the main process.

Added in Synapse v0.25.0.

Args:
@@ -796,6 +800,10 @@ class ModuleApi:
Raises:
synapse.api.errors.AuthError: the access token is invalid
"""
assert isinstance(
self._device_handler, DeviceHandler
), "invalidate_access_token can only be called on the main process"

# see if the access token corresponds to a device
user_info = yield defer.ensureDeferred(
self._auth.get_user_by_access_token(access_token)
@@ -805,7 +813,7 @@ class ModuleApi:
if device_id:
# delete the device, which will also delete its access tokens
yield defer.ensureDeferred(
self._hs.get_device_handler().delete_devices(user_id, [device_id])
self._device_handler.delete_devices(user_id, [device_id])
)
else:
# no associated device. Just delete the access token.


+ 8
- 3
synapse/replication/http/devices.py Parādīt failu

@@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING, Optional, Tuple

from twisted.web.server import Request

@@ -63,7 +63,12 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.device_list_updater = hs.get_device_handler().device_list_updater
from synapse.handlers.device import DeviceHandler

handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_list_updater = handler.device_list_updater

self.store = hs.get_datastores().main
self.clock = hs.get_clock()

@@ -73,7 +78,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):

async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str
) -> Tuple[int, JsonDict]:
) -> Tuple[int, Optional[JsonDict]]:
user_devices = await self.device_list_updater.user_device_resync(user_id)

return 200, user_devices


+ 15
- 11
synapse/rest/admin/__init__.py Parādīt failu

@@ -238,6 +238,10 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
"""
Register all the admin servlets.
"""
# Admin servlets aren't registered on workers.
if hs.config.worker.worker_app is not None:
return

register_servlets_for_client_rest_resource(hs, http_server)
BlockRoomRestServlet(hs).register(http_server)
ListRoomRestServlet(hs).register(http_server)
@@ -254,9 +258,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
UserTokenRestServlet(hs).register(http_server)
UserRestServletV2(hs).register(http_server)
UsersRestServletV2(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)
DeleteDevicesRestServlet(hs).register(http_server)
UserMediaStatisticsRestServlet(hs).register(http_server)
EventReportDetailRestServlet(hs).register(http_server)
EventReportsRestServlet(hs).register(http_server)
@@ -280,12 +281,13 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
UserByExternalId(hs).register(http_server)
UserByThreePid(hs).register(http_server)

# Some servlets only get registered for the main process.
if hs.config.worker.worker_app is None:
SendServerNoticeServlet(hs).register(http_server)
BackgroundUpdateEnabledRestServlet(hs).register(http_server)
BackgroundUpdateRestServlet(hs).register(http_server)
BackgroundUpdateStartJobRestServlet(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)
DeleteDevicesRestServlet(hs).register(http_server)
SendServerNoticeServlet(hs).register(http_server)
BackgroundUpdateEnabledRestServlet(hs).register(http_server)
BackgroundUpdateRestServlet(hs).register(http_server)
BackgroundUpdateStartJobRestServlet(hs).register(http_server)


def register_servlets_for_client_rest_resource(
@@ -294,9 +296,11 @@ def register_servlets_for_client_rest_resource(
"""Register only the servlets which need to be exposed on /_matrix/client/xxx"""
WhoisRestServlet(hs).register(http_server)
PurgeHistoryStatusRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
PurgeHistoryRestServlet(hs).register(http_server)
ResetPasswordRestServlet(hs).register(http_server)
# The following resources can only be run on the main process.
if hs.config.worker.worker_app is None:
DeactivateAccountRestServlet(hs).register(http_server)
ResetPasswordRestServlet(hs).register(http_server)
SearchUsersRestServlet(hs).register(http_server)
UserRegisterServlet(hs).register(http_server)
AccountValidityRenewServlet(hs).register(http_server)


+ 10
- 3
synapse/rest/admin/devices.py Parādīt failu

@@ -16,6 +16,7 @@ from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple

from synapse.api.errors import NotFoundError, SynapseError
from synapse.handlers.device import DeviceHandler
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@@ -43,7 +44,9 @@ class DeviceRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_handler = handler
self.store = hs.get_datastores().main
self.is_mine = hs.is_mine

@@ -112,7 +115,9 @@ class DevicesRestServlet(RestServlet):

def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_handler = handler
self.store = hs.get_datastores().main
self.is_mine = hs.is_mine

@@ -143,7 +148,9 @@ class DeleteDevicesRestServlet(RestServlet):

def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_handler = handler
self.store = hs.get_datastores().main
self.is_mine = hs.is_mine



+ 13
- 4
synapse/rest/client/devices.py Parādīt failu

@@ -20,6 +20,7 @@ from pydantic import Extra, StrictStr

from synapse.api import errors
from synapse.api.errors import NotFoundError
from synapse.handlers.device import DeviceHandler
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
@@ -80,7 +81,9 @@ class DeleteDevicesRestServlet(RestServlet):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_handler = handler
self.auth_handler = hs.get_auth_handler()

class PostBody(RequestBodyModel):
@@ -125,7 +128,9 @@ class DeviceRestServlet(RestServlet):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_handler = handler
self.auth_handler = hs.get_auth_handler()
self._msc3852_enabled = hs.config.experimental.msc3852_enabled

@@ -256,7 +261,9 @@ class DehydratedDeviceServlet(RestServlet):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_handler = handler

async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
@@ -313,7 +320,9 @@ class ClaimDehydratedDeviceServlet(RestServlet):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_handler = handler

class PostBody(RequestBodyModel):
device_id: StrictStr


+ 7
- 2
synapse/rest/client/logout.py Parādīt failu

@@ -15,6 +15,7 @@
import logging
from typing import TYPE_CHECKING, Tuple

from synapse.handlers.device import DeviceHandler
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
@@ -34,7 +35,9 @@ class LogoutRestServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self._device_handler = handler

async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_expired=True)
@@ -59,7 +62,9 @@ class LogoutAllRestServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self._device_handler = handler

async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_expired=True)


+ 1
- 1
synapse/server.py Parādīt failu

@@ -510,7 +510,7 @@ class HomeServer(metaclass=abc.ABCMeta):
)

@cache_in_self
def get_device_handler(self):
def get_device_handler(self) -> DeviceWorkerHandler:
if self.config.worker.worker_app:
return DeviceWorkerHandler(self)
else:


+ 13
- 6
tests/handlers/test_device.py Parādīt failu

@@ -19,7 +19,7 @@ from typing import Optional
from twisted.test.proto_helpers import MemoryReactor

from synapse.api.errors import NotFoundError, SynapseError
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
from synapse.server import HomeServer
from synapse.util import Clock

@@ -32,7 +32,9 @@ user2 = "@theresa:bbb"
class DeviceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None)
self.handler = hs.get_device_handler()
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.handler = handler
self.store = hs.get_datastores().main
return hs

@@ -61,6 +63,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
self.assertEqual(res, "fco")

dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
assert dev is not None
self.assertEqual(dev["display_name"], "display name")

def test_device_is_preserved_if_exists(self) -> None:
@@ -83,6 +86,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
self.assertEqual(res2, "fco")

dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
assert dev is not None
self.assertEqual(dev["display_name"], "display name")

def test_device_id_is_made_up_if_unspecified(self) -> None:
@@ -95,6 +99,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
)

dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id))
assert dev is not None
self.assertEqual(dev["display_name"], "display")

def test_get_devices_by_user(self) -> None:
@@ -264,7 +269,9 @@ class DeviceTestCase(unittest.HomeserverTestCase):
class DehydrationTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None)
self.handler = hs.get_device_handler()
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.handler = handler
self.registration = hs.get_registration_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
@@ -284,9 +291,9 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
)
)

retrieved_device_id, device_data = self.get_success(
self.handler.get_dehydrated_device(user_id=user_id)
)
result = self.get_success(self.handler.get_dehydrated_device(user_id=user_id))
assert result is not None
retrieved_device_id, device_data = result

self.assertEqual(retrieved_device_id, stored_dehydrated_device_id)
self.assertEqual(device_data, {"device_data": {"foo": "bar"}})


+ 4
- 1
tests/rest/admin/test_device.py Parādīt failu

@@ -19,6 +19,7 @@ from twisted.test.proto_helpers import MemoryReactor

import synapse.rest.admin
from synapse.api.errors import Codes
from synapse.handlers.device import DeviceHandler
from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.util import Clock
@@ -34,7 +35,9 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
]

def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_device_handler()
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.handler = handler

self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")


Notiek ielāde…
Atcelt
Saglabāt