You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

1245 lines
46 KiB

  1. # Copyright 2014-2016 OpenMarket Ltd
  2. # Copyright 2017-2018 New Vector Ltd
  3. # Copyright 2019 The Matrix.org Foundation C.I.C.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import datetime
  17. import os
  18. from typing import Any, Dict, List, Tuple
  19. import pkg_resources
  20. from twisted.test.proto_helpers import MemoryReactor
  21. import synapse.rest.admin
  22. from synapse.api.constants import (
  23. APP_SERVICE_REGISTRATION_TYPE,
  24. ApprovalNoticeMedium,
  25. LoginType,
  26. )
  27. from synapse.api.errors import Codes
  28. from synapse.appservice import ApplicationService
  29. from synapse.rest.client import account, account_validity, login, logout, register, sync
  30. from synapse.server import HomeServer
  31. from synapse.storage._base import db_to_json
  32. from synapse.types import JsonDict
  33. from synapse.util import Clock
  34. from tests import unittest
  35. from tests.unittest import override_config
  36. class RegisterRestServletTestCase(unittest.HomeserverTestCase):
  37. servlets = [
  38. login.register_servlets,
  39. register.register_servlets,
  40. synapse.rest.admin.register_servlets,
  41. ]
  42. url = b"/_matrix/client/r0/register"
  43. def default_config(self) -> Dict[str, Any]:
  44. config = super().default_config()
  45. config["allow_guest_access"] = True
  46. return config
  47. def test_POST_appservice_registration_valid(self) -> None:
  48. user_id = "@as_user_kermit:test"
  49. as_token = "i_am_an_app_service"
  50. appservice = ApplicationService(
  51. as_token,
  52. id="1234",
  53. namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
  54. sender="@as:test",
  55. )
  56. self.hs.get_datastores().main.services_cache.append(appservice)
  57. request_data = {
  58. "username": "as_user_kermit",
  59. "type": APP_SERVICE_REGISTRATION_TYPE,
  60. }
  61. channel = self.make_request(
  62. b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
  63. )
  64. self.assertEqual(channel.code, 200, msg=channel.result)
  65. det_data = {"user_id": user_id, "home_server": self.hs.hostname}
  66. self.assertDictContainsSubset(det_data, channel.json_body)
  67. def test_POST_appservice_registration_no_type(self) -> None:
  68. as_token = "i_am_an_app_service"
  69. appservice = ApplicationService(
  70. as_token,
  71. id="1234",
  72. namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
  73. sender="@as:test",
  74. )
  75. self.hs.get_datastores().main.services_cache.append(appservice)
  76. request_data = {"username": "as_user_kermit"}
  77. channel = self.make_request(
  78. b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
  79. )
  80. self.assertEqual(channel.code, 400, msg=channel.result)
  81. def test_POST_appservice_registration_invalid(self) -> None:
  82. self.appservice = None # no application service exists
  83. request_data = {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
  84. channel = self.make_request(
  85. b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
  86. )
  87. self.assertEqual(channel.code, 401, msg=channel.result)
  88. def test_POST_bad_password(self) -> None:
  89. request_data = {"username": "kermit", "password": 666}
  90. channel = self.make_request(b"POST", self.url, request_data)
  91. self.assertEqual(channel.code, 400, msg=channel.result)
  92. self.assertEqual(channel.json_body["error"], "Invalid password")
  93. def test_POST_bad_username(self) -> None:
  94. request_data = {"username": 777, "password": "monkey"}
  95. channel = self.make_request(b"POST", self.url, request_data)
  96. self.assertEqual(channel.code, 400, msg=channel.result)
  97. self.assertEqual(channel.json_body["error"], "Invalid username")
  98. def test_POST_user_valid(self) -> None:
  99. user_id = "@kermit:test"
  100. device_id = "frogfone"
  101. request_data = {
  102. "username": "kermit",
  103. "password": "monkey",
  104. "device_id": device_id,
  105. "auth": {"type": LoginType.DUMMY},
  106. }
  107. channel = self.make_request(b"POST", self.url, request_data)
  108. det_data = {
  109. "user_id": user_id,
  110. "home_server": self.hs.hostname,
  111. "device_id": device_id,
  112. }
  113. self.assertEqual(channel.code, 200, msg=channel.result)
  114. self.assertDictContainsSubset(det_data, channel.json_body)
  115. @override_config({"enable_registration": False})
  116. def test_POST_disabled_registration(self) -> None:
  117. request_data = {"username": "kermit", "password": "monkey"}
  118. self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
  119. channel = self.make_request(b"POST", self.url, request_data)
  120. self.assertEqual(channel.code, 403, msg=channel.result)
  121. self.assertEqual(channel.json_body["error"], "Registration has been disabled")
  122. self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
  123. def test_POST_guest_registration(self) -> None:
  124. self.hs.config.key.macaroon_secret_key = b"test"
  125. self.hs.config.registration.allow_guest_access = True
  126. channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
  127. det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"}
  128. self.assertEqual(channel.code, 200, msg=channel.result)
  129. self.assertDictContainsSubset(det_data, channel.json_body)
  130. def test_POST_disabled_guest_registration(self) -> None:
  131. self.hs.config.registration.allow_guest_access = False
  132. channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
  133. self.assertEqual(channel.code, 403, msg=channel.result)
  134. self.assertEqual(channel.json_body["error"], "Guest access is disabled")
  135. @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
  136. def test_POST_ratelimiting_guest(self) -> None:
  137. for i in range(0, 6):
  138. url = self.url + b"?kind=guest"
  139. channel = self.make_request(b"POST", url, b"{}")
  140. if i == 5:
  141. self.assertEqual(channel.code, 429, msg=channel.result)
  142. retry_after_ms = int(channel.json_body["retry_after_ms"])
  143. else:
  144. self.assertEqual(channel.code, 200, msg=channel.result)
  145. self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
  146. channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
  147. self.assertEqual(channel.code, 200, msg=channel.result)
  148. @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
  149. def test_POST_ratelimiting(self) -> None:
  150. for i in range(0, 6):
  151. request_data = {
  152. "username": "kermit" + str(i),
  153. "password": "monkey",
  154. "device_id": "frogfone",
  155. "auth": {"type": LoginType.DUMMY},
  156. }
  157. channel = self.make_request(b"POST", self.url, request_data)
  158. if i == 5:
  159. self.assertEqual(channel.code, 429, msg=channel.result)
  160. retry_after_ms = int(channel.json_body["retry_after_ms"])
  161. else:
  162. self.assertEqual(channel.code, 200, msg=channel.result)
  163. self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
  164. channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
  165. self.assertEqual(channel.code, 200, msg=channel.result)
  166. @override_config({"registration_requires_token": True})
  167. def test_POST_registration_requires_token(self) -> None:
  168. username = "kermit"
  169. device_id = "frogfone"
  170. token = "abcd"
  171. store = self.hs.get_datastores().main
  172. self.get_success(
  173. store.db_pool.simple_insert(
  174. "registration_tokens",
  175. {
  176. "token": token,
  177. "uses_allowed": None,
  178. "pending": 0,
  179. "completed": 0,
  180. "expiry_time": None,
  181. },
  182. )
  183. )
  184. params: JsonDict = {
  185. "username": username,
  186. "password": "monkey",
  187. "device_id": device_id,
  188. }
  189. # Request without auth to get flows and session
  190. channel = self.make_request(b"POST", self.url, params)
  191. self.assertEqual(channel.code, 401, msg=channel.result)
  192. flows = channel.json_body["flows"]
  193. # Synapse adds a dummy stage to differentiate flows where otherwise one
  194. # flow would be a subset of another flow.
  195. self.assertCountEqual(
  196. [[LoginType.REGISTRATION_TOKEN, LoginType.DUMMY]],
  197. (f["stages"] for f in flows),
  198. )
  199. session = channel.json_body["session"]
  200. # Do the registration token stage and check it has completed
  201. params["auth"] = {
  202. "type": LoginType.REGISTRATION_TOKEN,
  203. "token": token,
  204. "session": session,
  205. }
  206. channel = self.make_request(b"POST", self.url, params)
  207. self.assertEqual(channel.code, 401, msg=channel.result)
  208. completed = channel.json_body["completed"]
  209. self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
  210. # Do the m.login.dummy stage and check registration was successful
  211. params["auth"] = {
  212. "type": LoginType.DUMMY,
  213. "session": session,
  214. }
  215. channel = self.make_request(b"POST", self.url, params)
  216. det_data = {
  217. "user_id": f"@{username}:{self.hs.hostname}",
  218. "home_server": self.hs.hostname,
  219. "device_id": device_id,
  220. }
  221. self.assertEqual(channel.code, 200, msg=channel.result)
  222. self.assertDictContainsSubset(det_data, channel.json_body)
  223. # Check the `completed` counter has been incremented and pending is 0
  224. res = self.get_success(
  225. store.db_pool.simple_select_one(
  226. "registration_tokens",
  227. keyvalues={"token": token},
  228. retcols=["pending", "completed"],
  229. )
  230. )
  231. self.assertEqual(res["completed"], 1)
  232. self.assertEqual(res["pending"], 0)
  233. @override_config({"registration_requires_token": True})
  234. def test_POST_registration_token_invalid(self) -> None:
  235. params: JsonDict = {
  236. "username": "kermit",
  237. "password": "monkey",
  238. }
  239. # Request without auth to get session
  240. channel = self.make_request(b"POST", self.url, params)
  241. session = channel.json_body["session"]
  242. # Test with token param missing (invalid)
  243. params["auth"] = {
  244. "type": LoginType.REGISTRATION_TOKEN,
  245. "session": session,
  246. }
  247. channel = self.make_request(b"POST", self.url, params)
  248. self.assertEqual(channel.code, 401, msg=channel.result)
  249. self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM)
  250. self.assertEqual(channel.json_body["completed"], [])
  251. # Test with non-string (invalid)
  252. params["auth"]["token"] = 1234
  253. channel = self.make_request(b"POST", self.url, params)
  254. self.assertEqual(channel.code, 401, msg=channel.result)
  255. self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
  256. self.assertEqual(channel.json_body["completed"], [])
  257. # Test with unknown token (invalid)
  258. params["auth"]["token"] = "1234"
  259. channel = self.make_request(b"POST", self.url, params)
  260. self.assertEqual(channel.code, 401, msg=channel.result)
  261. self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
  262. self.assertEqual(channel.json_body["completed"], [])
  263. @override_config({"registration_requires_token": True})
  264. def test_POST_registration_token_limit_uses(self) -> None:
  265. token = "abcd"
  266. store = self.hs.get_datastores().main
  267. # Create token that can be used once
  268. self.get_success(
  269. store.db_pool.simple_insert(
  270. "registration_tokens",
  271. {
  272. "token": token,
  273. "uses_allowed": 1,
  274. "pending": 0,
  275. "completed": 0,
  276. "expiry_time": None,
  277. },
  278. )
  279. )
  280. params1: JsonDict = {"username": "bert", "password": "monkey"}
  281. params2: JsonDict = {"username": "ernie", "password": "monkey"}
  282. # Do 2 requests without auth to get two session IDs
  283. channel1 = self.make_request(b"POST", self.url, params1)
  284. session1 = channel1.json_body["session"]
  285. channel2 = self.make_request(b"POST", self.url, params2)
  286. session2 = channel2.json_body["session"]
  287. # Use token with session1 and check `pending` is 1
  288. params1["auth"] = {
  289. "type": LoginType.REGISTRATION_TOKEN,
  290. "token": token,
  291. "session": session1,
  292. }
  293. self.make_request(b"POST", self.url, params1)
  294. # Repeat request to make sure pending isn't increased again
  295. self.make_request(b"POST", self.url, params1)
  296. pending = self.get_success(
  297. store.db_pool.simple_select_one_onecol(
  298. "registration_tokens",
  299. keyvalues={"token": token},
  300. retcol="pending",
  301. )
  302. )
  303. self.assertEqual(pending, 1)
  304. # Check auth fails when using token with session2
  305. params2["auth"] = {
  306. "type": LoginType.REGISTRATION_TOKEN,
  307. "token": token,
  308. "session": session2,
  309. }
  310. channel = self.make_request(b"POST", self.url, params2)
  311. self.assertEqual(channel.code, 401, msg=channel.result)
  312. self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
  313. self.assertEqual(channel.json_body["completed"], [])
  314. # Complete registration with session1
  315. params1["auth"]["type"] = LoginType.DUMMY
  316. self.make_request(b"POST", self.url, params1)
  317. # Check pending=0 and completed=1
  318. res = self.get_success(
  319. store.db_pool.simple_select_one(
  320. "registration_tokens",
  321. keyvalues={"token": token},
  322. retcols=["pending", "completed"],
  323. )
  324. )
  325. self.assertEqual(res["pending"], 0)
  326. self.assertEqual(res["completed"], 1)
  327. # Check auth still fails when using token with session2
  328. channel = self.make_request(b"POST", self.url, params2)
  329. self.assertEqual(channel.code, 401, msg=channel.result)
  330. self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
  331. self.assertEqual(channel.json_body["completed"], [])
  332. @override_config({"registration_requires_token": True})
  333. def test_POST_registration_token_expiry(self) -> None:
  334. token = "abcd"
  335. now = self.hs.get_clock().time_msec()
  336. store = self.hs.get_datastores().main
  337. # Create token that expired yesterday
  338. self.get_success(
  339. store.db_pool.simple_insert(
  340. "registration_tokens",
  341. {
  342. "token": token,
  343. "uses_allowed": None,
  344. "pending": 0,
  345. "completed": 0,
  346. "expiry_time": now - 24 * 60 * 60 * 1000,
  347. },
  348. )
  349. )
  350. params: JsonDict = {"username": "kermit", "password": "monkey"}
  351. # Request without auth to get session
  352. channel = self.make_request(b"POST", self.url, params)
  353. session = channel.json_body["session"]
  354. # Check authentication fails with expired token
  355. params["auth"] = {
  356. "type": LoginType.REGISTRATION_TOKEN,
  357. "token": token,
  358. "session": session,
  359. }
  360. channel = self.make_request(b"POST", self.url, params)
  361. self.assertEqual(channel.code, 401, msg=channel.result)
  362. self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
  363. self.assertEqual(channel.json_body["completed"], [])
  364. # Update token so it expires tomorrow
  365. self.get_success(
  366. store.db_pool.simple_update_one(
  367. "registration_tokens",
  368. keyvalues={"token": token},
  369. updatevalues={"expiry_time": now + 24 * 60 * 60 * 1000},
  370. )
  371. )
  372. # Check authentication succeeds
  373. channel = self.make_request(b"POST", self.url, params)
  374. completed = channel.json_body["completed"]
  375. self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
  376. @override_config({"registration_requires_token": True})
  377. def test_POST_registration_token_session_expiry(self) -> None:
  378. """Test `pending` is decremented when an uncompleted session expires."""
  379. token = "abcd"
  380. store = self.hs.get_datastores().main
  381. self.get_success(
  382. store.db_pool.simple_insert(
  383. "registration_tokens",
  384. {
  385. "token": token,
  386. "uses_allowed": None,
  387. "pending": 0,
  388. "completed": 0,
  389. "expiry_time": None,
  390. },
  391. )
  392. )
  393. # Do 2 requests without auth to get two session IDs
  394. params1: JsonDict = {"username": "bert", "password": "monkey"}
  395. params2: JsonDict = {"username": "ernie", "password": "monkey"}
  396. channel1 = self.make_request(b"POST", self.url, params1)
  397. session1 = channel1.json_body["session"]
  398. channel2 = self.make_request(b"POST", self.url, params2)
  399. session2 = channel2.json_body["session"]
  400. # Use token with both sessions
  401. params1["auth"] = {
  402. "type": LoginType.REGISTRATION_TOKEN,
  403. "token": token,
  404. "session": session1,
  405. }
  406. self.make_request(b"POST", self.url, params1)
  407. params2["auth"] = {
  408. "type": LoginType.REGISTRATION_TOKEN,
  409. "token": token,
  410. "session": session2,
  411. }
  412. self.make_request(b"POST", self.url, params2)
  413. # Complete registration with session1
  414. params1["auth"]["type"] = LoginType.DUMMY
  415. self.make_request(b"POST", self.url, params1)
  416. # Check `result` of registration token stage for session1 is `True`
  417. result1 = self.get_success(
  418. store.db_pool.simple_select_one_onecol(
  419. "ui_auth_sessions_credentials",
  420. keyvalues={
  421. "session_id": session1,
  422. "stage_type": LoginType.REGISTRATION_TOKEN,
  423. },
  424. retcol="result",
  425. )
  426. )
  427. self.assertTrue(db_to_json(result1))
  428. # Check `result` for session2 is the token used
  429. result2 = self.get_success(
  430. store.db_pool.simple_select_one_onecol(
  431. "ui_auth_sessions_credentials",
  432. keyvalues={
  433. "session_id": session2,
  434. "stage_type": LoginType.REGISTRATION_TOKEN,
  435. },
  436. retcol="result",
  437. )
  438. )
  439. self.assertEqual(db_to_json(result2), token)
  440. # Delete both sessions (mimics expiry)
  441. self.get_success(
  442. store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
  443. )
  444. # Check pending is now 0
  445. pending = self.get_success(
  446. store.db_pool.simple_select_one_onecol(
  447. "registration_tokens",
  448. keyvalues={"token": token},
  449. retcol="pending",
  450. )
  451. )
  452. self.assertEqual(pending, 0)
  453. @override_config({"registration_requires_token": True})
  454. def test_POST_registration_token_session_expiry_deleted_token(self) -> None:
  455. """Test session expiry doesn't break when the token is deleted.
  456. 1. Start but don't complete UIA with a registration token
  457. 2. Delete the token from the database
  458. 3. Expire the session
  459. """
  460. token = "abcd"
  461. store = self.hs.get_datastores().main
  462. self.get_success(
  463. store.db_pool.simple_insert(
  464. "registration_tokens",
  465. {
  466. "token": token,
  467. "uses_allowed": None,
  468. "pending": 0,
  469. "completed": 0,
  470. "expiry_time": None,
  471. },
  472. )
  473. )
  474. # Do request without auth to get a session ID
  475. params: JsonDict = {"username": "kermit", "password": "monkey"}
  476. channel = self.make_request(b"POST", self.url, params)
  477. session = channel.json_body["session"]
  478. # Use token
  479. params["auth"] = {
  480. "type": LoginType.REGISTRATION_TOKEN,
  481. "token": token,
  482. "session": session,
  483. }
  484. self.make_request(b"POST", self.url, params)
  485. # Delete token
  486. self.get_success(
  487. store.db_pool.simple_delete_one(
  488. "registration_tokens",
  489. keyvalues={"token": token},
  490. )
  491. )
  492. # Delete session (mimics expiry)
  493. self.get_success(
  494. store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
  495. )
  496. def test_advertised_flows(self) -> None:
  497. channel = self.make_request(b"POST", self.url, b"{}")
  498. self.assertEqual(channel.code, 401, msg=channel.result)
  499. flows = channel.json_body["flows"]
  500. # with the stock config, we only expect the dummy flow
  501. self.assertCountEqual([["m.login.dummy"]], (f["stages"] for f in flows))
  502. @unittest.override_config(
  503. {
  504. "public_baseurl": "https://test_server",
  505. "enable_registration_captcha": True,
  506. "user_consent": {
  507. "version": "1",
  508. "template_dir": "/",
  509. "require_at_registration": True,
  510. },
  511. "account_threepid_delegates": {
  512. "msisdn": "https://id_server",
  513. },
  514. "email": {"notif_from": "Synapse <synapse@example.com>"},
  515. }
  516. )
  517. def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None:
  518. channel = self.make_request(b"POST", self.url, b"{}")
  519. self.assertEqual(channel.code, 401, msg=channel.result)
  520. flows = channel.json_body["flows"]
  521. self.assertCountEqual(
  522. [
  523. ["m.login.recaptcha", "m.login.terms", "m.login.dummy"],
  524. ["m.login.recaptcha", "m.login.terms", "m.login.email.identity"],
  525. ["m.login.recaptcha", "m.login.terms", "m.login.msisdn"],
  526. [
  527. "m.login.recaptcha",
  528. "m.login.terms",
  529. "m.login.msisdn",
  530. "m.login.email.identity",
  531. ],
  532. ],
  533. (f["stages"] for f in flows),
  534. )
  535. @unittest.override_config(
  536. {
  537. "public_baseurl": "https://test_server",
  538. "registrations_require_3pid": ["email"],
  539. "disable_msisdn_registration": True,
  540. "email": {
  541. "smtp_host": "mail_server",
  542. "smtp_port": 2525,
  543. "notif_from": "sender@host",
  544. },
  545. }
  546. )
  547. def test_advertised_flows_no_msisdn_email_required(self) -> None:
  548. channel = self.make_request(b"POST", self.url, b"{}")
  549. self.assertEqual(channel.code, 401, msg=channel.result)
  550. flows = channel.json_body["flows"]
  551. # with the stock config, we expect all four combinations of 3pid
  552. self.assertCountEqual(
  553. [["m.login.email.identity"]], (f["stages"] for f in flows)
  554. )
  555. @unittest.override_config(
  556. {
  557. "request_token_inhibit_3pid_errors": True,
  558. "public_baseurl": "https://test_server",
  559. "email": {
  560. "smtp_host": "mail_server",
  561. "smtp_port": 2525,
  562. "notif_from": "sender@host",
  563. },
  564. }
  565. )
  566. def test_request_token_existing_email_inhibit_error(self) -> None:
  567. """Test that requesting a token via this endpoint doesn't leak existing
  568. associations if configured that way.
  569. """
  570. user_id = self.register_user("kermit", "monkey")
  571. self.login("kermit", "monkey")
  572. email = "test@example.com"
  573. # Add a threepid
  574. self.get_success(
  575. self.hs.get_datastores().main.user_add_threepid(
  576. user_id=user_id,
  577. medium="email",
  578. address=email,
  579. validated_at=0,
  580. added_at=0,
  581. )
  582. )
  583. channel = self.make_request(
  584. "POST",
  585. b"register/email/requestToken",
  586. {"client_secret": "foobar", "email": email, "send_attempt": 1},
  587. )
  588. self.assertEqual(200, channel.code, channel.result)
  589. self.assertIsNotNone(channel.json_body.get("sid"))
  590. @unittest.override_config(
  591. {
  592. "public_baseurl": "https://test_server",
  593. "email": {
  594. "smtp_host": "mail_server",
  595. "smtp_port": 2525,
  596. "notif_from": "sender@host",
  597. },
  598. }
  599. )
  600. def test_reject_invalid_email(self) -> None:
  601. """Check that bad emails are rejected"""
  602. # Test for email with multiple @
  603. channel = self.make_request(
  604. "POST",
  605. b"register/email/requestToken",
  606. {"client_secret": "foobar", "email": "email@@email", "send_attempt": 1},
  607. )
  608. self.assertEqual(400, channel.code, channel.result)
  609. # Check error to ensure that we're not erroring due to a bug in the test.
  610. self.assertEqual(
  611. channel.json_body,
  612. {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"},
  613. )
  614. # Test for email with no @
  615. channel = self.make_request(
  616. "POST",
  617. b"register/email/requestToken",
  618. {"client_secret": "foobar", "email": "email", "send_attempt": 1},
  619. )
  620. self.assertEqual(400, channel.code, channel.result)
  621. self.assertEqual(
  622. channel.json_body,
  623. {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"},
  624. )
  625. # Test for super long email
  626. email = "a@" + "a" * 1000
  627. channel = self.make_request(
  628. "POST",
  629. b"register/email/requestToken",
  630. {"client_secret": "foobar", "email": email, "send_attempt": 1},
  631. )
  632. self.assertEqual(400, channel.code, channel.result)
  633. self.assertEqual(
  634. channel.json_body,
  635. {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"},
  636. )
  637. @override_config(
  638. {
  639. "inhibit_user_in_use_error": True,
  640. }
  641. )
  642. def test_inhibit_user_in_use_error(self) -> None:
  643. """Tests that the 'inhibit_user_in_use_error' configuration flag behaves
  644. correctly.
  645. """
  646. username = "arthur"
  647. # Manually register the user, so we know the test isn't passing because of a lack
  648. # of clashing.
  649. reg_handler = self.hs.get_registration_handler()
  650. self.get_success(reg_handler.register_user(username))
  651. # Check that /available correctly ignores the username provided despite the
  652. # username being already registered.
  653. channel = self.make_request("GET", "register/available?username=" + username)
  654. self.assertEqual(200, channel.code, channel.result)
  655. # Test that when starting a UIA registration flow the request doesn't fail because
  656. # of a conflicting username
  657. channel = self.make_request(
  658. "POST",
  659. "register",
  660. {"username": username, "type": "m.login.password", "password": "foo"},
  661. )
  662. self.assertEqual(channel.code, 401)
  663. self.assertIn("session", channel.json_body)
  664. # Test that finishing the registration fails because of a conflicting username.
  665. session = channel.json_body["session"]
  666. channel = self.make_request(
  667. "POST",
  668. "register",
  669. {"auth": {"session": session, "type": LoginType.DUMMY}},
  670. )
  671. self.assertEqual(channel.code, 400, channel.json_body)
  672. self.assertEqual(channel.json_body["errcode"], Codes.USER_IN_USE)
  673. @override_config(
  674. {
  675. "experimental_features": {
  676. "msc3866": {
  677. "enabled": True,
  678. "require_approval_for_new_accounts": True,
  679. }
  680. }
  681. }
  682. )
  683. def test_require_approval(self) -> None:
  684. channel = self.make_request(
  685. "POST",
  686. "register",
  687. {
  688. "username": "kermit",
  689. "password": "monkey",
  690. "auth": {"type": LoginType.DUMMY},
  691. },
  692. )
  693. self.assertEqual(403, channel.code, channel.result)
  694. self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
  695. self.assertEqual(
  696. ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
  697. )
  698. class AccountValidityTestCase(unittest.HomeserverTestCase):
  699. servlets = [
  700. register.register_servlets,
  701. synapse.rest.admin.register_servlets_for_client_rest_resource,
  702. login.register_servlets,
  703. sync.register_servlets,
  704. logout.register_servlets,
  705. account_validity.register_servlets,
  706. ]
  707. def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
  708. config = self.default_config()
  709. # Test for account expiring after a week.
  710. config["enable_registration"] = True
  711. config["account_validity"] = {
  712. "enabled": True,
  713. "period": 604800000, # Time in ms for 1 week
  714. }
  715. self.hs = self.setup_test_homeserver(config=config)
  716. return self.hs
  717. def test_validity_period(self) -> None:
  718. self.register_user("kermit", "monkey")
  719. tok = self.login("kermit", "monkey")
  720. # The specific endpoint doesn't matter, all we need is an authenticated
  721. # endpoint.
  722. channel = self.make_request(b"GET", "/sync", access_token=tok)
  723. self.assertEqual(channel.code, 200, msg=channel.result)
  724. self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
  725. channel = self.make_request(b"GET", "/sync", access_token=tok)
  726. self.assertEqual(channel.code, 403, msg=channel.result)
  727. self.assertEqual(
  728. channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
  729. )
  730. def test_manual_renewal(self) -> None:
  731. user_id = self.register_user("kermit", "monkey")
  732. tok = self.login("kermit", "monkey")
  733. self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
  734. # If we register the admin user at the beginning of the test, it will
  735. # expire at the same time as the normal user and the renewal request
  736. # will be denied.
  737. self.register_user("admin", "adminpassword", admin=True)
  738. admin_tok = self.login("admin", "adminpassword")
  739. url = "/_synapse/admin/v1/account_validity/validity"
  740. request_data = {"user_id": user_id}
  741. channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
  742. self.assertEqual(channel.code, 200, msg=channel.result)
  743. # The specific endpoint doesn't matter, all we need is an authenticated
  744. # endpoint.
  745. channel = self.make_request(b"GET", "/sync", access_token=tok)
  746. self.assertEqual(channel.code, 200, msg=channel.result)
  747. def test_manual_expire(self) -> None:
  748. user_id = self.register_user("kermit", "monkey")
  749. tok = self.login("kermit", "monkey")
  750. self.register_user("admin", "adminpassword", admin=True)
  751. admin_tok = self.login("admin", "adminpassword")
  752. url = "/_synapse/admin/v1/account_validity/validity"
  753. request_data = {
  754. "user_id": user_id,
  755. "expiration_ts": 0,
  756. "enable_renewal_emails": False,
  757. }
  758. channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
  759. self.assertEqual(channel.code, 200, msg=channel.result)
  760. # The specific endpoint doesn't matter, all we need is an authenticated
  761. # endpoint.
  762. channel = self.make_request(b"GET", "/sync", access_token=tok)
  763. self.assertEqual(channel.code, 403, msg=channel.result)
  764. self.assertEqual(
  765. channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
  766. )
  767. def test_logging_out_expired_user(self) -> None:
  768. user_id = self.register_user("kermit", "monkey")
  769. tok = self.login("kermit", "monkey")
  770. self.register_user("admin", "adminpassword", admin=True)
  771. admin_tok = self.login("admin", "adminpassword")
  772. url = "/_synapse/admin/v1/account_validity/validity"
  773. request_data = {
  774. "user_id": user_id,
  775. "expiration_ts": 0,
  776. "enable_renewal_emails": False,
  777. }
  778. channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
  779. self.assertEqual(channel.code, 200, msg=channel.result)
  780. # Try to log the user out
  781. channel = self.make_request(b"POST", "/logout", access_token=tok)
  782. self.assertEqual(channel.code, 200, msg=channel.result)
  783. # Log the user in again (allowed for expired accounts)
  784. tok = self.login("kermit", "monkey")
  785. # Try to log out all of the user's sessions
  786. channel = self.make_request(b"POST", "/logout/all", access_token=tok)
  787. self.assertEqual(channel.code, 200, msg=channel.result)
  788. class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
  789. servlets = [
  790. register.register_servlets,
  791. synapse.rest.admin.register_servlets_for_client_rest_resource,
  792. login.register_servlets,
  793. sync.register_servlets,
  794. account_validity.register_servlets,
  795. account.register_servlets,
  796. ]
  797. def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
  798. config = self.default_config()
  799. # Test for account expiring after a week and renewal emails being sent 2
  800. # days before expiry.
  801. config["enable_registration"] = True
  802. config["account_validity"] = {
  803. "enabled": True,
  804. "period": 604800000, # Time in ms for 1 week
  805. "renew_at": 172800000, # Time in ms for 2 days
  806. "renew_by_email_enabled": True,
  807. "renew_email_subject": "Renew your account",
  808. "account_renewed_html_path": "account_renewed.html",
  809. "invalid_token_html_path": "invalid_token.html",
  810. }
  811. # Email config.
  812. config["email"] = {
  813. "enable_notifs": True,
  814. "template_dir": os.path.abspath(
  815. pkg_resources.resource_filename("synapse", "res/templates")
  816. ),
  817. "expiry_template_html": "notice_expiry.html",
  818. "expiry_template_text": "notice_expiry.txt",
  819. "notif_template_html": "notif_mail.html",
  820. "notif_template_text": "notif_mail.txt",
  821. "smtp_host": "127.0.0.1",
  822. "smtp_port": 20,
  823. "require_transport_security": False,
  824. "smtp_user": None,
  825. "smtp_pass": None,
  826. "notif_from": "test@example.com",
  827. }
  828. self.hs = self.setup_test_homeserver(config=config)
  829. async def sendmail(*args: Any, **kwargs: Any) -> None:
  830. self.email_attempts.append((args, kwargs))
  831. self.email_attempts: List[Tuple[Any, Any]] = []
  832. self.hs.get_send_email_handler()._sendmail = sendmail
  833. self.store = self.hs.get_datastores().main
  834. return self.hs
  835. def test_renewal_email(self) -> None:
  836. self.email_attempts = []
  837. (user_id, tok) = self.create_user()
  838. # Move 5 days forward. This should trigger a renewal email to be sent.
  839. self.reactor.advance(datetime.timedelta(days=5).total_seconds())
  840. self.assertEqual(len(self.email_attempts), 1)
  841. # Retrieving the URL from the email is too much pain for now, so we
  842. # retrieve the token from the DB.
  843. renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id))
  844. url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
  845. channel = self.make_request(b"GET", url)
  846. self.assertEqual(channel.code, 200, msg=channel.result)
  847. # Check that we're getting HTML back.
  848. content_type = channel.headers.getRawHeaders(b"Content-Type")
  849. self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
  850. # Check that the HTML we're getting is the one we expect on a successful renewal.
  851. expiration_ts = self.get_success(self.store.get_expiration_ts_for_user(user_id))
  852. expected_html = self.hs.config.account_validity.account_validity_account_renewed_template.render(
  853. expiration_ts=expiration_ts
  854. )
  855. self.assertEqual(
  856. channel.result["body"], expected_html.encode("utf8"), channel.result
  857. )
  858. # Move 1 day forward. Try to renew with the same token again.
  859. url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
  860. channel = self.make_request(b"GET", url)
  861. self.assertEqual(channel.code, 200, msg=channel.result)
  862. # Check that we're getting HTML back.
  863. content_type = channel.headers.getRawHeaders(b"Content-Type")
  864. self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
  865. # Check that the HTML we're getting is the one we expect when reusing a
  866. # token. The account expiration date should not have changed.
  867. expected_html = self.hs.config.account_validity.account_validity_account_previously_renewed_template.render(
  868. expiration_ts=expiration_ts
  869. )
  870. self.assertEqual(
  871. channel.result["body"], expected_html.encode("utf8"), channel.result
  872. )
  873. # Move 3 days forward. If the renewal failed, every authed request with
  874. # our access token should be denied from now, otherwise they should
  875. # succeed.
  876. self.reactor.advance(datetime.timedelta(days=3).total_seconds())
  877. channel = self.make_request(b"GET", "/sync", access_token=tok)
  878. self.assertEqual(channel.code, 200, msg=channel.result)
  879. def test_renewal_invalid_token(self) -> None:
  880. # Hit the renewal endpoint with an invalid token and check that it behaves as
  881. # expected, i.e. that it responds with 404 Not Found and the correct HTML.
  882. url = "/_matrix/client/unstable/account_validity/renew?token=123"
  883. channel = self.make_request(b"GET", url)
  884. self.assertEqual(channel.code, 404, msg=channel.result)
  885. # Check that we're getting HTML back.
  886. content_type = channel.headers.getRawHeaders(b"Content-Type")
  887. self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
  888. # Check that the HTML we're getting is the one we expect when using an
  889. # invalid/unknown token.
  890. expected_html = (
  891. self.hs.config.account_validity.account_validity_invalid_token_template.render()
  892. )
  893. self.assertEqual(
  894. channel.result["body"], expected_html.encode("utf8"), channel.result
  895. )
  896. def test_manual_email_send(self) -> None:
  897. self.email_attempts = []
  898. (user_id, tok) = self.create_user()
  899. channel = self.make_request(
  900. b"POST",
  901. "/_matrix/client/unstable/account_validity/send_mail",
  902. access_token=tok,
  903. )
  904. self.assertEqual(channel.code, 200, msg=channel.result)
  905. self.assertEqual(len(self.email_attempts), 1)
  906. def test_deactivated_user(self) -> None:
  907. self.email_attempts = []
  908. (user_id, tok) = self.create_user()
  909. request_data = {
  910. "auth": {
  911. "type": "m.login.password",
  912. "user": user_id,
  913. "password": "monkey",
  914. },
  915. "erase": False,
  916. }
  917. channel = self.make_request(
  918. "POST", "account/deactivate", request_data, access_token=tok
  919. )
  920. self.assertEqual(channel.code, 200)
  921. self.reactor.advance(datetime.timedelta(days=8).total_seconds())
  922. self.assertEqual(len(self.email_attempts), 0)
  923. def create_user(self) -> Tuple[str, str]:
  924. user_id = self.register_user("kermit", "monkey")
  925. tok = self.login("kermit", "monkey")
  926. # We need to manually add an email address otherwise the handler will do
  927. # nothing.
  928. now = self.hs.get_clock().time_msec()
  929. self.get_success(
  930. self.store.user_add_threepid(
  931. user_id=user_id,
  932. medium="email",
  933. address="kermit@example.com",
  934. validated_at=now,
  935. added_at=now,
  936. )
  937. )
  938. return user_id, tok
  939. def test_manual_email_send_expired_account(self) -> None:
  940. user_id = self.register_user("kermit", "monkey")
  941. tok = self.login("kermit", "monkey")
  942. # We need to manually add an email address otherwise the handler will do
  943. # nothing.
  944. now = self.hs.get_clock().time_msec()
  945. self.get_success(
  946. self.store.user_add_threepid(
  947. user_id=user_id,
  948. medium="email",
  949. address="kermit@example.com",
  950. validated_at=now,
  951. added_at=now,
  952. )
  953. )
  954. # Make the account expire.
  955. self.reactor.advance(datetime.timedelta(days=8).total_seconds())
  956. # Ignore all emails sent by the automatic background task and only focus on the
  957. # ones sent manually.
  958. self.email_attempts = []
  959. # Test that we're still able to manually trigger a mail to be sent.
  960. channel = self.make_request(
  961. b"POST",
  962. "/_matrix/client/unstable/account_validity/send_mail",
  963. access_token=tok,
  964. )
  965. self.assertEqual(channel.code, 200, msg=channel.result)
  966. self.assertEqual(len(self.email_attempts), 1)
  967. class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
  968. servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
  969. def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
  970. self.validity_period = 10
  971. self.max_delta = self.validity_period * 10.0 / 100.0
  972. config = self.default_config()
  973. config["enable_registration"] = True
  974. config["account_validity"] = {"enabled": False}
  975. self.hs = self.setup_test_homeserver(config=config)
  976. # We need to set these directly, instead of in the homeserver config dict above.
  977. # This is due to account validity-related config options not being read by
  978. # Synapse when account_validity.enabled is False.
  979. self.hs.get_datastores().main._account_validity_period = self.validity_period
  980. self.hs.get_datastores().main._account_validity_startup_job_max_delta = (
  981. self.max_delta
  982. )
  983. self.store = self.hs.get_datastores().main
  984. return self.hs
  985. def test_background_job(self) -> None:
  986. """
  987. Tests the same thing as test_background_job, except that it sets the
  988. startup_job_max_delta parameter and checks that the expiration date is within the
  989. allowed range.
  990. """
  991. user_id = self.register_user("kermit_delta", "user")
  992. self.hs.config.account_validity.account_validity_startup_job_max_delta = (
  993. self.max_delta
  994. )
  995. now_ms = self.hs.get_clock().time_msec()
  996. self.get_success(self.store._set_expiration_date_when_missing())
  997. res = self.get_success(self.store.get_expiration_ts_for_user(user_id))
  998. assert res is not None
  999. self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
  1000. self.assertLessEqual(res, now_ms + self.validity_period)
  1001. class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
  1002. servlets = [register.register_servlets]
  1003. url = "/_matrix/client/v1/register/m.login.registration_token/validity"
  1004. def default_config(self) -> Dict[str, Any]:
  1005. config = super().default_config()
  1006. config["registration_requires_token"] = True
  1007. return config
  1008. def test_GET_token_valid(self) -> None:
  1009. token = "abcd"
  1010. store = self.hs.get_datastores().main
  1011. self.get_success(
  1012. store.db_pool.simple_insert(
  1013. "registration_tokens",
  1014. {
  1015. "token": token,
  1016. "uses_allowed": None,
  1017. "pending": 0,
  1018. "completed": 0,
  1019. "expiry_time": None,
  1020. },
  1021. )
  1022. )
  1023. channel = self.make_request(
  1024. b"GET",
  1025. f"{self.url}?token={token}",
  1026. )
  1027. self.assertEqual(channel.code, 200, msg=channel.result)
  1028. self.assertEqual(channel.json_body["valid"], True)
  1029. def test_GET_token_invalid(self) -> None:
  1030. token = "1234"
  1031. channel = self.make_request(
  1032. b"GET",
  1033. f"{self.url}?token={token}",
  1034. )
  1035. self.assertEqual(channel.code, 200, msg=channel.result)
  1036. self.assertEqual(channel.json_body["valid"], False)
  1037. @override_config(
  1038. {"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}}
  1039. )
  1040. def test_GET_ratelimiting(self) -> None:
  1041. token = "1234"
  1042. for i in range(0, 6):
  1043. channel = self.make_request(
  1044. b"GET",
  1045. f"{self.url}?token={token}",
  1046. )
  1047. if i == 5:
  1048. self.assertEqual(channel.code, 429, msg=channel.result)
  1049. retry_after_ms = int(channel.json_body["retry_after_ms"])
  1050. else:
  1051. self.assertEqual(channel.code, 200, msg=channel.result)
  1052. self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
  1053. channel = self.make_request(
  1054. b"GET",
  1055. f"{self.url}?token={token}",
  1056. )
  1057. self.assertEqual(channel.code, 200, msg=channel.result)