This commit adds support for handling a provided avatar picture URL when logging in via SSO. Signed-off-by: Ashish Kumar <ashfame@users.noreply.github.com> Fixes #9357.tags/v1.73.0rc1
@@ -0,0 +1 @@ | |||
Adds support for handling avatar in SSO login. Contributed by @ashfame. |
@@ -2968,10 +2968,17 @@ Options for each entry include: | |||
For the default provider, the following settings are available: | |||
* subject_claim: name of the claim containing a unique identifier | |||
* `subject_claim`: name of the claim containing a unique identifier | |||
for the user. Defaults to 'sub', which OpenID Connect | |||
compliant providers should provide. | |||
* `picture_claim`: name of the claim containing an url for the user's profile picture. | |||
Defaults to 'picture', which OpenID Connect compliant providers should provide | |||
and has to refer to a direct image file such as PNG, JPEG, or GIF image file. | |||
Currently only supported in monolithic (single-process) server configurations | |||
where the media repository runs within the Synapse process. | |||
* `localpart_template`: Jinja2 template for the localpart of the MXID. | |||
If this is not set, the user will be prompted to choose their | |||
own username (see the documentation for the `sso_auth_account_details.html` | |||
@@ -119,6 +119,9 @@ disallow_untyped_defs = True | |||
[mypy-tests.storage.test_profile] | |||
disallow_untyped_defs = True | |||
[mypy-tests.handlers.test_sso] | |||
disallow_untyped_defs = True | |||
[mypy-tests.storage.test_user_directory] | |||
disallow_untyped_defs = True | |||
@@ -137,7 +140,6 @@ disallow_untyped_defs = False | |||
[mypy-tests.utils] | |||
disallow_untyped_defs = True | |||
;; Dependencies without annotations | |||
;; Before ignoring a module, check to see if type stubs are available. | |||
;; The `typeshed` project maintains stubs here: | |||
@@ -1435,6 +1435,7 @@ class UserAttributeDict(TypedDict): | |||
localpart: Optional[str] | |||
confirm_localpart: bool | |||
display_name: Optional[str] | |||
picture: Optional[str] # may be omitted by older `OidcMappingProviders` | |||
emails: List[str] | |||
@@ -1520,6 +1521,7 @@ env.filters.update( | |||
@attr.s(slots=True, frozen=True, auto_attribs=True) | |||
class JinjaOidcMappingConfig: | |||
subject_claim: str | |||
picture_claim: str | |||
localpart_template: Optional[Template] | |||
display_name_template: Optional[Template] | |||
email_template: Optional[Template] | |||
@@ -1539,6 +1541,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): | |||
@staticmethod | |||
def parse_config(config: dict) -> JinjaOidcMappingConfig: | |||
subject_claim = config.get("subject_claim", "sub") | |||
picture_claim = config.get("picture_claim", "picture") | |||
def parse_template_config(option_name: str) -> Optional[Template]: | |||
if option_name not in config: | |||
@@ -1572,6 +1575,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): | |||
return JinjaOidcMappingConfig( | |||
subject_claim=subject_claim, | |||
picture_claim=picture_claim, | |||
localpart_template=localpart_template, | |||
display_name_template=display_name_template, | |||
email_template=email_template, | |||
@@ -1611,10 +1615,13 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): | |||
if email: | |||
emails.append(email) | |||
picture = userinfo.get("picture") | |||
return UserAttributeDict( | |||
localpart=localpart, | |||
display_name=display_name, | |||
emails=emails, | |||
picture=picture, | |||
confirm_localpart=self._config.confirm_localpart, | |||
) | |||
@@ -12,6 +12,8 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import abc | |||
import hashlib | |||
import io | |||
import logging | |||
from typing import ( | |||
TYPE_CHECKING, | |||
@@ -138,6 +140,7 @@ class UserAttributes: | |||
localpart: Optional[str] | |||
confirm_localpart: bool = False | |||
display_name: Optional[str] = None | |||
picture: Optional[str] = None | |||
emails: Collection[str] = attr.Factory(list) | |||
@@ -196,6 +199,10 @@ class SsoHandler: | |||
self._error_template = hs.config.sso.sso_error_template | |||
self._bad_user_template = hs.config.sso.sso_auth_bad_user_template | |||
self._profile_handler = hs.get_profile_handler() | |||
self._media_repo = ( | |||
hs.get_media_repository() if hs.config.media.can_load_media_repo else None | |||
) | |||
self._http_client = hs.get_proxied_blacklisted_http_client() | |||
# The following template is shown after a successful user interactive | |||
# authentication session. It tells the user they can close the window. | |||
@@ -495,6 +502,8 @@ class SsoHandler: | |||
await self._profile_handler.set_displayname( | |||
user_id_obj, requester, attributes.display_name, True | |||
) | |||
if attributes.picture: | |||
await self.set_avatar(user_id, attributes.picture) | |||
await self._auth_handler.complete_sso_login( | |||
user_id, | |||
@@ -703,8 +712,110 @@ class SsoHandler: | |||
await self._store.record_user_external_id( | |||
auth_provider_id, remote_user_id, registered_user_id | |||
) | |||
# Set avatar, if available | |||
if attributes.picture: | |||
await self.set_avatar(registered_user_id, attributes.picture) | |||
return registered_user_id | |||
async def set_avatar(self, user_id: str, picture_https_url: str) -> bool: | |||
"""Set avatar of the user. | |||
This downloads the image file from the URL provided, stores that in | |||
the media repository and then sets the avatar on the user's profile. | |||
It can detect if the same image is being saved again and bails early by storing | |||
the hash of the file in the `upload_name` of the avatar image. | |||
Currently, it only supports server configurations which run the media repository | |||
within the same process. | |||
It silently fails and logs a warning by raising an exception and catching it | |||
internally if: | |||
* it is unable to fetch the image itself (non 200 status code) or | |||
* the image supplied is bigger than max allowed size or | |||
* the image type is not one of the allowed image types. | |||
Args: | |||
user_id: matrix user ID in the form @localpart:domain as a string. | |||
picture_https_url: HTTPS url for the picture image file. | |||
Returns: `True` if the user's avatar has been successfully set to the image at | |||
`picture_https_url`. | |||
""" | |||
if self._media_repo is None: | |||
logger.info( | |||
"failed to set user avatar because out-of-process media repositories " | |||
"are not supported yet " | |||
) | |||
return False | |||
try: | |||
uid = UserID.from_string(user_id) | |||
def is_allowed_mime_type(content_type: str) -> bool: | |||
if ( | |||
self._profile_handler.allowed_avatar_mimetypes | |||
and content_type | |||
not in self._profile_handler.allowed_avatar_mimetypes | |||
): | |||
return False | |||
return True | |||
# download picture, enforcing size limit & mime type check | |||
picture = io.BytesIO() | |||
content_length, headers, uri, code = await self._http_client.get_file( | |||
url=picture_https_url, | |||
output_stream=picture, | |||
max_size=self._profile_handler.max_avatar_size, | |||
is_allowed_content_type=is_allowed_mime_type, | |||
) | |||
if code != 200: | |||
raise Exception( | |||
"GET request to download sso avatar image returned {}".format(code) | |||
) | |||
# upload name includes hash of the image file's content so that we can | |||
# easily check if it requires an update or not, the next time user logs in | |||
upload_name = "sso_avatar_" + hashlib.sha256(picture.read()).hexdigest() | |||
# bail if user already has the same avatar | |||
profile = await self._profile_handler.get_profile(user_id) | |||
if profile["avatar_url"] is not None: | |||
server_name = profile["avatar_url"].split("/")[-2] | |||
media_id = profile["avatar_url"].split("/")[-1] | |||
if server_name == self._server_name: | |||
media = await self._media_repo.store.get_local_media(media_id) | |||
if media is not None and upload_name == media["upload_name"]: | |||
logger.info("skipping saving the user avatar") | |||
return True | |||
# store it in media repository | |||
avatar_mxc_url = await self._media_repo.create_content( | |||
media_type=headers[b"Content-Type"][0].decode("utf-8"), | |||
upload_name=upload_name, | |||
content=picture, | |||
content_length=content_length, | |||
auth_user=uid, | |||
) | |||
# save it as user avatar | |||
await self._profile_handler.set_avatar_url( | |||
uid, | |||
create_requester(uid), | |||
str(avatar_mxc_url), | |||
) | |||
logger.info("successfully saved the user avatar") | |||
return True | |||
except Exception: | |||
logger.warning("failed to save the user avatar") | |||
return False | |||
async def complete_sso_ui_auth_request( | |||
self, | |||
auth_provider_id: str, | |||
@@ -0,0 +1,145 @@ | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from http import HTTPStatus | |||
from typing import BinaryIO, Callable, Dict, List, Optional, Tuple | |||
from unittest.mock import Mock | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from twisted.web.http_headers import Headers | |||
from synapse.api.errors import Codes, SynapseError | |||
from synapse.http.client import RawHeaders | |||
from synapse.server import HomeServer | |||
from synapse.util import Clock | |||
from tests import unittest | |||
from tests.test_utils import SMALL_PNG, FakeResponse | |||
class TestSSOHandler(unittest.HomeserverTestCase): | |||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | |||
self.http_client = Mock(spec=["get_file"]) | |||
self.http_client.get_file.side_effect = mock_get_file | |||
self.http_client.user_agent = b"Synapse Test" | |||
hs = self.setup_test_homeserver( | |||
proxied_blacklisted_http_client=self.http_client | |||
) | |||
return hs | |||
async def test_set_avatar(self) -> None: | |||
"""Tests successfully setting the avatar of a newly created user""" | |||
handler = self.hs.get_sso_handler() | |||
# Create a new user to set avatar for | |||
reg_handler = self.hs.get_registration_handler() | |||
user_id = self.get_success(reg_handler.register_user(approved=True)) | |||
self.assertTrue( | |||
self.get_success(handler.set_avatar(user_id, "http://my.server/me.png")) | |||
) | |||
# Ensure avatar is set on this newly created user, | |||
# so no need to compare for the exact image | |||
profile_handler = self.hs.get_profile_handler() | |||
profile = self.get_success(profile_handler.get_profile(user_id)) | |||
self.assertIsNot(profile["avatar_url"], None) | |||
@unittest.override_config({"max_avatar_size": 1}) | |||
async def test_set_avatar_too_big_image(self) -> None: | |||
"""Tests that saving an avatar fails when it is too big""" | |||
handler = self.hs.get_sso_handler() | |||
# any random user works since image check is supposed to fail | |||
user_id = "@sso-user:test" | |||
self.assertFalse( | |||
self.get_success(handler.set_avatar(user_id, "http://my.server/me.png")) | |||
) | |||
@unittest.override_config({"allowed_avatar_mimetypes": ["image/jpeg"]}) | |||
async def test_set_avatar_incorrect_mime_type(self) -> None: | |||
"""Tests that saving an avatar fails when its mime type is not allowed""" | |||
handler = self.hs.get_sso_handler() | |||
# any random user works since image check is supposed to fail | |||
user_id = "@sso-user:test" | |||
self.assertFalse( | |||
self.get_success(handler.set_avatar(user_id, "http://my.server/me.png")) | |||
) | |||
async def test_skip_saving_avatar_when_not_changed(self) -> None: | |||
"""Tests whether saving of avatar correctly skips if the avatar hasn't | |||
changed""" | |||
handler = self.hs.get_sso_handler() | |||
# Create a new user to set avatar for | |||
reg_handler = self.hs.get_registration_handler() | |||
user_id = self.get_success(reg_handler.register_user(approved=True)) | |||
# set avatar for the first time, should be a success | |||
self.assertTrue( | |||
self.get_success(handler.set_avatar(user_id, "http://my.server/me.png")) | |||
) | |||
# get avatar picture for comparison after another attempt | |||
profile_handler = self.hs.get_profile_handler() | |||
profile = self.get_success(profile_handler.get_profile(user_id)) | |||
url_to_match = profile["avatar_url"] | |||
# set same avatar for the second time, should be a success | |||
self.assertTrue( | |||
self.get_success(handler.set_avatar(user_id, "http://my.server/me.png")) | |||
) | |||
# compare avatar picture's url from previous step | |||
profile = self.get_success(profile_handler.get_profile(user_id)) | |||
self.assertEqual(profile["avatar_url"], url_to_match) | |||
async def mock_get_file( | |||
url: str, | |||
output_stream: BinaryIO, | |||
max_size: Optional[int] = None, | |||
headers: Optional[RawHeaders] = None, | |||
is_allowed_content_type: Optional[Callable[[str], bool]] = None, | |||
) -> Tuple[int, Dict[bytes, List[bytes]], str, int]: | |||
fake_response = FakeResponse(code=404) | |||
if url == "http://my.server/me.png": | |||
fake_response = FakeResponse( | |||
code=200, | |||
headers=Headers( | |||
{"Content-Type": ["image/png"], "Content-Length": [str(len(SMALL_PNG))]} | |||
), | |||
body=SMALL_PNG, | |||
) | |||
if max_size is not None and max_size < len(SMALL_PNG): | |||
raise SynapseError( | |||
HTTPStatus.BAD_GATEWAY, | |||
"Requested file is too large > %r bytes" % (max_size,), | |||
Codes.TOO_LARGE, | |||
) | |||
if is_allowed_content_type and not is_allowed_content_type("image/png"): | |||
raise SynapseError( | |||
HTTPStatus.BAD_GATEWAY, | |||
( | |||
"Requested file's content type not allowed for this operation: %s" | |||
% "image/png" | |||
), | |||
) | |||
output_stream.write(fake_response.body) | |||
return len(SMALL_PNG), {b"Content-Type": [b"image/png"]}, "", 200 |