diff --git a/changelog.d/15207.feature b/changelog.d/15207.feature new file mode 100644 index 0000000000..17790d62eb --- /dev/null +++ b/changelog.d/15207.feature @@ -0,0 +1 @@ +Adds on_user_login ModuleAPI callback allowing to execute custom code after (on) Auth. \ No newline at end of file diff --git a/docs/modules/account_validity_callbacks.md b/docs/modules/account_validity_callbacks.md index 3cd0e72198..f5eefcd7d6 100644 --- a/docs/modules/account_validity_callbacks.md +++ b/docs/modules/account_validity_callbacks.md @@ -42,3 +42,16 @@ operations to keep track of them. (e.g. add them to a database table). The user represented by their Matrix user ID. If multiple modules implement this callback, Synapse runs them all in order. + +### `on_user_login` + +_First introduced in Synapse v1.98.0_ + +```python +async def on_user_login(user_id: str, auth_provider_type: str, auth_provider_id: str) -> None +``` + +Called after successfully login or registration of a user for cases when module needs to perform extra operations after auth. +represented by their Matrix user ID. + +If multiple modules implement this callback, Synapse runs them all in order. diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 5e1e8e1abb..68d4227baa 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -296,8 +296,7 @@ impl<'source> FromPyObject<'source> for JsonValue { match l.iter().map(SimpleJsonValue::extract).collect() { Ok(a) => Ok(JsonValue::Array(a)), Err(e) => Err(PyTypeError::new_err(format!( - "Can't convert to JsonValue::Array: {}", - e + "Can't convert to JsonValue::Array: {e}" ))), } } else if let Ok(v) = SimpleJsonValue::extract(ob) { diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 6c2a49a3b9..c66bb6364f 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -98,6 +98,22 @@ class AccountValidityHandler: for callback in self._module_api_callbacks.on_user_registration_callbacks: await callback(user_id) + async def on_user_login( + self, + user_id: str, + auth_provider_type: Optional[str], + auth_provider_id: Optional[str], + ) -> None: + """Tell third-party modules about a user logins. + + Args: + user_id: The mxID of the user. + auth_provider_type: The type of login. + auth_provider_id: The ID of the auth provider. + """ + for callback in self._module_api_callbacks.on_user_login_callbacks: + await callback(user_id, auth_provider_type, auth_provider_id) + @wrap_as_background_process("send_renewals") async def _send_renewal_emails(self) -> None: """Gets the list of users whose account is expiring in the amount of time diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 2b0c505130..89cbaff864 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -212,6 +212,7 @@ class AuthHandler: self._password_enabled_for_reauth = hs.config.auth.password_enabled_for_reauth self._password_localdb_enabled = hs.config.auth.password_localdb_enabled self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules + self._account_validity_handler = hs.get_account_validity_handler() # Ratelimiter for failed auth during UIA. Uses same ratelimit config # as per `rc_login.failed_attempts`. @@ -1783,6 +1784,13 @@ class AuthHandler: client_redirect_url, "loginToken", login_token ) + # Run post-login module callback handlers + await self._account_validity_handler.on_user_login( + user_id=registered_user_id, + auth_provider_type=LoginType.SSO, + auth_provider_id=auth_provider_id, + ) + # if the client is whitelisted, we can redirect straight to it if client_redirect_url.startswith(self._whitelisted_sso_clients): request.redirect(redirect_url) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 812144a128..6ee53511f2 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -80,6 +80,7 @@ from synapse.module_api.callbacks.account_validity_callbacks import ( ON_LEGACY_ADMIN_REQUEST, ON_LEGACY_RENEW_CALLBACK, ON_LEGACY_SEND_MAIL_CALLBACK, + ON_USER_LOGIN_CALLBACK, ON_USER_REGISTRATION_CALLBACK, ) from synapse.module_api.callbacks.spamchecker_callbacks import ( @@ -334,6 +335,7 @@ class ModuleApi: *, is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None, on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None, + on_user_login: Optional[ON_USER_LOGIN_CALLBACK] = None, on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None, on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None, on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None, @@ -345,6 +347,7 @@ class ModuleApi: return self._callbacks.account_validity.register_callbacks( is_user_expired=is_user_expired, on_user_registration=on_user_registration, + on_user_login=on_user_login, on_legacy_send_mail=on_legacy_send_mail, on_legacy_renew=on_legacy_renew, on_legacy_admin_request=on_legacy_admin_request, diff --git a/synapse/module_api/callbacks/account_validity_callbacks.py b/synapse/module_api/callbacks/account_validity_callbacks.py index 531d0c9ddc..cbfdfbd3d1 100644 --- a/synapse/module_api/callbacks/account_validity_callbacks.py +++ b/synapse/module_api/callbacks/account_validity_callbacks.py @@ -22,6 +22,7 @@ logger = logging.getLogger(__name__) # Types for callbacks to be registered via the module api IS_USER_EXPIRED_CALLBACK = Callable[[str], Awaitable[Optional[bool]]] ON_USER_REGISTRATION_CALLBACK = Callable[[str], Awaitable] +ON_USER_LOGIN_CALLBACK = Callable[[str, Optional[str], Optional[str]], Awaitable] # Temporary hooks to allow for a transition from `/_matrix/client` endpoints # to `/_synapse/client/account_validity`. See `register_callbacks` below. ON_LEGACY_SEND_MAIL_CALLBACK = Callable[[str], Awaitable] @@ -33,6 +34,7 @@ class AccountValidityModuleApiCallbacks: def __init__(self) -> None: self.is_user_expired_callbacks: List[IS_USER_EXPIRED_CALLBACK] = [] self.on_user_registration_callbacks: List[ON_USER_REGISTRATION_CALLBACK] = [] + self.on_user_login_callbacks: List[ON_USER_LOGIN_CALLBACK] = [] self.on_legacy_send_mail_callback: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None self.on_legacy_renew_callback: Optional[ON_LEGACY_RENEW_CALLBACK] = None @@ -44,6 +46,7 @@ class AccountValidityModuleApiCallbacks: self, is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None, on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None, + on_user_login: Optional[ON_USER_LOGIN_CALLBACK] = None, on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None, on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None, on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None, @@ -55,6 +58,9 @@ class AccountValidityModuleApiCallbacks: if on_user_registration is not None: self.on_user_registration_callbacks.append(on_user_registration) + if on_user_login is not None: + self.on_user_login_callbacks.append(on_user_login) + # The builtin account validity feature exposes 3 endpoints (send_mail, renew, and # an admin one). As part of moving the feature into a module, we need to change # the path from /_matrix/client/unstable/account_validity/... to diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 7be327e26f..546f042f87 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -115,6 +115,7 @@ class LoginRestServlet(RestServlet): self.registration_handler = hs.get_registration_handler() self._sso_handler = hs.get_sso_handler() self._spam_checker = hs.get_module_api_callbacks().spam_checker + self._account_validity_handler = hs.get_account_validity_handler() self._well_known_builder = WellKnownBuilder(hs) self._address_ratelimiter = Ratelimiter( @@ -470,6 +471,13 @@ class LoginRestServlet(RestServlet): device_id=device_id, ) + # execute the callback + await self._account_validity_handler.on_user_login( + user_id, + auth_provider_type=login_submission.get("type"), + auth_provider_id=auth_provider_id, + ) + if valid_until_ms is not None: expires_in_ms = valid_until_ms - self.clock.time_msec() result["expires_in_ms"] = expires_in_ms