Signed-off-by: Andrii Yasynyshyn yasinishyn.a.n@gmail.comtags/v1.98.0rc1
@@ -0,0 +1 @@ | |||||
Adds on_user_login ModuleAPI callback allowing to execute custom code after (on) Auth. |
@@ -42,3 +42,16 @@ operations to keep track of them. (e.g. add them to a database table). The user | |||||
represented by their Matrix user ID. | represented by their Matrix user ID. | ||||
If multiple modules implement this callback, Synapse runs them all in order. | If multiple modules implement this callback, Synapse runs them all in order. | ||||
### `on_user_login` | |||||
_First introduced in Synapse v1.98.0_ | |||||
```python | |||||
async def on_user_login(user_id: str, auth_provider_type: str, auth_provider_id: str) -> None | |||||
``` | |||||
Called after successfully login or registration of a user for cases when module needs to perform extra operations after auth. | |||||
represented by their Matrix user ID. | |||||
If multiple modules implement this callback, Synapse runs them all in order. |
@@ -296,8 +296,7 @@ impl<'source> FromPyObject<'source> for JsonValue { | |||||
match l.iter().map(SimpleJsonValue::extract).collect() { | match l.iter().map(SimpleJsonValue::extract).collect() { | ||||
Ok(a) => Ok(JsonValue::Array(a)), | Ok(a) => Ok(JsonValue::Array(a)), | ||||
Err(e) => Err(PyTypeError::new_err(format!( | Err(e) => Err(PyTypeError::new_err(format!( | ||||
"Can't convert to JsonValue::Array: {}", | |||||
e | |||||
"Can't convert to JsonValue::Array: {e}" | |||||
))), | ))), | ||||
} | } | ||||
} else if let Ok(v) = SimpleJsonValue::extract(ob) { | } else if let Ok(v) = SimpleJsonValue::extract(ob) { | ||||
@@ -98,6 +98,22 @@ class AccountValidityHandler: | |||||
for callback in self._module_api_callbacks.on_user_registration_callbacks: | for callback in self._module_api_callbacks.on_user_registration_callbacks: | ||||
await callback(user_id) | await callback(user_id) | ||||
async def on_user_login( | |||||
self, | |||||
user_id: str, | |||||
auth_provider_type: Optional[str], | |||||
auth_provider_id: Optional[str], | |||||
) -> None: | |||||
"""Tell third-party modules about a user logins. | |||||
Args: | |||||
user_id: The mxID of the user. | |||||
auth_provider_type: The type of login. | |||||
auth_provider_id: The ID of the auth provider. | |||||
""" | |||||
for callback in self._module_api_callbacks.on_user_login_callbacks: | |||||
await callback(user_id, auth_provider_type, auth_provider_id) | |||||
@wrap_as_background_process("send_renewals") | @wrap_as_background_process("send_renewals") | ||||
async def _send_renewal_emails(self) -> None: | async def _send_renewal_emails(self) -> None: | ||||
"""Gets the list of users whose account is expiring in the amount of time | """Gets the list of users whose account is expiring in the amount of time | ||||
@@ -212,6 +212,7 @@ class AuthHandler: | |||||
self._password_enabled_for_reauth = hs.config.auth.password_enabled_for_reauth | self._password_enabled_for_reauth = hs.config.auth.password_enabled_for_reauth | ||||
self._password_localdb_enabled = hs.config.auth.password_localdb_enabled | self._password_localdb_enabled = hs.config.auth.password_localdb_enabled | ||||
self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules | self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules | ||||
self._account_validity_handler = hs.get_account_validity_handler() | |||||
# Ratelimiter for failed auth during UIA. Uses same ratelimit config | # Ratelimiter for failed auth during UIA. Uses same ratelimit config | ||||
# as per `rc_login.failed_attempts`. | # as per `rc_login.failed_attempts`. | ||||
@@ -1783,6 +1784,13 @@ class AuthHandler: | |||||
client_redirect_url, "loginToken", login_token | client_redirect_url, "loginToken", login_token | ||||
) | ) | ||||
# Run post-login module callback handlers | |||||
await self._account_validity_handler.on_user_login( | |||||
user_id=registered_user_id, | |||||
auth_provider_type=LoginType.SSO, | |||||
auth_provider_id=auth_provider_id, | |||||
) | |||||
# if the client is whitelisted, we can redirect straight to it | # if the client is whitelisted, we can redirect straight to it | ||||
if client_redirect_url.startswith(self._whitelisted_sso_clients): | if client_redirect_url.startswith(self._whitelisted_sso_clients): | ||||
request.redirect(redirect_url) | request.redirect(redirect_url) | ||||
@@ -80,6 +80,7 @@ from synapse.module_api.callbacks.account_validity_callbacks import ( | |||||
ON_LEGACY_ADMIN_REQUEST, | ON_LEGACY_ADMIN_REQUEST, | ||||
ON_LEGACY_RENEW_CALLBACK, | ON_LEGACY_RENEW_CALLBACK, | ||||
ON_LEGACY_SEND_MAIL_CALLBACK, | ON_LEGACY_SEND_MAIL_CALLBACK, | ||||
ON_USER_LOGIN_CALLBACK, | |||||
ON_USER_REGISTRATION_CALLBACK, | ON_USER_REGISTRATION_CALLBACK, | ||||
) | ) | ||||
from synapse.module_api.callbacks.spamchecker_callbacks import ( | from synapse.module_api.callbacks.spamchecker_callbacks import ( | ||||
@@ -334,6 +335,7 @@ class ModuleApi: | |||||
*, | *, | ||||
is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None, | is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None, | ||||
on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None, | on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None, | ||||
on_user_login: Optional[ON_USER_LOGIN_CALLBACK] = None, | |||||
on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None, | on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None, | ||||
on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None, | on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None, | ||||
on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None, | on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None, | ||||
@@ -345,6 +347,7 @@ class ModuleApi: | |||||
return self._callbacks.account_validity.register_callbacks( | return self._callbacks.account_validity.register_callbacks( | ||||
is_user_expired=is_user_expired, | is_user_expired=is_user_expired, | ||||
on_user_registration=on_user_registration, | on_user_registration=on_user_registration, | ||||
on_user_login=on_user_login, | |||||
on_legacy_send_mail=on_legacy_send_mail, | on_legacy_send_mail=on_legacy_send_mail, | ||||
on_legacy_renew=on_legacy_renew, | on_legacy_renew=on_legacy_renew, | ||||
on_legacy_admin_request=on_legacy_admin_request, | on_legacy_admin_request=on_legacy_admin_request, | ||||
@@ -22,6 +22,7 @@ logger = logging.getLogger(__name__) | |||||
# Types for callbacks to be registered via the module api | # Types for callbacks to be registered via the module api | ||||
IS_USER_EXPIRED_CALLBACK = Callable[[str], Awaitable[Optional[bool]]] | IS_USER_EXPIRED_CALLBACK = Callable[[str], Awaitable[Optional[bool]]] | ||||
ON_USER_REGISTRATION_CALLBACK = Callable[[str], Awaitable] | ON_USER_REGISTRATION_CALLBACK = Callable[[str], Awaitable] | ||||
ON_USER_LOGIN_CALLBACK = Callable[[str, Optional[str], Optional[str]], Awaitable] | |||||
# Temporary hooks to allow for a transition from `/_matrix/client` endpoints | # Temporary hooks to allow for a transition from `/_matrix/client` endpoints | ||||
# to `/_synapse/client/account_validity`. See `register_callbacks` below. | # to `/_synapse/client/account_validity`. See `register_callbacks` below. | ||||
ON_LEGACY_SEND_MAIL_CALLBACK = Callable[[str], Awaitable] | ON_LEGACY_SEND_MAIL_CALLBACK = Callable[[str], Awaitable] | ||||
@@ -33,6 +34,7 @@ class AccountValidityModuleApiCallbacks: | |||||
def __init__(self) -> None: | def __init__(self) -> None: | ||||
self.is_user_expired_callbacks: List[IS_USER_EXPIRED_CALLBACK] = [] | self.is_user_expired_callbacks: List[IS_USER_EXPIRED_CALLBACK] = [] | ||||
self.on_user_registration_callbacks: List[ON_USER_REGISTRATION_CALLBACK] = [] | self.on_user_registration_callbacks: List[ON_USER_REGISTRATION_CALLBACK] = [] | ||||
self.on_user_login_callbacks: List[ON_USER_LOGIN_CALLBACK] = [] | |||||
self.on_legacy_send_mail_callback: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None | self.on_legacy_send_mail_callback: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None | ||||
self.on_legacy_renew_callback: Optional[ON_LEGACY_RENEW_CALLBACK] = None | self.on_legacy_renew_callback: Optional[ON_LEGACY_RENEW_CALLBACK] = None | ||||
@@ -44,6 +46,7 @@ class AccountValidityModuleApiCallbacks: | |||||
self, | self, | ||||
is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None, | is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None, | ||||
on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None, | on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None, | ||||
on_user_login: Optional[ON_USER_LOGIN_CALLBACK] = None, | |||||
on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None, | on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None, | ||||
on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None, | on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None, | ||||
on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None, | on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None, | ||||
@@ -55,6 +58,9 @@ class AccountValidityModuleApiCallbacks: | |||||
if on_user_registration is not None: | if on_user_registration is not None: | ||||
self.on_user_registration_callbacks.append(on_user_registration) | self.on_user_registration_callbacks.append(on_user_registration) | ||||
if on_user_login is not None: | |||||
self.on_user_login_callbacks.append(on_user_login) | |||||
# The builtin account validity feature exposes 3 endpoints (send_mail, renew, and | # The builtin account validity feature exposes 3 endpoints (send_mail, renew, and | ||||
# an admin one). As part of moving the feature into a module, we need to change | # an admin one). As part of moving the feature into a module, we need to change | ||||
# the path from /_matrix/client/unstable/account_validity/... to | # the path from /_matrix/client/unstable/account_validity/... to | ||||
@@ -115,6 +115,7 @@ class LoginRestServlet(RestServlet): | |||||
self.registration_handler = hs.get_registration_handler() | self.registration_handler = hs.get_registration_handler() | ||||
self._sso_handler = hs.get_sso_handler() | self._sso_handler = hs.get_sso_handler() | ||||
self._spam_checker = hs.get_module_api_callbacks().spam_checker | self._spam_checker = hs.get_module_api_callbacks().spam_checker | ||||
self._account_validity_handler = hs.get_account_validity_handler() | |||||
self._well_known_builder = WellKnownBuilder(hs) | self._well_known_builder = WellKnownBuilder(hs) | ||||
self._address_ratelimiter = Ratelimiter( | self._address_ratelimiter = Ratelimiter( | ||||
@@ -470,6 +471,13 @@ class LoginRestServlet(RestServlet): | |||||
device_id=device_id, | device_id=device_id, | ||||
) | ) | ||||
# execute the callback | |||||
await self._account_validity_handler.on_user_login( | |||||
user_id, | |||||
auth_provider_type=login_submission.get("type"), | |||||
auth_provider_id=auth_provider_id, | |||||
) | |||||
if valid_until_ms is not None: | if valid_until_ms is not None: | ||||
expires_in_ms = valid_until_ms - self.clock.time_msec() | expires_in_ms = valid_until_ms - self.clock.time_msec() | ||||
result["expires_in_ms"] = expires_in_ms | result["expires_in_ms"] = expires_in_ms | ||||