You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

579 lines
20 KiB

  1. # Copyright 2016 OpenMarket Ltd
  2. # Copyright 2018 New Vector Ltd
  3. # Copyright 2020 The Matrix.org Foundation C.I.C.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. from typing import Optional
  17. from unittest import mock
  18. from twisted.internet.defer import ensureDeferred
  19. from twisted.test.proto_helpers import MemoryReactor
  20. from synapse.api.constants import RoomEncryptionAlgorithms
  21. from synapse.api.errors import NotFoundError, SynapseError
  22. from synapse.appservice import ApplicationService
  23. from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
  24. from synapse.rest import admin
  25. from synapse.rest.client import devices, login, register
  26. from synapse.server import HomeServer
  27. from synapse.storage.databases.main.appservice import _make_exclusive_regex
  28. from synapse.types import JsonDict, create_requester
  29. from synapse.util import Clock
  30. from tests import unittest
  31. from tests.unittest import override_config
  # Fixture user IDs shared by the DeviceTestCase helpers: `user1` owns the
  # devices that the lookup/update/delete tests inspect, while `user2` only
  # gets a single extra device recorded by `_record_users`.
  user1 = "@boris:aaa"
  user2 = "@theresa:bbb"
  class DeviceTestCase(unittest.HomeserverTestCase):
      """Tests for the device handler.

      Covers device registration, lookup, renaming and deletion, as well as
      federation device queries whose keys are supplied by an appservice.
      """

      def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
          """Create the homeserver under test with a mocked appservice API.

          The mock is kept on ``self.appservice_api`` so that tests can script
          its responses; the device handler and the main datastore are stored
          on ``self`` as well.
          """
          appservice_api = mock.AsyncMock()
          self.appservice_api = appservice_api
          homeserver = self.setup_test_homeserver(
              "server", application_service_api=appservice_api
          )

          device_handler = homeserver.get_device_handler()
          assert isinstance(device_handler, DeviceHandler)
          self.handler = device_handler
          self.store = homeserver.get_datastores().main
          return homeserver

      def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
          """Advance the clock before each test."""
          # The timestamps asserted in these tests rely on the clock starting
          # 1000 seconds in.
          self.reactor.advance(1000)

      def test_device_is_created_with_invalid_name(self) -> None:
          """Registering a device with an over-long display name must fail."""
          oversized_name = "a" * (MAX_DEVICE_DISPLAY_NAME_LEN + 1)
          registration = self.handler.check_device_registered(
              user_id="@boris:foo",
              device_id="foo",
              initial_device_display_name=oversized_name,
          )
          self.get_failure(registration, SynapseError)

      def test_device_is_created_if_doesnt_exist(self) -> None:
          """Registering an unknown device ID stores it with the given name."""
          registration = self.handler.check_device_registered(
              user_id="@boris:foo",
              device_id="fco",
              initial_device_display_name="display name",
          )
          registered_id = self.get_success(registration)
          self.assertEqual(registered_id, "fco")

          stored = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
          assert stored is not None
          self.assertEqual(stored["display_name"], "display name")

      def test_device_is_preserved_if_exists(self) -> None:
          """Re-registering an existing device keeps its original display name."""
          # Register the same device twice; only the first name should stick.
          for requested_name in ("display name", "new display name"):
              registered_id = self.get_success(
                  self.handler.check_device_registered(
                      user_id="@boris:foo",
                      device_id="fco",
                      initial_device_display_name=requested_name,
                  )
              )
              self.assertEqual(registered_id, "fco")

          stored = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
          assert stored is not None
          self.assertEqual(stored["display_name"], "display name")

      def test_device_id_is_made_up_if_unspecified(self) -> None:
          """Registering with ``device_id=None`` yields an ID that can be looked up."""
          registration = self.handler.check_device_registered(
              user_id="@theresa:foo",
              device_id=None,
              initial_device_display_name="display",
          )
          generated_id = self.get_success(registration)

          stored = self.get_success(
              self.handler.store.get_device("@theresa:foo", generated_id)
          )
          assert stored is not None
          self.assertEqual(stored["display_name"], "display")

      def test_get_devices_by_user(self) -> None:
          """All of a user's devices are returned, with last-seen data where known."""
          self._record_users()

          devices_found = self.get_success(self.handler.get_devices_by_user(user1))
          self.assertEqual(3, len(devices_found))

          by_device_id = {entry["device_id"]: entry for entry in devices_found}
          expected_entries = (
              {
                  "user_id": user1,
                  "device_id": "xyz",
                  "display_name": "display 0",
                  "last_seen_ip": None,
                  "last_seen_ts": None,
              },
              {
                  "user_id": user1,
                  "device_id": "fco",
                  "display_name": "display 1",
                  "last_seen_ip": "ip1",
                  "last_seen_ts": 1000000,
              },
              {
                  "user_id": user1,
                  "device_id": "abc",
                  "display_name": "display 2",
                  "last_seen_ip": "ip3",
                  "last_seen_ts": 3000000,
              },
          )
          # Comparing item views with <= checks that every expected pair is
          # present in the returned device, while tolerating extra fields.
          for expected in expected_entries:
              actual = by_device_id[expected["device_id"]]
              self.assertLessEqual(expected.items(), actual.items())

      def test_get_device(self) -> None:
          """A single device lookup reports the most recently recorded client IP."""
          self._record_users()

          device = self.get_success(self.handler.get_device(user1, "abc"))
          expected = {
              "user_id": user1,
              "device_id": "abc",
              "display_name": "display 2",
              "last_seen_ip": "ip3",
              "last_seen_ts": 3000000,
          }
          # Subset check: the returned device may carry additional fields.
          self.assertLessEqual(expected.items(), device.items())

      def test_delete_device(self) -> None:
          """A deleted device can no longer be fetched."""
          self._record_users()

          # Remove the device...
          deletion = self.handler.delete_devices(user1, ["abc"])
          self.get_success(deletion)

          # ...and confirm that looking it up now fails.
          lookup = self.handler.get_device(user1, "abc")
          self.get_failure(lookup, NotFoundError)

          # Checking that the access token was invalidated would be nice too,
          # but that's a bit of a PITA.

      def test_delete_device_and_device_inbox(self) -> None:
          """Deleting a device also clears its rows from ``device_inbox``."""
          self._record_users()

          # Seed a device_inbox row for the device that is about to be deleted.
          inbox_row = {
              "user_id": user1,
              "device_id": "abc",
              "stream_id": 1,
              "message_json": "{}",
          }
          self.get_success(self.store.db_pool.simple_insert("device_inbox", inbox_row))

          # Delete the device.
          self.get_success(self.handler.delete_devices(user1, ["abc"]))

          # The inbox row must be gone as well.
          remaining = self.get_success(
              self.store.db_pool.simple_select_one(
                  table="device_inbox",
                  keyvalues={"user_id": user1, "device_id": "abc"},
                  retcols=("user_id", "device_id"),
                  allow_none=True,
                  desc="get_device_id_from_device_inbox",
              )
          )
          self.assertIsNone(remaining)

      def test_update_device(self) -> None:
          """Updating a device's display name is reflected on the next lookup."""
          self._record_users()

          changes = {"display_name": "new display"}
          self.get_success(self.handler.update_device(user1, "abc", changes))

          device = self.get_success(self.handler.get_device(user1, "abc"))
          self.assertEqual(device["display_name"], "new display")

      def test_update_device_too_long_display_name(self) -> None:
          """Update a device with a display name that is invalid (too long)."""
          self._record_users()

          # Ask for a display name one character longer than the maximum.
          changes = {"display_name": "a" * (MAX_DEVICE_DISPLAY_NAME_LEN + 1)}
          rejected_update = self.handler.update_device(user1, "abc", changes)
          self.get_failure(rejected_update, SynapseError)

          # The stored display name must be untouched.
          device = self.get_success(self.handler.get_device(user1, "abc"))
          self.assertEqual(device["display_name"], "display 2")

      def test_update_unknown_device(self) -> None:
          """Updating a device that does not exist raises NotFoundError."""
          changes = {"display_name": "new_display"}
          missing_update = self.handler.update_device(
              "user_id", "unknown_device_id", changes
          )
          self.get_failure(missing_update, NotFoundError)

      def _record_users(self) -> None:
          """Register the fixture devices, then advance the clock by 10000 seconds."""
          # Cover both devices with a recorded client IP and devices without
          # one. "abc" is recorded twice so that the later IP wins.
          fixtures = (
              (user1, "xyz", "display 0"),
              (user1, "fco", "display 1", "token1", "ip1"),
              (user1, "abc", "display 2", "token2", "ip2"),
              (user1, "abc", "display 2", "token3", "ip3"),
              (user2, "def", "dispkay", "token4", "ip4"),
          )
          for fixture in fixtures:
              self._record_user(*fixture)

          self.reactor.advance(10000)

      def _record_user(
          self,
          user_id: str,
          device_id: str,
          display_name: str,
          access_token: Optional[str] = None,
          ip: Optional[str] = None,
      ) -> None:
          """Register one device and, if a token and IP are given, record a client IP.

          The clock is advanced by 1000 seconds only when a client IP was
          recorded, which is what spaces out the ``last_seen_ts`` values.
          """
          registered_id = self.get_success(
              self.handler.check_device_registered(
                  user_id=user_id,
                  device_id=device_id,
                  initial_device_display_name=display_name,
              )
          )

          # Without both a token and an IP there is nothing further to record.
          if access_token is None or ip is None:
              return

          self.get_success(
              self.store.insert_client_ip(
                  user_id, access_token, ip, "user_agent", registered_id
              )
          )
          self.reactor.advance(1000)

      @override_config({"experimental_features": {"msc3984_appservice_key_query": True}})
      def test_on_federation_query_user_devices_appservice(self) -> None:
          """Keys returned by an appservice take precedence over stored keys."""
          local_user = "@boris:" + self.hs.hostname
          first_device, second_device, third_device = "abc", "def", "ghi"

          # Three devices are involved:
          #
          #  * the first is only uploaded to the homeserver;
          #  * the second is uploaded to the homeserver, but the appservice
          #    returns a newer copy of its keys;
          #  * the third is only returned by the appservice.
          uploaded_keys_first: JsonDict = {
              "user_id": local_user,
              "device_id": first_device,
              "algorithms": [
                  "m.olm.curve25519-aes-sha2",
                  RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
              ],
              "keys": {
                  "ed25519:abc": "base64+ed25519+key",
                  "curve25519:abc": "base64+curve25519+key",
              },
              "signatures": {local_user: {"ed25519:abc": "base64+signature"}},
          }
          uploaded_keys_second: JsonDict = {
              "user_id": local_user,
              "device_id": second_device,
              "algorithms": [
                  "m.olm.curve25519-aes-sha2",
                  RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
              ],
              "keys": {
                  "ed25519:def": "base64+ed25519+key",
                  "curve25519:def": "base64+curve25519+key",
              },
              "signatures": {local_user: {"ed25519:def": "base64+signature"}},
          }
          appservice_keys_second: JsonDict = {
              "user_id": local_user,
              "device_id": second_device,
              "algorithms": [
                  "m.olm.curve25519-aes-sha2",
                  RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
              ],
              # Same device ID as the uploaded copy, but different keys.
              "keys": {
                  "ed25519:xyz": "base64+ed25519+key",
                  "curve25519:xyz": "base64+curve25519+key",
              },
              "signatures": {local_user: {"ed25519:xyz": "base64+signature"}},
          }
          appservice_keys_third: JsonDict = {
              "user_id": local_user,
              "device_id": third_device,
              "algorithms": [
                  "m.olm.curve25519-aes-sha2",
                  RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
              ],
              "keys": {
                  "ed25519:jkl": "base64+ed25519+key",
                  "curve25519:jkl": "base64+curve25519+key",
              },
              "signatures": {local_user: {"ed25519:jkl": "base64+signature"}},
          }

          # Upload the homeserver-side keys for the first two devices.
          keys_handler = self.hs.get_e2e_keys_handler()
          uploads = (
              (first_device, uploaded_keys_first),
              (second_device, uploaded_keys_second),
          )
          for device, uploaded_keys in uploads:
              self.get_success(
                  keys_handler.upload_keys_for_user(
                      local_user, device, {"device_keys": uploaded_keys}
                  )
              )

          # Install an appservice whose exclusive namespace covers this user.
          appservice = ApplicationService(
              token="i_am_an_app_service",
              id="1234",
              namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]},
              # Note: the sender does not have to match the regex above.
              sender="@as_main:test",
          )
          main_store = self.hs.get_datastores().main
          main_store.services_cache = [appservice]
          main_store.exclusive_user_regex = _make_exclusive_regex([appservice])

          # Script the appservice's answer to the key query.
          self.appservice_api.query_keys.return_value = {
              "device_keys": {
                  local_user: {
                      second_device: appservice_keys_second,
                      third_device: appservice_keys_third,
                  }
              }
          }

          # Ask for all of the user's devices.
          response = self.get_success(
              self.handler.on_federation_query_user_devices(local_user)
          )
          self.assertIn("devices", response)

          returned_devices = response["devices"]
          # Drop any "unsigned" section so the comparison covers only the keys.
          for entry in returned_devices:
              entry["keys"].pop("unsigned", None)

          expected_devices = [
              {"device_id": first_device, "keys": uploaded_keys_first},
              {"device_id": second_device, "keys": appservice_keys_second},
              {"device_id": third_device, "keys": appservice_keys_third},
          ]
          self.assertEqual(returned_devices, expected_devices)
  class DehydrationTestCase(unittest.HomeserverTestCase):
      """Tests for storing dehydrated devices, rehydrating them, and fetching
      the to-device messages queued for them."""

      servlets = [
          admin.register_servlets_for_client_rest_resource,
          login.register_servlets,
          register.register_servlets,
          devices.register_servlets,
      ]

      def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
          """Create the homeserver and keep the handlers these tests use on ``self``."""
          homeserver = self.setup_test_homeserver("server")

          device_handler = homeserver.get_device_handler()
          assert isinstance(device_handler, DeviceHandler)
          self.handler = device_handler

          self.message_handler = homeserver.get_device_message_handler()
          self.registration = homeserver.get_registration_handler()
          self.auth = homeserver.get_auth()
          self.store = homeserver.get_datastores().main
          return homeserver

      def test_dehydrate_and_rehydrate_device(self) -> None:
          """A stored dehydrated device can be claimed by a new login exactly once."""
          user_id = "@boris:dehydration"
          self.get_success(self.store.register_user(user_id, "foobar"))

          # Store a dehydrated device and check that it can be read back.
          storing = self.handler.store_dehydrated_device(
              user_id=user_id,
              device_id=None,
              device_data={"device_data": {"foo": "bar"}},
              initial_device_display_name="dehydrated device",
          )
          stored_dehydrated_device_id = self.get_success(storing)

          fetched = self.get_success(self.handler.get_dehydrated_device(user_id=user_id))
          assert fetched is not None
          retrieved_device_id, device_data = fetched

          self.assertEqual(retrieved_device_id, stored_dehydrated_device_id)
          self.assertEqual(device_data, {"device_data": {"foo": "bar"}})

          # Log the user in on a fresh device; this is the one that will be
          # swapped for the dehydrated device.
          new_login = self.registration.register_device(
              user_id=user_id,
              device_id=None,
              initial_display_name="new device",
          )
          (
              login_device_id,
              access_token,
              _expiration_time,
              _refresh_token,
          ) = self.get_success(new_login)

          # Claiming a device that does not exist is an error.
          bad_claim = self.handler.rehydrate_device(
              user_id=user_id,
              access_token=access_token,
              device_id="not the right device ID",
          )
          self.get_failure(bad_claim, NotFoundError)

          # Claiming the right device succeeds and switches our device ID to
          # the dehydrated device's ID.
          good_claim = self.handler.rehydrate_device(
              user_id=user_id,
              access_token=access_token,
              device_id=retrieved_device_id,
          )
          self.assertEqual(self.get_success(good_claim), {"success": True})

          # The access token now maps to the dehydrated device's ID.
          user_info = self.get_success(self.auth.get_user_by_access_token(access_token))
          self.assertEqual(user_info.device_id, retrieved_device_id)

          # The device carries the display name chosen at login.
          rehydrated = self.get_success(
              self.handler.get_device(user_id, retrieved_device_id)
          )
          self.assertEqual(rehydrated["display_name"], "new device")

          # The device ID originally assigned at login is gone.
          self.get_failure(
              self.handler.get_device(user_id, login_device_id),
              NotFoundError,
          )

          # Nothing is left to rehydrate.
          leftover = self.get_success(
              self.handler.get_dehydrated_device(user_id=user_id)
          )
          self.assertIsNone(leftover)

      @unittest.override_config(
          {"experimental_features": {"msc2697_enabled": False, "msc3814_enabled": True}}
      )
      def test_dehydrate_v2_and_fetch_events(self) -> None:
          """Messages sent to a dehydrated device can be fetched, repeatedly."""
          user_id = "@boris:server"
          self.get_success(self.store.register_user(user_id, "foobar"))

          # Store a dehydrated device and check that it can be read back.
          storing = self.handler.store_dehydrated_device(
              user_id=user_id,
              device_id=None,
              device_data={"device_data": {"foo": "bar"}},
              initial_device_display_name="dehydrated device",
          )
          stored_dehydrated_device_id = self.get_success(storing)

          device_info = self.get_success(
              self.handler.get_dehydrated_device(user_id=user_id)
          )
          assert device_info is not None
          retrieved_device_id, device_data = device_info

          self.assertEqual(retrieved_device_id, stored_dehydrated_device_id)
          self.assertEqual(device_data, {"device_data": {"foo": "bar"}})

          # Log the user in on a fresh device.
          new_login = self.registration.register_device(
              user_id=user_id,
              device_id=None,
              initial_display_name="new device",
          )
          (
              login_device_id,
              access_token,
              _expiration_time,
              _refresh_token,
          ) = self.get_success(new_login)

          requester = create_requester(user_id, device_id=login_device_id)

          # Asking for the messages of a device that does not exist is an error.
          bad_fetch = self.message_handler.get_events_for_dehydrated_device(
              requester=requester,
              device_id="not the right device ID",
              since_token=None,
              limit=10,
          )
          self.get_failure(bad_fetch, SynapseError)

          # Send a to-device message addressed to the dehydrated device.
          outgoing = {user_id: {stored_dehydrated_device_id: {"body": "foo"}}}
          ensureDeferred(
              self.message_handler.send_device_message(
                  requester=requester,
                  message_type="test.message",
                  messages=outgoing,
              )
          )
          self.pump()

          # Fetch the dehydrated device's messages twice. The second pass must
          # return the same message, because it has not been deleted.
          for _attempt in range(2):
              res = self.get_success(
                  self.message_handler.get_events_for_dehydrated_device(
                      requester=requester,
                      device_id=stored_dehydrated_device_id,
                      since_token=None,
                      limit=10,
                  )
              )
              self.assertTrue(len(res["next_batch"]) > 1)
              self.assertEqual(len(res["events"]), 1)
              self.assertEqual(res["events"][0]["content"]["body"], "foo")