Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.
 
 
 
 
 
 

1693 Zeilen
65 KiB

  1. # Copyright 2020 Quentin Gliech
  2. # Copyright 2021 The Matrix.org Foundation C.I.C.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import binascii
  16. import inspect
  17. import json
  18. import logging
  19. from typing import (
  20. TYPE_CHECKING,
  21. Any,
  22. Dict,
  23. Generic,
  24. List,
  25. Optional,
  26. Type,
  27. TypeVar,
  28. Union,
  29. )
  30. from urllib.parse import urlencode, urlparse
  31. import attr
  32. import unpaddedbase64
  33. from authlib.common.security import generate_token
  34. from authlib.jose import JsonWebToken, JWTClaims
  35. from authlib.jose.errors import InvalidClaimError, JoseError, MissingClaimError
  36. from authlib.oauth2.auth import ClientAuth
  37. from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
  38. from authlib.oauth2.rfc7636.challenge import create_s256_code_challenge
  39. from authlib.oidc.core import CodeIDToken, UserInfo
  40. from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
  41. from jinja2 import Environment, Template
  42. from pymacaroons.exceptions import (
  43. MacaroonDeserializationException,
  44. MacaroonInitException,
  45. MacaroonInvalidSignatureException,
  46. )
  47. from typing_extensions import TypedDict
  48. from twisted.web.client import readBody
  49. from twisted.web.http_headers import Headers
  50. from synapse.api.errors import SynapseError
  51. from synapse.config import ConfigError
  52. from synapse.config.oidc import OidcProviderClientSecretJwtKey, OidcProviderConfig
  53. from synapse.handlers.sso import MappingException, UserAttributes
  54. from synapse.http.server import finish_request
  55. from synapse.http.servlet import parse_string
  56. from synapse.http.site import SynapseRequest
  57. from synapse.logging.context import make_deferred_yieldable
  58. from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
  59. from synapse.util import Clock, json_decoder
  60. from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
  61. from synapse.util.macaroons import MacaroonGenerator, OidcSessionData
  62. from synapse.util.templates import _localpart_from_email_filter
  63. if TYPE_CHECKING:
  64. from synapse.server import HomeServer
  65. logger = logging.getLogger(__name__)
# We want the session cookie to be returned to us even when the request is the
# POSTed result of a form on another domain, as is used with
# `response_mode=form_post`.
#
# Modern browsers will not do so unless we set SameSite=None; however *older*
# browsers (including all versions of Safari on iOS 12?) don't support
# SameSite=None, and interpret it as SameSite=Strict:
# https://bugs.webkit.org/show_bug.cgi?id=198181
#
# As a rather painful workaround, we set *two* cookies, one with SameSite=None
# and one with no SameSite, in the hope that at least one of them will get
# back to us.
#
# Secure is necessary for SameSite=None (and, empirically, also breaks things
# on iOS 12.)
#
# Each entry is a (cookie name, cookie options) pair, both as bytes. The
# callback handler tries the names in order when reading the session, and uses
# the options when writing the expiring `Set-Cookie` values.
_SESSION_COOKIES = [
    (b"oidc_session", b"HttpOnly; Secure; SameSite=None"),
    (b"oidc_session_no_samesite", b"HttpOnly"),
]
#: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and
#: OpenID.Core sec 3.1.3.3.
class Token(TypedDict):
    """Typed view of the JSON object returned by the token endpoint."""

    # Bearer credential; sent in the `Authorization` header when fetching userinfo.
    access_token: str
    token_type: str
    # JWT carrying the user's claims; read with `.get()` when parsing the ID token.
    id_token: Optional[str]
    refresh_token: Optional[str]
    # NOTE(review): presumably a lifetime in seconds, per RFC6749 — confirm.
    expires_in: int
    scope: Optional[str]
#: A JWK, as per RFC7517 sec 4. The type could be more precise than that, but
#: there is no real point of doing this in our case.
JWK = Dict[str, str]

#: Type variable for the claims class passed as `claims_cls` to `_verify_jwt`;
#: that method's return type is the same `C`.
C = TypeVar("C")
#: A JWK Set, as per RFC7517 sec 5.
class JWKS(TypedDict):
    """Typed view of a JSON Web Key Set document."""

    # The individual keys of the set.
    keys: List[JWK]
  102. class OidcHandler:
  103. """Handles requests related to the OpenID Connect login flow."""
  104. def __init__(self, hs: "HomeServer"):
  105. self._sso_handler = hs.get_sso_handler()
  106. provider_confs = hs.config.oidc.oidc_providers
  107. # we should not have been instantiated if there is no configured provider.
  108. assert provider_confs
  109. self._macaroon_generator = hs.get_macaroon_generator()
  110. self._providers: Dict[str, "OidcProvider"] = {
  111. p.idp_id: OidcProvider(hs, self._macaroon_generator, p)
  112. for p in provider_confs
  113. }
  114. async def load_metadata(self) -> None:
  115. """Validate the config and load the metadata from the remote endpoint.
  116. Called at startup to ensure we have everything we need.
  117. """
  118. for idp_id, p in self._providers.items():
  119. try:
  120. await p.load_metadata()
  121. if not p._uses_userinfo:
  122. await p.load_jwks()
  123. except Exception as e:
  124. raise Exception(
  125. "Error while initialising OIDC provider %r" % (idp_id,)
  126. ) from e
  127. async def handle_oidc_callback(self, request: SynapseRequest) -> None:
  128. """Handle an incoming request to /_synapse/client/oidc/callback
  129. Since we might want to display OIDC-related errors in a user-friendly
  130. way, we don't raise SynapseError from here. Instead, we call
  131. ``self._sso_handler.render_error`` which displays an HTML page for the error.
  132. Most of the OpenID Connect logic happens here:
  133. - first, we check if there was any error returned by the provider and
  134. display it
  135. - then we fetch the session cookie, decode and verify it
  136. - the ``state`` query parameter should match with the one stored in the
  137. session cookie
  138. Once we know the session is legit, we then delegate to the OIDC Provider
  139. implementation, which will exchange the code with the provider and complete the
  140. login/authentication.
  141. Args:
  142. request: the incoming request from the browser.
  143. """
  144. # This will always be set by the time Twisted calls us.
  145. assert request.args is not None
  146. # The provider might redirect with an error.
  147. # In that case, just display it as-is.
  148. if b"error" in request.args:
  149. # error response from the auth server. see:
  150. # https://tools.ietf.org/html/rfc6749#section-4.1.2.1
  151. # https://openid.net/specs/openid-connect-core-1_0.html#AuthError
  152. error = request.args[b"error"][0].decode()
  153. description = request.args.get(b"error_description", [b""])[0].decode()
  154. # Most of the errors returned by the provider could be due by
  155. # either the provider misbehaving or Synapse being misconfigured.
  156. # The only exception of that is "access_denied", where the user
  157. # probably cancelled the login flow. In other cases, log those errors.
  158. logger.log(
  159. logging.INFO if error == "access_denied" else logging.ERROR,
  160. "Received OIDC callback with error: %s %s",
  161. error,
  162. description,
  163. )
  164. self._sso_handler.render_error(request, error, description)
  165. return
  166. # otherwise, it is presumably a successful response. see:
  167. # https://tools.ietf.org/html/rfc6749#section-4.1.2
  168. # Fetch the session cookie. See the comments on SESSION_COOKIES for why there
  169. # are two.
  170. for cookie_name, _ in _SESSION_COOKIES:
  171. session: Optional[bytes] = request.getCookie(cookie_name)
  172. if session is not None:
  173. break
  174. else:
  175. logger.info("Received OIDC callback, with no session cookie")
  176. self._sso_handler.render_error(
  177. request, "missing_session", "No session cookie found"
  178. )
  179. return
  180. # Remove the cookies. There is a good chance that if the callback failed
  181. # once, it will fail next time and the code will already be exchanged.
  182. # Removing the cookies early avoids spamming the provider with token requests.
  183. #
  184. # we have to build the header by hand rather than calling request.addCookie
  185. # because the latter does not support SameSite=None
  186. # (https://twistedmatrix.com/trac/ticket/10088)
  187. for cookie_name, options in _SESSION_COOKIES:
  188. request.cookies.append(
  189. b"%s=; Expires=Thu, Jan 01 1970 00:00:00 UTC; %s"
  190. % (cookie_name, options)
  191. )
  192. # Check for the state query parameter
  193. if b"state" not in request.args:
  194. logger.info("Received OIDC callback, with no state parameter")
  195. self._sso_handler.render_error(
  196. request, "invalid_request", "State parameter is missing"
  197. )
  198. return
  199. state = request.args[b"state"][0].decode()
  200. # Deserialize the session token and verify it.
  201. try:
  202. session_data = self._macaroon_generator.verify_oidc_session_token(
  203. session, state
  204. )
  205. except (MacaroonInitException, MacaroonDeserializationException, KeyError) as e:
  206. logger.exception("Invalid session for OIDC callback")
  207. self._sso_handler.render_error(request, "invalid_session", str(e))
  208. return
  209. except MacaroonInvalidSignatureException as e:
  210. logger.warning("Could not verify session for OIDC callback: %s", e)
  211. self._sso_handler.render_error(request, "mismatching_session", str(e))
  212. return
  213. logger.info("Received OIDC callback for IdP %s", session_data.idp_id)
  214. oidc_provider = self._providers.get(session_data.idp_id)
  215. if not oidc_provider:
  216. logger.error("OIDC session uses unknown IdP %r", oidc_provider)
  217. self._sso_handler.render_error(request, "unknown_idp", "Unknown IdP")
  218. return
  219. if b"code" not in request.args:
  220. logger.info("Code parameter is missing")
  221. self._sso_handler.render_error(
  222. request, "invalid_request", "Code parameter is missing"
  223. )
  224. return
  225. code = request.args[b"code"][0].decode()
  226. await oidc_provider.handle_oidc_callback(request, session_data, code)
  227. async def handle_backchannel_logout(self, request: SynapseRequest) -> None:
  228. """Handle an incoming request to /_synapse/client/oidc/backchannel_logout
  229. This extracts the logout_token from the request and tries to figure out
  230. which OpenID Provider it is comming from. This works by matching the iss claim
  231. with the issuer and the aud claim with the client_id.
  232. Since at this point we don't know who signed the JWT, we can't just
  233. decode it using authlib since it will always verifies the signature. We
  234. have to decode it manually without validating the signature. The actual JWT
  235. verification is done in the `OidcProvider.handler_backchannel_logout` method,
  236. once we figured out which provider sent the request.
  237. Args:
  238. request: the incoming request from the browser.
  239. """
  240. logout_token = parse_string(request, "logout_token")
  241. if logout_token is None:
  242. raise SynapseError(400, "Missing logout_token in request")
  243. # A JWT looks like this:
  244. # header.payload.signature
  245. # where all parts are encoded with urlsafe base64.
  246. # The aud and iss claims we care about are in the payload part, which
  247. # is a JSON object.
  248. try:
  249. # By destructuring the list after splitting, we ensure that we have
  250. # exactly 3 segments
  251. _, payload, _ = logout_token.split(".")
  252. except ValueError:
  253. raise SynapseError(400, "Invalid logout_token in request")
  254. try:
  255. payload_bytes = unpaddedbase64.decode_base64(payload)
  256. claims = json_decoder.decode(payload_bytes.decode("utf-8"))
  257. except (json.JSONDecodeError, binascii.Error, UnicodeError):
  258. raise SynapseError(400, "Invalid logout_token payload in request")
  259. try:
  260. # Let's extract the iss and aud claims
  261. iss = claims["iss"]
  262. aud = claims["aud"]
  263. # The aud claim can be either a string or a list of string. Here we
  264. # normalize it as a list of strings.
  265. if isinstance(aud, str):
  266. aud = [aud]
  267. # Check that we have the right types for the aud and the iss claims
  268. if not isinstance(iss, str) or not isinstance(aud, list):
  269. raise TypeError()
  270. for a in aud:
  271. if not isinstance(a, str):
  272. raise TypeError()
  273. # At this point we properly checked both claims types
  274. issuer: str = iss
  275. audience: List[str] = aud
  276. except (TypeError, KeyError):
  277. raise SynapseError(400, "Invalid issuer/audience in logout_token")
  278. # Now that we know the audience and the issuer, we can figure out from
  279. # what provider it is coming from
  280. oidc_provider: Optional[OidcProvider] = None
  281. for provider in self._providers.values():
  282. if provider.issuer == issuer and provider.client_id in audience:
  283. oidc_provider = provider
  284. break
  285. if oidc_provider is None:
  286. raise SynapseError(400, "Could not find the OP that issued this event")
  287. # Ask the provider to handle the logout request.
  288. await oidc_provider.handle_backchannel_logout(request, logout_token)
  289. class OidcError(Exception):
  290. """Used to catch errors when calling the token_endpoint"""
  291. def __init__(self, error: str, error_description: Optional[str] = None):
  292. self.error = error
  293. self.error_description = error_description
  294. def __str__(self) -> str:
  295. if self.error_description:
  296. return f"{self.error}: {self.error_description}"
  297. return self.error
  298. class OidcProvider:
  299. """Wraps the config for a single OIDC IdentityProvider
  300. Provides methods for handling redirect requests and callbacks via that particular
  301. IdP.
  302. """
    def __init__(
        self,
        hs: "HomeServer",
        macaroon_generator: MacaroonGenerator,
        provider: OidcProviderConfig,
    ):
        """Set up the state for a single configured identity provider.

        Args:
            hs: the homeserver, used to look up the datastore, clock, HTTP
                client, handlers and server-level config.
            macaroon_generator: generator stored on this instance.
            provider: the configuration of this identity provider.
        """
        self._store = hs.get_datastores().main
        self._clock = hs.get_clock()

        # NOTE(review): the attribute name is spelt `_macaroon_generaton`;
        # presumably other methods of this class read it under that exact
        # name, so confirm all usages before renaming.
        self._macaroon_generaton = macaroon_generator

        self._config = provider
        self._callback_url: str = hs.config.oidc.oidc_callback_url

        # Calculate the prefix for OIDC callback paths based on the public_baseurl.
        # We'll insert this into the Path= parameter of any session cookies we set.
        public_baseurl_path = urlparse(hs.config.server.public_baseurl).path
        self._callback_path_prefix = (
            public_baseurl_path.encode("utf-8") + b"_synapse/client/oidc"
        )

        self._oidc_attribute_requirements = provider.attribute_requirements
        self._scopes = provider.scopes
        self._user_profile_method = provider.user_profile_method

        # Prefer a static client secret; otherwise fall back to a secret built
        # from the configured JWT key. May stay None if neither is configured.
        client_secret: Optional[Union[str, JwtClientSecret]] = None
        if provider.client_secret:
            client_secret = provider.client_secret
        elif provider.client_secret_jwt_key:
            client_secret = JwtClientSecret(
                provider.client_secret_jwt_key,
                provider.client_id,
                provider.issuer,
                hs.get_clock(),
            )

        self._client_auth = ClientAuth(
            provider.client_id,
            client_secret,
            provider.client_auth_method,
        )
        self._client_auth_method = provider.client_auth_method

        # cache of metadata for the identity provider (endpoint uris, mostly). This is
        # loaded on-demand from the discovery endpoint (if discovery is enabled), with
        # possible overrides from the config.  Access via `load_metadata`.
        self._provider_metadata = RetryOnExceptionCachedCall(self._load_metadata)

        # cache of JWKs used by the identity provider to sign tokens. Loaded on demand
        # from the IdP's jwks_uri, if required.
        self._jwks = RetryOnExceptionCachedCall(self._load_jwks)

        self._user_mapping_provider = provider.user_mapping_provider_class(
            provider.user_mapping_provider_config
        )
        self._skip_verification = provider.skip_verification
        self._allow_existing_users = provider.allow_existing_users

        self._http_client = hs.get_proxied_http_client()
        self._server_name: str = hs.config.server.server_name

        # identifier for the external_ids table
        self.idp_id = provider.idp_id

        # user-facing name of this auth provider
        self.idp_name = provider.idp_name

        # MXC URI for icon for this auth provider
        self.idp_icon = provider.idp_icon

        # optional brand identifier for this auth provider
        self.idp_brand = provider.idp_brand

        self._sso_handler = hs.get_sso_handler()
        self._device_handler = hs.get_device_handler()

        # Registered last, once the attributes above are populated.
        self._sso_handler.register_identity_provider(self)
  364. def _validate_metadata(self, m: OpenIDProviderMetadata) -> None:
  365. """Verifies the provider metadata.
  366. This checks the validity of the currently loaded provider. Not
  367. everything is checked, only:
  368. - ``issuer``
  369. - ``authorization_endpoint``
  370. - ``token_endpoint``
  371. - ``response_types_supported`` (checks if "code" is in it)
  372. - ``jwks_uri``
  373. Raises:
  374. ValueError: if something in the provider is not valid
  375. """
  376. # Skip verification to allow non-compliant providers (e.g. issuers not running on a secure origin)
  377. if self._skip_verification is True:
  378. return
  379. m.validate_issuer()
  380. m.validate_authorization_endpoint()
  381. m.validate_token_endpoint()
  382. if m.get("token_endpoint_auth_methods_supported") is not None:
  383. m.validate_token_endpoint_auth_methods_supported()
  384. if (
  385. self._client_auth_method
  386. not in m["token_endpoint_auth_methods_supported"]
  387. ):
  388. raise ValueError(
  389. '"{auth_method}" not in "token_endpoint_auth_methods_supported" ({supported!r})'.format(
  390. auth_method=self._client_auth_method,
  391. supported=m["token_endpoint_auth_methods_supported"],
  392. )
  393. )
  394. # If PKCE support is advertised ensure the wanted method is available.
  395. if m.get("code_challenge_methods_supported") is not None:
  396. m.validate_code_challenge_methods_supported()
  397. if "S256" not in m["code_challenge_methods_supported"]:
  398. raise ValueError(
  399. '"S256" not in "code_challenge_methods_supported" ({supported!r})'.format(
  400. supported=m["code_challenge_methods_supported"],
  401. )
  402. )
  403. if m.get("response_types_supported") is not None:
  404. m.validate_response_types_supported()
  405. if "code" not in m["response_types_supported"]:
  406. raise ValueError(
  407. '"code" not in "response_types_supported" (%r)'
  408. % (m["response_types_supported"],)
  409. )
  410. # Ensure there's a userinfo endpoint to fetch from if it is required.
  411. if self._uses_userinfo:
  412. if m.get("userinfo_endpoint") is None:
  413. raise ValueError(
  414. 'provider has no "userinfo_endpoint", even though it is required'
  415. )
  416. else:
  417. # If we're not using userinfo, we need a valid jwks to validate the ID token
  418. m.validate_jwks_uri()
  419. if self._config.backchannel_logout_enabled:
  420. if not m.get("backchannel_logout_supported", False):
  421. logger.warning(
  422. "OIDC Back-Channel Logout is enabled for issuer %r"
  423. "but it does not advertise support for it",
  424. self.issuer,
  425. )
  426. elif not m.get("backchannel_logout_session_supported", False):
  427. logger.warning(
  428. "OIDC Back-Channel Logout is enabled and supported "
  429. "by issuer %r but it might not send a session ID with "
  430. "logout tokens, which is required for the logouts to work",
  431. self.issuer,
  432. )
  433. if not self._config.backchannel_logout_ignore_sub:
  434. # If OIDC backchannel logouts are enabled, the provider mapping provider
  435. # should use the `sub` claim. We verify that by mapping a dumb user and
  436. # see if we get back the sub claim
  437. user = UserInfo({"sub": "thisisasubject"})
  438. try:
  439. subject = self._user_mapping_provider.get_remote_user_id(user)
  440. if subject != user["sub"]:
  441. raise ValueError("Unexpected subject")
  442. except Exception:
  443. logger.warning(
  444. f"OIDC Back-Channel Logout is enabled for issuer {self.issuer!r} "
  445. "but it looks like the configured `user_mapping_provider` "
  446. "does not use the `sub` claim as subject. If it is the case, "
  447. "and you want Synapse to ignore the `sub` claim in OIDC "
  448. "Back-Channel Logouts, set `backchannel_logout_ignore_sub` "
  449. "to `true` in the issuer config."
  450. )
  451. @property
  452. def _uses_userinfo(self) -> bool:
  453. """Returns True if the ``userinfo_endpoint`` should be used.
  454. This is based on the requested scopes: if the scopes include
  455. ``openid``, the provider should give use an ID token containing the
  456. user information. If not, we should fetch them using the
  457. ``access_token`` with the ``userinfo_endpoint``.
  458. """
  459. return (
  460. "openid" not in self._scopes
  461. or self._user_profile_method == "userinfo_endpoint"
  462. )
  463. @property
  464. def issuer(self) -> str:
  465. """The issuer identifying this provider."""
  466. return self._config.issuer
  467. @property
  468. def client_id(self) -> str:
  469. """The client_id used when interacting with this provider."""
  470. return self._config.client_id
  471. async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata:
  472. """Return the provider metadata.
  473. If this is the first call, the metadata is built from the config and from the
  474. metadata discovery endpoint (if enabled), and then validated. If the metadata
  475. is successfully validated, it is then cached for future use.
  476. Args:
  477. force: If true, any cached metadata is discarded to force a reload.
  478. Raises:
  479. ValueError: if something in the provider is not valid
  480. Returns:
  481. The provider's metadata.
  482. """
  483. if force:
  484. # reset the cached call to ensure we get a new result
  485. self._provider_metadata = RetryOnExceptionCachedCall(self._load_metadata)
  486. return await self._provider_metadata.get()
  487. async def _load_metadata(self) -> OpenIDProviderMetadata:
  488. # start out with just the issuer (unlike the other settings, discovered issuer
  489. # takes precedence over configured issuer, because configured issuer is
  490. # required for discovery to take place.)
  491. #
  492. metadata = OpenIDProviderMetadata(issuer=self._config.issuer)
  493. # load any data from the discovery endpoint, if enabled
  494. if self._config.discover:
  495. url = get_well_known_url(self._config.issuer, external=True)
  496. metadata_response = await self._http_client.get_json(url)
  497. metadata.update(metadata_response)
  498. # override any discovered data with any settings in our config
  499. if self._config.authorization_endpoint:
  500. metadata["authorization_endpoint"] = self._config.authorization_endpoint
  501. if self._config.token_endpoint:
  502. metadata["token_endpoint"] = self._config.token_endpoint
  503. if self._config.userinfo_endpoint:
  504. metadata["userinfo_endpoint"] = self._config.userinfo_endpoint
  505. if self._config.jwks_uri:
  506. metadata["jwks_uri"] = self._config.jwks_uri
  507. if self._config.pkce_method == "always":
  508. metadata["code_challenge_methods_supported"] = ["S256"]
  509. elif self._config.pkce_method == "never":
  510. metadata.pop("code_challenge_methods_supported", None)
  511. self._validate_metadata(metadata)
  512. return metadata
  513. async def load_jwks(self, force: bool = False) -> JWKS:
  514. """Load the JSON Web Key Set used to sign ID tokens.
  515. If we're not using the ``userinfo_endpoint``, user infos are extracted
  516. from the ID token, which is a JWT signed by keys given by the provider.
  517. The keys are then cached.
  518. Args:
  519. force: Force reloading the keys.
  520. Returns:
  521. The key set
  522. Looks like this::
  523. {
  524. 'keys': [
  525. {
  526. 'kid': 'abcdef',
  527. 'kty': 'RSA',
  528. 'alg': 'RS256',
  529. 'use': 'sig',
  530. 'e': 'XXXX',
  531. 'n': 'XXXX',
  532. }
  533. ]
  534. }
  535. """
  536. if force:
  537. # reset the cached call to ensure we get a new result
  538. self._jwks = RetryOnExceptionCachedCall(self._load_jwks)
  539. return await self._jwks.get()
  540. async def _load_jwks(self) -> JWKS:
  541. metadata = await self.load_metadata()
  542. # Load the JWKS using the `jwks_uri` metadata.
  543. uri = metadata.get("jwks_uri")
  544. if not uri:
  545. # this should be unreachable: load_metadata validates that
  546. # there is a jwks_uri in the metadata if _uses_userinfo is unset
  547. raise RuntimeError('Missing "jwks_uri" in metadata')
  548. jwk_set = await self._http_client.get_json(uri)
  549. return jwk_set
  550. async def _exchange_code(self, code: str, code_verifier: str) -> Token:
  551. """Exchange an authorization code for a token.
  552. This calls the ``token_endpoint`` with the authorization code we
  553. received in the callback to exchange it for a token. The call uses the
  554. ``ClientAuth`` to authenticate with the client with its ID and secret.
  555. See:
  556. https://tools.ietf.org/html/rfc6749#section-3.2
  557. https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
  558. Args:
  559. code: The authorization code we got from the callback.
  560. code_verifier: The PKCE code verifier to send, blank if unused.
  561. Returns:
  562. A dict containing various tokens.
  563. May look like this::
  564. {
  565. 'token_type': 'bearer',
  566. 'access_token': 'abcdef',
  567. 'expires_in': 3599,
  568. 'id_token': 'ghijkl',
  569. 'refresh_token': 'mnopqr',
  570. }
  571. Raises:
  572. OidcError: when the ``token_endpoint`` returned an error.
  573. """
  574. metadata = await self.load_metadata()
  575. token_endpoint = metadata.get("token_endpoint")
  576. raw_headers: Dict[str, str] = {
  577. "Content-Type": "application/x-www-form-urlencoded",
  578. "User-Agent": self._http_client.user_agent.decode("ascii"),
  579. "Accept": "application/json",
  580. }
  581. args = {
  582. "grant_type": "authorization_code",
  583. "code": code,
  584. "redirect_uri": self._callback_url,
  585. }
  586. if code_verifier:
  587. args["code_verifier"] = code_verifier
  588. body = urlencode(args, True)
  589. # Fill the body/headers with credentials
  590. uri, raw_headers, body = self._client_auth.prepare(
  591. method="POST", uri=token_endpoint, headers=raw_headers, body=body
  592. )
  593. headers = Headers({k: [v] for (k, v) in raw_headers.items()})
  594. # Do the actual request
  595. # We're not using the SimpleHttpClient util methods as we don't want to
  596. # check the HTTP status code and we do the body encoding ourself.
  597. response = await self._http_client.request(
  598. method="POST",
  599. uri=uri,
  600. data=body.encode("utf-8"),
  601. headers=headers,
  602. )
  603. # This is used in multiple error messages below
  604. status = "{code} {phrase}".format(
  605. code=response.code, phrase=response.phrase.decode("utf-8")
  606. )
  607. resp_body = await make_deferred_yieldable(readBody(response))
  608. if response.code >= 500:
  609. # In case of a server error, we should first try to decode the body
  610. # and check for an error field. If not, we respond with a generic
  611. # error message.
  612. try:
  613. resp = json_decoder.decode(resp_body.decode("utf-8"))
  614. error = resp["error"]
  615. description = resp.get("error_description", error)
  616. except (ValueError, KeyError):
  617. # Catch ValueError for the JSON decoding and KeyError for the "error" field
  618. error = "server_error"
  619. description = (
  620. (
  621. 'Authorization server responded with a "{status}" error '
  622. "while exchanging the authorization code."
  623. ).format(status=status),
  624. )
  625. raise OidcError(error, description)
  626. # Since it is a not a 5xx code, body should be a valid JSON. It will
  627. # raise if not.
  628. resp = json_decoder.decode(resp_body.decode("utf-8"))
  629. if "error" in resp:
  630. error = resp["error"]
  631. # In case the authorization server responded with an error field,
  632. # it should be a 4xx code. If not, warn about it but don't do
  633. # anything special and report the original error message.
  634. if response.code < 400:
  635. logger.debug(
  636. "Invalid response from the authorization server: "
  637. 'responded with a "{status}" '
  638. "but body has an error field: {error!r}".format(
  639. status=status, error=resp["error"]
  640. )
  641. )
  642. description = resp.get("error_description", error)
  643. raise OidcError(error, description)
  644. # Now, this should not be an error. According to RFC6749 sec 5.1, it
  645. # should be a 200 code. We're a bit more flexible than that, and will
  646. # only throw on a 4xx code.
  647. if response.code >= 400:
  648. description = (
  649. 'Authorization server responded with a "{status}" error '
  650. 'but did not include an "error" field in its response.'.format(
  651. status=status
  652. )
  653. )
  654. logger.warning(description)
  655. # Body was still valid JSON. Might be useful to log it for debugging.
  656. logger.warning("Code exchange response: %r", resp)
  657. raise OidcError("server_error", description)
  658. return resp
  659. async def _fetch_userinfo(self, token: Token) -> UserInfo:
  660. """Fetch user information from the ``userinfo_endpoint``.
  661. Args:
  662. token: the token given by the ``token_endpoint``.
  663. Must include an ``access_token`` field.
  664. Returns:
  665. an object representing the user.
  666. """
  667. logger.debug("Using the OAuth2 access_token to request userinfo")
  668. metadata = await self.load_metadata()
  669. resp = await self._http_client.get_json(
  670. metadata["userinfo_endpoint"],
  671. headers={"Authorization": ["Bearer {}".format(token["access_token"])]},
  672. )
  673. logger.debug("Retrieved user info from userinfo endpoint: %r", resp)
  674. return UserInfo(resp)
  675. async def _verify_jwt(
  676. self,
  677. alg_values: List[str],
  678. token: str,
  679. claims_cls: Type[C],
  680. claims_options: Optional[dict] = None,
  681. claims_params: Optional[dict] = None,
  682. ) -> C:
  683. """Decode and validate a JWT, re-fetching the JWKS as needed.
  684. Args:
  685. alg_values: list of `alg` values allowed when verifying the JWT.
  686. token: the JWT.
  687. claims_cls: the JWTClaims class to use to validate the claims.
  688. claims_options: dict of options passed to the `claims_cls` constructor.
  689. claims_params: dict of params passed to the `claims_cls` constructor.
  690. Returns:
  691. The decoded claims in the JWT.
  692. """
  693. jwt = JsonWebToken(alg_values)
  694. logger.debug("Attempting to decode JWT (%s) %r", claims_cls.__name__, token)
  695. # Try to decode the keys in cache first, then retry by forcing the keys
  696. # to be reloaded
  697. jwk_set = await self.load_jwks()
  698. try:
  699. claims = jwt.decode(
  700. token,
  701. key=jwk_set,
  702. claims_cls=claims_cls,
  703. claims_options=claims_options,
  704. claims_params=claims_params,
  705. )
  706. except ValueError:
  707. logger.info("Reloading JWKS after decode error")
  708. jwk_set = await self.load_jwks(force=True) # try reloading the jwks
  709. claims = jwt.decode(
  710. token,
  711. key=jwk_set,
  712. claims_cls=claims_cls,
  713. claims_options=claims_options,
  714. claims_params=claims_params,
  715. )
  716. logger.debug("Decoded JWT (%s) %r; validating", claims_cls.__name__, claims)
  717. claims.validate(
  718. now=self._clock.time(), leeway=120
  719. ) # allows 2 min of clock skew
  720. return claims
  721. async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
  722. """Return an instance of UserInfo from token's ``id_token``.
  723. Args:
  724. token: the token given by the ``token_endpoint``.
  725. Must include an ``id_token`` field.
  726. nonce: the nonce value originally sent in the initial authorization
  727. request. This value should match the one inside the token.
  728. Returns:
  729. The decoded claims in the ID token.
  730. """
  731. id_token = token.get("id_token")
  732. # That has been theoritically been checked by the caller, so even though
  733. # assertion are not enabled in production, it is mainly here to appease mypy
  734. assert id_token is not None
  735. metadata = await self.load_metadata()
  736. claims_params = {
  737. "nonce": nonce,
  738. "client_id": self._client_auth.client_id,
  739. }
  740. if "access_token" in token:
  741. # If we got an `access_token`, there should be an `at_hash` claim
  742. # in the `id_token` that we can check against.
  743. claims_params["access_token"] = token["access_token"]
  744. claims_options = {"iss": {"values": [metadata["issuer"]]}}
  745. alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
  746. claims = await self._verify_jwt(
  747. alg_values=alg_values,
  748. token=id_token,
  749. claims_cls=CodeIDToken,
  750. claims_options=claims_options,
  751. claims_params=claims_params,
  752. )
  753. return claims
  754. async def handle_redirect_request(
  755. self,
  756. request: SynapseRequest,
  757. client_redirect_url: Optional[bytes],
  758. ui_auth_session_id: Optional[str] = None,
  759. ) -> str:
  760. """Handle an incoming request to /login/sso/redirect
  761. It returns a redirect to the authorization endpoint with a few
  762. parameters:
  763. - ``client_id``: the client ID set in ``oidc_config.client_id``
  764. - ``response_type``: ``code``
  765. - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/client/oidc/callback``
  766. - ``scope``: the list of scopes set in ``oidc_config.scopes``
  767. - ``state``: a random string
  768. - ``nonce``: a random string
  769. - ``code_challenge``: a RFC7636 code challenge (if PKCE is supported)
  770. In addition to generating a redirect URL, we are setting a cookie with
  771. a signed macaroon token containing the state, the nonce, the
  772. client_redirect_url, and (optionally) the code_verifier params. The state,
  773. nonce, and client_redirect_url are then checked when the client comes back
  774. from the provider. The code_verifier is passed back to the server during
  775. the token exchange and compared to the code_challenge sent in this request.
  776. Args:
  777. request: the incoming request from the browser.
  778. We'll respond to it with a redirect and a cookie.
  779. client_redirect_url: the URL that we should redirect the client to
  780. when everything is done (or None for UI Auth)
  781. ui_auth_session_id: The session ID of the ongoing UI Auth (or
  782. None if this is a login).
  783. Returns:
  784. The redirect URL to the authorization endpoint.
  785. """
  786. state = generate_token()
  787. nonce = generate_token()
  788. code_verifier = ""
  789. if not client_redirect_url:
  790. client_redirect_url = b""
  791. metadata = await self.load_metadata()
  792. # Automatically enable PKCE if it is supported.
  793. extra_grant_values = {}
  794. if metadata.get("code_challenge_methods_supported"):
  795. code_verifier = generate_token(48)
  796. # Note that we verified the server supports S256 earlier (in
  797. # OidcProvider._validate_metadata).
  798. extra_grant_values = {
  799. "code_challenge_method": "S256",
  800. "code_challenge": create_s256_code_challenge(code_verifier),
  801. }
  802. cookie = self._macaroon_generaton.generate_oidc_session_token(
  803. state=state,
  804. session_data=OidcSessionData(
  805. idp_id=self.idp_id,
  806. nonce=nonce,
  807. client_redirect_url=client_redirect_url.decode(),
  808. ui_auth_session_id=ui_auth_session_id or "",
  809. code_verifier=code_verifier,
  810. ),
  811. )
  812. # Set the cookies. See the comments on _SESSION_COOKIES for why there are two.
  813. #
  814. # we have to build the header by hand rather than calling request.addCookie
  815. # because the latter does not support SameSite=None
  816. # (https://twistedmatrix.com/trac/ticket/10088)
  817. for cookie_name, options in _SESSION_COOKIES:
  818. request.cookies.append(
  819. b"%s=%s; Max-Age=3600; Path=%s; %s"
  820. % (
  821. cookie_name,
  822. cookie.encode("utf-8"),
  823. self._callback_path_prefix,
  824. options,
  825. )
  826. )
  827. authorization_endpoint = metadata.get("authorization_endpoint")
  828. return prepare_grant_uri(
  829. authorization_endpoint,
  830. client_id=self._client_auth.client_id,
  831. response_type="code",
  832. redirect_uri=self._callback_url,
  833. scope=self._scopes,
  834. state=state,
  835. nonce=nonce,
  836. **extra_grant_values,
  837. )
  838. async def handle_oidc_callback(
  839. self, request: SynapseRequest, session_data: "OidcSessionData", code: str
  840. ) -> None:
  841. """Handle an incoming request to /_synapse/client/oidc/callback
  842. By this time we have already validated the session on the synapse side, and
  843. now need to do the provider-specific operations. This includes:
  844. - exchange the code with the provider using the ``token_endpoint`` (see
  845. ``_exchange_code``)
  846. - once we have the token, use it to either extract the UserInfo from
  847. the ``id_token`` (``_parse_id_token``), or use the ``access_token``
  848. to fetch UserInfo from the ``userinfo_endpoint``
  849. (``_fetch_userinfo``)
  850. - map those UserInfo to a Matrix user (``_map_userinfo_to_user``) and
  851. finish the login
  852. Args:
  853. request: the incoming request from the browser.
  854. session_data: the session data, extracted from our cookie
  855. code: The authorization code we got from the callback.
  856. """
  857. # Exchange the code with the provider
  858. try:
  859. logger.debug("Exchanging OAuth2 code for a token")
  860. token = await self._exchange_code(
  861. code, code_verifier=session_data.code_verifier
  862. )
  863. except OidcError as e:
  864. logger.warning("Could not exchange OAuth2 code: %s", e)
  865. self._sso_handler.render_error(request, e.error, e.error_description)
  866. return
  867. logger.debug("Successfully obtained OAuth2 token data: %r", token)
  868. # If there is an id_token, it should be validated, regardless of the
  869. # userinfo endpoint is used or not.
  870. if token.get("id_token") is not None:
  871. try:
  872. id_token = await self._parse_id_token(token, nonce=session_data.nonce)
  873. sid = id_token.get("sid")
  874. except Exception as e:
  875. logger.exception("Invalid id_token")
  876. self._sso_handler.render_error(request, "invalid_token", str(e))
  877. return
  878. else:
  879. id_token = None
  880. sid = None
  881. # Now that we have a token, get the userinfo either from the `id_token`
  882. # claims or by fetching the `userinfo_endpoint`.
  883. if self._uses_userinfo:
  884. try:
  885. userinfo = await self._fetch_userinfo(token)
  886. except Exception as e:
  887. logger.exception("Could not fetch userinfo")
  888. self._sso_handler.render_error(request, "fetch_error", str(e))
  889. return
  890. elif id_token is not None:
  891. userinfo = UserInfo(id_token)
  892. else:
  893. logger.error("Missing id_token in token response")
  894. self._sso_handler.render_error(
  895. request, "invalid_token", "Missing id_token in token response"
  896. )
  897. return
  898. # first check if we're doing a UIA
  899. if session_data.ui_auth_session_id:
  900. try:
  901. remote_user_id = self._remote_id_from_userinfo(userinfo)
  902. except Exception as e:
  903. logger.exception("Could not extract remote user id")
  904. self._sso_handler.render_error(request, "mapping_error", str(e))
  905. return
  906. return await self._sso_handler.complete_sso_ui_auth_request(
  907. self.idp_id, remote_user_id, session_data.ui_auth_session_id, request
  908. )
  909. # otherwise, it's a login
  910. logger.debug("Userinfo for OIDC login: %s", userinfo)
  911. # Ensure that the attributes of the logged in user meet the required
  912. # attributes by checking the userinfo against attribute_requirements
  913. # In order to deal with the fact that OIDC userinfo can contain many
  914. # types of data, we wrap non-list values in lists.
  915. if not self._sso_handler.check_required_attributes(
  916. request,
  917. {k: v if isinstance(v, list) else [v] for k, v in userinfo.items()},
  918. self._oidc_attribute_requirements,
  919. ):
  920. return
  921. # Call the mapper to register/login the user
  922. try:
  923. await self._complete_oidc_login(
  924. userinfo, token, request, session_data.client_redirect_url, sid
  925. )
  926. except MappingException as e:
  927. logger.exception("Could not map user")
  928. self._sso_handler.render_error(request, "mapping_error", str(e))
  929. async def _complete_oidc_login(
  930. self,
  931. userinfo: UserInfo,
  932. token: Token,
  933. request: SynapseRequest,
  934. client_redirect_url: str,
  935. sid: Optional[str],
  936. ) -> None:
  937. """Given a UserInfo response, complete the login flow
  938. UserInfo should have a claim that uniquely identifies users. This claim
  939. is usually `sub`, but can be configured with `oidc_config.subject_claim`.
  940. It is then used as an `external_id`.
  941. If we don't find the user that way, we should register the user,
  942. mapping the localpart and the display name from the UserInfo.
  943. If a user already exists with the mxid we've mapped and allow_existing_users
  944. is disabled, raise an exception.
  945. Otherwise, render a redirect back to the client_redirect_url with a loginToken.
  946. Args:
  947. userinfo: an object representing the user
  948. token: a dict with the tokens obtained from the provider
  949. request: The request to respond to
  950. client_redirect_url: The redirect URL passed in by the client.
  951. Raises:
  952. MappingException: if there was an error while mapping some properties
  953. """
  954. try:
  955. remote_user_id = self._remote_id_from_userinfo(userinfo)
  956. except Exception as e:
  957. raise MappingException(
  958. "Failed to extract subject from OIDC response: %s" % (e,)
  959. )
  960. # Older mapping providers don't accept the `failures` argument, so we
  961. # try and detect support.
  962. mapper_signature = inspect.signature(
  963. self._user_mapping_provider.map_user_attributes
  964. )
  965. supports_failures = "failures" in mapper_signature.parameters
  966. async def oidc_response_to_user_attributes(failures: int) -> UserAttributes:
  967. """
  968. Call the mapping provider to map the OIDC userinfo and token to user attributes.
  969. This is backwards compatibility for abstraction for the SSO handler.
  970. """
  971. if supports_failures:
  972. attributes = await self._user_mapping_provider.map_user_attributes(
  973. userinfo, token, failures
  974. )
  975. else:
  976. # If the mapping provider does not support processing failures,
  977. # do not continually generate the same Matrix ID since it will
  978. # continue to already be in use. Note that the error raised is
  979. # arbitrary and will get turned into a MappingException.
  980. if failures:
  981. raise MappingException(
  982. "Mapping provider does not support de-duplicating Matrix IDs"
  983. )
  984. attributes = await self._user_mapping_provider.map_user_attributes(
  985. userinfo, token
  986. )
  987. return UserAttributes(**attributes)
  988. async def grandfather_existing_users() -> Optional[str]:
  989. if self._allow_existing_users:
  990. # If allowing existing users we want to generate a single localpart
  991. # and attempt to match it.
  992. attributes = await oidc_response_to_user_attributes(failures=0)
  993. if attributes.localpart is None:
  994. # If no localpart is returned then we will generate one, so
  995. # there is no need to search for existing users.
  996. return None
  997. user_id = UserID(attributes.localpart, self._server_name).to_string()
  998. users = await self._store.get_users_by_id_case_insensitive(user_id)
  999. if users:
  1000. # If an existing matrix ID is returned, then use it.
  1001. if len(users) == 1:
  1002. previously_registered_user_id = next(iter(users))
  1003. elif user_id in users:
  1004. previously_registered_user_id = user_id
  1005. else:
  1006. # Do not attempt to continue generating Matrix IDs.
  1007. raise MappingException(
  1008. "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
  1009. user_id, users
  1010. )
  1011. )
  1012. return previously_registered_user_id
  1013. return None
  1014. # Mapping providers might not have get_extra_attributes: only call this
  1015. # method if it exists.
  1016. extra_attributes = None
  1017. get_extra_attributes = getattr(
  1018. self._user_mapping_provider, "get_extra_attributes", None
  1019. )
  1020. if get_extra_attributes:
  1021. extra_attributes = await get_extra_attributes(userinfo, token)
  1022. await self._sso_handler.complete_sso_login_request(
  1023. self.idp_id,
  1024. remote_user_id,
  1025. request,
  1026. client_redirect_url,
  1027. oidc_response_to_user_attributes,
  1028. grandfather_existing_users,
  1029. extra_attributes,
  1030. auth_provider_session_id=sid,
  1031. registration_enabled=self._config.enable_registration,
  1032. )
  1033. def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
  1034. """Extract the unique remote id from an OIDC UserInfo block
  1035. Args:
  1036. userinfo: An object representing the user given by the OIDC provider
  1037. Returns:
  1038. remote user id
  1039. """
  1040. remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
  1041. # Some OIDC providers use integer IDs, but Synapse expects external IDs
  1042. # to be strings.
  1043. return str(remote_user_id)
  1044. async def handle_backchannel_logout(
  1045. self, request: SynapseRequest, logout_token: str
  1046. ) -> None:
  1047. """Handle an incoming request to /_synapse/client/oidc/backchannel_logout
  1048. The OIDC Provider posts a logout token to this endpoint when a user
  1049. session ends. That token is a JWT signed with the same keys as
  1050. ID tokens. The OpenID Connect Back-Channel Logout draft explains how to
  1051. validate the JWT and figure out what session to end.
  1052. Args:
  1053. request: The request to respond to
  1054. logout_token: The logout token (a JWT) extracted from the request body
  1055. """
  1056. # Back-Channel Logout can be disabled in the config, hence this check.
  1057. # This is not that important for now since Synapse is registered
  1058. # manually to the OP, so not specifying the backchannel-logout URI is
  1059. # as effective than disabling it here. It might make more sense if we
  1060. # support dynamic registration in Synapse at some point.
  1061. if not self._config.backchannel_logout_enabled:
  1062. logger.warning(
  1063. f"Received an OIDC Back-Channel Logout request from issuer {self.issuer!r} but it is disabled in config"
  1064. )
  1065. # TODO: this responds with a 400 status code, which is what the OIDC
  1066. # Back-Channel Logout spec expects, but spec also suggests answering with
  1067. # a JSON object, with the `error` and `error_description` fields set, which
  1068. # we are not doing here.
  1069. # See https://openid.net/specs/openid-connect-backchannel-1_0.html#BCResponse
  1070. raise SynapseError(
  1071. 400, "OpenID Connect Back-Channel Logout is disabled for this provider"
  1072. )
  1073. metadata = await self.load_metadata()
  1074. # As per OIDC Back-Channel Logout 1.0 sec. 2.4:
  1075. # A Logout Token MUST be signed and MAY also be encrypted. The same
  1076. # keys are used to sign and encrypt Logout Tokens as are used for ID
  1077. # Tokens. If the Logout Token is encrypted, it SHOULD replicate the
  1078. # iss (issuer) claim in the JWT Header Parameters, as specified in
  1079. # Section 5.3 of [JWT].
  1080. alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
  1081. # As per sec. 2.6:
  1082. # 3. Validate the iss, aud, and iat Claims in the same way they are
  1083. # validated in ID Tokens.
  1084. # Which means the audience should contain Synapse's client_id and the
  1085. # issuer should be the IdP issuer
  1086. claims_options = {
  1087. "iss": {"values": [metadata["issuer"]]},
  1088. "aud": {"values": [self.client_id]},
  1089. }
  1090. try:
  1091. claims = await self._verify_jwt(
  1092. alg_values=alg_values,
  1093. token=logout_token,
  1094. claims_cls=LogoutToken,
  1095. claims_options=claims_options,
  1096. )
  1097. except JoseError:
  1098. logger.exception("Invalid logout_token")
  1099. raise SynapseError(400, "Invalid logout_token")
  1100. # As per sec. 2.6:
  1101. # 4. Verify that the Logout Token contains a sub Claim, a sid Claim,
  1102. # or both.
  1103. # 5. Verify that the Logout Token contains an events Claim whose
  1104. # value is JSON object containing the member name
  1105. # http://schemas.openid.net/event/backchannel-logout.
  1106. # 6. Verify that the Logout Token does not contain a nonce Claim.
  1107. # This is all verified by the LogoutToken claims class, so at this
  1108. # point the `sid` claim exists and is a string.
  1109. sid: str = claims.get("sid")
  1110. # If the `sub` claim was included in the logout token, we check that it matches
  1111. # that it matches the right user. We can have cases where the `sub` claim is not
  1112. # the ID saved in database, so we let admins disable this check in config.
  1113. sub: Optional[str] = claims.get("sub")
  1114. expected_user_id: Optional[str] = None
  1115. if sub is not None and not self._config.backchannel_logout_ignore_sub:
  1116. expected_user_id = await self._store.get_user_by_external_id(
  1117. self.idp_id, sub
  1118. )
  1119. # Invalidate any running user-mapping sessions, in-flight login tokens and
  1120. # active devices
  1121. await self._sso_handler.revoke_sessions_for_provider_session_id(
  1122. auth_provider_id=self.idp_id,
  1123. auth_provider_session_id=sid,
  1124. expected_user_id=expected_user_id,
  1125. )
  1126. request.setResponseCode(200)
  1127. request.setHeader(b"Cache-Control", b"no-cache, no-store")
  1128. request.setHeader(b"Pragma", b"no-cache")
  1129. finish_request(request)
  1130. class LogoutToken(JWTClaims): # type: ignore[misc]
  1131. """
  1132. Holds and verify claims of a logout token, as per
  1133. https://openid.net/specs/openid-connect-backchannel-1_0.html#LogoutToken
  1134. """
  1135. REGISTERED_CLAIMS = ["iss", "sub", "aud", "iat", "jti", "events", "sid"]
  1136. def validate(self, now: Optional[int] = None, leeway: int = 0) -> None:
  1137. """Validate everything in claims payload."""
  1138. super().validate(now, leeway)
  1139. self.validate_sid()
  1140. self.validate_events()
  1141. self.validate_nonce()
  1142. def validate_sid(self) -> None:
  1143. """Ensure the sid claim is present"""
  1144. sid = self.get("sid")
  1145. if not sid:
  1146. raise MissingClaimError("sid")
  1147. if not isinstance(sid, str):
  1148. raise InvalidClaimError("sid")
  1149. def validate_nonce(self) -> None:
  1150. """Ensure the nonce claim is absent"""
  1151. if "nonce" in self:
  1152. raise InvalidClaimError("nonce")
  1153. def validate_events(self) -> None:
  1154. """Ensure the events claim is present and with the right value"""
  1155. events = self.get("events")
  1156. if not events:
  1157. raise MissingClaimError("events")
  1158. if not isinstance(events, dict):
  1159. raise InvalidClaimError("events")
  1160. if "http://schemas.openid.net/event/backchannel-logout" not in events:
  1161. raise InvalidClaimError("events")
# Number of seconds a newly-generated client secret should be valid for (one
# hour). JwtClientSecret adds this to the issue time to get the `exp` claim.
CLIENT_SECRET_VALIDITY_SECONDS = 3600

# Minimum remaining validity (ten minutes) on a client secret before we should
# generate a new one: JwtClientSecret stops reusing its cached secret this many
# seconds before that secret expires.
CLIENT_SECRET_MIN_VALIDITY_SECONDS = 600
  1166. class JwtClientSecret:
  1167. """A class which generates a new client secret on demand, based on a JWK
  1168. This implementation is designed to comply with the requirements for Apple Sign in:
  1169. https://developer.apple.com/documentation/sign_in_with_apple/generate_and_validate_tokens#3262048
  1170. It looks like those requirements are based on https://tools.ietf.org/html/rfc7523,
  1171. but it's worth noting that we still put the generated secret in the "client_secret"
  1172. field (or rather, whereever client_auth_method puts it) rather than in a
  1173. client_assertion field in the body as that RFC seems to require.
  1174. """
  1175. def __init__(
  1176. self,
  1177. key: OidcProviderClientSecretJwtKey,
  1178. oauth_client_id: str,
  1179. oauth_issuer: str,
  1180. clock: Clock,
  1181. ):
  1182. self._key = key
  1183. self._oauth_client_id = oauth_client_id
  1184. self._oauth_issuer = oauth_issuer
  1185. self._clock = clock
  1186. self._cached_secret = b""
  1187. self._cached_secret_replacement_time = 0
  1188. def __str__(self) -> str:
  1189. # if client_auth_method is client_secret_basic, then ClientAuth.prepare calls
  1190. # encode_client_secret_basic, which calls "{}".format(secret), which ends up
  1191. # here.
  1192. return self._get_secret().decode("ascii")
  1193. def __bytes__(self) -> bytes:
  1194. # if client_auth_method is client_secret_post, then ClientAuth.prepare calls
  1195. # encode_client_secret_post, which ends up here.
  1196. return self._get_secret()
  1197. def _get_secret(self) -> bytes:
  1198. now = self._clock.time()
  1199. # if we have enough validity on our existing secret, use it
  1200. if now < self._cached_secret_replacement_time:
  1201. return self._cached_secret
  1202. issued_at = int(now)
  1203. expires_at = issued_at + CLIENT_SECRET_VALIDITY_SECONDS
  1204. # we copy the configured header because jwt.encode modifies it.
  1205. header = dict(self._key.jwt_header)
  1206. # see https://tools.ietf.org/html/rfc7523#section-3
  1207. payload = {
  1208. "sub": self._oauth_client_id,
  1209. "aud": self._oauth_issuer,
  1210. "iat": issued_at,
  1211. "exp": expires_at,
  1212. **self._key.jwt_payload,
  1213. }
  1214. logger.info(
  1215. "Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload
  1216. )
  1217. jwt = JsonWebToken(header["alg"])
  1218. self._cached_secret = jwt.encode(header, payload, self._key.key)
  1219. self._cached_secret_replacement_time = (
  1220. expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS
  1221. )
  1222. return self._cached_secret
class UserAttributeDict(TypedDict):
    """Shape of the dict returned by ``OidcMappingProvider.map_user_attributes``."""

    # The localpart mapped for the user; may be None.
    localpart: Optional[str]
    # NOTE(review): presumably asks the user to confirm the mapped localpart --
    # the consumer of this flag is outside this section; confirm.
    confirm_localpart: bool
    # The display name mapped for the user; may be None.
    display_name: Optional[str]
    picture: Optional[str]  # may be omitted by older `OidcMappingProviders`
    # The email addresses mapped for the user (possibly empty).
    emails: List[str]
  1229. class OidcMappingProvider(Generic[C]):
  1230. """A mapping provider maps a UserInfo object to user attributes.
  1231. It should provide the API described by this class.
  1232. """
  1233. def __init__(self, config: C):
  1234. """
  1235. Args:
  1236. config: A custom config object from this module, parsed by ``parse_config()``
  1237. """
  1238. @staticmethod
  1239. def parse_config(config: dict) -> C:
  1240. """Parse the dict provided by the homeserver's config
  1241. Args:
  1242. config: A dictionary containing configuration options for this provider
  1243. Returns:
  1244. A custom config object for this module
  1245. """
  1246. raise NotImplementedError()
  1247. def get_remote_user_id(self, userinfo: UserInfo) -> str:
  1248. """Get a unique user ID for this user.
  1249. Usually, in an OIDC-compliant scenario, it should be the ``sub`` claim from the UserInfo object.
  1250. Args:
  1251. userinfo: An object representing the user given by the OIDC provider
  1252. Returns:
  1253. A unique user ID
  1254. """
  1255. raise NotImplementedError()
  1256. async def map_user_attributes(
  1257. self, userinfo: UserInfo, token: Token, failures: int
  1258. ) -> UserAttributeDict:
  1259. """Map a `UserInfo` object into user attributes.
  1260. Args:
  1261. userinfo: An object representing the user given by the OIDC provider
  1262. token: A dict with the tokens returned by the provider
  1263. failures: How many times a call to this function with this
  1264. UserInfo has resulted in a failure.
  1265. Returns:
  1266. A dict containing the ``localpart`` and (optionally) the ``display_name``
  1267. """
  1268. raise NotImplementedError()
  1269. async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
  1270. """Map a `UserInfo` object into additional attributes passed to the client during login.
  1271. Args:
  1272. userinfo: An object representing the user given by the OIDC provider
  1273. token: A dict with the tokens returned by the provider
  1274. Returns:
  1275. A dict containing additional attributes. Must be JSON serializable.
  1276. """
  1277. return {}
  1278. # Used to clear out "None" values in templates
  1279. def jinja_finalize(thing: Any) -> Any:
  1280. return thing if thing is not None else ""
# Module-level Jinja environment used to compile the mapping templates (see
# `JinjaOidcMappingProvider.parse_config`). `jinja_finalize` is installed as
# the `finalize` hook, and the `localpart_from_email` filter is made available
# to templates.
env = Environment(finalize=jinja_finalize)
env.filters.update(
    {
        "localpart_from_email": _localpart_from_email_filter,
    }
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class JinjaOidcMappingConfig:
    """Parsed configuration of `JinjaOidcMappingProvider`, holding compiled templates."""

    # Rendered to obtain the unique remote user id.
    subject_template: Template
    # Rendered to obtain the `picture` user attribute.
    picture_template: Template
    # Rendered to obtain the localpart; None when not configured.
    localpart_template: Optional[Template]
    # Rendered to obtain the display name; None when not configured.
    display_name_template: Optional[Template]
    # Rendered to obtain the user's email address; None when not configured.
    email_template: Optional[Template]
    # One template per extra attribute, keyed by attribute name.
    extra_attributes: Dict[str, Template]
    # Copied as-is into the `confirm_localpart` entry of the mapped attributes.
    confirm_localpart: bool = False
  1296. class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
  1297. """An implementation of a mapping provider based on Jinja templates.
  1298. This is the default mapping provider.
  1299. """
  1300. def __init__(self, config: JinjaOidcMappingConfig):
  1301. self._config = config
  1302. @staticmethod
  1303. def parse_config(config: dict) -> JinjaOidcMappingConfig:
  1304. def parse_template_config_with_claim(
  1305. option_name: str, default_claim: str
  1306. ) -> Template:
  1307. template_name = f"{option_name}_template"
  1308. template = config.get(template_name)
  1309. if not template:
  1310. # Convert the legacy subject_claim into a template.
  1311. claim = config.get(f"{option_name}_claim", default_claim)
  1312. template = "{{ user.%s }}" % (claim,)
  1313. try:
  1314. return env.from_string(template)
  1315. except Exception as e:
  1316. raise ConfigError("invalid jinja template", path=[template_name]) from e
  1317. subject_template = parse_template_config_with_claim("subject", "sub")
  1318. picture_template = parse_template_config_with_claim("picture", "picture")
  1319. def parse_template_config(option_name: str) -> Optional[Template]:
  1320. if option_name not in config:
  1321. return None
  1322. try:
  1323. return env.from_string(config[option_name])
  1324. except Exception as e:
  1325. raise ConfigError("invalid jinja template", path=[option_name]) from e
  1326. localpart_template = parse_template_config("localpart_template")
  1327. display_name_template = parse_template_config("display_name_template")
  1328. email_template = parse_template_config("email_template")
  1329. extra_attributes = {} # type Dict[str, Template]
  1330. if "extra_attributes" in config:
  1331. extra_attributes_config = config.get("extra_attributes") or {}
  1332. if not isinstance(extra_attributes_config, dict):
  1333. raise ConfigError("must be a dict", path=["extra_attributes"])
  1334. for key, value in extra_attributes_config.items():
  1335. try:
  1336. extra_attributes[key] = env.from_string(value)
  1337. except Exception as e:
  1338. raise ConfigError(
  1339. "invalid jinja template", path=["extra_attributes", key]
  1340. ) from e
  1341. confirm_localpart = config.get("confirm_localpart") or False
  1342. if not isinstance(confirm_localpart, bool):
  1343. raise ConfigError("must be a bool", path=["confirm_localpart"])
  1344. return JinjaOidcMappingConfig(
  1345. subject_template=subject_template,
  1346. picture_template=picture_template,
  1347. localpart_template=localpart_template,
  1348. display_name_template=display_name_template,
  1349. email_template=email_template,
  1350. extra_attributes=extra_attributes,
  1351. confirm_localpart=confirm_localpart,
  1352. )
  1353. def get_remote_user_id(self, userinfo: UserInfo) -> str:
  1354. return self._config.subject_template.render(user=userinfo).strip()
  1355. async def map_user_attributes(
  1356. self, userinfo: UserInfo, token: Token, failures: int
  1357. ) -> UserAttributeDict:
  1358. localpart = None
  1359. if self._config.localpart_template:
  1360. localpart = self._config.localpart_template.render(user=userinfo).strip()
  1361. # Ensure only valid characters are included in the MXID.
  1362. localpart = map_username_to_mxid_localpart(localpart)
  1363. # Append suffix integer if last call to this function failed to produce
  1364. # a usable mxid.
  1365. localpart += str(failures) if failures else ""
  1366. def render_template_field(template: Optional[Template]) -> Optional[str]:
  1367. if template is None:
  1368. return None
  1369. return template.render(user=userinfo).strip()
  1370. display_name = render_template_field(self._config.display_name_template)
  1371. if display_name == "":
  1372. display_name = None
  1373. emails: List[str] = []
  1374. email = render_template_field(self._config.email_template)
  1375. if email:
  1376. emails.append(email)
  1377. picture = self._config.picture_template.render(user=userinfo).strip()
  1378. return UserAttributeDict(
  1379. localpart=localpart,
  1380. display_name=display_name,
  1381. emails=emails,
  1382. picture=picture,
  1383. confirm_localpart=self._config.confirm_localpart,
  1384. )
  1385. async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
  1386. extras: Dict[str, str] = {}
  1387. for key, template in self._config.extra_attributes.items():
  1388. try:
  1389. extras[key] = template.render(user=userinfo).strip()
  1390. except Exception as e:
  1391. # Log an error and skip this value (don't break login for this).
  1392. logger.error("Failed to render OIDC extra attribute %s: %s" % (key, e))
  1393. return extras