You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

1312 lines
43 KiB

  1. # Copyright 2022 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import re
  16. from email.parser import Parser
  17. from http import HTTPStatus
  18. from typing import Any, Dict, List, Optional, Union
  19. from unittest.mock import Mock
  20. import pkg_resources
  21. from twisted.internet.interfaces import IReactorTCP
  22. from twisted.test.proto_helpers import MemoryReactor
  23. import synapse.rest.admin
  24. from synapse.api.constants import LoginType, Membership
  25. from synapse.api.errors import Codes, HttpResponseException
  26. from synapse.appservice import ApplicationService
  27. from synapse.rest import admin
  28. from synapse.rest.client import account, login, register, room
  29. from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
  30. from synapse.server import HomeServer
  31. from synapse.types import JsonDict, UserID
  32. from synapse.util import Clock
  33. from tests import unittest
  34. from tests.server import FakeSite, make_request
  35. from tests.unittest import override_config
  36. class PasswordResetTestCase(unittest.HomeserverTestCase):
  37. servlets = [
  38. account.register_servlets,
  39. synapse.rest.admin.register_servlets_for_client_rest_resource,
  40. register.register_servlets,
  41. login.register_servlets,
  42. ]
  43. def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
  44. config = self.default_config()
  45. # Email config.
  46. config["email"] = {
  47. "enable_notifs": False,
  48. "template_dir": os.path.abspath(
  49. pkg_resources.resource_filename("synapse", "res/templates")
  50. ),
  51. "smtp_host": "127.0.0.1",
  52. "smtp_port": 20,
  53. "require_transport_security": False,
  54. "smtp_user": None,
  55. "smtp_pass": None,
  56. "notif_from": "test@example.com",
  57. }
  58. config["public_baseurl"] = "https://example.com"
  59. hs = self.setup_test_homeserver(config=config)
  60. async def sendmail(
  61. reactor: IReactorTCP,
  62. smtphost: str,
  63. smtpport: int,
  64. from_addr: str,
  65. to_addr: str,
  66. msg_bytes: bytes,
  67. *args: Any,
  68. **kwargs: Any,
  69. ) -> None:
  70. self.email_attempts.append(msg_bytes)
  71. self.email_attempts: List[bytes] = []
  72. hs.get_send_email_handler()._sendmail = sendmail
  73. return hs
  74. def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
  75. self.store = hs.get_datastores().main
  76. self.submit_token_resource = PasswordResetSubmitTokenResource(hs)
  77. def attempt_wrong_password_login(self, username: str, password: str) -> None:
  78. """Attempts to login as the user with the given password, asserting
  79. that the attempt *fails*.
  80. """
  81. body = {"type": "m.login.password", "user": username, "password": password}
  82. channel = self.make_request("POST", "/_matrix/client/r0/login", body)
  83. self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
  84. def test_basic_password_reset(self) -> None:
  85. """Test basic password reset flow"""
  86. old_password = "monkey"
  87. new_password = "kangeroo"
  88. user_id = self.register_user("kermit", old_password)
  89. self.login("kermit", old_password)
  90. email = "test@example.com"
  91. # Add a threepid
  92. self.get_success(
  93. self.store.user_add_threepid(
  94. user_id=user_id,
  95. medium="email",
  96. address=email,
  97. validated_at=0,
  98. added_at=0,
  99. )
  100. )
  101. client_secret = "foobar"
  102. session_id = self._request_token(email, client_secret)
  103. self.assertEqual(len(self.email_attempts), 1)
  104. link = self._get_link_from_email()
  105. self._validate_token(link)
  106. self._reset_password(new_password, session_id, client_secret)
  107. # Assert we can log in with the new password
  108. self.login("kermit", new_password)
  109. # Assert we can't log in with the old password
  110. self.attempt_wrong_password_login("kermit", old_password)
  111. @override_config({"rc_3pid_validation": {"burst_count": 3}})
  112. def test_ratelimit_by_email(self) -> None:
  113. """Test that we ratelimit /requestToken for the same email."""
  114. old_password = "monkey"
  115. new_password = "kangeroo"
  116. user_id = self.register_user("kermit", old_password)
  117. self.login("kermit", old_password)
  118. email = "test1@example.com"
  119. # Add a threepid
  120. self.get_success(
  121. self.store.user_add_threepid(
  122. user_id=user_id,
  123. medium="email",
  124. address=email,
  125. validated_at=0,
  126. added_at=0,
  127. )
  128. )
  129. def reset(ip: str) -> None:
  130. client_secret = "foobar"
  131. session_id = self._request_token(email, client_secret, ip)
  132. self.assertEqual(len(self.email_attempts), 1)
  133. link = self._get_link_from_email()
  134. self._validate_token(link)
  135. self._reset_password(new_password, session_id, client_secret)
  136. self.email_attempts.clear()
  137. # We expect to be able to make three requests before getting rate
  138. # limited.
  139. #
  140. # We change IPs to ensure that we're not being ratelimited due to the
  141. # same IP
  142. reset("127.0.0.1")
  143. reset("127.0.0.2")
  144. reset("127.0.0.3")
  145. with self.assertRaises(HttpResponseException) as cm:
  146. reset("127.0.0.4")
  147. self.assertEqual(cm.exception.code, 429)
  148. def test_basic_password_reset_canonicalise_email(self) -> None:
  149. """Test basic password reset flow
  150. Request password reset with different spelling
  151. """
  152. old_password = "monkey"
  153. new_password = "kangeroo"
  154. user_id = self.register_user("kermit", old_password)
  155. self.login("kermit", old_password)
  156. email_profile = "test@example.com"
  157. email_passwort_reset = "TEST@EXAMPLE.COM"
  158. # Add a threepid
  159. self.get_success(
  160. self.store.user_add_threepid(
  161. user_id=user_id,
  162. medium="email",
  163. address=email_profile,
  164. validated_at=0,
  165. added_at=0,
  166. )
  167. )
  168. client_secret = "foobar"
  169. session_id = self._request_token(email_passwort_reset, client_secret)
  170. self.assertEqual(len(self.email_attempts), 1)
  171. link = self._get_link_from_email()
  172. self._validate_token(link)
  173. self._reset_password(new_password, session_id, client_secret)
  174. # Assert we can log in with the new password
  175. self.login("kermit", new_password)
  176. # Assert we can't log in with the old password
  177. self.attempt_wrong_password_login("kermit", old_password)
  178. def test_cant_reset_password_without_clicking_link(self) -> None:
  179. """Test that we do actually need to click the link in the email"""
  180. old_password = "monkey"
  181. new_password = "kangeroo"
  182. user_id = self.register_user("kermit", old_password)
  183. self.login("kermit", old_password)
  184. email = "test@example.com"
  185. # Add a threepid
  186. self.get_success(
  187. self.store.user_add_threepid(
  188. user_id=user_id,
  189. medium="email",
  190. address=email,
  191. validated_at=0,
  192. added_at=0,
  193. )
  194. )
  195. client_secret = "foobar"
  196. session_id = self._request_token(email, client_secret)
  197. self.assertEqual(len(self.email_attempts), 1)
  198. # Attempt to reset password without clicking the link
  199. self._reset_password(new_password, session_id, client_secret, expected_code=401)
  200. # Assert we can log in with the old password
  201. self.login("kermit", old_password)
  202. # Assert we can't log in with the new password
  203. self.attempt_wrong_password_login("kermit", new_password)
  204. def test_no_valid_token(self) -> None:
  205. """Test that we do actually need to request a token and can't just
  206. make a session up.
  207. """
  208. old_password = "monkey"
  209. new_password = "kangeroo"
  210. user_id = self.register_user("kermit", old_password)
  211. self.login("kermit", old_password)
  212. email = "test@example.com"
  213. # Add a threepid
  214. self.get_success(
  215. self.store.user_add_threepid(
  216. user_id=user_id,
  217. medium="email",
  218. address=email,
  219. validated_at=0,
  220. added_at=0,
  221. )
  222. )
  223. client_secret = "foobar"
  224. session_id = "weasle"
  225. # Attempt to reset password without even requesting an email
  226. self._reset_password(new_password, session_id, client_secret, expected_code=401)
  227. # Assert we can log in with the old password
  228. self.login("kermit", old_password)
  229. # Assert we can't log in with the new password
  230. self.attempt_wrong_password_login("kermit", new_password)
  231. @unittest.override_config({"request_token_inhibit_3pid_errors": True})
  232. def test_password_reset_bad_email_inhibit_error(self) -> None:
  233. """Test that triggering a password reset with an email address that isn't bound
  234. to an account doesn't leak the lack of binding for that address if configured
  235. that way.
  236. """
  237. self.register_user("kermit", "monkey")
  238. self.login("kermit", "monkey")
  239. email = "test@example.com"
  240. client_secret = "foobar"
  241. session_id = self._request_token(email, client_secret)
  242. self.assertIsNotNone(session_id)
  243. def _request_token(
  244. self,
  245. email: str,
  246. client_secret: str,
  247. ip: str = "127.0.0.1",
  248. ) -> str:
  249. channel = self.make_request(
  250. "POST",
  251. b"account/password/email/requestToken",
  252. {"client_secret": client_secret, "email": email, "send_attempt": 1},
  253. client_ip=ip,
  254. )
  255. if channel.code != 200:
  256. raise HttpResponseException(
  257. channel.code,
  258. channel.result["reason"],
  259. channel.result["body"],
  260. )
  261. return channel.json_body["sid"]
  262. def _validate_token(self, link: str) -> None:
  263. # Remove the host
  264. path = link.replace("https://example.com", "")
  265. # Load the password reset confirmation page
  266. channel = make_request(
  267. self.reactor,
  268. FakeSite(self.submit_token_resource, self.reactor),
  269. "GET",
  270. path,
  271. shorthand=False,
  272. )
  273. self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
  274. # Now POST to the same endpoint, mimicking the same behaviour as clicking the
  275. # password reset confirm button
  276. # Confirm the password reset
  277. channel = make_request(
  278. self.reactor,
  279. FakeSite(self.submit_token_resource, self.reactor),
  280. "POST",
  281. path,
  282. content=b"",
  283. shorthand=False,
  284. content_is_form=True,
  285. )
  286. self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
  287. def _get_link_from_email(self) -> str:
  288. assert self.email_attempts, "No emails have been sent"
  289. raw_msg = self.email_attempts[-1].decode("UTF-8")
  290. mail = Parser().parsestr(raw_msg)
  291. text = None
  292. for part in mail.walk():
  293. if part.get_content_type() == "text/plain":
  294. text = part.get_payload(decode=True).decode("UTF-8")
  295. break
  296. if not text:
  297. self.fail("Could not find text portion of email to parse")
  298. assert text is not None
  299. match = re.search(r"https://example.com\S+", text)
  300. assert match, "Could not find link in email"
  301. return match.group(0)
  302. def _reset_password(
  303. self,
  304. new_password: str,
  305. session_id: str,
  306. client_secret: str,
  307. expected_code: int = HTTPStatus.OK,
  308. ) -> None:
  309. channel = self.make_request(
  310. "POST",
  311. b"account/password",
  312. {
  313. "new_password": new_password,
  314. "auth": {
  315. "type": LoginType.EMAIL_IDENTITY,
  316. "threepid_creds": {
  317. "client_secret": client_secret,
  318. "sid": session_id,
  319. },
  320. },
  321. },
  322. )
  323. self.assertEqual(expected_code, channel.code, channel.result)
  324. class DeactivateTestCase(unittest.HomeserverTestCase):
  325. servlets = [
  326. synapse.rest.admin.register_servlets_for_client_rest_resource,
  327. login.register_servlets,
  328. account.register_servlets,
  329. room.register_servlets,
  330. ]
  331. def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
  332. self.hs = self.setup_test_homeserver()
  333. return self.hs
  334. def test_deactivate_account(self) -> None:
  335. user_id = self.register_user("kermit", "test")
  336. tok = self.login("kermit", "test")
  337. self.deactivate(user_id, tok)
  338. store = self.hs.get_datastores().main
  339. # Check that the user has been marked as deactivated.
  340. self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id)))
  341. # Check that this access token has been invalidated.
  342. channel = self.make_request("GET", "account/whoami", access_token=tok)
  343. self.assertEqual(channel.code, 401)
  344. def test_pending_invites(self) -> None:
  345. """Tests that deactivating a user rejects every pending invite for them."""
  346. store = self.hs.get_datastores().main
  347. inviter_id = self.register_user("inviter", "test")
  348. inviter_tok = self.login("inviter", "test")
  349. invitee_id = self.register_user("invitee", "test")
  350. invitee_tok = self.login("invitee", "test")
  351. # Make @inviter:test invite @invitee:test in a new room.
  352. room_id = self.helper.create_room_as(inviter_id, tok=inviter_tok)
  353. self.helper.invite(
  354. room=room_id, src=inviter_id, targ=invitee_id, tok=inviter_tok
  355. )
  356. # Make sure the invite is here.
  357. pending_invites = self.get_success(
  358. store.get_invited_rooms_for_local_user(invitee_id)
  359. )
  360. self.assertEqual(len(pending_invites), 1, pending_invites)
  361. self.assertEqual(pending_invites[0].room_id, room_id, pending_invites)
  362. # Deactivate @invitee:test.
  363. self.deactivate(invitee_id, invitee_tok)
  364. # Check that the invite isn't there anymore.
  365. pending_invites = self.get_success(
  366. store.get_invited_rooms_for_local_user(invitee_id)
  367. )
  368. self.assertEqual(len(pending_invites), 0, pending_invites)
  369. # Check that the membership of @invitee:test in the room is now "leave".
  370. memberships = self.get_success(
  371. store.get_rooms_for_local_user_where_membership_is(
  372. invitee_id, [Membership.LEAVE]
  373. )
  374. )
  375. self.assertEqual(len(memberships), 1, memberships)
  376. self.assertEqual(memberships[0].room_id, room_id, memberships)
  377. def deactivate(self, user_id: str, tok: str) -> None:
  378. request_data = {
  379. "auth": {
  380. "type": "m.login.password",
  381. "user": user_id,
  382. "password": "test",
  383. },
  384. "erase": False,
  385. }
  386. channel = self.make_request(
  387. "POST", "account/deactivate", request_data, access_token=tok
  388. )
  389. self.assertEqual(channel.code, 200, channel.json_body)
  390. class WhoamiTestCase(unittest.HomeserverTestCase):
  391. servlets = [
  392. synapse.rest.admin.register_servlets_for_client_rest_resource,
  393. login.register_servlets,
  394. account.register_servlets,
  395. register.register_servlets,
  396. ]
  397. def default_config(self) -> Dict[str, Any]:
  398. config = super().default_config()
  399. config["allow_guest_access"] = True
  400. return config
  401. def test_GET_whoami(self) -> None:
  402. device_id = "wouldgohere"
  403. user_id = self.register_user("kermit", "test")
  404. tok = self.login("kermit", "test", device_id=device_id)
  405. whoami = self._whoami(tok)
  406. self.assertEqual(
  407. whoami,
  408. {
  409. "user_id": user_id,
  410. "device_id": device_id,
  411. "is_guest": False,
  412. },
  413. )
  414. def test_GET_whoami_guests(self) -> None:
  415. channel = self.make_request(
  416. b"POST", b"/_matrix/client/r0/register?kind=guest", b"{}"
  417. )
  418. tok = channel.json_body["access_token"]
  419. user_id = channel.json_body["user_id"]
  420. device_id = channel.json_body["device_id"]
  421. whoami = self._whoami(tok)
  422. self.assertEqual(
  423. whoami,
  424. {
  425. "user_id": user_id,
  426. "device_id": device_id,
  427. "is_guest": True,
  428. },
  429. )
  430. def test_GET_whoami_appservices(self) -> None:
  431. user_id = "@as:test"
  432. as_token = "i_am_an_app_service"
  433. appservice = ApplicationService(
  434. as_token,
  435. id="1234",
  436. namespaces={"users": [{"regex": user_id, "exclusive": True}]},
  437. sender=user_id,
  438. )
  439. self.hs.get_datastores().main.services_cache.append(appservice)
  440. whoami = self._whoami(as_token)
  441. self.assertEqual(
  442. whoami,
  443. {
  444. "user_id": user_id,
  445. "is_guest": False,
  446. },
  447. )
  448. self.assertFalse(hasattr(whoami, "device_id"))
  449. def _whoami(self, tok: str) -> JsonDict:
  450. channel = self.make_request("GET", "account/whoami", {}, access_token=tok)
  451. self.assertEqual(channel.code, 200)
  452. return channel.json_body
  453. class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
  454. servlets = [
  455. account.register_servlets,
  456. login.register_servlets,
  457. synapse.rest.admin.register_servlets_for_client_rest_resource,
  458. ]
  459. def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
  460. config = self.default_config()
  461. # Email config.
  462. config["email"] = {
  463. "enable_notifs": False,
  464. "template_dir": os.path.abspath(
  465. pkg_resources.resource_filename("synapse", "res/templates")
  466. ),
  467. "smtp_host": "127.0.0.1",
  468. "smtp_port": 20,
  469. "require_transport_security": False,
  470. "smtp_user": None,
  471. "smtp_pass": None,
  472. "notif_from": "test@example.com",
  473. }
  474. config["public_baseurl"] = "https://example.com"
  475. self.hs = self.setup_test_homeserver(config=config)
  476. async def sendmail(
  477. reactor: IReactorTCP,
  478. smtphost: str,
  479. smtpport: int,
  480. from_addr: str,
  481. to_addr: str,
  482. msg_bytes: bytes,
  483. *args: Any,
  484. **kwargs: Any,
  485. ) -> None:
  486. self.email_attempts.append(msg_bytes)
  487. self.email_attempts: List[bytes] = []
  488. self.hs.get_send_email_handler()._sendmail = sendmail
  489. return self.hs
  490. def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
  491. self.store = hs.get_datastores().main
  492. self.user_id = self.register_user("kermit", "test")
  493. self.user_id_tok = self.login("kermit", "test")
  494. self.email = "test@example.com"
  495. self.url_3pid = b"account/3pid"
  496. def test_add_valid_email(self) -> None:
  497. self._add_email(self.email, self.email)
  498. def test_add_valid_email_second_time(self) -> None:
  499. self._add_email(self.email, self.email)
  500. self._request_token_invalid_email(
  501. self.email,
  502. expected_errcode=Codes.THREEPID_IN_USE,
  503. expected_error="Email is already in use",
  504. )
  505. def test_add_valid_email_second_time_canonicalise(self) -> None:
  506. self._add_email(self.email, self.email)
  507. self._request_token_invalid_email(
  508. "TEST@EXAMPLE.COM",
  509. expected_errcode=Codes.THREEPID_IN_USE,
  510. expected_error="Email is already in use",
  511. )
  512. def test_add_email_no_at(self) -> None:
  513. self._request_token_invalid_email(
  514. "address-without-at.bar",
  515. expected_errcode=Codes.BAD_JSON,
  516. expected_error="Unable to parse email address",
  517. )
  518. def test_add_email_two_at(self) -> None:
  519. self._request_token_invalid_email(
  520. "foo@foo@test.bar",
  521. expected_errcode=Codes.BAD_JSON,
  522. expected_error="Unable to parse email address",
  523. )
  524. def test_add_email_bad_format(self) -> None:
  525. self._request_token_invalid_email(
  526. "user@bad.example.net@good.example.com",
  527. expected_errcode=Codes.BAD_JSON,
  528. expected_error="Unable to parse email address",
  529. )
  530. def test_add_email_domain_to_lower(self) -> None:
  531. self._add_email("foo@TEST.BAR", "foo@test.bar")
  532. def test_add_email_domain_with_umlaut(self) -> None:
  533. self._add_email("foo@Öumlaut.com", "foo@öumlaut.com")
  534. def test_add_email_address_casefold(self) -> None:
  535. self._add_email("Strauß@Example.com", "strauss@example.com")
  536. def test_address_trim(self) -> None:
  537. self._add_email(" foo@test.bar ", "foo@test.bar")
  538. @override_config({"rc_3pid_validation": {"burst_count": 3}})
  539. def test_ratelimit_by_ip(self) -> None:
  540. """Tests that adding emails is ratelimited by IP"""
  541. # We expect to be able to set three emails before getting ratelimited.
  542. self._add_email("foo1@test.bar", "foo1@test.bar")
  543. self._add_email("foo2@test.bar", "foo2@test.bar")
  544. self._add_email("foo3@test.bar", "foo3@test.bar")
  545. with self.assertRaises(HttpResponseException) as cm:
  546. self._add_email("foo4@test.bar", "foo4@test.bar")
  547. self.assertEqual(cm.exception.code, 429)
  548. def test_add_email_if_disabled(self) -> None:
  549. """Test adding email to profile when doing so is disallowed"""
  550. self.hs.config.registration.enable_3pid_changes = False
  551. client_secret = "foobar"
  552. channel = self.make_request(
  553. "POST",
  554. b"/_matrix/client/unstable/account/3pid/email/requestToken",
  555. {
  556. "client_secret": client_secret,
  557. "email": "test@example.com",
  558. "send_attempt": 1,
  559. },
  560. )
  561. self.assertEqual(
  562. HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
  563. )
  564. self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
  565. def test_delete_email(self) -> None:
  566. """Test deleting an email from profile"""
  567. # Add a threepid
  568. self.get_success(
  569. self.store.user_add_threepid(
  570. user_id=self.user_id,
  571. medium="email",
  572. address=self.email,
  573. validated_at=0,
  574. added_at=0,
  575. )
  576. )
  577. channel = self.make_request(
  578. "POST",
  579. b"account/3pid/delete",
  580. {"medium": "email", "address": self.email},
  581. access_token=self.user_id_tok,
  582. )
  583. self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
  584. # Get user
  585. channel = self.make_request(
  586. "GET",
  587. self.url_3pid,
  588. access_token=self.user_id_tok,
  589. )
  590. self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
  591. self.assertFalse(channel.json_body["threepids"])
  592. def test_delete_email_if_disabled(self) -> None:
  593. """Test deleting an email from profile when disallowed"""
  594. self.hs.config.registration.enable_3pid_changes = False
  595. # Add a threepid
  596. self.get_success(
  597. self.store.user_add_threepid(
  598. user_id=self.user_id,
  599. medium="email",
  600. address=self.email,
  601. validated_at=0,
  602. added_at=0,
  603. )
  604. )
  605. channel = self.make_request(
  606. "POST",
  607. b"account/3pid/delete",
  608. {"medium": "email", "address": self.email},
  609. access_token=self.user_id_tok,
  610. )
  611. self.assertEqual(
  612. HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
  613. )
  614. self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
  615. # Get user
  616. channel = self.make_request(
  617. "GET",
  618. self.url_3pid,
  619. access_token=self.user_id_tok,
  620. )
  621. self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
  622. self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
  623. self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
  624. def test_cant_add_email_without_clicking_link(self) -> None:
  625. """Test that we do actually need to click the link in the email"""
  626. client_secret = "foobar"
  627. session_id = self._request_token(self.email, client_secret)
  628. self.assertEqual(len(self.email_attempts), 1)
  629. # Attempt to add email without clicking the link
  630. channel = self.make_request(
  631. "POST",
  632. b"/_matrix/client/unstable/account/3pid/add",
  633. {
  634. "client_secret": client_secret,
  635. "sid": session_id,
  636. "auth": {
  637. "type": "m.login.password",
  638. "user": self.user_id,
  639. "password": "test",
  640. },
  641. },
  642. access_token=self.user_id_tok,
  643. )
  644. self.assertEqual(
  645. HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
  646. )
  647. self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
  648. # Get user
  649. channel = self.make_request(
  650. "GET",
  651. self.url_3pid,
  652. access_token=self.user_id_tok,
  653. )
  654. self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
  655. self.assertFalse(channel.json_body["threepids"])
  656. def test_no_valid_token(self) -> None:
  657. """Test that we do actually need to request a token and can't just
  658. make a session up.
  659. """
  660. client_secret = "foobar"
  661. session_id = "weasle"
  662. # Attempt to add email without even requesting an email
  663. channel = self.make_request(
  664. "POST",
  665. b"/_matrix/client/unstable/account/3pid/add",
  666. {
  667. "client_secret": client_secret,
  668. "sid": session_id,
  669. "auth": {
  670. "type": "m.login.password",
  671. "user": self.user_id,
  672. "password": "test",
  673. },
  674. },
  675. access_token=self.user_id_tok,
  676. )
  677. self.assertEqual(
  678. HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
  679. )
  680. self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
  681. # Get user
  682. channel = self.make_request(
  683. "GET",
  684. self.url_3pid,
  685. access_token=self.user_id_tok,
  686. )
  687. self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
  688. self.assertFalse(channel.json_body["threepids"])
  689. @override_config({"next_link_domain_whitelist": None})
  690. def test_next_link(self) -> None:
  691. """Tests a valid next_link parameter value with no whitelist (good case)"""
  692. self._request_token(
  693. "something@example.com",
  694. "some_secret",
  695. next_link="https://example.com/a/good/site",
  696. expect_code=HTTPStatus.OK,
  697. )
  698. @override_config({"next_link_domain_whitelist": None})
  699. def test_next_link_exotic_protocol(self) -> None:
  700. """Tests using a esoteric protocol as a next_link parameter value.
  701. Someone may be hosting a client on IPFS etc.
  702. """
  703. self._request_token(
  704. "something@example.com",
  705. "some_secret",
  706. next_link="some-protocol://abcdefghijklmopqrstuvwxyz",
  707. expect_code=HTTPStatus.OK,
  708. )
  709. @override_config({"next_link_domain_whitelist": None})
  710. def test_next_link_file_uri(self) -> None:
  711. """Tests next_link parameters cannot be file URI"""
  712. # Attempt to use a next_link value that points to the local disk
  713. self._request_token(
  714. "something@example.com",
  715. "some_secret",
  716. next_link="file:///host/path",
  717. expect_code=HTTPStatus.BAD_REQUEST,
  718. )
  719. @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
  720. def test_next_link_domain_whitelist(self) -> None:
  721. """Tests next_link parameters must fit the whitelist if provided"""
  722. # Ensure not providing a next_link parameter still works
  723. self._request_token(
  724. "something@example.com",
  725. "some_secret",
  726. next_link=None,
  727. expect_code=HTTPStatus.OK,
  728. )
  729. self._request_token(
  730. "something@example.com",
  731. "some_secret",
  732. next_link="https://example.com/some/good/page",
  733. expect_code=HTTPStatus.OK,
  734. )
  735. self._request_token(
  736. "something@example.com",
  737. "some_secret",
  738. next_link="https://example.org/some/also/good/page",
  739. expect_code=HTTPStatus.OK,
  740. )
  741. self._request_token(
  742. "something@example.com",
  743. "some_secret",
  744. next_link="https://bad.example.org/some/bad/page",
  745. expect_code=HTTPStatus.BAD_REQUEST,
  746. )
  747. @override_config({"next_link_domain_whitelist": []})
  748. def test_empty_next_link_domain_whitelist(self) -> None:
  749. """Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially
  750. disallowed
  751. """
  752. self._request_token(
  753. "something@example.com",
  754. "some_secret",
  755. next_link="https://example.com/a/page",
  756. expect_code=HTTPStatus.BAD_REQUEST,
  757. )
  758. def _request_token(
  759. self,
  760. email: str,
  761. client_secret: str,
  762. next_link: Optional[str] = None,
  763. expect_code: int = HTTPStatus.OK,
  764. ) -> Optional[str]:
  765. """Request a validation token to add an email address to a user's account
  766. Args:
  767. email: The email address to validate
  768. client_secret: A secret string
  769. next_link: A link to redirect the user to after validation
  770. expect_code: Expected return code of the call
  771. Returns:
  772. The ID of the new threepid validation session, or None if the response
  773. did not contain a session ID.
  774. """
  775. body = {"client_secret": client_secret, "email": email, "send_attempt": 1}
  776. if next_link:
  777. body["next_link"] = next_link
  778. channel = self.make_request(
  779. "POST",
  780. b"account/3pid/email/requestToken",
  781. body,
  782. )
  783. if channel.code != expect_code:
  784. raise HttpResponseException(
  785. channel.code,
  786. channel.result["reason"],
  787. channel.result["body"],
  788. )
  789. return channel.json_body.get("sid")
  790. def _request_token_invalid_email(
  791. self,
  792. email: str,
  793. expected_errcode: str,
  794. expected_error: str,
  795. client_secret: str = "foobar",
  796. ) -> None:
  797. channel = self.make_request(
  798. "POST",
  799. b"account/3pid/email/requestToken",
  800. {"client_secret": client_secret, "email": email, "send_attempt": 1},
  801. )
  802. self.assertEqual(
  803. HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
  804. )
  805. self.assertEqual(expected_errcode, channel.json_body["errcode"])
  806. self.assertIn(expected_error, channel.json_body["error"])
  807. def _validate_token(self, link: str) -> None:
  808. # Remove the host
  809. path = link.replace("https://example.com", "")
  810. channel = self.make_request("GET", path, shorthand=False)
  811. self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
  812. def _get_link_from_email(self) -> str:
  813. assert self.email_attempts, "No emails have been sent"
  814. raw_msg = self.email_attempts[-1].decode("UTF-8")
  815. mail = Parser().parsestr(raw_msg)
  816. text = None
  817. for part in mail.walk():
  818. if part.get_content_type() == "text/plain":
  819. text = part.get_payload(decode=True).decode("UTF-8")
  820. break
  821. if not text:
  822. self.fail("Could not find text portion of email to parse")
  823. assert text is not None
  824. match = re.search(r"https://example.com\S+", text)
  825. assert match, "Could not find link in email"
  826. return match.group(0)
  827. def _add_email(self, request_email: str, expected_email: str) -> None:
  828. """Test adding an email to profile"""
  829. previous_email_attempts = len(self.email_attempts)
  830. client_secret = "foobar"
  831. session_id = self._request_token(request_email, client_secret)
  832. self.assertEqual(len(self.email_attempts) - previous_email_attempts, 1)
  833. link = self._get_link_from_email()
  834. self._validate_token(link)
  835. channel = self.make_request(
  836. "POST",
  837. b"/_matrix/client/unstable/account/3pid/add",
  838. {
  839. "client_secret": client_secret,
  840. "sid": session_id,
  841. "auth": {
  842. "type": "m.login.password",
  843. "user": self.user_id,
  844. "password": "test",
  845. },
  846. },
  847. access_token=self.user_id_tok,
  848. )
  849. self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
  850. # Get user
  851. channel = self.make_request(
  852. "GET",
  853. self.url_3pid,
  854. access_token=self.user_id_tok,
  855. )
  856. self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
  857. self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
  858. threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
  859. self.assertIn(expected_email, threepids)
  860. class AccountStatusTestCase(unittest.HomeserverTestCase):
  861. servlets = [
  862. account.register_servlets,
  863. admin.register_servlets,
  864. login.register_servlets,
  865. ]
  866. url = "/_matrix/client/unstable/org.matrix.msc3720/account_status"
  867. def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
  868. config = self.default_config()
  869. config["experimental_features"] = {"msc3720_enabled": True}
  870. return self.setup_test_homeserver(config=config)
  871. def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
  872. self.requester = self.register_user("requester", "password")
  873. self.requester_tok = self.login("requester", "password")
  874. self.server_name = hs.config.server.server_name
  875. def test_missing_mxid(self) -> None:
  876. """Tests that not providing any MXID raises an error."""
  877. self._test_status(
  878. users=None,
  879. expected_status_code=HTTPStatus.BAD_REQUEST,
  880. expected_errcode=Codes.MISSING_PARAM,
  881. )
  882. def test_invalid_mxid(self) -> None:
  883. """Tests that providing an invalid MXID raises an error."""
  884. self._test_status(
  885. users=["bad:test"],
  886. expected_status_code=HTTPStatus.BAD_REQUEST,
  887. expected_errcode=Codes.INVALID_PARAM,
  888. )
  889. def test_local_user_not_exists(self) -> None:
  890. """Tests that the account status endpoints correctly reports that a user doesn't
  891. exist.
  892. """
  893. user = "@unknown:" + self.hs.config.server.server_name
  894. self._test_status(
  895. users=[user],
  896. expected_statuses={
  897. user: {
  898. "exists": False,
  899. },
  900. },
  901. expected_failures=[],
  902. )
  903. def test_local_user_exists(self) -> None:
  904. """Tests that the account status endpoint correctly reports that a user doesn't
  905. exist.
  906. """
  907. user = self.register_user("someuser", "password")
  908. self._test_status(
  909. users=[user],
  910. expected_statuses={
  911. user: {
  912. "exists": True,
  913. "deactivated": False,
  914. },
  915. },
  916. expected_failures=[],
  917. )
  918. def test_local_user_deactivated(self) -> None:
  919. """Tests that the account status endpoint correctly reports a deactivated user."""
  920. user = self.register_user("someuser", "password")
  921. self.get_success(
  922. self.hs.get_datastores().main.set_user_deactivated_status(
  923. user, deactivated=True
  924. )
  925. )
  926. self._test_status(
  927. users=[user],
  928. expected_statuses={
  929. user: {
  930. "exists": True,
  931. "deactivated": True,
  932. },
  933. },
  934. expected_failures=[],
  935. )
  936. def test_mixed_local_and_remote_users(self) -> None:
  937. """Tests that if some users are remote the account status endpoint correctly
  938. merges the remote responses with the local result.
  939. """
  940. # We use 3 users: one doesn't exist but belongs on the local homeserver, one is
  941. # deactivated and belongs on one remote homeserver, and one belongs to another
  942. # remote homeserver that didn't return any result (the federation code should
  943. # mark that user as a failure).
  944. users = [
  945. "@unknown:" + self.hs.config.server.server_name,
  946. "@deactivated:remote",
  947. "@failed:otherremote",
  948. "@bad:badremote",
  949. ]
  950. async def post_json(
  951. destination: str,
  952. path: str,
  953. data: Optional[JsonDict] = None,
  954. *a: Any,
  955. **kwa: Any,
  956. ) -> Union[JsonDict, list]:
  957. if destination == "remote":
  958. return {
  959. "account_statuses": {
  960. users[1]: {
  961. "exists": True,
  962. "deactivated": True,
  963. },
  964. }
  965. }
  966. elif destination == "badremote":
  967. # badremote tries to overwrite the status of a user that doesn't belong
  968. # to it (i.e. users[1]) with false data, which Synapse is expected to
  969. # ignore.
  970. return {
  971. "account_statuses": {
  972. users[3]: {
  973. "exists": False,
  974. },
  975. users[1]: {
  976. "exists": False,
  977. },
  978. }
  979. }
  980. # if destination == "otherremote"
  981. else:
  982. return {}
  983. # Register a mock that will return the expected result depending on the remote.
  984. self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) # type: ignore[assignment]
  985. # Check that we've got the correct response from the client-side endpoint.
  986. self._test_status(
  987. users=users,
  988. expected_statuses={
  989. users[0]: {
  990. "exists": False,
  991. },
  992. users[1]: {
  993. "exists": True,
  994. "deactivated": True,
  995. },
  996. users[3]: {
  997. "exists": False,
  998. },
  999. },
  1000. expected_failures=[users[2]],
  1001. )
  1002. @unittest.override_config(
  1003. {
  1004. "use_account_validity_in_account_status": True,
  1005. }
  1006. )
  1007. def test_no_account_validity(self) -> None:
  1008. """Tests that if we decide to include account validity in the response but no
  1009. account validity 'is_user_expired' callback is provided, we default to marking all
  1010. users as not expired.
  1011. """
  1012. user = self.register_user("someuser", "password")
  1013. self._test_status(
  1014. users=[user],
  1015. expected_statuses={
  1016. user: {
  1017. "exists": True,
  1018. "deactivated": False,
  1019. "org.matrix.expired": False,
  1020. },
  1021. },
  1022. expected_failures=[],
  1023. )
  1024. @unittest.override_config(
  1025. {
  1026. "use_account_validity_in_account_status": True,
  1027. }
  1028. )
  1029. def test_account_validity_expired(self) -> None:
  1030. """Test that if we decide to include account validity in the response and the user
  1031. is expired, we return the correct info.
  1032. """
  1033. user = self.register_user("someuser", "password")
  1034. async def is_expired(user_id: str) -> bool:
  1035. # We can't blindly say everyone is expired, otherwise the request to get the
  1036. # account status will fail.
  1037. return UserID.from_string(user_id).localpart == "someuser"
  1038. self.hs.get_account_validity_handler()._is_user_expired_callbacks.append(
  1039. is_expired
  1040. )
  1041. self._test_status(
  1042. users=[user],
  1043. expected_statuses={
  1044. user: {
  1045. "exists": True,
  1046. "deactivated": False,
  1047. "org.matrix.expired": True,
  1048. },
  1049. },
  1050. expected_failures=[],
  1051. )
  1052. def _test_status(
  1053. self,
  1054. users: Optional[List[str]],
  1055. expected_status_code: int = HTTPStatus.OK,
  1056. expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None,
  1057. expected_failures: Optional[List[str]] = None,
  1058. expected_errcode: Optional[str] = None,
  1059. ) -> None:
  1060. """Send a request to the account status endpoint and check that the response
  1061. matches with what's expected.
  1062. Args:
  1063. users: The account(s) to request the status of, if any. If set to None, no
  1064. `user_id` query parameter will be included in the request.
  1065. expected_status_code: The expected HTTP status code.
  1066. expected_statuses: The expected account statuses, if any.
  1067. expected_failures: The expected failures, if any.
  1068. expected_errcode: The expected Matrix error code, if any.
  1069. """
  1070. content = {}
  1071. if users is not None:
  1072. content["user_ids"] = users
  1073. channel = self.make_request(
  1074. method="POST",
  1075. path=self.url,
  1076. content=content,
  1077. access_token=self.requester_tok,
  1078. )
  1079. self.assertEqual(channel.code, expected_status_code)
  1080. if expected_statuses is not None:
  1081. self.assertEqual(channel.json_body["account_statuses"], expected_statuses)
  1082. if expected_failures is not None:
  1083. self.assertEqual(channel.json_body["failures"], expected_failures)
  1084. if expected_errcode is not None:
  1085. self.assertEqual(channel.json_body["errcode"], expected_errcode)