|
|
@@ -17,7 +17,7 @@ |
|
|
|
|
|
|
|
import logging |
|
|
|
import re |
|
|
|
from typing import Any, Dict, List, Optional |
|
|
|
from typing import Any, Dict, List, Optional, Tuple |
|
|
|
|
|
|
|
from synapse.api.constants import UserTypes |
|
|
|
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError |
|
|
@@ -84,17 +84,17 @@ class RegistrationWorkerStore(SQLBaseStore): |
|
|
|
return is_trial |
|
|
|
|
|
|
|
@cached() |
|
|
|
def get_user_by_access_token(self, token): |
|
|
|
async def get_user_by_access_token(self, token: str) -> Optional[dict]: |
|
|
|
"""Get a user from the given access token. |
|
|
|
|
|
|
|
Args: |
|
|
|
token (str): The access token of a user. |
|
|
|
token: The access token of a user. |
|
|
|
Returns: |
|
|
|
defer.Deferred: None, if the token did not match, otherwise dict |
|
|
|
including the keys `name`, `is_guest`, `device_id`, `token_id`, |
|
|
|
`valid_until_ms`. |
|
|
|
None, if the token did not match, otherwise dict |
|
|
|
including the keys `name`, `is_guest`, `device_id`, `token_id`, |
|
|
|
`valid_until_ms`. |
|
|
|
""" |
|
|
|
return self.db_pool.runInteraction( |
|
|
|
return await self.db_pool.runInteraction( |
|
|
|
"get_user_by_access_token", self._query_for_auth, token |
|
|
|
) |
|
|
|
|
|
|
@@ -281,13 +281,12 @@ class RegistrationWorkerStore(SQLBaseStore): |
|
|
|
|
|
|
|
return bool(res) if res else False |
|
|
|
|
|
|
|
def set_server_admin(self, user, admin): |
|
|
|
async def set_server_admin(self, user: UserID, admin: bool) -> None: |
|
|
|
"""Sets whether a user is an admin of this homeserver. |
|
|
|
|
|
|
|
Args: |
|
|
|
user (UserID): user ID of the user to test |
|
|
|
admin (bool): true iff the user is to be a server admin, |
|
|
|
false otherwise. |
|
|
|
user: user ID of the user to test |
|
|
|
admin: true iff the user is to be a server admin, false otherwise. |
|
|
|
""" |
|
|
|
|
|
|
|
def set_server_admin_txn(txn): |
|
|
@@ -298,7 +297,7 @@ class RegistrationWorkerStore(SQLBaseStore): |
|
|
|
txn, self.get_user_by_id, (user.to_string(),) |
|
|
|
) |
|
|
|
|
|
|
|
return self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) |
|
|
|
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) |
|
|
|
|
|
|
|
def _query_for_auth(self, txn, token): |
|
|
|
sql = ( |
|
|
@@ -364,9 +363,11 @@ class RegistrationWorkerStore(SQLBaseStore): |
|
|
|
) |
|
|
|
return True if res == UserTypes.SUPPORT else False |
|
|
|
|
|
|
|
def get_users_by_id_case_insensitive(self, user_id): |
|
|
|
async def get_users_by_id_case_insensitive(self, user_id: str) -> Dict[str, str]: |
|
|
|
"""Gets users that match user_id case insensitively. |
|
|
|
Returns a mapping of user_id -> password_hash. |
|
|
|
|
|
|
|
Returns: |
|
|
|
A mapping of user_id -> password_hash. |
|
|
|
""" |
|
|
|
|
|
|
|
def f(txn): |
|
|
@@ -374,7 +375,7 @@ class RegistrationWorkerStore(SQLBaseStore): |
|
|
|
txn.execute(sql, (user_id,)) |
|
|
|
return dict(txn) |
|
|
|
|
|
|
|
return self.db_pool.runInteraction("get_users_by_id_case_insensitive", f) |
|
|
|
return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f) |
|
|
|
|
|
|
|
async def get_user_by_external_id( |
|
|
|
self, auth_provider: str, external_id: str |
|
|
@@ -408,7 +409,7 @@ class RegistrationWorkerStore(SQLBaseStore): |
|
|
|
|
|
|
|
return await self.db_pool.runInteraction("count_users", _count_users) |
|
|
|
|
|
|
|
def count_daily_user_type(self): |
|
|
|
async def count_daily_user_type(self) -> Dict[str, int]: |
|
|
|
""" |
|
|
|
Counts 1) native non guest users |
|
|
|
2) native guests users |
|
|
@@ -437,7 +438,7 @@ class RegistrationWorkerStore(SQLBaseStore): |
|
|
|
results[row[0]] = row[1] |
|
|
|
return results |
|
|
|
|
|
|
|
return self.db_pool.runInteraction( |
|
|
|
return await self.db_pool.runInteraction( |
|
|
|
"count_daily_user_type", _count_daily_user_type |
|
|
|
) |
|
|
|
|
|
|
@@ -663,24 +664,29 @@ class RegistrationWorkerStore(SQLBaseStore): |
|
|
|
# Convert the integer into a boolean. |
|
|
|
return res == 1 |
|
|
|
|
|
|
|
def get_threepid_validation_session( |
|
|
|
self, medium, client_secret, address=None, sid=None, validated=True |
|
|
|
): |
|
|
|
async def get_threepid_validation_session( |
|
|
|
self, |
|
|
|
medium: Optional[str], |
|
|
|
client_secret: str, |
|
|
|
address: Optional[str] = None, |
|
|
|
sid: Optional[str] = None, |
|
|
|
validated: Optional[bool] = True, |
|
|
|
) -> Optional[Dict[str, Any]]: |
|
|
|
"""Gets a session_id and last_send_attempt (if available) for a |
|
|
|
combination of validation metadata |
|
|
|
|
|
|
|
Args: |
|
|
|
medium (str|None): The medium of the 3PID |
|
|
|
address (str|None): The address of the 3PID |
|
|
|
sid (str|None): The ID of the validation session |
|
|
|
client_secret (str): A unique string provided by the client to help identify this |
|
|
|
medium: The medium of the 3PID |
|
|
|
client_secret: A unique string provided by the client to help identify this |
|
|
|
validation attempt |
|
|
|
validated (bool|None): Whether sessions should be filtered by |
|
|
|
address: The address of the 3PID |
|
|
|
sid: The ID of the validation session |
|
|
|
validated: Whether sessions should be filtered by |
|
|
|
whether they have been validated already or not. None to |
|
|
|
perform no filtering |
|
|
|
|
|
|
|
Returns: |
|
|
|
Deferred[dict|None]: A dict containing the following: |
|
|
|
A dict containing the following: |
|
|
|
* address - address of the 3pid |
|
|
|
* medium - medium of the 3pid |
|
|
|
* client_secret - a secret provided by the client for this validation session |
|
|
@@ -726,17 +732,17 @@ class RegistrationWorkerStore(SQLBaseStore): |
|
|
|
|
|
|
|
return rows[0] |
|
|
|
|
|
|
|
return self.db_pool.runInteraction( |
|
|
|
return await self.db_pool.runInteraction( |
|
|
|
"get_threepid_validation_session", get_threepid_validation_session_txn |
|
|
|
) |
|
|
|
|
|
|
|
def delete_threepid_session(self, session_id): |
|
|
|
async def delete_threepid_session(self, session_id: str) -> None: |
|
|
|
"""Removes a threepid validation session from the database. This can |
|
|
|
be done after validation has been performed and whatever action was |
|
|
|
waiting on it has been carried out |
|
|
|
|
|
|
|
Args: |
|
|
|
session_id (str): The ID of the session to delete |
|
|
|
session_id: The ID of the session to delete |
|
|
|
""" |
|
|
|
|
|
|
|
def delete_threepid_session_txn(txn): |
|
|
@@ -751,7 +757,7 @@ class RegistrationWorkerStore(SQLBaseStore): |
|
|
|
keyvalues={"session_id": session_id}, |
|
|
|
) |
|
|
|
|
|
|
|
return self.db_pool.runInteraction( |
|
|
|
await self.db_pool.runInteraction( |
|
|
|
"delete_threepid_session", delete_threepid_session_txn |
|
|
|
) |
|
|
|
|
|
|
@@ -941,43 +947,40 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): |
|
|
|
desc="add_access_token_to_user", |
|
|
|
) |
|
|
|
|
|
|
|
def register_user( |
|
|
|
async def register_user( |
|
|
|
self, |
|
|
|
user_id, |
|
|
|
password_hash=None, |
|
|
|
was_guest=False, |
|
|
|
make_guest=False, |
|
|
|
appservice_id=None, |
|
|
|
create_profile_with_displayname=None, |
|
|
|
admin=False, |
|
|
|
user_type=None, |
|
|
|
shadow_banned=False, |
|
|
|
): |
|
|
|
user_id: str, |
|
|
|
password_hash: Optional[str] = None, |
|
|
|
was_guest: bool = False, |
|
|
|
make_guest: bool = False, |
|
|
|
appservice_id: Optional[str] = None, |
|
|
|
create_profile_with_displayname: Optional[str] = None, |
|
|
|
admin: bool = False, |
|
|
|
user_type: Optional[str] = None, |
|
|
|
shadow_banned: bool = False, |
|
|
|
) -> None: |
|
|
|
"""Attempts to register an account. |
|
|
|
|
|
|
|
Args: |
|
|
|
user_id (str): The desired user ID to register. |
|
|
|
password_hash (str|None): Optional. The password hash for this user. |
|
|
|
was_guest (bool): Optional. Whether this is a guest account being |
|
|
|
upgraded to a non-guest account. |
|
|
|
make_guest (boolean): True if the the new user should be guest, |
|
|
|
false to add a regular user account. |
|
|
|
appservice_id (str): The ID of the appservice registering the user. |
|
|
|
create_profile_with_displayname (unicode): Optionally create a profile for |
|
|
|
user_id: The desired user ID to register. |
|
|
|
password_hash: Optional. The password hash for this user. |
|
|
|
was_guest: Whether this is a guest account being upgraded to a |
|
|
|
non-guest account. |
|
|
|
make_guest: True if the the new user should be guest, false to add a |
|
|
|
regular user account. |
|
|
|
appservice_id: The ID of the appservice registering the user. |
|
|
|
create_profile_with_displayname: Optionally create a profile for |
|
|
|
the user, setting their displayname to the given value |
|
|
|
admin (boolean): is an admin user? |
|
|
|
user_type (str|None): type of user. One of the values from |
|
|
|
api.constants.UserTypes, or None for a normal user. |
|
|
|
shadow_banned (bool): Whether the user is shadow-banned, |
|
|
|
i.e. they may be told their requests succeeded but we ignore them. |
|
|
|
admin: is an admin user? |
|
|
|
user_type: type of user. One of the values from api.constants.UserTypes, |
|
|
|
or None for a normal user. |
|
|
|
shadow_banned: Whether the user is shadow-banned, i.e. they may be |
|
|
|
told their requests succeeded but we ignore them. |
|
|
|
|
|
|
|
Raises: |
|
|
|
StoreError if the user_id could not be registered. |
|
|
|
|
|
|
|
Returns: |
|
|
|
Deferred |
|
|
|
""" |
|
|
|
return self.db_pool.runInteraction( |
|
|
|
await self.db_pool.runInteraction( |
|
|
|
"register_user", |
|
|
|
self._register_user, |
|
|
|
user_id, |
|
|
@@ -1101,7 +1104,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): |
|
|
|
desc="record_user_external_id", |
|
|
|
) |
|
|
|
|
|
|
|
def user_set_password_hash(self, user_id, password_hash): |
|
|
|
async def user_set_password_hash(self, user_id: str, password_hash: str) -> None: |
|
|
|
""" |
|
|
|
NB. This does *not* evict any cache because the one use for this |
|
|
|
removes most of the entries subsequently anyway so it would be |
|
|
@@ -1114,17 +1117,18 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): |
|
|
|
) |
|
|
|
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) |
|
|
|
|
|
|
|
return self.db_pool.runInteraction( |
|
|
|
await self.db_pool.runInteraction( |
|
|
|
"user_set_password_hash", user_set_password_hash_txn |
|
|
|
) |
|
|
|
|
|
|
|
def user_set_consent_version(self, user_id, consent_version): |
|
|
|
async def user_set_consent_version( |
|
|
|
self, user_id: str, consent_version: str |
|
|
|
) -> None: |
|
|
|
"""Updates the user table to record privacy policy consent |
|
|
|
|
|
|
|
Args: |
|
|
|
user_id (str): full mxid of the user to update |
|
|
|
consent_version (str): version of the policy the user has consented |
|
|
|
to |
|
|
|
user_id: full mxid of the user to update |
|
|
|
consent_version: version of the policy the user has consented to |
|
|
|
|
|
|
|
Raises: |
|
|
|
StoreError(404) if user not found |
|
|
@@ -1139,16 +1143,17 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): |
|
|
|
) |
|
|
|
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) |
|
|
|
|
|
|
|
return self.db_pool.runInteraction("user_set_consent_version", f) |
|
|
|
await self.db_pool.runInteraction("user_set_consent_version", f) |
|
|
|
|
|
|
|
def user_set_consent_server_notice_sent(self, user_id, consent_version): |
|
|
|
async def user_set_consent_server_notice_sent( |
|
|
|
self, user_id: str, consent_version: str |
|
|
|
) -> None: |
|
|
|
"""Updates the user table to record that we have sent the user a server |
|
|
|
notice about privacy policy consent |
|
|
|
|
|
|
|
Args: |
|
|
|
user_id (str): full mxid of the user to update |
|
|
|
consent_version (str): version of the policy we have notified the |
|
|
|
user about |
|
|
|
user_id: full mxid of the user to update |
|
|
|
consent_version: version of the policy we have notified the user about |
|
|
|
|
|
|
|
Raises: |
|
|
|
StoreError(404) if user not found |
|
|
@@ -1163,22 +1168,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): |
|
|
|
) |
|
|
|
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) |
|
|
|
|
|
|
|
return self.db_pool.runInteraction("user_set_consent_server_notice_sent", f) |
|
|
|
await self.db_pool.runInteraction("user_set_consent_server_notice_sent", f) |
|
|
|
|
|
|
|
def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None): |
|
|
|
async def user_delete_access_tokens( |
|
|
|
self, |
|
|
|
user_id: str, |
|
|
|
except_token_id: Optional[str] = None, |
|
|
|
device_id: Optional[str] = None, |
|
|
|
) -> List[Tuple[str, int, Optional[str]]]: |
|
|
|
""" |
|
|
|
Invalidate access tokens belonging to a user |
|
|
|
|
|
|
|
Args: |
|
|
|
user_id (str): ID of user the tokens belong to |
|
|
|
except_token_id (str): list of access_tokens IDs which should |
|
|
|
*not* be deleted |
|
|
|
device_id (str|None): ID of device the tokens are associated with. |
|
|
|
user_id: ID of user the tokens belong to |
|
|
|
except_token_id: access_tokens ID which should *not* be deleted |
|
|
|
device_id: ID of device the tokens are associated with. |
|
|
|
If None, tokens associated with any device (or no device) will |
|
|
|
be deleted |
|
|
|
Returns: |
|
|
|
defer.Deferred[list[str, int, str|None, int]]: a list of |
|
|
|
(token, token id, device id) for each of the deleted tokens |
|
|
|
A tuple of (token, token id, device id) for each of the deleted tokens |
|
|
|
""" |
|
|
|
|
|
|
|
def f(txn): |
|
|
@@ -1209,9 +1217,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): |
|
|
|
|
|
|
|
return tokens_and_devices |
|
|
|
|
|
|
|
return self.db_pool.runInteraction("user_delete_access_tokens", f) |
|
|
|
return await self.db_pool.runInteraction("user_delete_access_tokens", f) |
|
|
|
|
|
|
|
def delete_access_token(self, access_token): |
|
|
|
async def delete_access_token(self, access_token: str) -> None: |
|
|
|
def f(txn): |
|
|
|
self.db_pool.simple_delete_one_txn( |
|
|
|
txn, table="access_tokens", keyvalues={"token": access_token} |
|
|
@@ -1221,7 +1229,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): |
|
|
|
txn, self.get_user_by_access_token, (access_token,) |
|
|
|
) |
|
|
|
|
|
|
|
return self.db_pool.runInteraction("delete_access_token", f) |
|
|
|
await self.db_pool.runInteraction("delete_access_token", f) |
|
|
|
|
|
|
|
@cached() |
|
|
|
async def is_guest(self, user_id: str) -> bool: |
|
|
@@ -1272,24 +1280,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): |
|
|
|
desc="get_users_pending_deactivation", |
|
|
|
) |
|
|
|
|
|
|
|
def validate_threepid_session(self, session_id, client_secret, token, current_ts): |
|
|
|
async def validate_threepid_session( |
|
|
|
self, session_id: str, client_secret: str, token: str, current_ts: int |
|
|
|
) -> Optional[str]: |
|
|
|
"""Attempt to validate a threepid session using a token |
|
|
|
|
|
|
|
Args: |
|
|
|
session_id (str): The id of a validation session |
|
|
|
client_secret (str): A unique string provided by the client to |
|
|
|
help identify this validation attempt |
|
|
|
token (str): A validation token |
|
|
|
current_ts (int): The current unix time in milliseconds. Used for |
|
|
|
checking token expiry status |
|
|
|
session_id: The id of a validation session |
|
|
|
client_secret: A unique string provided by the client to help identify |
|
|
|
this validation attempt |
|
|
|
token: A validation token |
|
|
|
current_ts: The current unix time in milliseconds. Used for checking |
|
|
|
token expiry status |
|
|
|
|
|
|
|
Raises: |
|
|
|
ThreepidValidationError: if a matching validation token was not found or has |
|
|
|
expired |
|
|
|
|
|
|
|
Returns: |
|
|
|
deferred str|None: A str representing a link to redirect the user |
|
|
|
to if there is one. |
|
|
|
A str representing a link to redirect the user to if there is one. |
|
|
|
""" |
|
|
|
|
|
|
|
# Insert everything into a transaction in order to run atomically |
|
|
@@ -1359,36 +1368,35 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): |
|
|
|
return next_link |
|
|
|
|
|
|
|
# Return next_link if it exists |
|
|
|
return self.db_pool.runInteraction( |
|
|
|
return await self.db_pool.runInteraction( |
|
|
|
"validate_threepid_session_txn", validate_threepid_session_txn |
|
|
|
) |
|
|
|
|
|
|
|
def start_or_continue_validation_session( |
|
|
|
async def start_or_continue_validation_session( |
|
|
|
self, |
|
|
|
medium, |
|
|
|
address, |
|
|
|
session_id, |
|
|
|
client_secret, |
|
|
|
send_attempt, |
|
|
|
next_link, |
|
|
|
token, |
|
|
|
token_expires, |
|
|
|
): |
|
|
|
medium: str, |
|
|
|
address: str, |
|
|
|
session_id: str, |
|
|
|
client_secret: str, |
|
|
|
send_attempt: int, |
|
|
|
next_link: Optional[str], |
|
|
|
token: str, |
|
|
|
token_expires: int, |
|
|
|
) -> None: |
|
|
|
"""Creates a new threepid validation session if it does not already |
|
|
|
exist and associates a new validation token with it |
|
|
|
|
|
|
|
Args: |
|
|
|
medium (str): The medium of the 3PID |
|
|
|
address (str): The address of the 3PID |
|
|
|
session_id (str): The id of this validation session |
|
|
|
client_secret (str): A unique string provided by the client to |
|
|
|
help identify this validation attempt |
|
|
|
send_attempt (int): The latest send_attempt on this session |
|
|
|
next_link (str|None): The link to redirect the user to upon |
|
|
|
successful validation |
|
|
|
token (str): The validation token |
|
|
|
token_expires (int): The timestamp for which after the token |
|
|
|
will no longer be valid |
|
|
|
medium: The medium of the 3PID |
|
|
|
address: The address of the 3PID |
|
|
|
session_id: The id of this validation session |
|
|
|
client_secret: A unique string provided by the client to help |
|
|
|
identify this validation attempt |
|
|
|
send_attempt: The latest send_attempt on this session |
|
|
|
next_link: The link to redirect the user to upon successful validation |
|
|
|
token: The validation token |
|
|
|
token_expires: The timestamp for which after the token will no |
|
|
|
longer be valid |
|
|
|
""" |
|
|
|
|
|
|
|
def start_or_continue_validation_session_txn(txn): |
|
|
@@ -1417,12 +1425,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): |
|
|
|
}, |
|
|
|
) |
|
|
|
|
|
|
|
return self.db_pool.runInteraction( |
|
|
|
await self.db_pool.runInteraction( |
|
|
|
"start_or_continue_validation_session", |
|
|
|
start_or_continue_validation_session_txn, |
|
|
|
) |
|
|
|
|
|
|
|
def cull_expired_threepid_validation_tokens(self): |
|
|
|
async def cull_expired_threepid_validation_tokens(self) -> None: |
|
|
|
"""Remove threepid validation tokens with expiry dates that have passed""" |
|
|
|
|
|
|
|
def cull_expired_threepid_validation_tokens_txn(txn, ts): |
|
|
@@ -1430,9 +1438,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): |
|
|
|
DELETE FROM threepid_validation_token WHERE |
|
|
|
expires < ? |
|
|
|
""" |
|
|
|
return txn.execute(sql, (ts,)) |
|
|
|
txn.execute(sql, (ts,)) |
|
|
|
|
|
|
|
return self.db_pool.runInteraction( |
|
|
|
await self.db_pool.runInteraction( |
|
|
|
"cull_expired_threepid_validation_tokens", |
|
|
|
cull_expired_threepid_validation_tokens_txn, |
|
|
|
self.clock.time_msec(), |
|
|
|