@@ -0,0 +1 @@ | |||
Refactor `get_user_by_id`. |
@@ -268,7 +268,7 @@ class InternalAuth(BaseAuth): | |||
stored_user = await self.store.get_user_by_id(user_id) | |||
if not stored_user: | |||
raise InvalidClientTokenError("Unknown user_id %s" % user_id) | |||
if not stored_user["is_guest"]: | |||
if not stored_user.is_guest: | |||
raise InvalidClientTokenError( | |||
"Guest access token used for regular user" | |||
) | |||
@@ -300,7 +300,7 @@ class MSC3861DelegatedAuth(BaseAuth): | |||
user_id = UserID(username, self._hostname) | |||
# First try to find a user from the username claim | |||
user_info = await self.store.get_userinfo_by_id(user_id=user_id.to_string()) | |||
user_info = await self.store.get_user_by_id(user_id=user_id.to_string()) | |||
if user_info is None: | |||
# If the user does not exist, we should create it on the fly | |||
# TODO: we could use SCIM to provision users ahead of time and listen | |||
@@ -102,7 +102,7 @@ class AccountHandler: | |||
""" | |||
status = {"exists": False} | |||
userinfo = await self._main_store.get_userinfo_by_id(user_id.to_string()) | |||
userinfo = await self._main_store.get_user_by_id(user_id.to_string()) | |||
if userinfo is not None: | |||
status = { | |||
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set | |||
from synapse.api.constants import Direction, Membership | |||
from synapse.events import EventBase | |||
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID | |||
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID, UserInfo | |||
from synapse.visibility import filter_events_for_client | |||
if TYPE_CHECKING: | |||
@@ -57,38 +57,30 @@ class AdminHandler: | |||
async def get_user(self, user: UserID) -> Optional[JsonDict]: | |||
"""Function to get user details""" | |||
user_info_dict = await self._store.get_user_by_id(user.to_string()) | |||
if user_info_dict is None: | |||
user_info: Optional[UserInfo] = await self._store.get_user_by_id( | |||
user.to_string() | |||
) | |||
if user_info is None: | |||
return None | |||
# Restrict returned information to a known set of fields. This prevents additional | |||
# fields added to get_user_by_id from modifying Synapse's external API surface. | |||
user_info_to_return = { | |||
"name", | |||
"admin", | |||
"deactivated", | |||
"locked", | |||
"shadow_banned", | |||
"creation_ts", | |||
"appservice_id", | |||
"consent_server_notice_sent", | |||
"consent_version", | |||
"consent_ts", | |||
"user_type", | |||
"is_guest", | |||
"last_seen_ts", | |||
user_info_dict = { | |||
"name": user.to_string(), | |||
"admin": user_info.is_admin, | |||
"deactivated": user_info.is_deactivated, | |||
"locked": user_info.locked, | |||
"shadow_banned": user_info.is_shadow_banned, | |||
"creation_ts": user_info.creation_ts, | |||
"appservice_id": user_info.appservice_id, | |||
"consent_server_notice_sent": user_info.consent_server_notice_sent, | |||
"consent_version": user_info.consent_version, | |||
"consent_ts": user_info.consent_ts, | |||
"user_type": user_info.user_type, | |||
"is_guest": user_info.is_guest, | |||
} | |||
if self._msc3866_enabled: | |||
# Only include the approved flag if support for MSC3866 is enabled. | |||
user_info_to_return.add("approved") | |||
# Restrict returned keys to a known set. | |||
user_info_dict = { | |||
key: value | |||
for key, value in user_info_dict.items() | |||
if key in user_info_to_return | |||
} | |||
user_info_dict["approved"] = user_info.approved | |||
# Add additional user metadata | |||
profile = await self._store.get_profileinfo(user) | |||
@@ -105,6 +97,9 @@ class AdminHandler: | |||
user_info_dict["external_ids"] = external_ids | |||
user_info_dict["erased"] = await self._store.is_user_erased(user.to_string()) | |||
last_seen_ts = await self._store.get_last_seen_for_user_id(user.to_string()) | |||
user_info_dict["last_seen_ts"] = last_seen_ts | |||
return user_info_dict | |||
async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any: | |||
@@ -828,13 +828,13 @@ class EventCreationHandler: | |||
u = await self.store.get_user_by_id(user_id) | |||
assert u is not None | |||
if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT): | |||
if u.user_type in (UserTypes.SUPPORT, UserTypes.BOT): | |||
# support and bot users are not required to consent | |||
return | |||
if u["appservice_id"] is not None: | |||
if u.appservice_id is not None: | |||
# users registered by an appservice are exempt | |||
return | |||
if u["consent_version"] == self.config.consent.user_consent_version: | |||
if u.consent_version == self.config.consent.user_consent_version: | |||
return | |||
consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart) | |||
@@ -572,7 +572,7 @@ class ModuleApi: | |||
Returns: | |||
UserInfo object if a user was found, otherwise None | |||
""" | |||
return await self._store.get_userinfo_by_id(user_id) | |||
return await self._store.get_user_by_id(user_id) | |||
async def get_user_by_req( | |||
self, | |||
@@ -1878,7 +1878,7 @@ class AccountDataManager: | |||
raise TypeError(f"new_data must be a dict; got {type(new_data).__name__}") | |||
# Ensure the user exists, so we don't just write to users that aren't there. | |||
if await self._store.get_userinfo_by_id(user_id) is None: | |||
if await self._store.get_user_by_id(user_id) is None: | |||
raise ValueError(f"User {user_id} does not exist on this server.") | |||
await self._handler.add_account_data_for_user(user_id, data_type, new_data) |
@@ -129,7 +129,7 @@ class ConsentResource(DirectServeHtmlResource): | |||
if u is None: | |||
raise NotFoundError("Unknown user") | |||
has_consented = u["consent_version"] == version | |||
has_consented = u.consent_version == version | |||
userhmac = userhmac_bytes.decode("ascii") | |||
try: | |||
@@ -79,15 +79,15 @@ class ConsentServerNotices: | |||
if u is None: | |||
return | |||
if u["is_guest"] and not self._send_to_guests: | |||
if u.is_guest and not self._send_to_guests: | |||
# don't send to guests | |||
return | |||
if u["consent_version"] == self._current_consent_version: | |||
if u.consent_version == self._current_consent_version: | |||
# user has already consented | |||
return | |||
if u["consent_server_notice_sent"] == self._current_consent_version: | |||
if u.consent_server_notice_sent == self._current_consent_version: | |||
# we've already sent a notice to the user | |||
return | |||
@@ -764,3 +764,14 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke | |||
} | |||
return list(results.values()) | |||
async def get_last_seen_for_user_id(self, user_id: str) -> Optional[int]: | |||
"""Get the last seen timestamp for a user, if we have it.""" | |||
return await self.db_pool.simple_select_one_onecol( | |||
table="user_ips", | |||
keyvalues={"user_id": user_id}, | |||
retcol="MAX(last_seen)", | |||
allow_none=True, | |||
desc="get_last_seen_for_user_id", | |||
) |
@@ -16,7 +16,7 @@ | |||
import logging | |||
import random | |||
import re | |||
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast | |||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast | |||
import attr | |||
@@ -192,8 +192,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): | |||
) | |||
@cached() | |||
async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]: | |||
"""Deprecated: use get_userinfo_by_id instead""" | |||
async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]: | |||
"""Returns info about the user account, if it exists.""" | |||
def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]: | |||
# We could technically use simple_select_one here, but it would not perform | |||
@@ -202,16 +202,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): | |||
txn.execute( | |||
""" | |||
SELECT | |||
name, password_hash, is_guest, admin, consent_version, consent_ts, | |||
name, is_guest, admin, consent_version, consent_ts, | |||
consent_server_notice_sent, appservice_id, creation_ts, user_type, | |||
deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned, | |||
COALESCE(approved, TRUE) AS approved, | |||
COALESCE(locked, FALSE) AS locked, last_seen_ts | |||
COALESCE(locked, FALSE) AS locked | |||
FROM users | |||
LEFT JOIN ( | |||
SELECT user_id, MAX(last_seen) AS last_seen_ts | |||
FROM user_ips GROUP BY user_id | |||
) ls ON users.name = ls.user_id | |||
WHERE name = ? | |||
""", | |||
(user_id,), | |||
@@ -228,51 +224,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): | |||
desc="get_user_by_id", | |||
func=get_user_by_id_txn, | |||
) | |||
if row is not None: | |||
# If we're using SQLite our boolean values will be integers. Because we | |||
# present some of this data as is to e.g. server admins via REST APIs, we | |||
# want to make sure we're returning the right type of data. | |||
# Note: when adding a column name to this list, be wary of NULLable columns, | |||
# since NULL values will be turned into False. | |||
boolean_columns = [ | |||
"admin", | |||
"deactivated", | |||
"shadow_banned", | |||
"approved", | |||
"locked", | |||
] | |||
for column in boolean_columns: | |||
row[column] = bool(row[column]) | |||
return row | |||
async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: | |||
"""Get a UserInfo object for a user by user ID. | |||
Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed, | |||
this method should be cached. | |||
Args: | |||
user_id: The user to fetch user info for. | |||
Returns: | |||
`UserInfo` object if user found, otherwise `None`. | |||
""" | |||
user_data = await self.get_user_by_id(user_id) | |||
if not user_data: | |||
if row is None: | |||
return None | |||
return UserInfo( | |||
appservice_id=user_data["appservice_id"], | |||
consent_server_notice_sent=user_data["consent_server_notice_sent"], | |||
consent_version=user_data["consent_version"], | |||
creation_ts=user_data["creation_ts"], | |||
is_admin=bool(user_data["admin"]), | |||
is_deactivated=bool(user_data["deactivated"]), | |||
is_guest=bool(user_data["is_guest"]), | |||
is_shadow_banned=bool(user_data["shadow_banned"]), | |||
user_id=UserID.from_string(user_data["name"]), | |||
user_type=user_data["user_type"], | |||
last_seen_ts=user_data["last_seen_ts"], | |||
appservice_id=row["appservice_id"], | |||
consent_server_notice_sent=row["consent_server_notice_sent"], | |||
consent_version=row["consent_version"], | |||
consent_ts=row["consent_ts"], | |||
creation_ts=row["creation_ts"], | |||
is_admin=bool(row["admin"]), | |||
is_deactivated=bool(row["deactivated"]), | |||
is_guest=bool(row["is_guest"]), | |||
is_shadow_banned=bool(row["shadow_banned"]), | |||
user_id=UserID.from_string(row["name"]), | |||
user_type=row["user_type"], | |||
approved=bool(row["approved"]), | |||
locked=bool(row["locked"]), | |||
) | |||
async def is_trial_user(self, user_id: str) -> bool: | |||
@@ -290,10 +258,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): | |||
now = self._clock.time_msec() | |||
days = self.config.server.mau_appservice_trial_days.get( | |||
info["appservice_id"], self.config.server.mau_trial_days | |||
info.appservice_id, self.config.server.mau_trial_days | |||
) | |||
trial_duration_ms = days * 24 * 60 * 60 * 1000 | |||
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms | |||
is_trial = (now - info.creation_ts * 1000) < trial_duration_ms | |||
return is_trial | |||
@cached() | |||
@@ -933,33 +933,37 @@ def get_verify_key_from_cross_signing_key( | |||
@attr.s(auto_attribs=True, frozen=True, slots=True) | |||
class UserInfo: | |||
"""Holds information about a user. Result of get_userinfo_by_id. | |||
"""Holds information about a user. Result of get_user_by_id. | |||
Attributes: | |||
user_id: ID of the user. | |||
appservice_id: Application service ID that created this user. | |||
consent_server_notice_sent: Version of policy documents the user has been sent. | |||
consent_version: Version of policy documents the user has consented to. | |||
consent_ts: Time the user consented | |||
creation_ts: Creation timestamp of the user. | |||
is_admin: True if the user is an admin. | |||
is_deactivated: True if the user has been deactivated. | |||
is_guest: True if the user is a guest user. | |||
is_shadow_banned: True if the user has been shadow-banned. | |||
user_type: User type (None for normal user, 'support' and 'bot' other options). | |||
last_seen_ts: Last activity timestamp of the user. | |||
approved: If the user has been "approved" to register on the server. | |||
locked: Whether the user's account has been locked | |||
""" | |||
user_id: UserID | |||
appservice_id: Optional[int] | |||
consent_server_notice_sent: Optional[str] | |||
consent_version: Optional[str] | |||
consent_ts: Optional[int] | |||
user_type: Optional[str] | |||
creation_ts: int | |||
is_admin: bool | |||
is_deactivated: bool | |||
is_guest: bool | |||
is_shadow_banned: bool | |||
last_seen_ts: Optional[int] | |||
approved: bool | |||
locked: bool | |||
class UserProfile(TypedDict): | |||
@@ -188,8 +188,11 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||
) | |||
app_service.is_interested_in_user = Mock(return_value=True) | |||
self.store.get_app_service_by_token = Mock(return_value=app_service) | |||
# This just needs to return a truth-y value. | |||
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False}) | |||
class FakeUserInfo: | |||
is_guest = False | |||
self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo()) | |||
self.store.get_user_by_access_token = AsyncMock(return_value=None) | |||
request = Mock(args={}) | |||
@@ -341,7 +344,10 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||
) | |||
def test_get_guest_user_from_macaroon(self) -> None: | |||
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": True}) | |||
class FakeUserInfo: | |||
is_guest = True | |||
self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo()) | |||
self.store.get_user_by_access_token = AsyncMock(return_value=None) | |||
user_id = "@baldrick:matrix.org" | |||
@@ -16,7 +16,7 @@ from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.api.constants import UserTypes | |||
from synapse.api.errors import ThreepidValidationError | |||
from synapse.server import HomeServer | |||
from synapse.types import JsonDict, UserID | |||
from synapse.types import JsonDict, UserID, UserInfo | |||
from synapse.util import Clock | |||
from tests.unittest import HomeserverTestCase, override_config | |||
@@ -35,24 +35,22 @@ class RegistrationStoreTestCase(HomeserverTestCase): | |||
self.get_success(self.store.register_user(self.user_id, self.pwhash)) | |||
self.assertEqual( | |||
{ | |||
UserInfo( | |||
# TODO(paul): Surely this field should be 'user_id', not 'name' | |||
"name": self.user_id, | |||
"password_hash": self.pwhash, | |||
"admin": 0, | |||
"is_guest": 0, | |||
"consent_version": None, | |||
"consent_ts": None, | |||
"consent_server_notice_sent": None, | |||
"appservice_id": None, | |||
"creation_ts": 0, | |||
"user_type": None, | |||
"deactivated": 0, | |||
"locked": 0, | |||
"shadow_banned": 0, | |||
"approved": 1, | |||
"last_seen_ts": None, | |||
}, | |||
user_id=UserID.from_string(self.user_id), | |||
is_admin=False, | |||
is_guest=False, | |||
consent_server_notice_sent=None, | |||
consent_ts=None, | |||
consent_version=None, | |||
appservice_id=None, | |||
creation_ts=0, | |||
user_type=None, | |||
is_deactivated=False, | |||
locked=False, | |||
is_shadow_banned=False, | |||
approved=True, | |||
), | |||
(self.get_success(self.store.get_user_by_id(self.user_id))), | |||
) | |||
@@ -65,9 +63,11 @@ class RegistrationStoreTestCase(HomeserverTestCase): | |||
user = self.get_success(self.store.get_user_by_id(self.user_id)) | |||
assert user | |||
self.assertEqual(user["consent_version"], "1") | |||
self.assertGreater(user["consent_ts"], before_consent) | |||
self.assertLess(user["consent_ts"], self.clock.time_msec()) | |||
self.assertEqual(user.consent_version, "1") | |||
self.assertIsNotNone(user.consent_ts) | |||
assert user.consent_ts is not None | |||
self.assertGreater(user.consent_ts, before_consent) | |||
self.assertLess(user.consent_ts, self.clock.time_msec()) | |||
def test_add_tokens(self) -> None: | |||
self.get_success(self.store.register_user(self.user_id, self.pwhash)) | |||
@@ -215,7 +215,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase): | |||
user = self.get_success(self.store.get_user_by_id(self.user_id)) | |||
assert user is not None | |||
self.assertTrue(user["approved"]) | |||
self.assertTrue(user.approved) | |||
approved = self.get_success(self.store.is_user_approved(self.user_id)) | |||
self.assertTrue(approved) | |||
@@ -228,7 +228,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase): | |||
user = self.get_success(self.store.get_user_by_id(self.user_id)) | |||
assert user is not None | |||
self.assertFalse(user["approved"]) | |||
self.assertFalse(user.approved) | |||
approved = self.get_success(self.store.is_user_approved(self.user_id)) | |||
self.assertFalse(approved) | |||
@@ -248,7 +248,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase): | |||
user = self.get_success(self.store.get_user_by_id(self.user_id)) | |||
self.assertIsNotNone(user) | |||
assert user is not None | |||
self.assertEqual(user["approved"], 1) | |||
self.assertEqual(user.approved, 1) | |||
approved = self.get_success(self.store.is_user_approved(self.user_id)) | |||
self.assertTrue(approved) | |||