You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

601 lines
22 KiB

  1. # Copyright 2014-2016 OpenMarket Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. import re
  16. from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
  17. from typing_extensions import TypedDict
  18. from synapse.api.errors import Codes, LoginError, SynapseError
  19. from synapse.api.ratelimiting import Ratelimiter
  20. from synapse.api.urls import CLIENT_API_PREFIX
  21. from synapse.appservice import ApplicationService
  22. from synapse.handlers.sso import SsoIdentityProvider
  23. from synapse.http import get_request_uri
  24. from synapse.http.server import HttpServer, finish_request
  25. from synapse.http.servlet import (
  26. RestServlet,
  27. assert_params_in_dict,
  28. parse_boolean,
  29. parse_bytes_from_args,
  30. parse_json_object_from_request,
  31. parse_string,
  32. )
  33. from synapse.http.site import SynapseRequest
  34. from synapse.rest.client._base import client_patterns
  35. from synapse.rest.well_known import WellKnownBuilder
  36. from synapse.types import JsonDict, UserID
  37. if TYPE_CHECKING:
  38. from synapse.server import HomeServer
  39. logger = logging.getLogger(__name__)
  40. class LoginResponse(TypedDict, total=False):
  41. user_id: str
  42. access_token: str
  43. home_server: str
  44. expires_in_ms: Optional[int]
  45. refresh_token: Optional[str]
  46. device_id: str
  47. well_known: Optional[Dict[str, Any]]
  48. class LoginRestServlet(RestServlet):
  49. PATTERNS = client_patterns("/login$", v1=True)
  50. CAS_TYPE = "m.login.cas"
  51. SSO_TYPE = "m.login.sso"
  52. TOKEN_TYPE = "m.login.token"
  53. JWT_TYPE = "org.matrix.login.jwt"
  54. JWT_TYPE_DEPRECATED = "m.login.jwt"
  55. APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
  56. REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token"
  57. def __init__(self, hs: "HomeServer"):
  58. super().__init__()
  59. self.hs = hs
  60. # JWT configuration variables.
  61. self.jwt_enabled = hs.config.jwt_enabled
  62. self.jwt_secret = hs.config.jwt_secret
  63. self.jwt_algorithm = hs.config.jwt_algorithm
  64. self.jwt_issuer = hs.config.jwt_issuer
  65. self.jwt_audiences = hs.config.jwt_audiences
  66. # SSO configuration.
  67. self.saml2_enabled = hs.config.saml2_enabled
  68. self.cas_enabled = hs.config.cas_enabled
  69. self.oidc_enabled = hs.config.oidc_enabled
  70. self._msc2858_enabled = hs.config.experimental.msc2858_enabled
  71. self._msc2918_enabled = hs.config.access_token_lifetime is not None
  72. self.auth = hs.get_auth()
  73. self.clock = hs.get_clock()
  74. self.auth_handler = self.hs.get_auth_handler()
  75. self.registration_handler = hs.get_registration_handler()
  76. self._sso_handler = hs.get_sso_handler()
  77. self._well_known_builder = WellKnownBuilder(hs)
  78. self._address_ratelimiter = Ratelimiter(
  79. store=hs.get_datastore(),
  80. clock=hs.get_clock(),
  81. rate_hz=self.hs.config.rc_login_address.per_second,
  82. burst_count=self.hs.config.rc_login_address.burst_count,
  83. )
  84. self._account_ratelimiter = Ratelimiter(
  85. store=hs.get_datastore(),
  86. clock=hs.get_clock(),
  87. rate_hz=self.hs.config.rc_login_account.per_second,
  88. burst_count=self.hs.config.rc_login_account.burst_count,
  89. )
  90. def on_GET(self, request: SynapseRequest):
  91. flows = []
  92. if self.jwt_enabled:
  93. flows.append({"type": LoginRestServlet.JWT_TYPE})
  94. flows.append({"type": LoginRestServlet.JWT_TYPE_DEPRECATED})
  95. if self.cas_enabled:
  96. # we advertise CAS for backwards compat, though MSC1721 renamed it
  97. # to SSO.
  98. flows.append({"type": LoginRestServlet.CAS_TYPE})
  99. if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
  100. sso_flow: JsonDict = {
  101. "type": LoginRestServlet.SSO_TYPE,
  102. "identity_providers": [
  103. _get_auth_flow_dict_for_idp(
  104. idp,
  105. )
  106. for idp in self._sso_handler.get_identity_providers().values()
  107. ],
  108. }
  109. if self._msc2858_enabled:
  110. # backwards-compatibility support for clients which don't
  111. # support the stable API yet
  112. sso_flow["org.matrix.msc2858.identity_providers"] = [
  113. _get_auth_flow_dict_for_idp(idp, use_unstable_brands=True)
  114. for idp in self._sso_handler.get_identity_providers().values()
  115. ]
  116. flows.append(sso_flow)
  117. # While it's valid for us to advertise this login type generally,
  118. # synapse currently only gives out these tokens as part of the
  119. # SSO login flow.
  120. # Generally we don't want to advertise login flows that clients
  121. # don't know how to implement, since they (currently) will always
  122. # fall back to the fallback API if they don't understand one of the
  123. # login flow types returned.
  124. flows.append({"type": LoginRestServlet.TOKEN_TYPE})
  125. flows.extend({"type": t} for t in self.auth_handler.get_supported_login_types())
  126. flows.append({"type": LoginRestServlet.APPSERVICE_TYPE})
  127. return 200, {"flows": flows}
  128. async def on_POST(self, request: SynapseRequest):
  129. login_submission = parse_json_object_from_request(request)
  130. if self._msc2918_enabled:
  131. # Check if this login should also issue a refresh token, as per
  132. # MSC2918
  133. should_issue_refresh_token = parse_boolean(
  134. request, name=LoginRestServlet.REFRESH_TOKEN_PARAM, default=False
  135. )
  136. else:
  137. should_issue_refresh_token = False
  138. try:
  139. if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
  140. appservice = self.auth.get_appservice_by_req(request)
  141. if appservice.is_rate_limited():
  142. await self._address_ratelimiter.ratelimit(
  143. None, request.getClientIP()
  144. )
  145. result = await self._do_appservice_login(
  146. login_submission,
  147. appservice,
  148. should_issue_refresh_token=should_issue_refresh_token,
  149. )
  150. elif self.jwt_enabled and (
  151. login_submission["type"] == LoginRestServlet.JWT_TYPE
  152. or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
  153. ):
  154. await self._address_ratelimiter.ratelimit(None, request.getClientIP())
  155. result = await self._do_jwt_login(
  156. login_submission,
  157. should_issue_refresh_token=should_issue_refresh_token,
  158. )
  159. elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
  160. await self._address_ratelimiter.ratelimit(None, request.getClientIP())
  161. result = await self._do_token_login(
  162. login_submission,
  163. should_issue_refresh_token=should_issue_refresh_token,
  164. )
  165. else:
  166. await self._address_ratelimiter.ratelimit(None, request.getClientIP())
  167. result = await self._do_other_login(
  168. login_submission,
  169. should_issue_refresh_token=should_issue_refresh_token,
  170. )
  171. except KeyError:
  172. raise SynapseError(400, "Missing JSON keys.")
  173. well_known_data = self._well_known_builder.get_well_known()
  174. if well_known_data:
  175. result["well_known"] = well_known_data
  176. return 200, result
  177. async def _do_appservice_login(
  178. self,
  179. login_submission: JsonDict,
  180. appservice: ApplicationService,
  181. should_issue_refresh_token: bool = False,
  182. ):
  183. identifier = login_submission.get("identifier")
  184. logger.info("Got appservice login request with identifier: %r", identifier)
  185. if not isinstance(identifier, dict):
  186. raise SynapseError(
  187. 400, "Invalid identifier in login submission", Codes.INVALID_PARAM
  188. )
  189. # this login flow only supports identifiers of type "m.id.user".
  190. if identifier.get("type") != "m.id.user":
  191. raise SynapseError(
  192. 400, "Unknown login identifier type", Codes.INVALID_PARAM
  193. )
  194. user = identifier.get("user")
  195. if not isinstance(user, str):
  196. raise SynapseError(400, "Invalid user in identifier", Codes.INVALID_PARAM)
  197. if user.startswith("@"):
  198. qualified_user_id = user
  199. else:
  200. qualified_user_id = UserID(user, self.hs.hostname).to_string()
  201. if not appservice.is_interested_in_user(qualified_user_id):
  202. raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
  203. return await self._complete_login(
  204. qualified_user_id,
  205. login_submission,
  206. ratelimit=appservice.is_rate_limited(),
  207. should_issue_refresh_token=should_issue_refresh_token,
  208. )
  209. async def _do_other_login(
  210. self, login_submission: JsonDict, should_issue_refresh_token: bool = False
  211. ) -> LoginResponse:
  212. """Handle non-token/saml/jwt logins
  213. Args:
  214. login_submission:
  215. should_issue_refresh_token: True if this login should issue
  216. a refresh token alongside the access token.
  217. Returns:
  218. HTTP response
  219. """
  220. # Log the request we got, but only certain fields to minimise the chance of
  221. # logging someone's password (even if they accidentally put it in the wrong
  222. # field)
  223. logger.info(
  224. "Got login request with identifier: %r, medium: %r, address: %r, user: %r",
  225. login_submission.get("identifier"),
  226. login_submission.get("medium"),
  227. login_submission.get("address"),
  228. login_submission.get("user"),
  229. )
  230. canonical_user_id, callback = await self.auth_handler.validate_login(
  231. login_submission, ratelimit=True
  232. )
  233. result = await self._complete_login(
  234. canonical_user_id,
  235. login_submission,
  236. callback,
  237. should_issue_refresh_token=should_issue_refresh_token,
  238. )
  239. return result
  240. async def _complete_login(
  241. self,
  242. user_id: str,
  243. login_submission: JsonDict,
  244. callback: Optional[Callable[[LoginResponse], Awaitable[None]]] = None,
  245. create_non_existent_users: bool = False,
  246. ratelimit: bool = True,
  247. auth_provider_id: Optional[str] = None,
  248. should_issue_refresh_token: bool = False,
  249. ) -> LoginResponse:
  250. """Called when we've successfully authed the user and now need to
  251. actually login them in (e.g. create devices). This gets called on
  252. all successful logins.
  253. Applies the ratelimiting for successful login attempts against an
  254. account.
  255. Args:
  256. user_id: ID of the user to register.
  257. login_submission: Dictionary of login information.
  258. callback: Callback function to run after login.
  259. create_non_existent_users: Whether to create the user if they don't
  260. exist. Defaults to False.
  261. ratelimit: Whether to ratelimit the login request.
  262. auth_provider_id: The SSO IdP the user used, if any (just used for the
  263. prometheus metrics).
  264. should_issue_refresh_token: True if this login should issue
  265. a refresh token alongside the access token.
  266. Returns:
  267. result: Dictionary of account information after successful login.
  268. """
  269. # Before we actually log them in we check if they've already logged in
  270. # too often. This happens here rather than before as we don't
  271. # necessarily know the user before now.
  272. if ratelimit:
  273. await self._account_ratelimiter.ratelimit(None, user_id.lower())
  274. if create_non_existent_users:
  275. canonical_uid = await self.auth_handler.check_user_exists(user_id)
  276. if not canonical_uid:
  277. canonical_uid = await self.registration_handler.register_user(
  278. localpart=UserID.from_string(user_id).localpart
  279. )
  280. user_id = canonical_uid
  281. device_id = login_submission.get("device_id")
  282. initial_display_name = login_submission.get("initial_device_display_name")
  283. (
  284. device_id,
  285. access_token,
  286. valid_until_ms,
  287. refresh_token,
  288. ) = await self.registration_handler.register_device(
  289. user_id,
  290. device_id,
  291. initial_display_name,
  292. auth_provider_id=auth_provider_id,
  293. should_issue_refresh_token=should_issue_refresh_token,
  294. )
  295. result = LoginResponse(
  296. user_id=user_id,
  297. access_token=access_token,
  298. home_server=self.hs.hostname,
  299. device_id=device_id,
  300. )
  301. if valid_until_ms is not None:
  302. expires_in_ms = valid_until_ms - self.clock.time_msec()
  303. result["expires_in_ms"] = expires_in_ms
  304. if refresh_token is not None:
  305. result["refresh_token"] = refresh_token
  306. if callback is not None:
  307. await callback(result)
  308. return result
  309. async def _do_token_login(
  310. self, login_submission: JsonDict, should_issue_refresh_token: bool = False
  311. ) -> LoginResponse:
  312. """
  313. Handle the final stage of SSO login.
  314. Args:
  315. login_submission: The JSON request body.
  316. should_issue_refresh_token: True if this login should issue
  317. a refresh token alongside the access token.
  318. Returns:
  319. The body of the JSON response.
  320. """
  321. token = login_submission["token"]
  322. auth_handler = self.auth_handler
  323. res = await auth_handler.validate_short_term_login_token(token)
  324. return await self._complete_login(
  325. res.user_id,
  326. login_submission,
  327. self.auth_handler._sso_login_callback,
  328. auth_provider_id=res.auth_provider_id,
  329. should_issue_refresh_token=should_issue_refresh_token,
  330. )
  331. async def _do_jwt_login(
  332. self, login_submission: JsonDict, should_issue_refresh_token: bool = False
  333. ) -> LoginResponse:
  334. token = login_submission.get("token", None)
  335. if token is None:
  336. raise LoginError(
  337. 403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN
  338. )
  339. import jwt
  340. try:
  341. payload = jwt.decode(
  342. token,
  343. self.jwt_secret,
  344. algorithms=[self.jwt_algorithm],
  345. issuer=self.jwt_issuer,
  346. audience=self.jwt_audiences,
  347. )
  348. except jwt.PyJWTError as e:
  349. # A JWT error occurred, return some info back to the client.
  350. raise LoginError(
  351. 403,
  352. "JWT validation failed: %s" % (str(e),),
  353. errcode=Codes.FORBIDDEN,
  354. )
  355. user = payload.get("sub", None)
  356. if user is None:
  357. raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
  358. user_id = UserID(user, self.hs.hostname).to_string()
  359. result = await self._complete_login(
  360. user_id,
  361. login_submission,
  362. create_non_existent_users=True,
  363. should_issue_refresh_token=should_issue_refresh_token,
  364. )
  365. return result
  366. def _get_auth_flow_dict_for_idp(
  367. idp: SsoIdentityProvider, use_unstable_brands: bool = False
  368. ) -> JsonDict:
  369. """Return an entry for the login flow dict
  370. Returns an entry suitable for inclusion in "identity_providers" in the
  371. response to GET /_matrix/client/r0/login
  372. Args:
  373. idp: the identity provider to describe
  374. use_unstable_brands: whether we should use brand identifiers suitable
  375. for the unstable API
  376. """
  377. e: JsonDict = {"id": idp.idp_id, "name": idp.idp_name}
  378. if idp.idp_icon:
  379. e["icon"] = idp.idp_icon
  380. if idp.idp_brand:
  381. e["brand"] = idp.idp_brand
  382. # use the stable brand identifier if the unstable identifier isn't defined.
  383. if use_unstable_brands and idp.unstable_idp_brand:
  384. e["brand"] = idp.unstable_idp_brand
  385. return e
  386. class RefreshTokenServlet(RestServlet):
  387. PATTERNS = client_patterns(
  388. "/org.matrix.msc2918.refresh_token/refresh$", releases=(), unstable=True
  389. )
  390. def __init__(self, hs: "HomeServer"):
  391. self._auth_handler = hs.get_auth_handler()
  392. self._clock = hs.get_clock()
  393. self.access_token_lifetime = hs.config.access_token_lifetime
  394. async def on_POST(
  395. self,
  396. request: SynapseRequest,
  397. ):
  398. refresh_submission = parse_json_object_from_request(request)
  399. assert_params_in_dict(refresh_submission, ["refresh_token"])
  400. token = refresh_submission["refresh_token"]
  401. if not isinstance(token, str):
  402. raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM)
  403. valid_until_ms = self._clock.time_msec() + self.access_token_lifetime
  404. access_token, refresh_token = await self._auth_handler.refresh_token(
  405. token, valid_until_ms
  406. )
  407. expires_in_ms = valid_until_ms - self._clock.time_msec()
  408. return (
  409. 200,
  410. {
  411. "access_token": access_token,
  412. "refresh_token": refresh_token,
  413. "expires_in_ms": expires_in_ms,
  414. },
  415. )
  416. class SsoRedirectServlet(RestServlet):
  417. PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [
  418. re.compile(
  419. "^"
  420. + CLIENT_API_PREFIX
  421. + "/r0/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$"
  422. )
  423. ]
  424. def __init__(self, hs: "HomeServer"):
  425. # make sure that the relevant handlers are instantiated, so that they
  426. # register themselves with the main SSOHandler.
  427. if hs.config.cas_enabled:
  428. hs.get_cas_handler()
  429. if hs.config.saml2_enabled:
  430. hs.get_saml_handler()
  431. if hs.config.oidc_enabled:
  432. hs.get_oidc_handler()
  433. self._sso_handler = hs.get_sso_handler()
  434. self._msc2858_enabled = hs.config.experimental.msc2858_enabled
  435. self._public_baseurl = hs.config.public_baseurl
  436. def register(self, http_server: HttpServer) -> None:
  437. super().register(http_server)
  438. if self._msc2858_enabled:
  439. # expose additional endpoint for MSC2858 support: backwards-compat support
  440. # for clients which don't yet support the stable endpoints.
  441. http_server.register_paths(
  442. "GET",
  443. client_patterns(
  444. "/org.matrix.msc2858/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$",
  445. releases=(),
  446. unstable=True,
  447. ),
  448. self.on_GET,
  449. self.__class__.__name__,
  450. )
  451. async def on_GET(
  452. self, request: SynapseRequest, idp_id: Optional[str] = None
  453. ) -> None:
  454. if not self._public_baseurl:
  455. raise SynapseError(400, "SSO requires a valid public_baseurl")
  456. # if this isn't the expected hostname, redirect to the right one, so that we
  457. # get our cookies back.
  458. requested_uri = get_request_uri(request)
  459. baseurl_bytes = self._public_baseurl.encode("utf-8")
  460. if not requested_uri.startswith(baseurl_bytes):
  461. # swap out the incorrect base URL for the right one.
  462. #
  463. # The idea here is to redirect from
  464. # https://foo.bar/whatever/_matrix/...
  465. # to
  466. # https://public.baseurl/_matrix/...
  467. #
  468. i = requested_uri.index(b"/_matrix")
  469. new_uri = baseurl_bytes[:-1] + requested_uri[i:]
  470. logger.info(
  471. "Requested URI %s is not canonical: redirecting to %s",
  472. requested_uri.decode("utf-8", errors="replace"),
  473. new_uri.decode("utf-8", errors="replace"),
  474. )
  475. request.redirect(new_uri)
  476. finish_request(request)
  477. return
  478. args: Dict[bytes, List[bytes]] = request.args # type: ignore
  479. client_redirect_url = parse_bytes_from_args(args, "redirectUrl", required=True)
  480. sso_url = await self._sso_handler.handle_redirect_request(
  481. request,
  482. client_redirect_url,
  483. idp_id,
  484. )
  485. logger.info("Redirecting to %s", sso_url)
  486. request.redirect(sso_url)
  487. finish_request(request)
  488. class CasTicketServlet(RestServlet):
  489. PATTERNS = client_patterns("/login/cas/ticket", v1=True)
  490. def __init__(self, hs):
  491. super().__init__()
  492. self._cas_handler = hs.get_cas_handler()
  493. async def on_GET(self, request: SynapseRequest) -> None:
  494. client_redirect_url = parse_string(request, "redirectUrl")
  495. ticket = parse_string(request, "ticket", required=True)
  496. # Maybe get a session ID (if this ticket is from user interactive
  497. # authentication).
  498. session = parse_string(request, "session")
  499. # Either client_redirect_url or session must be provided.
  500. if not client_redirect_url and not session:
  501. message = "Missing string query parameter redirectUrl or session"
  502. raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
  503. await self._cas_handler.handle_ticket(
  504. request, ticket, client_redirect_url, session
  505. )
  506. def register_servlets(hs, http_server):
  507. LoginRestServlet(hs).register(http_server)
  508. if hs.config.access_token_lifetime is not None:
  509. RefreshTokenServlet(hs).register(http_server)
  510. SsoRedirectServlet(hs).register(http_server)
  511. if hs.config.cas_enabled:
  512. CasTicketServlet(hs).register(http_server)