Co-authored-by: Azrenbeth <7782548+Azrenbeth@users.noreply.github.com> Co-authored-by: Brendan Abolivier <babolivier@matrix.org>tags/v1.46.0rc1
@@ -0,0 +1 @@ | |||
Port the Password Auth Providers module interface to the new generic interface. |
@@ -43,6 +43,7 @@ | |||
- [Third-party rules callbacks](modules/third_party_rules_callbacks.md) | |||
- [Presence router callbacks](modules/presence_router_callbacks.md) | |||
- [Account validity callbacks](modules/account_validity_callbacks.md) | |||
- [Password auth provider callbacks](modules/password_auth_provider_callbacks.md) | |||
- [Porting a legacy module to the new interface](modules/porting_legacy_module.md) | |||
- [Workers](workers.md) | |||
- [Using `synctl` with Workers](synctl_workers.md) | |||
@@ -0,0 +1,153 @@ | |||
# Password auth provider callbacks | |||
Password auth providers offer a way for server administrators to integrate | |||
their Synapse installation with an external authentication system. The callbacks can be | |||
registered by using the Module API's `register_password_auth_provider_callbacks` method. | |||
## Callbacks | |||
### `auth_checkers` | |||
``` | |||
auth_checkers: Dict[Tuple[str,Tuple], Callable] | |||
``` | |||
A dict mapping from tuples of a login type identifier (such as `m.login.password`) and a | |||
tuple of field names (such as `("password", "secret_thing")`) to authentication checking | |||
callbacks, which should be of the following form: | |||
```python | |||
async def check_auth( | |||
user: str, | |||
login_type: str, | |||
login_dict: "synapse.module_api.JsonDict", | |||
) -> Optional[ | |||
Tuple[ | |||
str, | |||
Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]] | |||
] | |||
] | |||
``` | |||
The login type and field names should be provided by the user in the | |||
request to the `/login` API. [The Matrix specification](https://matrix.org/docs/spec/client_server/latest#authentication-types) | |||
defines some types, however user defined ones are also allowed. | |||
The callback is passed the `user` field provided by the client (which might not be in | |||
`@username:server` form), the login type, and a dictionary of login secrets passed by | |||
the client. | |||
If the authentication is successful, the module must return the user's Matrix ID (e.g. | |||
`@alice:example.com`) and optionally a callback to be called with the response to the | |||
`/login` request. If the module doesn't wish to return a callback, it must return `None` | |||
instead. | |||
If the authentication is unsuccessful, the module must return `None`. | |||
### `check_3pid_auth` | |||
```python | |||
async def check_3pid_auth( | |||
medium: str, | |||
address: str, | |||
password: str, | |||
) -> Optional[ | |||
Tuple[ | |||
str, | |||
Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]] | |||
] | |||
] | |||
``` | |||
Called when a user attempts to register or log in with a third party identifier, | |||
such as email. It is passed the medium (eg. `email`), an address (eg. `jdoe@example.com`) | |||
and the user's password. | |||
If the authentication is successful, the module must return the user's Matrix ID (e.g. | |||
`@alice:example.com`) and optionally a callback to be called with the response to the `/login` request. | |||
If the module doesn't wish to return a callback, it must return None instead. | |||
If the authentication is unsuccessful, the module must return None. | |||
### `on_logged_out` | |||
```python | |||
async def on_logged_out( | |||
user_id: str, | |||
device_id: Optional[str], | |||
access_token: str | |||
) -> None | |||
``` | |||
Called during a logout request for a user. It is passed the qualified user ID, the ID of the | |||
deactivated device (if any: access tokens are occasionally created without an associated | |||
device ID), and the (now deactivated) access token. | |||
## Example | |||
The example module below implements authentication checkers for two different login types: | |||
- `my.login.type` | |||
- Expects a `my_field` field to be sent to `/login` | |||
- Is checked by the method: `self.check_my_login` | |||
- `m.login.password` (defined in [the spec](https://matrix.org/docs/spec/client_server/latest#password-based)) | |||
- Expects a `password` field to be sent to `/login` | |||
- Is checked by the method: `self.check_pass` | |||
```python | |||
from typing import Awaitable, Callable, Optional, Tuple | |||
import synapse | |||
from synapse import module_api | |||
class MyAuthProvider: | |||
def __init__(self, config: dict, api: module_api): | |||
self.api = api | |||
self.credentials = { | |||
"bob": "building", | |||
"@scoop:matrix.org": "digging", | |||
} | |||
api.register_password_auth_provider_callbacks( | |||
auth_checkers={ | |||
("my.login_type", ("my_field",)): self.check_my_login, | |||
("m.login.password", ("password",)): self.check_pass, | |||
}, | |||
) | |||
async def check_my_login( | |||
self, | |||
username: str, | |||
login_type: str, | |||
login_dict: "synapse.module_api.JsonDict", | |||
) -> Optional[ | |||
Tuple[ | |||
str, | |||
Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]], | |||
] | |||
]: | |||
if login_type != "my.login_type": | |||
return None | |||
if self.credentials.get(username) == login_dict.get("my_field"): | |||
return self.api.get_qualified_user_id(username) | |||
async def check_pass( | |||
self, | |||
username: str, | |||
login_type: str, | |||
login_dict: "synapse.module_api.JsonDict", | |||
) -> Optional[ | |||
Tuple[ | |||
str, | |||
Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]], | |||
] | |||
]: | |||
if login_type != "m.login.password": | |||
return None | |||
if self.credentials.get(username) == login_dict.get("password"): | |||
return self.api.get_qualified_user_id(username) | |||
``` |
@@ -12,6 +12,9 @@ should register this resource in its `__init__` method using the `register_web_r | |||
method from the `ModuleApi` class (see [this section](writing_a_module.html#registering-a-web-resource) for | |||
more info). | |||
There is no longer a `get_db_schema_files` callback provided for password auth provider modules. Any | |||
changes to the database should now be made by the module using the module API class. | |||
The module's author should also update any example in the module's configuration to only | |||
use the new `modules` section in Synapse's configuration file (see [this section](index.html#using-modules) | |||
for more info). |
@@ -1,3 +1,9 @@ | |||
<h2 style="color:red"> | |||
This page of the Synapse documentation is now deprecated. For up to date | |||
documentation on setting up or writing a password auth provider module, please see | |||
<a href="modules.md">this page</a>. | |||
</h2> | |||
# Password auth provider modules | |||
Password auth providers offer a way for server administrators to | |||
@@ -2260,34 +2260,6 @@ email: | |||
#email_validation: "[%(server_name)s] Validate your email" | |||
# Password providers allow homeserver administrators to integrate | |||
# their Synapse installation with existing authentication methods | |||
# ex. LDAP, external tokens, etc. | |||
# | |||
# For more information and known implementations, please see | |||
# https://matrix-org.github.io/synapse/latest/password_auth_providers.html | |||
# | |||
# Note: instances wishing to use SAML or CAS authentication should | |||
# instead use the `saml2_config` or `cas_config` options, | |||
# respectively. | |||
# | |||
password_providers: | |||
# # Example config for an LDAP auth provider | |||
# - module: "ldap_auth_provider.LdapAuthProvider" | |||
# config: | |||
# enabled: true | |||
# uri: "ldap://ldap.example.com:389" | |||
# start_tls: true | |||
# base: "ou=users,dc=example,dc=com" | |||
# attributes: | |||
# uid: "cn" | |||
# mail: "email" | |||
# name: "givenName" | |||
# #bind_dn: | |||
# #bind_password: | |||
# #filter: "(objectClass=posixAccount)" | |||
## Push ## | |||
@@ -42,6 +42,7 @@ from synapse.crypto import context_factory | |||
from synapse.events.presence_router import load_legacy_presence_router | |||
from synapse.events.spamcheck import load_legacy_spam_checkers | |||
from synapse.events.third_party_rules import load_legacy_third_party_event_rules | |||
from synapse.handlers.auth import load_legacy_password_auth_providers | |||
from synapse.logging.context import PreserveLoggingContext | |||
from synapse.metrics.background_process_metrics import wrap_as_background_process | |||
from synapse.metrics.jemalloc import setup_jemalloc_stats | |||
@@ -379,6 +380,7 @@ async def start(hs: "HomeServer"): | |||
load_legacy_spam_checkers(hs) | |||
load_legacy_third_party_event_rules(hs) | |||
load_legacy_presence_router(hs) | |||
load_legacy_password_auth_providers(hs) | |||
# If we've configured an expiry time for caches, start the background job now. | |||
setup_expire_lru_cache_entries(hs) | |||
@@ -25,6 +25,29 @@ class PasswordAuthProviderConfig(Config): | |||
section = "authproviders" | |||
def read_config(self, config, **kwargs): | |||
"""Parses the old password auth providers config. The config format looks like this: | |||
password_providers: | |||
# Example config for an LDAP auth provider | |||
- module: "ldap_auth_provider.LdapAuthProvider" | |||
config: | |||
enabled: true | |||
uri: "ldap://ldap.example.com:389" | |||
start_tls: true | |||
base: "ou=users,dc=example,dc=com" | |||
attributes: | |||
uid: "cn" | |||
mail: "email" | |||
name: "givenName" | |||
#bind_dn: | |||
#bind_password: | |||
#filter: "(objectClass=posixAccount)" | |||
We expect admins to use modules for this feature (which is why it doesn't appear | |||
in the sample config file), but we want to keep support for it around for a bit | |||
for backwards compatibility. | |||
""" | |||
self.password_providers: List[Tuple[Type, Any]] = [] | |||
providers = [] | |||
@@ -49,33 +72,3 @@ class PasswordAuthProviderConfig(Config): | |||
) | |||
self.password_providers.append((provider_class, provider_config)) | |||
def generate_config_section(self, **kwargs): | |||
return """\ | |||
# Password providers allow homeserver administrators to integrate | |||
# their Synapse installation with existing authentication methods | |||
# ex. LDAP, external tokens, etc. | |||
# | |||
# For more information and known implementations, please see | |||
# https://matrix-org.github.io/synapse/latest/password_auth_providers.html | |||
# | |||
# Note: instances wishing to use SAML or CAS authentication should | |||
# instead use the `saml2_config` or `cas_config` options, | |||
# respectively. | |||
# | |||
password_providers: | |||
# # Example config for an LDAP auth provider | |||
# - module: "ldap_auth_provider.LdapAuthProvider" | |||
# config: | |||
# enabled: true | |||
# uri: "ldap://ldap.example.com:389" | |||
# start_tls: true | |||
# base: "ou=users,dc=example,dc=com" | |||
# attributes: | |||
# uid: "cn" | |||
# mail: "email" | |||
# name: "givenName" | |||
# #bind_dn: | |||
# #bind_password: | |||
# #filter: "(objectClass=posixAccount)" | |||
""" |
@@ -200,46 +200,13 @@ class AuthHandler: | |||
self.bcrypt_rounds = hs.config.registration.bcrypt_rounds | |||
# we can't use hs.get_module_api() here, because to do so will create an | |||
# import loop. | |||
# | |||
# TODO: refactor this class to separate the lower-level stuff that | |||
# ModuleApi can use from the higher-level stuff that uses ModuleApi, as | |||
# better way to break the loop | |||
account_handler = ModuleApi(hs, self) | |||
self.password_providers = [ | |||
PasswordProvider.load(module, config, account_handler) | |||
for module, config in hs.config.authproviders.password_providers | |||
] | |||
logger.info("Extra password_providers: %s", self.password_providers) | |||
self.password_auth_provider = hs.get_password_auth_provider() | |||
self.hs = hs # FIXME better possibility to access registrationHandler later? | |||
self.macaroon_gen = hs.get_macaroon_generator() | |||
self._password_enabled = hs.config.auth.password_enabled | |||
self._password_localdb_enabled = hs.config.auth.password_localdb_enabled | |||
# start out by assuming PASSWORD is enabled; we will remove it later if not. | |||
login_types = set() | |||
if self._password_localdb_enabled: | |||
login_types.add(LoginType.PASSWORD) | |||
for provider in self.password_providers: | |||
login_types.update(provider.get_supported_login_types().keys()) | |||
if not self._password_enabled: | |||
login_types.discard(LoginType.PASSWORD) | |||
# Some clients just pick the first type in the list. In this case, we want | |||
# them to use PASSWORD (rather than token or whatever), so we want to make sure | |||
# that comes first, where it's present. | |||
self._supported_login_types = [] | |||
if LoginType.PASSWORD in login_types: | |||
self._supported_login_types.append(LoginType.PASSWORD) | |||
login_types.remove(LoginType.PASSWORD) | |||
self._supported_login_types.extend(login_types) | |||
# Ratelimiter for failed auth during UIA. Uses same ratelimit config | |||
# as per `rc_login.failed_attempts`. | |||
self._failed_uia_attempts_ratelimiter = Ratelimiter( | |||
@@ -427,11 +394,10 @@ class AuthHandler: | |||
ui_auth_types.add(LoginType.PASSWORD) | |||
# also allow auth from password providers | |||
for provider in self.password_providers: | |||
for t in provider.get_supported_login_types().keys(): | |||
if t == LoginType.PASSWORD and not self._password_enabled: | |||
continue | |||
ui_auth_types.add(t) | |||
for t in self.password_auth_provider.get_supported_login_types().keys(): | |||
if t == LoginType.PASSWORD and not self._password_enabled: | |||
continue | |||
ui_auth_types.add(t) | |||
# if sso is enabled, allow the user to log in via SSO iff they have a mapping | |||
# from sso to mxid. | |||
@@ -1038,7 +1004,25 @@ class AuthHandler: | |||
Returns: | |||
login types | |||
""" | |||
return self._supported_login_types | |||
# Load any login types registered by modules | |||
# This is stored in the password_auth_provider so this doesn't trigger | |||
# any callbacks | |||
types = list(self.password_auth_provider.get_supported_login_types().keys()) | |||
# This list should include PASSWORD if (either _password_localdb_enabled is | |||
# true or if one of the modules registered it) AND _password_enabled is true | |||
# Also: | |||
# Some clients just pick the first type in the list. In this case, we want | |||
# them to use PASSWORD (rather than token or whatever), so we want to make sure | |||
# that comes first, where it's present. | |||
if LoginType.PASSWORD in types: | |||
types.remove(LoginType.PASSWORD) | |||
if self._password_enabled: | |||
types.insert(0, LoginType.PASSWORD) | |||
elif self._password_localdb_enabled and self._password_enabled: | |||
types.insert(0, LoginType.PASSWORD) | |||
return types | |||
async def validate_login( | |||
self, | |||
@@ -1217,15 +1201,20 @@ class AuthHandler: | |||
known_login_type = False | |||
for provider in self.password_providers: | |||
supported_login_types = provider.get_supported_login_types() | |||
if login_type not in supported_login_types: | |||
# this password provider doesn't understand this login type | |||
continue | |||
# Check if login_type matches a type registered by one of the modules | |||
# We don't need to remove LoginType.PASSWORD from the list if password login is | |||
# disabled, since if that were the case then by this point we know that the | |||
# login_type is not LoginType.PASSWORD | |||
supported_login_types = self.password_auth_provider.get_supported_login_types() | |||
# check if the login type being used is supported by a module | |||
if login_type in supported_login_types: | |||
# Make a note that this login type is supported by the server | |||
known_login_type = True | |||
# Get all the fields expected for this login types | |||
login_fields = supported_login_types[login_type] | |||
# go through the login submission and keep track of which required fields are | |||
# provided/not provided | |||
missing_fields = [] | |||
login_dict = {} | |||
for f in login_fields: | |||
@@ -1233,6 +1222,7 @@ class AuthHandler: | |||
missing_fields.append(f) | |||
else: | |||
login_dict[f] = login_submission[f] | |||
# raise an error if any of the expected fields for that login type weren't provided | |||
if missing_fields: | |||
raise SynapseError( | |||
400, | |||
@@ -1240,10 +1230,15 @@ class AuthHandler: | |||
% (login_type, missing_fields), | |||
) | |||
result = await provider.check_auth(username, login_type, login_dict) | |||
# call all of the check_auth hooks for that login_type | |||
# it will return a result once the first success is found (or None otherwise) | |||
result = await self.password_auth_provider.check_auth( | |||
username, login_type, login_dict | |||
) | |||
if result: | |||
return result | |||
# if no module managed to authenticate the user, then fallback to built in password based auth | |||
if login_type == LoginType.PASSWORD and self._password_localdb_enabled: | |||
known_login_type = True | |||
@@ -1282,11 +1277,16 @@ class AuthHandler: | |||
completed login/registration, or `None`. If authentication was | |||
unsuccessful, `user_id` and `callback` are both `None`. | |||
""" | |||
for provider in self.password_providers: | |||
result = await provider.check_3pid_auth(medium, address, password) | |||
if result: | |||
return result | |||
# call all of the check_3pid_auth callbacks | |||
# Result will be from the first callback that returns something other than None | |||
# If all the callbacks return None, then result is also set to None | |||
result = await self.password_auth_provider.check_3pid_auth( | |||
medium, address, password | |||
) | |||
if result: | |||
return result | |||
# if result is None then return (None, None) | |||
return None, None | |||
async def _check_local_password(self, user_id: str, password: str) -> Optional[str]: | |||
@@ -1365,13 +1365,12 @@ class AuthHandler: | |||
user_info = await self.auth.get_user_by_access_token(access_token) | |||
await self.store.delete_access_token(access_token) | |||
# see if any of our auth providers want to know about this | |||
for provider in self.password_providers: | |||
await provider.on_logged_out( | |||
user_id=user_info.user_id, | |||
device_id=user_info.device_id, | |||
access_token=access_token, | |||
) | |||
# see if any modules want to know about this | |||
await self.password_auth_provider.on_logged_out( | |||
user_id=user_info.user_id, | |||
device_id=user_info.device_id, | |||
access_token=access_token, | |||
) | |||
# delete pushers associated with this access token | |||
if user_info.token_id is not None: | |||
@@ -1398,12 +1397,11 @@ class AuthHandler: | |||
user_id, except_token_id=except_token_id, device_id=device_id | |||
) | |||
# see if any of our auth providers want to know about this | |||
for provider in self.password_providers: | |||
for token, _, device_id in tokens_and_devices: | |||
await provider.on_logged_out( | |||
user_id=user_id, device_id=device_id, access_token=token | |||
) | |||
# see if any modules want to know about this | |||
for token, _, device_id in tokens_and_devices: | |||
await self.password_auth_provider.on_logged_out( | |||
user_id=user_id, device_id=device_id, access_token=token | |||
) | |||
# delete pushers associated with the access tokens | |||
await self.hs.get_pusherpool().remove_pushers_by_access_token( | |||
@@ -1811,40 +1809,228 @@ class MacaroonGenerator: | |||
return macaroon | |||
class PasswordProvider: | |||
"""Wrapper for a password auth provider module | |||
def load_legacy_password_auth_providers(hs: "HomeServer") -> None: | |||
module_api = hs.get_module_api() | |||
for module, config in hs.config.authproviders.password_providers: | |||
load_single_legacy_password_auth_provider( | |||
module=module, config=config, api=module_api | |||
) | |||
This class abstracts out all of the backwards-compatibility hacks for | |||
password providers, to provide a consistent interface. | |||
""" | |||
@classmethod | |||
def load( | |||
cls, module: Type, config: JsonDict, module_api: ModuleApi | |||
) -> "PasswordProvider": | |||
try: | |||
pp = module(config=config, account_handler=module_api) | |||
except Exception as e: | |||
logger.error("Error while initializing %r: %s", module, e) | |||
raise | |||
return cls(pp, module_api) | |||
def load_single_legacy_password_auth_provider( | |||
module: Type, config: JsonDict, api: ModuleApi | |||
) -> None: | |||
try: | |||
provider = module(config=config, account_handler=api) | |||
except Exception as e: | |||
logger.error("Error while initializing %r: %s", module, e) | |||
raise | |||
# The known hooks. If a module implements a method who's name appears in this set | |||
# we'll want to register it | |||
password_auth_provider_methods = { | |||
"check_3pid_auth", | |||
"on_logged_out", | |||
} | |||
# All methods that the module provides should be async, but this wasn't enforced | |||
# in the old module system, so we wrap them if needed | |||
def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]: | |||
# f might be None if the callback isn't implemented by the module. In this | |||
# case we don't want to register a callback at all so we return None. | |||
if f is None: | |||
return None | |||
# We need to wrap check_password because its old form would return a boolean | |||
# but we now want it to behave just like check_auth() and return the matrix id of | |||
# the user if authentication succeeded or None otherwise | |||
if f.__name__ == "check_password": | |||
async def wrapped_check_password( | |||
username: str, login_type: str, login_dict: JsonDict | |||
) -> Optional[Tuple[str, Optional[Callable]]]: | |||
# We've already made sure f is not None above, but mypy doesn't do well | |||
# across function boundaries so we need to tell it f is definitely not | |||
# None. | |||
assert f is not None | |||
matrix_user_id = api.get_qualified_user_id(username) | |||
password = login_dict["password"] | |||
is_valid = await f(matrix_user_id, password) | |||
if is_valid: | |||
return matrix_user_id, None | |||
return None | |||
def __init__(self, pp: "PasswordProvider", module_api: ModuleApi): | |||
self._pp = pp | |||
self._module_api = module_api | |||
return wrapped_check_password | |||
# We need to wrap check_auth as in the old form it could return | |||
# just a str, but now it must return Optional[Tuple[str, Optional[Callable]] | |||
if f.__name__ == "check_auth": | |||
async def wrapped_check_auth( | |||
username: str, login_type: str, login_dict: JsonDict | |||
) -> Optional[Tuple[str, Optional[Callable]]]: | |||
# We've already made sure f is not None above, but mypy doesn't do well | |||
# across function boundaries so we need to tell it f is definitely not | |||
# None. | |||
assert f is not None | |||
result = await f(username, login_type, login_dict) | |||
if isinstance(result, str): | |||
return result, None | |||
return result | |||
return wrapped_check_auth | |||
# We need to wrap check_3pid_auth as in the old form it could return | |||
# just a str, but now it must return Optional[Tuple[str, Optional[Callable]] | |||
if f.__name__ == "check_3pid_auth": | |||
async def wrapped_check_3pid_auth( | |||
medium: str, address: str, password: str | |||
) -> Optional[Tuple[str, Optional[Callable]]]: | |||
# We've already made sure f is not None above, but mypy doesn't do well | |||
# across function boundaries so we need to tell it f is definitely not | |||
# None. | |||
assert f is not None | |||
result = await f(medium, address, password) | |||
if isinstance(result, str): | |||
return result, None | |||
return result | |||
self._supported_login_types = {} | |||
return wrapped_check_3pid_auth | |||
# grandfather in check_password support | |||
if hasattr(self._pp, "check_password"): | |||
self._supported_login_types[LoginType.PASSWORD] = ("password",) | |||
def run(*args: Tuple, **kwargs: Dict) -> Awaitable: | |||
# mypy doesn't do well across function boundaries so we need to tell it | |||
# f is definitely not None. | |||
assert f is not None | |||
g = getattr(self._pp, "get_supported_login_types", None) | |||
if g: | |||
self._supported_login_types.update(g()) | |||
return maybe_awaitable(f(*args, **kwargs)) | |||
def __str__(self) -> str: | |||
return str(self._pp) | |||
return run | |||
# populate hooks with the implemented methods, wrapped with async_wrapper | |||
hooks = { | |||
hook: async_wrapper(getattr(provider, hook, None)) | |||
for hook in password_auth_provider_methods | |||
} | |||
supported_login_types = {} | |||
# call get_supported_login_types and add that to the dict | |||
g = getattr(provider, "get_supported_login_types", None) | |||
if g is not None: | |||
# Note the old module style also called get_supported_login_types at loading time | |||
# and it is synchronous | |||
supported_login_types.update(g()) | |||
auth_checkers = {} | |||
# Legacy modules have a check_auth method which expects to be called with one of | |||
# the keys returned by get_supported_login_types. New style modules register a | |||
# dictionary of login_type->check_auth_method mappings | |||
check_auth = async_wrapper(getattr(provider, "check_auth", None)) | |||
if check_auth is not None: | |||
for login_type, fields in supported_login_types.items(): | |||
# need tuple(fields) since fields can be any Iterable type (so may not be hashable) | |||
auth_checkers[(login_type, tuple(fields))] = check_auth | |||
# if it has a "check_password" method then it should handle all auth checks | |||
# with login type of LoginType.PASSWORD | |||
check_password = async_wrapper(getattr(provider, "check_password", None)) | |||
if check_password is not None: | |||
# need to use a tuple here for ("password",) not a list since lists aren't hashable | |||
auth_checkers[(LoginType.PASSWORD, ("password",))] = check_password | |||
api.register_password_auth_provider_callbacks(hooks, auth_checkers=auth_checkers) | |||
CHECK_3PID_AUTH_CALLBACK = Callable[ | |||
[str, str, str], | |||
Awaitable[ | |||
Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]] | |||
], | |||
] | |||
ON_LOGGED_OUT_CALLBACK = Callable[[str, Optional[str], str], Awaitable] | |||
CHECK_AUTH_CALLBACK = Callable[ | |||
[str, str, JsonDict], | |||
Awaitable[ | |||
Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]] | |||
], | |||
] | |||
class PasswordAuthProvider: | |||
""" | |||
A class that the AuthHandler calls when authenticating users | |||
It allows modules to provide alternative methods for authentication | |||
""" | |||
def __init__(self) -> None: | |||
# lists of callbacks | |||
self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = [] | |||
self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = [] | |||
# Mapping from login type to login parameters | |||
self._supported_login_types: Dict[str, Iterable[str]] = {} | |||
# Mapping from login type to auth checker callbacks | |||
self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {} | |||
def register_password_auth_provider_callbacks( | |||
self, | |||
check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None, | |||
on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None, | |||
auth_checkers: Optional[Dict[Tuple[str, Tuple], CHECK_AUTH_CALLBACK]] = None, | |||
) -> None: | |||
# Register check_3pid_auth callback | |||
if check_3pid_auth is not None: | |||
self.check_3pid_auth_callbacks.append(check_3pid_auth) | |||
# register on_logged_out callback | |||
if on_logged_out is not None: | |||
self.on_logged_out_callbacks.append(on_logged_out) | |||
if auth_checkers is not None: | |||
# register a new supported login_type | |||
# Iterate through all of the types being registered | |||
for (login_type, fields), callback in auth_checkers.items(): | |||
# Note: fields may be empty here. This would allow a modules auth checker to | |||
# be called with just 'login_type' and no password or other secrets | |||
# Need to check that all the field names are strings or may get nasty errors later | |||
for f in fields: | |||
if not isinstance(f, str): | |||
raise RuntimeError( | |||
"A module tried to register support for login type: %s with parameters %s" | |||
" but all parameter names must be strings" | |||
% (login_type, fields) | |||
) | |||
# 2 modules supporting the same login type must expect the same fields | |||
# e.g. 1 can't expect "pass" if the other expects "password" | |||
# so throw an exception if that happens | |||
if login_type not in self._supported_login_types.get(login_type, []): | |||
self._supported_login_types[login_type] = fields | |||
else: | |||
fields_currently_supported = self._supported_login_types.get( | |||
login_type | |||
) | |||
if fields_currently_supported != fields: | |||
raise RuntimeError( | |||
"A module tried to register support for login type: %s with parameters %s" | |||
" but another module had already registered support for that type with parameters %s" | |||
% (login_type, fields, fields_currently_supported) | |||
) | |||
# Add the new method to the list of auth_checker_callbacks for this login type | |||
self.auth_checker_callbacks.setdefault(login_type, []).append(callback) | |||
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]: | |||
"""Get the login types supported by this password provider | |||
@@ -1852,20 +2038,15 @@ class PasswordProvider: | |||
Returns a map from a login type identifier (such as m.login.password) to an | |||
iterable giving the fields which must be provided by the user in the submission | |||
to the /login API. | |||
This wrapper adds m.login.password to the list if the underlying password | |||
provider supports the check_password() api. | |||
""" | |||
return self._supported_login_types | |||
async def check_auth( | |||
self, username: str, login_type: str, login_dict: JsonDict | |||
) -> Optional[Tuple[str, Optional[Callable]]]: | |||
) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]: | |||
"""Check if the user has presented valid login credentials | |||
This wrapper also calls check_password() if the underlying password provider | |||
supports the check_password() api and the login type is m.login.password. | |||
Args: | |||
username: user id presented by the client. Either an MXID or an unqualified | |||
username. | |||
@@ -1879,63 +2060,130 @@ class PasswordProvider: | |||
user, and `callback` is an optional callback which will be called with the | |||
result from the /login call (including access_token, device_id, etc.) | |||
""" | |||
# first grandfather in a call to check_password | |||
if login_type == LoginType.PASSWORD: | |||
check_password = getattr(self._pp, "check_password", None) | |||
if check_password: | |||
qualified_user_id = self._module_api.get_qualified_user_id(username) | |||
is_valid = await check_password( | |||
qualified_user_id, login_dict["password"] | |||
) | |||
if is_valid: | |||
return qualified_user_id, None | |||
check_auth = getattr(self._pp, "check_auth", None) | |||
if not check_auth: | |||
return None | |||
result = await check_auth(username, login_type, login_dict) | |||
# Go through all callbacks for the login type until one returns with a value | |||
# other than None (i.e. until a callback returns a success) | |||
for callback in self.auth_checker_callbacks[login_type]: | |||
try: | |||
result = await callback(username, login_type, login_dict) | |||
except Exception as e: | |||
logger.warning("Failed to run module API callback %s: %s", callback, e) | |||
continue | |||
# Check if the return value is a str or a tuple | |||
if isinstance(result, str): | |||
# If it's a str, set callback function to None | |||
return result, None | |||
if result is not None: | |||
# Check that the callback returned a Tuple[str, Optional[Callable]] | |||
# "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks | |||
# result is always the right type, but as it is 3rd party code it might not be | |||
if not isinstance(result, tuple) or len(result) != 2: | |||
logger.warning( | |||
"Wrong type returned by module API callback %s: %s, expected" | |||
" Optional[Tuple[str, Optional[Callable]]]", | |||
callback, | |||
result, | |||
) | |||
continue | |||
return result | |||
# pull out the two parts of the tuple so we can do type checking | |||
str_result, callback_result = result | |||
# the 1st item in the tuple should be a str | |||
if not isinstance(str_result, str): | |||
logger.warning( # type: ignore[unreachable] | |||
"Wrong type returned by module API callback %s: %s, expected" | |||
" Optional[Tuple[str, Optional[Callable]]]", | |||
callback, | |||
result, | |||
) | |||
continue | |||
# the second should be Optional[Callable] | |||
if callback_result is not None: | |||
if not callable(callback_result): | |||
logger.warning( # type: ignore[unreachable] | |||
"Wrong type returned by module API callback %s: %s, expected" | |||
" Optional[Tuple[str, Optional[Callable]]]", | |||
callback, | |||
result, | |||
) | |||
continue | |||
# The result is a (str, Optional[callback]) tuple so return the successful result | |||
return result | |||
# If this point has been reached then none of the callbacks successfully authenticated | |||
# the user so return None | |||
return None | |||
async def check_3pid_auth( | |||
self, medium: str, address: str, password: str | |||
) -> Optional[Tuple[str, Optional[Callable]]]: | |||
g = getattr(self._pp, "check_3pid_auth", None) | |||
if not g: | |||
return None | |||
) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]: | |||
# This function is able to return a deferred that either | |||
# resolves None, meaning authentication failure, or upon | |||
# success, to a str (which is the user_id) or a tuple of | |||
# (user_id, callback_func), where callback_func should be run | |||
# after we've finished everything else | |||
result = await g(medium, address, password) | |||
# Check if the return value is a str or a tuple | |||
if isinstance(result, str): | |||
# If it's a str, set callback function to None | |||
return result, None | |||
for callback in self.check_3pid_auth_callbacks: | |||
try: | |||
result = await callback(medium, address, password) | |||
except Exception as e: | |||
logger.warning("Failed to run module API callback %s: %s", callback, e) | |||
continue | |||
return result | |||
if result is not None: | |||
# Check that the callback returned a Tuple[str, Optional[Callable]] | |||
# "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks | |||
# result is always the right type, but as it is 3rd party code it might not be | |||
if not isinstance(result, tuple) or len(result) != 2: | |||
logger.warning( | |||
"Wrong type returned by module API callback %s: %s, expected" | |||
" Optional[Tuple[str, Optional[Callable]]]", | |||
callback, | |||
result, | |||
) | |||
continue | |||
# pull out the two parts of the tuple so we can do type checking | |||
str_result, callback_result = result | |||
# the 1st item in the tuple should be a str | |||
if not isinstance(str_result, str): | |||
logger.warning( # type: ignore[unreachable] | |||
"Wrong type returned by module API callback %s: %s, expected" | |||
" Optional[Tuple[str, Optional[Callable]]]", | |||
callback, | |||
result, | |||
) | |||
continue | |||
# the second should be Optional[Callable] | |||
if callback_result is not None: | |||
if not callable(callback_result): | |||
logger.warning( # type: ignore[unreachable] | |||
"Wrong type returned by module API callback %s: %s, expected" | |||
" Optional[Tuple[str, Optional[Callable]]]", | |||
callback, | |||
result, | |||
) | |||
continue | |||
# The result is a (str, Optional[callback]) tuple so return the successful result | |||
return result | |||
# If this point has been reached then none of the callbacks successfully authenticated | |||
# the user so return None | |||
return None | |||
async def on_logged_out( | |||
self, user_id: str, device_id: Optional[str], access_token: str | |||
) -> None: | |||
g = getattr(self._pp, "on_logged_out", None) | |||
if not g: | |||
return | |||
# This might return an awaitable, if it does block the log out | |||
# until it completes. | |||
await maybe_awaitable( | |||
g( | |||
user_id=user_id, | |||
device_id=device_id, | |||
access_token=access_token, | |||
) | |||
) | |||
# call all of the on_logged_out callbacks | |||
for callback in self.on_logged_out_callbacks: | |||
try: | |||
callback(user_id, device_id, access_token) | |||
except Exception as e: | |||
logger.warning("Failed to run module API callback %s: %s", callback, e) | |||
continue |
@@ -45,6 +45,7 @@ from synapse.http.servlet import parse_json_object_from_request | |||
from synapse.http.site import SynapseRequest | |||
from synapse.logging.context import make_deferred_yieldable, run_in_background | |||
from synapse.metrics.background_process_metrics import run_as_background_process | |||
from synapse.rest.client.login import LoginResponse | |||
from synapse.storage.database import DatabasePool, LoggingTransaction | |||
from synapse.storage.databases.main.roommember import ProfileInfo | |||
from synapse.storage.state import StateFilter | |||
@@ -83,6 +84,8 @@ __all__ = [ | |||
"DirectServeJsonResource", | |||
"ModuleApi", | |||
"PRESENCE_ALL_USERS", | |||
"LoginResponse", | |||
"JsonDict", | |||
] | |||
logger = logging.getLogger(__name__) | |||
@@ -139,6 +142,7 @@ class ModuleApi: | |||
self._spam_checker = hs.get_spam_checker() | |||
self._account_validity_handler = hs.get_account_validity_handler() | |||
self._third_party_event_rules = hs.get_third_party_event_rules() | |||
self._password_auth_provider = hs.get_password_auth_provider() | |||
self._presence_router = hs.get_presence_router() | |||
################################################################################# | |||
@@ -164,6 +168,11 @@ class ModuleApi: | |||
"""Registers callbacks for presence router capabilities.""" | |||
return self._presence_router.register_presence_router_callbacks | |||
@property | |||
def register_password_auth_provider_callbacks(self): | |||
"""Registers callbacks for password auth provider capabilities.""" | |||
return self._password_auth_provider.register_password_auth_provider_callbacks | |||
def register_web_resource(self, path: str, resource: IResource): | |||
"""Registers a web resource to be served at the given path. | |||
@@ -65,7 +65,7 @@ from synapse.handlers.account_data import AccountDataHandler | |||
from synapse.handlers.account_validity import AccountValidityHandler | |||
from synapse.handlers.admin import AdminHandler | |||
from synapse.handlers.appservice import ApplicationServicesHandler | |||
from synapse.handlers.auth import AuthHandler, MacaroonGenerator | |||
from synapse.handlers.auth import AuthHandler, MacaroonGenerator, PasswordAuthProvider | |||
from synapse.handlers.cas import CasHandler | |||
from synapse.handlers.deactivate_account import DeactivateAccountHandler | |||
from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler | |||
@@ -687,6 +687,10 @@ class HomeServer(metaclass=abc.ABCMeta): | |||
def get_third_party_event_rules(self) -> ThirdPartyEventRules: | |||
return ThirdPartyEventRules(self) | |||
@cache_in_self | |||
def get_password_auth_provider(self) -> PasswordAuthProvider: | |||
return PasswordAuthProvider() | |||
@cache_in_self | |||
def get_room_member_handler(self) -> RoomMemberHandler: | |||
if self.config.worker.worker_app: | |||
@@ -549,6 +549,8 @@ def _apply_module_schemas( | |||
database_engine: | |||
config: application config | |||
""" | |||
# This is the old way for password_auth_provider modules to make changes | |||
# to the database. This should instead be done using the module API | |||
for (mod, _config) in config.authproviders.password_providers: | |||
if not hasattr(mod, "get_db_schema_files"): | |||
continue | |||
@@ -20,6 +20,8 @@ from unittest.mock import Mock | |||
from twisted.internet import defer | |||
import synapse | |||
from synapse.handlers.auth import load_legacy_password_auth_providers | |||
from synapse.module_api import ModuleApi | |||
from synapse.rest.client import devices, login | |||
from synapse.types import JsonDict | |||
@@ -36,8 +38,8 @@ ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_servi | |||
mock_password_provider = Mock() | |||
class PasswordOnlyAuthProvider: | |||
"""A password_provider which only implements `check_password`.""" | |||
class LegacyPasswordOnlyAuthProvider: | |||
"""A legacy password_provider which only implements `check_password`.""" | |||
@staticmethod | |||
def parse_config(self): | |||
@@ -50,8 +52,8 @@ class PasswordOnlyAuthProvider: | |||
return mock_password_provider.check_password(*args) | |||
class CustomAuthProvider: | |||
"""A password_provider which implements a custom login type.""" | |||
class LegacyCustomAuthProvider: | |||
"""A legacy password_provider which implements a custom login type.""" | |||
@staticmethod | |||
def parse_config(self): | |||
@@ -67,7 +69,23 @@ class CustomAuthProvider: | |||
return mock_password_provider.check_auth(*args) | |||
class PasswordCustomAuthProvider: | |||
class CustomAuthProvider: | |||
"""A module which registers password_auth_provider callbacks for a custom login type.""" | |||
@staticmethod | |||
def parse_config(self): | |||
pass | |||
def __init__(self, config, api: ModuleApi): | |||
api.register_password_auth_provider_callbacks( | |||
auth_checkers={("test.login_type", ("test_field",)): self.check_auth}, | |||
) | |||
def check_auth(self, *args): | |||
return mock_password_provider.check_auth(*args) | |||
class LegacyPasswordCustomAuthProvider: | |||
"""A password_provider which implements password login via `check_auth`, as well | |||
as a custom type.""" | |||
@@ -85,8 +103,32 @@ class PasswordCustomAuthProvider: | |||
return mock_password_provider.check_auth(*args) | |||
def providers_config(*providers: Type[Any]) -> dict: | |||
"""Returns a config dict that will enable the given password auth providers""" | |||
class PasswordCustomAuthProvider: | |||
"""A module which registers password_auth_provider callbacks for a custom login type. | |||
as well as a password login""" | |||
@staticmethod | |||
def parse_config(self): | |||
pass | |||
def __init__(self, config, api: ModuleApi): | |||
api.register_password_auth_provider_callbacks( | |||
auth_checkers={ | |||
("test.login_type", ("test_field",)): self.check_auth, | |||
("m.login.password", ("password",)): self.check_auth, | |||
}, | |||
) | |||
pass | |||
def check_auth(self, *args): | |||
return mock_password_provider.check_auth(*args) | |||
def check_pass(self, *args): | |||
return mock_password_provider.check_password(*args) | |||
def legacy_providers_config(*providers: Type[Any]) -> dict: | |||
"""Returns a config dict that will enable the given legacy password auth providers""" | |||
return { | |||
"password_providers": [ | |||
{"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}} | |||
@@ -95,6 +137,16 @@ def providers_config(*providers: Type[Any]) -> dict: | |||
} | |||
def providers_config(*providers: Type[Any]) -> dict: | |||
"""Returns a config dict that will enable the given modules""" | |||
return { | |||
"modules": [ | |||
{"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}} | |||
for provider in providers | |||
] | |||
} | |||
class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
servlets = [ | |||
synapse.rest.admin.register_servlets, | |||
@@ -107,8 +159,21 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
mock_password_provider.reset_mock() | |||
super().setUp() | |||
@override_config(providers_config(PasswordOnlyAuthProvider)) | |||
def test_password_only_auth_provider_login(self): | |||
def make_homeserver(self, reactor, clock): | |||
hs = self.setup_test_homeserver() | |||
# Load the modules into the homeserver | |||
module_api = hs.get_module_api() | |||
for module, config in hs.config.modules.loaded_modules: | |||
module(config=config, api=module_api) | |||
load_legacy_password_auth_providers(hs) | |||
return hs | |||
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) | |||
def test_password_only_auth_progiver_login_legacy(self): | |||
self.password_only_auth_provider_login_test_body() | |||
def password_only_auth_provider_login_test_body(self): | |||
# login flows should only have m.login.password | |||
flows = self._get_login_flows() | |||
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS) | |||
@@ -138,8 +203,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
"@ USER🙂NAME :test", " pASS😢word " | |||
) | |||
@override_config(providers_config(PasswordOnlyAuthProvider)) | |||
def test_password_only_auth_provider_ui_auth(self): | |||
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) | |||
def test_password_only_auth_provider_ui_auth_legacy(self): | |||
self.password_only_auth_provider_ui_auth_test_body() | |||
def password_only_auth_provider_ui_auth_test_body(self): | |||
"""UI Auth should delegate correctly to the password provider""" | |||
# create the user, otherwise access doesn't work | |||
@@ -172,8 +240,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
self.assertEqual(channel.code, 200) | |||
mock_password_provider.check_password.assert_called_once_with("@u:test", "p") | |||
@override_config(providers_config(PasswordOnlyAuthProvider)) | |||
def test_local_user_fallback_login(self): | |||
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) | |||
def test_local_user_fallback_login_legacy(self): | |||
self.local_user_fallback_login_test_body() | |||
def local_user_fallback_login_test_body(self): | |||
"""rejected login should fall back to local db""" | |||
self.register_user("localuser", "localpass") | |||
@@ -186,8 +257,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
self.assertEqual(channel.code, 200, channel.result) | |||
self.assertEqual("@localuser:test", channel.json_body["user_id"]) | |||
@override_config(providers_config(PasswordOnlyAuthProvider)) | |||
def test_local_user_fallback_ui_auth(self): | |||
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) | |||
def test_local_user_fallback_ui_auth_legacy(self): | |||
self.local_user_fallback_ui_auth_test_body() | |||
def local_user_fallback_ui_auth_test_body(self): | |||
"""rejected login should fall back to local db""" | |||
self.register_user("localuser", "localpass") | |||
@@ -223,11 +297,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
@override_config( | |||
{ | |||
**providers_config(PasswordOnlyAuthProvider), | |||
**legacy_providers_config(LegacyPasswordOnlyAuthProvider), | |||
"password_config": {"localdb_enabled": False}, | |||
} | |||
) | |||
def test_no_local_user_fallback_login(self): | |||
def test_no_local_user_fallback_login_legacy(self): | |||
self.no_local_user_fallback_login_test_body() | |||
def no_local_user_fallback_login_test_body(self): | |||
"""localdb_enabled can block login with the local password""" | |||
self.register_user("localuser", "localpass") | |||
@@ -242,11 +319,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
@override_config( | |||
{ | |||
**providers_config(PasswordOnlyAuthProvider), | |||
**legacy_providers_config(LegacyPasswordOnlyAuthProvider), | |||
"password_config": {"localdb_enabled": False}, | |||
} | |||
) | |||
def test_no_local_user_fallback_ui_auth(self): | |||
def test_no_local_user_fallback_ui_auth_legacy(self): | |||
self.no_local_user_fallback_ui_auth_test_body() | |||
def no_local_user_fallback_ui_auth_test_body(self): | |||
"""localdb_enabled can block ui auth with the local password""" | |||
self.register_user("localuser", "localpass") | |||
@@ -280,11 +360,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
@override_config( | |||
{ | |||
**providers_config(PasswordOnlyAuthProvider), | |||
**legacy_providers_config(LegacyPasswordOnlyAuthProvider), | |||
"password_config": {"enabled": False}, | |||
} | |||
) | |||
def test_password_auth_disabled(self): | |||
def test_password_auth_disabled_legacy(self): | |||
self.password_auth_disabled_test_body() | |||
def password_auth_disabled_test_body(self): | |||
"""password auth doesn't work if it's disabled across the board""" | |||
# login flows should be empty | |||
flows = self._get_login_flows() | |||
@@ -295,8 +378,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
self.assertEqual(channel.code, 400, channel.result) | |||
mock_password_provider.check_password.assert_not_called() | |||
@override_config(legacy_providers_config(LegacyCustomAuthProvider)) | |||
def test_custom_auth_provider_login_legacy(self): | |||
self.custom_auth_provider_login_test_body() | |||
@override_config(providers_config(CustomAuthProvider)) | |||
def test_custom_auth_provider_login(self): | |||
self.custom_auth_provider_login_test_body() | |||
def custom_auth_provider_login_test_body(self): | |||
# login flows should have the custom flow and m.login.password, since we | |||
# haven't disabled local password lookup. | |||
# (password must come first, because reasons) | |||
@@ -312,7 +402,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
self.assertEqual(channel.code, 400, channel.result) | |||
mock_password_provider.check_auth.assert_not_called() | |||
mock_password_provider.check_auth.return_value = defer.succeed("@user:bz") | |||
mock_password_provider.check_auth.return_value = defer.succeed( | |||
("@user:bz", None) | |||
) | |||
channel = self._send_login("test.login_type", "u", test_field="y") | |||
self.assertEqual(channel.code, 200, channel.result) | |||
self.assertEqual("@user:bz", channel.json_body["user_id"]) | |||
@@ -325,7 +417,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
# in these cases, but at least we can guard against the API changing | |||
# unexpectedly | |||
mock_password_provider.check_auth.return_value = defer.succeed( | |||
"@ MALFORMED! :bz" | |||
("@ MALFORMED! :bz", None) | |||
) | |||
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ") | |||
self.assertEqual(channel.code, 200, channel.result) | |||
@@ -334,8 +426,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
" USER🙂NAME ", "test.login_type", {"test_field": " abc "} | |||
) | |||
@override_config(legacy_providers_config(LegacyCustomAuthProvider)) | |||
def test_custom_auth_provider_ui_auth_legacy(self): | |||
self.custom_auth_provider_ui_auth_test_body() | |||
@override_config(providers_config(CustomAuthProvider)) | |||
def test_custom_auth_provider_ui_auth(self): | |||
self.custom_auth_provider_ui_auth_test_body() | |||
def custom_auth_provider_ui_auth_test_body(self): | |||
# register the user and log in twice, to get two devices | |||
self.register_user("localuser", "localpass") | |||
tok1 = self.login("localuser", "localpass") | |||
@@ -367,7 +466,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
mock_password_provider.reset_mock() | |||
# right params, but authing as the wrong user | |||
mock_password_provider.check_auth.return_value = defer.succeed("@user:bz") | |||
mock_password_provider.check_auth.return_value = defer.succeed( | |||
("@user:bz", None) | |||
) | |||
body["auth"]["test_field"] = "foo" | |||
channel = self._delete_device(tok1, "dev2", body) | |||
self.assertEqual(channel.code, 403) | |||
@@ -379,7 +480,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
# and finally, succeed | |||
mock_password_provider.check_auth.return_value = defer.succeed( | |||
"@localuser:test" | |||
("@localuser:test", None) | |||
) | |||
channel = self._delete_device(tok1, "dev2", body) | |||
self.assertEqual(channel.code, 200) | |||
@@ -387,8 +488,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
"localuser", "test.login_type", {"test_field": "foo"} | |||
) | |||
@override_config(legacy_providers_config(LegacyCustomAuthProvider)) | |||
def test_custom_auth_provider_callback_legacy(self): | |||
self.custom_auth_provider_callback_test_body() | |||
@override_config(providers_config(CustomAuthProvider)) | |||
def test_custom_auth_provider_callback(self): | |||
self.custom_auth_provider_callback_test_body() | |||
def custom_auth_provider_callback_test_body(self): | |||
callback = Mock(return_value=defer.succeed(None)) | |||
mock_password_provider.check_auth.return_value = defer.succeed( | |||
@@ -410,10 +518,22 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
for p in ["user_id", "access_token", "device_id", "home_server"]: | |||
self.assertIn(p, call_args[0]) | |||
@override_config( | |||
{ | |||
**legacy_providers_config(LegacyCustomAuthProvider), | |||
"password_config": {"enabled": False}, | |||
} | |||
) | |||
def test_custom_auth_password_disabled_legacy(self): | |||
self.custom_auth_password_disabled_test_body() | |||
@override_config( | |||
{**providers_config(CustomAuthProvider), "password_config": {"enabled": False}} | |||
) | |||
def test_custom_auth_password_disabled(self): | |||
self.custom_auth_password_disabled_test_body() | |||
def custom_auth_password_disabled_test_body(self): | |||
"""Test login with a custom auth provider where password login is disabled""" | |||
self.register_user("localuser", "localpass") | |||
@@ -425,6 +545,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
self.assertEqual(channel.code, 400, channel.result) | |||
mock_password_provider.check_auth.assert_not_called() | |||
@override_config( | |||
{ | |||
**legacy_providers_config(LegacyCustomAuthProvider), | |||
"password_config": {"enabled": False, "localdb_enabled": False}, | |||
} | |||
) | |||
def test_custom_auth_password_disabled_localdb_enabled_legacy(self): | |||
self.custom_auth_password_disabled_localdb_enabled_test_body() | |||
@override_config( | |||
{ | |||
**providers_config(CustomAuthProvider), | |||
@@ -432,6 +561,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
} | |||
) | |||
def test_custom_auth_password_disabled_localdb_enabled(self): | |||
self.custom_auth_password_disabled_localdb_enabled_test_body() | |||
def custom_auth_password_disabled_localdb_enabled_test_body(self): | |||
"""Check the localdb_enabled == enabled == False | |||
Regression test for https://github.com/matrix-org/synapse/issues/8914: check | |||
@@ -448,6 +580,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
self.assertEqual(channel.code, 400, channel.result) | |||
mock_password_provider.check_auth.assert_not_called() | |||
@override_config( | |||
{ | |||
**legacy_providers_config(LegacyPasswordCustomAuthProvider), | |||
"password_config": {"enabled": False}, | |||
} | |||
) | |||
def test_password_custom_auth_password_disabled_login_legacy(self): | |||
self.password_custom_auth_password_disabled_login_test_body() | |||
@override_config( | |||
{ | |||
**providers_config(PasswordCustomAuthProvider), | |||
@@ -455,6 +596,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
} | |||
) | |||
def test_password_custom_auth_password_disabled_login(self): | |||
self.password_custom_auth_password_disabled_login_test_body() | |||
def password_custom_auth_password_disabled_login_test_body(self): | |||
"""log in with a custom auth provider which implements password, but password | |||
login is disabled""" | |||
self.register_user("localuser", "localpass") | |||
@@ -466,6 +610,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
channel = self._send_password_login("localuser", "localpass") | |||
self.assertEqual(channel.code, 400, channel.result) | |||
mock_password_provider.check_auth.assert_not_called() | |||
mock_password_provider.check_password.assert_not_called() | |||
@override_config( | |||
{ | |||
**legacy_providers_config(LegacyPasswordCustomAuthProvider), | |||
"password_config": {"enabled": False}, | |||
} | |||
) | |||
def test_password_custom_auth_password_disabled_ui_auth_legacy(self): | |||
self.password_custom_auth_password_disabled_ui_auth_test_body() | |||
@override_config( | |||
{ | |||
@@ -474,12 +628,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
} | |||
) | |||
def test_password_custom_auth_password_disabled_ui_auth(self): | |||
self.password_custom_auth_password_disabled_ui_auth_test_body() | |||
def password_custom_auth_password_disabled_ui_auth_test_body(self): | |||
"""UI Auth with a custom auth provider which implements password, but password | |||
login is disabled""" | |||
# register the user and log in twice via the test login type to get two devices, | |||
self.register_user("localuser", "localpass") | |||
mock_password_provider.check_auth.return_value = defer.succeed( | |||
"@localuser:test" | |||
("@localuser:test", None) | |||
) | |||
channel = self._send_login("test.login_type", "localuser", test_field="") | |||
self.assertEqual(channel.code, 200, channel.result) | |||
@@ -516,6 +673,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
"Password login has been disabled.", channel.json_body["error"] | |||
) | |||
mock_password_provider.check_auth.assert_not_called() | |||
mock_password_provider.check_password.assert_not_called() | |||
mock_password_provider.reset_mock() | |||
# successful auth | |||
@@ -526,6 +684,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
mock_password_provider.check_auth.assert_called_once_with( | |||
"localuser", "test.login_type", {"test_field": "x"} | |||
) | |||
mock_password_provider.check_password.assert_not_called() | |||
@override_config( | |||
{ | |||
**legacy_providers_config(LegacyCustomAuthProvider), | |||
"password_config": {"localdb_enabled": False}, | |||
} | |||
) | |||
def test_custom_auth_no_local_user_fallback_legacy(self): | |||
self.custom_auth_no_local_user_fallback_test_body() | |||
@override_config( | |||
{ | |||
@@ -534,6 +702,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
} | |||
) | |||
def test_custom_auth_no_local_user_fallback(self): | |||
self.custom_auth_no_local_user_fallback_test_body() | |||
def custom_auth_no_local_user_fallback_test_body(self): | |||
"""Test login with a custom auth provider where the local db is disabled""" | |||
self.register_user("localuser", "localpass") | |||