Ver a proveniência

Refactor `get_user_by_id` (#16316)

tags/v1.93.0rc1
Erik Johnston há 8 meses
committed by GitHub
ascendente
cometimento
954921736b
Não foi encontrada uma chave conhecida para esta assinatura, na base de dados ID da chave GPG: 4AEE18F83AFDEB23
14 ficheiros alterados com 108 adições e 123 eliminações
  1. +1
    -0
      changelog.d/16316.misc
  2. +1
    -1
      synapse/api/auth/internal.py
  3. +1
    -1
      synapse/api/auth/msc3861_delegated.py
  4. +1
    -1
      synapse/handlers/account.py
  5. +22
    -27
      synapse/handlers/admin.py
  6. +3
    -3
      synapse/handlers/message.py
  7. +2
    -2
      synapse/module_api/__init__.py
  8. +1
    -1
      synapse/rest/consent/consent_resource.py
  9. +3
    -3
      synapse/server_notices/consent_server_notices.py
  10. +11
    -0
      synapse/storage/databases/main/client_ips.py
  11. +22
    -54
      synapse/storage/databases/main/registration.py
  12. +7
    -3
      synapse/types/__init__.py
  13. +9
    -3
      tests/api/test_auth.py
  14. +24
    -24
      tests/storage/test_registration.py

+ 1
- 0
changelog.d/16316.misc Ver ficheiro

@@ -0,0 +1 @@
Refactor `get_user_by_id`.

+ 1
- 1
synapse/api/auth/internal.py Ver ficheiro

@@ -268,7 +268,7 @@ class InternalAuth(BaseAuth):
stored_user = await self.store.get_user_by_id(user_id)
if not stored_user:
raise InvalidClientTokenError("Unknown user_id %s" % user_id)
if not stored_user["is_guest"]:
if not stored_user.is_guest:
raise InvalidClientTokenError(
"Guest access token used for regular user"
)


+ 1
- 1
synapse/api/auth/msc3861_delegated.py Ver ficheiro

@@ -300,7 +300,7 @@ class MSC3861DelegatedAuth(BaseAuth):
user_id = UserID(username, self._hostname)

# First try to find a user from the username claim
user_info = await self.store.get_userinfo_by_id(user_id=user_id.to_string())
user_info = await self.store.get_user_by_id(user_id=user_id.to_string())
if user_info is None:
# If the user does not exist, we should create it on the fly
# TODO: we could use SCIM to provision users ahead of time and listen


+ 1
- 1
synapse/handlers/account.py Ver ficheiro

@@ -102,7 +102,7 @@ class AccountHandler:
"""
status = {"exists": False}

userinfo = await self._main_store.get_userinfo_by_id(user_id.to_string())
userinfo = await self._main_store.get_user_by_id(user_id.to_string())

if userinfo is not None:
status = {


+ 22
- 27
synapse/handlers/admin.py Ver ficheiro

@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set

from synapse.api.constants import Direction, Membership
from synapse.events import EventBase
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID, UserInfo
from synapse.visibility import filter_events_for_client

if TYPE_CHECKING:
@@ -57,38 +57,30 @@ class AdminHandler:

async def get_user(self, user: UserID) -> Optional[JsonDict]:
"""Function to get user details"""
user_info_dict = await self._store.get_user_by_id(user.to_string())
if user_info_dict is None:
user_info: Optional[UserInfo] = await self._store.get_user_by_id(
user.to_string()
)
if user_info is None:
return None

# Restrict returned information to a known set of fields. This prevents additional
# fields added to get_user_by_id from modifying Synapse's external API surface.
user_info_to_return = {
"name",
"admin",
"deactivated",
"locked",
"shadow_banned",
"creation_ts",
"appservice_id",
"consent_server_notice_sent",
"consent_version",
"consent_ts",
"user_type",
"is_guest",
"last_seen_ts",
user_info_dict = {
"name": user.to_string(),
"admin": user_info.is_admin,
"deactivated": user_info.is_deactivated,
"locked": user_info.locked,
"shadow_banned": user_info.is_shadow_banned,
"creation_ts": user_info.creation_ts,
"appservice_id": user_info.appservice_id,
"consent_server_notice_sent": user_info.consent_server_notice_sent,
"consent_version": user_info.consent_version,
"consent_ts": user_info.consent_ts,
"user_type": user_info.user_type,
"is_guest": user_info.is_guest,
}

if self._msc3866_enabled:
# Only include the approved flag if support for MSC3866 is enabled.
user_info_to_return.add("approved")

# Restrict returned keys to a known set.
user_info_dict = {
key: value
for key, value in user_info_dict.items()
if key in user_info_to_return
}
user_info_dict["approved"] = user_info.approved

# Add additional user metadata
profile = await self._store.get_profileinfo(user)
@@ -105,6 +97,9 @@ class AdminHandler:
user_info_dict["external_ids"] = external_ids
user_info_dict["erased"] = await self._store.is_user_erased(user.to_string())

last_seen_ts = await self._store.get_last_seen_for_user_id(user.to_string())
user_info_dict["last_seen_ts"] = last_seen_ts

return user_info_dict

async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any:


+ 3
- 3
synapse/handlers/message.py Ver ficheiro

@@ -828,13 +828,13 @@ class EventCreationHandler:

u = await self.store.get_user_by_id(user_id)
assert u is not None
if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT):
if u.user_type in (UserTypes.SUPPORT, UserTypes.BOT):
# support and bot users are not required to consent
return
if u["appservice_id"] is not None:
if u.appservice_id is not None:
# users registered by an appservice are exempt
return
if u["consent_version"] == self.config.consent.user_consent_version:
if u.consent_version == self.config.consent.user_consent_version:
return

consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart)


+ 2
- 2
synapse/module_api/__init__.py Ver ficheiro

@@ -572,7 +572,7 @@ class ModuleApi:
Returns:
UserInfo object if a user was found, otherwise None
"""
return await self._store.get_userinfo_by_id(user_id)
return await self._store.get_user_by_id(user_id)

async def get_user_by_req(
self,
@@ -1878,7 +1878,7 @@ class AccountDataManager:
raise TypeError(f"new_data must be a dict; got {type(new_data).__name__}")

# Ensure the user exists, so we don't just write to users that aren't there.
if await self._store.get_userinfo_by_id(user_id) is None:
if await self._store.get_user_by_id(user_id) is None:
raise ValueError(f"User {user_id} does not exist on this server.")

await self._handler.add_account_data_for_user(user_id, data_type, new_data)

+ 1
- 1
synapse/rest/consent/consent_resource.py Ver ficheiro

@@ -129,7 +129,7 @@ class ConsentResource(DirectServeHtmlResource):
if u is None:
raise NotFoundError("Unknown user")

has_consented = u["consent_version"] == version
has_consented = u.consent_version == version
userhmac = userhmac_bytes.decode("ascii")

try:


+ 3
- 3
synapse/server_notices/consent_server_notices.py Ver ficheiro

@@ -79,15 +79,15 @@ class ConsentServerNotices:
if u is None:
return

if u["is_guest"] and not self._send_to_guests:
if u.is_guest and not self._send_to_guests:
# don't send to guests
return

if u["consent_version"] == self._current_consent_version:
if u.consent_version == self._current_consent_version:
# user has already consented
return

if u["consent_server_notice_sent"] == self._current_consent_version:
if u.consent_server_notice_sent == self._current_consent_version:
# we've already sent a notice to the user
return



+ 11
- 0
synapse/storage/databases/main/client_ips.py Ver ficheiro

@@ -764,3 +764,14 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
}

return list(results.values())

async def get_last_seen_for_user_id(self, user_id: str) -> Optional[int]:
"""Get the last seen timestamp for a user, if we have it."""

return await self.db_pool.simple_select_one_onecol(
table="user_ips",
keyvalues={"user_id": user_id},
retcol="MAX(last_seen)",
allow_none=True,
desc="get_last_seen_for_user_id",
)

+ 22
- 54
synapse/storage/databases/main/registration.py Ver ficheiro

@@ -16,7 +16,7 @@
import logging
import random
import re
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast

import attr

@@ -192,8 +192,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)

@cached()
async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]:
"""Deprecated: use get_userinfo_by_id instead"""
async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]:
"""Returns info about the user account, if it exists."""

def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
# We could technically use simple_select_one here, but it would not perform
@@ -202,16 +202,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
txn.execute(
"""
SELECT
name, password_hash, is_guest, admin, consent_version, consent_ts,
name, is_guest, admin, consent_version, consent_ts,
consent_server_notice_sent, appservice_id, creation_ts, user_type,
deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
COALESCE(approved, TRUE) AS approved,
COALESCE(locked, FALSE) AS locked, last_seen_ts
COALESCE(locked, FALSE) AS locked
FROM users
LEFT JOIN (
SELECT user_id, MAX(last_seen) AS last_seen_ts
FROM user_ips GROUP BY user_id
) ls ON users.name = ls.user_id
WHERE name = ?
""",
(user_id,),
@@ -228,51 +224,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="get_user_by_id",
func=get_user_by_id_txn,
)

if row is not None:
# If we're using SQLite our boolean values will be integers. Because we
# present some of this data as is to e.g. server admins via REST APIs, we
# want to make sure we're returning the right type of data.
# Note: when adding a column name to this list, be wary of NULLable columns,
# since NULL values will be turned into False.
boolean_columns = [
"admin",
"deactivated",
"shadow_banned",
"approved",
"locked",
]
for column in boolean_columns:
row[column] = bool(row[column])

return row

async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
"""Get a UserInfo object for a user by user ID.

Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed,
this method should be cached.

Args:
user_id: The user to fetch user info for.
Returns:
`UserInfo` object if user found, otherwise `None`.
"""
user_data = await self.get_user_by_id(user_id)
if not user_data:
if row is None:
return None

return UserInfo(
appservice_id=user_data["appservice_id"],
consent_server_notice_sent=user_data["consent_server_notice_sent"],
consent_version=user_data["consent_version"],
creation_ts=user_data["creation_ts"],
is_admin=bool(user_data["admin"]),
is_deactivated=bool(user_data["deactivated"]),
is_guest=bool(user_data["is_guest"]),
is_shadow_banned=bool(user_data["shadow_banned"]),
user_id=UserID.from_string(user_data["name"]),
user_type=user_data["user_type"],
last_seen_ts=user_data["last_seen_ts"],
appservice_id=row["appservice_id"],
consent_server_notice_sent=row["consent_server_notice_sent"],
consent_version=row["consent_version"],
consent_ts=row["consent_ts"],
creation_ts=row["creation_ts"],
is_admin=bool(row["admin"]),
is_deactivated=bool(row["deactivated"]),
is_guest=bool(row["is_guest"]),
is_shadow_banned=bool(row["shadow_banned"]),
user_id=UserID.from_string(row["name"]),
user_type=row["user_type"],
approved=bool(row["approved"]),
locked=bool(row["locked"]),
)

async def is_trial_user(self, user_id: str) -> bool:
@@ -290,10 +258,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):

now = self._clock.time_msec()
days = self.config.server.mau_appservice_trial_days.get(
info["appservice_id"], self.config.server.mau_trial_days
info.appservice_id, self.config.server.mau_trial_days
)
trial_duration_ms = days * 24 * 60 * 60 * 1000
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
is_trial = (now - info.creation_ts * 1000) < trial_duration_ms
return is_trial

@cached()


+ 7
- 3
synapse/types/__init__.py Ver ficheiro

@@ -933,33 +933,37 @@ def get_verify_key_from_cross_signing_key(

@attr.s(auto_attribs=True, frozen=True, slots=True)
class UserInfo:
"""Holds information about a user. Result of get_userinfo_by_id.
"""Holds information about a user. Result of get_user_by_id.

Attributes:
user_id: ID of the user.
appservice_id: Application service ID that created this user.
consent_server_notice_sent: Version of policy documents the user has been sent.
consent_version: Version of policy documents the user has consented to.
consent_ts: Time the user consented
creation_ts: Creation timestamp of the user.
is_admin: True if the user is an admin.
is_deactivated: True if the user has been deactivated.
is_guest: True if the user is a guest user.
is_shadow_banned: True if the user has been shadow-banned.
user_type: User type (None for normal user, 'support' and 'bot' other options).
last_seen_ts: Last activity timestamp of the user.
approved: If the user has been "approved" to register on the server.
locked: Whether the user's account has been locked
"""

user_id: UserID
appservice_id: Optional[int]
consent_server_notice_sent: Optional[str]
consent_version: Optional[str]
consent_ts: Optional[int]
user_type: Optional[str]
creation_ts: int
is_admin: bool
is_deactivated: bool
is_guest: bool
is_shadow_banned: bool
last_seen_ts: Optional[int]
approved: bool
locked: bool


class UserProfile(TypedDict):


+ 9
- 3
tests/api/test_auth.py Ver ficheiro

@@ -188,8 +188,11 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service)
# This just needs to return a truth-y value.
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})

class FakeUserInfo:
is_guest = False

self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo())
self.store.get_user_by_access_token = AsyncMock(return_value=None)

request = Mock(args={})
@@ -341,7 +344,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
)

def test_get_guest_user_from_macaroon(self) -> None:
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": True})
class FakeUserInfo:
is_guest = True

self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo())
self.store.get_user_by_access_token = AsyncMock(return_value=None)

user_id = "@baldrick:matrix.org"


+ 24
- 24
tests/storage/test_registration.py Ver ficheiro

@@ -16,7 +16,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import UserTypes
from synapse.api.errors import ThreepidValidationError
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID
from synapse.types import JsonDict, UserID, UserInfo
from synapse.util import Clock

from tests.unittest import HomeserverTestCase, override_config
@@ -35,24 +35,22 @@ class RegistrationStoreTestCase(HomeserverTestCase):
self.get_success(self.store.register_user(self.user_id, self.pwhash))

self.assertEqual(
{
UserInfo(
# TODO(paul): Surely this field should be 'user_id', not 'name'
"name": self.user_id,
"password_hash": self.pwhash,
"admin": 0,
"is_guest": 0,
"consent_version": None,
"consent_ts": None,
"consent_server_notice_sent": None,
"appservice_id": None,
"creation_ts": 0,
"user_type": None,
"deactivated": 0,
"locked": 0,
"shadow_banned": 0,
"approved": 1,
"last_seen_ts": None,
},
user_id=UserID.from_string(self.user_id),
is_admin=False,
is_guest=False,
consent_server_notice_sent=None,
consent_ts=None,
consent_version=None,
appservice_id=None,
creation_ts=0,
user_type=None,
is_deactivated=False,
locked=False,
is_shadow_banned=False,
approved=True,
),
(self.get_success(self.store.get_user_by_id(self.user_id))),
)

@@ -65,9 +63,11 @@ class RegistrationStoreTestCase(HomeserverTestCase):

user = self.get_success(self.store.get_user_by_id(self.user_id))
assert user
self.assertEqual(user["consent_version"], "1")
self.assertGreater(user["consent_ts"], before_consent)
self.assertLess(user["consent_ts"], self.clock.time_msec())
self.assertEqual(user.consent_version, "1")
self.assertIsNotNone(user.consent_ts)
assert user.consent_ts is not None
self.assertGreater(user.consent_ts, before_consent)
self.assertLess(user.consent_ts, self.clock.time_msec())

def test_add_tokens(self) -> None:
self.get_success(self.store.register_user(self.user_id, self.pwhash))
@@ -215,7 +215,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):

user = self.get_success(self.store.get_user_by_id(self.user_id))
assert user is not None
self.assertTrue(user["approved"])
self.assertTrue(user.approved)

approved = self.get_success(self.store.is_user_approved(self.user_id))
self.assertTrue(approved)
@@ -228,7 +228,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):

user = self.get_success(self.store.get_user_by_id(self.user_id))
assert user is not None
self.assertFalse(user["approved"])
self.assertFalse(user.approved)

approved = self.get_success(self.store.is_user_approved(self.user_id))
self.assertFalse(approved)
@@ -248,7 +248,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
user = self.get_success(self.store.get_user_by_id(self.user_id))
self.assertIsNotNone(user)
assert user is not None
self.assertEqual(user["approved"], 1)
self.assertEqual(user.approved, 1)

approved = self.get_success(self.store.is_user_approved(self.user_id))
self.assertTrue(approved)


Carregando…
Cancelar
Guardar