@@ -0,0 +1 @@ | |||
Read from column `full_user_id` rather than `user_id` of tables `profiles` and `user_filters`. |
@@ -152,9 +152,9 @@ class Filtering: | |||
self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {}) | |||
async def get_user_filter( | |||
self, user_localpart: str, filter_id: Union[int, str] | |||
self, user_id: UserID, filter_id: Union[int, str] | |||
) -> "FilterCollection": | |||
result = await self.store.get_user_filter(user_localpart, filter_id) | |||
result = await self.store.get_user_filter(user_id, filter_id) | |||
return FilterCollection(self._hs, result) | |||
def add_user_filter(self, user_id: UserID, user_filter: JsonDict) -> Awaitable[int]: | |||
@@ -164,7 +164,7 @@ class AccountValidityHandler: | |||
try: | |||
user_display_name = await self.store.get_profile_displayname( | |||
UserID.from_string(user_id).localpart | |||
UserID.from_string(user_id) | |||
) | |||
if user_display_name is None: | |||
user_display_name = user_id | |||
@@ -89,7 +89,7 @@ class AdminHandler: | |||
} | |||
# Add additional user metadata | |||
profile = await self._store.get_profileinfo(user.localpart) | |||
profile = await self._store.get_profileinfo(user) | |||
threepids = await self._store.user_get_threepids(user.to_string()) | |||
external_ids = [ | |||
({"auth_provider": auth_provider, "external_id": external_id}) | |||
@@ -1759,7 +1759,7 @@ class AuthHandler: | |||
return | |||
user_profile_data = await self.store.get_profileinfo( | |||
UserID.from_string(registered_user_id).localpart | |||
UserID.from_string(registered_user_id) | |||
) | |||
# Store any extra attributes which will be passed in the login response. | |||
@@ -297,5 +297,5 @@ class DeactivateAccountHandler: | |||
# Add the user to the directory, if necessary. Note that | |||
# this must be done after the user is re-activated, because | |||
# deactivated users are excluded from the user directory. | |||
profile = await self.store.get_profileinfo(user.localpart) | |||
profile = await self.store.get_profileinfo(user) | |||
await self.user_directory_handler.handle_local_profile_change(user_id, profile) |
@@ -67,7 +67,7 @@ class ProfileHandler: | |||
target_user = UserID.from_string(user_id) | |||
if self.hs.is_mine(target_user): | |||
profileinfo = await self.store.get_profileinfo(target_user.localpart) | |||
profileinfo = await self.store.get_profileinfo(target_user) | |||
if profileinfo.display_name is None: | |||
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) | |||
@@ -99,9 +99,7 @@ class ProfileHandler: | |||
async def get_displayname(self, target_user: UserID) -> Optional[str]: | |||
if self.hs.is_mine(target_user): | |||
try: | |||
displayname = await self.store.get_profile_displayname( | |||
target_user.localpart | |||
) | |||
displayname = await self.store.get_profile_displayname(target_user) | |||
except StoreError as e: | |||
if e.code == 404: | |||
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) | |||
@@ -147,7 +145,7 @@ class ProfileHandler: | |||
raise AuthError(400, "Cannot set another user's displayname") | |||
if not by_admin and not self.hs.config.registration.enable_set_displayname: | |||
profile = await self.store.get_profileinfo(target_user.localpart) | |||
profile = await self.store.get_profileinfo(target_user) | |||
if profile.display_name: | |||
raise SynapseError( | |||
400, | |||
@@ -180,7 +178,7 @@ class ProfileHandler: | |||
await self.store.set_profile_displayname(target_user, displayname_to_set) | |||
profile = await self.store.get_profileinfo(target_user.localpart) | |||
profile = await self.store.get_profileinfo(target_user) | |||
await self.user_directory_handler.handle_local_profile_change( | |||
target_user.to_string(), profile | |||
) | |||
@@ -194,9 +192,7 @@ class ProfileHandler: | |||
async def get_avatar_url(self, target_user: UserID) -> Optional[str]: | |||
if self.hs.is_mine(target_user): | |||
try: | |||
avatar_url = await self.store.get_profile_avatar_url( | |||
target_user.localpart | |||
) | |||
avatar_url = await self.store.get_profile_avatar_url(target_user) | |||
except StoreError as e: | |||
if e.code == 404: | |||
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) | |||
@@ -241,7 +237,7 @@ class ProfileHandler: | |||
raise AuthError(400, "Cannot set another user's avatar_url") | |||
if not by_admin and not self.hs.config.registration.enable_set_avatar_url: | |||
profile = await self.store.get_profileinfo(target_user.localpart) | |||
profile = await self.store.get_profileinfo(target_user) | |||
if profile.avatar_url: | |||
raise SynapseError( | |||
400, "Changing avatar is disabled on this server", Codes.FORBIDDEN | |||
@@ -272,7 +268,7 @@ class ProfileHandler: | |||
await self.store.set_profile_avatar_url(target_user, avatar_url_to_set) | |||
profile = await self.store.get_profileinfo(target_user.localpart) | |||
profile = await self.store.get_profileinfo(target_user) | |||
await self.user_directory_handler.handle_local_profile_change( | |||
target_user.to_string(), profile | |||
) | |||
@@ -369,14 +365,10 @@ class ProfileHandler: | |||
response = {} | |||
try: | |||
if just_field is None or just_field == "displayname": | |||
response["displayname"] = await self.store.get_profile_displayname( | |||
user.localpart | |||
) | |||
response["displayname"] = await self.store.get_profile_displayname(user) | |||
if just_field is None or just_field == "avatar_url": | |||
response["avatar_url"] = await self.store.get_profile_avatar_url( | |||
user.localpart | |||
) | |||
response["avatar_url"] = await self.store.get_profile_avatar_url(user) | |||
except StoreError as e: | |||
if e.code == 404: | |||
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) | |||
@@ -315,7 +315,7 @@ class RegistrationHandler: | |||
approved=approved, | |||
) | |||
profile = await self.store.get_profileinfo(localpart) | |||
profile = await self.store.get_profileinfo(user) | |||
await self.user_directory_handler.handle_local_profile_change( | |||
user_id, profile | |||
) | |||
@@ -655,7 +655,9 @@ class ModuleApi: | |||
Returns: | |||
The profile information (i.e. display name and avatar URL). | |||
""" | |||
return await self._store.get_profileinfo(localpart) | |||
server_name = self._hs.hostname | |||
user_id = UserID.from_string(f"@{localpart}:{server_name}") | |||
return await self._store.get_profileinfo(user_id) | |||
async def get_threepids_for_user(self, user_id: str) -> List[Dict[str, str]]: | |||
"""Look up the threepids (email addresses and phone numbers) associated with the | |||
@@ -247,7 +247,7 @@ class Mailer: | |||
try: | |||
user_display_name = await self.store.get_profile_displayname( | |||
UserID.from_string(user_id).localpart | |||
UserID.from_string(user_id) | |||
) | |||
if user_display_name is None: | |||
user_display_name = user_id | |||
@@ -58,7 +58,7 @@ class GetFilterRestServlet(RestServlet): | |||
try: | |||
filter_collection = await self.filtering.get_user_filter( | |||
user_localpart=target_user.localpart, filter_id=filter_id_int | |||
user_id=target_user, filter_id=filter_id_int | |||
) | |||
except StoreError as e: | |||
if e.code != 404: | |||
@@ -178,7 +178,7 @@ class SyncRestServlet(RestServlet): | |||
else: | |||
try: | |||
filter_collection = await self.filtering.get_user_filter( | |||
user.localpart, filter_id | |||
user, filter_id | |||
) | |||
except StoreError as err: | |||
if err.code != 404: | |||
@@ -145,7 +145,7 @@ class FilteringWorkerStore(SQLBaseStore): | |||
@cached(num_args=2) | |||
async def get_user_filter( | |||
self, user_localpart: str, filter_id: Union[int, str] | |||
self, user_id: UserID, filter_id: Union[int, str] | |||
) -> JsonDict: | |||
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail | |||
# with a coherent error message rather than 500 M_UNKNOWN. | |||
@@ -156,7 +156,7 @@ class FilteringWorkerStore(SQLBaseStore): | |||
def_json = await self.db_pool.simple_select_one_onecol( | |||
table="user_filters", | |||
keyvalues={"user_id": user_localpart, "filter_id": filter_id}, | |||
keyvalues={"full_user_id": user_id.to_string(), "filter_id": filter_id}, | |||
retcol="filter_json", | |||
allow_none=False, | |||
desc="get_user_filter", | |||
@@ -172,15 +172,15 @@ class FilteringWorkerStore(SQLBaseStore): | |||
def _do_txn(txn: LoggingTransaction) -> int: | |||
sql = ( | |||
"SELECT filter_id FROM user_filters " | |||
"WHERE user_id = ? AND filter_json = ?" | |||
"WHERE full_user_id = ? AND filter_json = ?" | |||
) | |||
txn.execute(sql, (user_id.localpart, bytearray(def_json))) | |||
txn.execute(sql, (user_id.to_string(), bytearray(def_json))) | |||
filter_id_response = txn.fetchone() | |||
if filter_id_response is not None: | |||
return filter_id_response[0] | |||
sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?" | |||
txn.execute(sql, (user_id.localpart,)) | |||
sql = "SELECT MAX(filter_id) FROM user_filters WHERE full_user_id = ?" | |||
txn.execute(sql, (user_id.to_string(),)) | |||
max_id = cast(Tuple[Optional[int]], txn.fetchone())[0] | |||
if max_id is None: | |||
filter_id = 0 | |||
@@ -137,11 +137,11 @@ class ProfileWorkerStore(SQLBaseStore): | |||
return 50 | |||
async def get_profileinfo(self, user_localpart: str) -> ProfileInfo: | |||
async def get_profileinfo(self, user_id: UserID) -> ProfileInfo: | |||
try: | |||
profile = await self.db_pool.simple_select_one( | |||
table="profiles", | |||
keyvalues={"user_id": user_localpart}, | |||
keyvalues={"full_user_id": user_id.to_string()}, | |||
retcols=("displayname", "avatar_url"), | |||
desc="get_profileinfo", | |||
) | |||
@@ -156,18 +156,18 @@ class ProfileWorkerStore(SQLBaseStore): | |||
avatar_url=profile["avatar_url"], display_name=profile["displayname"] | |||
) | |||
async def get_profile_displayname(self, user_localpart: str) -> Optional[str]: | |||
async def get_profile_displayname(self, user_id: UserID) -> Optional[str]: | |||
return await self.db_pool.simple_select_one_onecol( | |||
table="profiles", | |||
keyvalues={"user_id": user_localpart}, | |||
keyvalues={"full_user_id": user_id.to_string()}, | |||
retcol="displayname", | |||
desc="get_profile_displayname", | |||
) | |||
async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]: | |||
async def get_profile_avatar_url(self, user_id: UserID) -> Optional[str]: | |||
return await self.db_pool.simple_select_one_onecol( | |||
table="profiles", | |||
keyvalues={"user_id": user_localpart}, | |||
keyvalues={"full_user_id": user_id.to_string()}, | |||
retcol="avatar_url", | |||
desc="get_profile_avatar_url", | |||
) | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
SCHEMA_VERSION = 77 # remember to update the list below when updating | |||
SCHEMA_VERSION = 78 # remember to update the list below when updating | |||
"""Represents the expectations made by the codebase about the database schema | |||
This should be incremented whenever the codebase changes its requirements on the | |||
@@ -103,6 +103,9 @@ Changes in SCHEMA_VERSION = 76: | |||
Changes in SCHEMA_VERSION = 77 | |||
- (Postgres) Add NOT VALID CHECK (full_user_id IS NOT NULL) to tables profiles and user_filters | |||
Changes in SCHEMA_VERSION = 78 | |||
- Validate check (full_user_id IS NOT NULL) on tables profiles and user_filters | |||
""" | |||
@@ -0,0 +1,92 @@ | |||
# Copyright 2023 The Matrix.org Foundation C.I.C | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from synapse.config.homeserver import HomeServerConfig | |||
from synapse.storage.database import LoggingTransaction | |||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine | |||
def run_upgrade( | |||
cur: LoggingTransaction, | |||
database_engine: BaseDatabaseEngine, | |||
config: HomeServerConfig, | |||
) -> None: | |||
""" | |||
Part 3 of a multi-step migration to drop the column `user_id` and replace it with | |||
`full_user_id`. See the database schema docs for more information on the full | |||
migration steps. | |||
""" | |||
hostname = config.server.server_name | |||
if isinstance(database_engine, PostgresEngine): | |||
# check if the constraint can be validated | |||
check_sql = """ | |||
SELECT user_id from profiles WHERE full_user_id IS NULL | |||
""" | |||
cur.execute(check_sql) | |||
res = cur.fetchall() | |||
if res: | |||
# there are rows the background job missed, finish them here before we validate the constraint | |||
process_rows_sql = """ | |||
UPDATE profiles | |||
SET full_user_id = '@' || user_id || ? | |||
WHERE user_id IN ( | |||
SELECT user_id FROM profiles WHERE full_user_id IS NULL | |||
) | |||
""" | |||
cur.execute(process_rows_sql, (f":{hostname}",)) | |||
# Now we can validate | |||
validate_sql = """ | |||
ALTER TABLE profiles VALIDATE CONSTRAINT full_user_id_not_null | |||
""" | |||
cur.execute(validate_sql) | |||
else: | |||
# in SQLite we need to rewrite the table to add the constraint. | |||
# First drop any temporary table that might be here from a previous failed migration. | |||
cur.execute("DROP TABLE IF EXISTS temp_profiles") | |||
create_sql = """ | |||
CREATE TABLE temp_profiles ( | |||
full_user_id text NOT NULL, | |||
user_id text, | |||
displayname text, | |||
avatar_url text, | |||
UNIQUE (full_user_id), | |||
UNIQUE (user_id) | |||
) | |||
""" | |||
cur.execute(create_sql) | |||
copy_sql = """ | |||
INSERT INTO temp_profiles ( | |||
user_id, | |||
displayname, | |||
avatar_url, | |||
full_user_id) | |||
SELECT user_id, displayname, avatar_url, '@' || user_id || ':' || ? FROM profiles | |||
""" | |||
cur.execute(copy_sql, (f"{hostname}",)) | |||
drop_sql = """ | |||
DROP TABLE profiles | |||
""" | |||
cur.execute(drop_sql) | |||
rename_sql = """ | |||
ALTER TABLE temp_profiles RENAME to profiles | |||
""" | |||
cur.execute(rename_sql) |
@@ -0,0 +1,95 @@ | |||
# Copyright 2023 The Matrix.org Foundation C.I.C | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from synapse.config.homeserver import HomeServerConfig | |||
from synapse.storage.database import LoggingTransaction | |||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine | |||
def run_upgrade( | |||
cur: LoggingTransaction, | |||
database_engine: BaseDatabaseEngine, | |||
config: HomeServerConfig, | |||
) -> None: | |||
""" | |||
Part 3 of a multi-step migration to drop the column `user_id` and replace it with | |||
`full_user_id`. See the database schema docs for more information on the full | |||
migration steps. | |||
""" | |||
hostname = config.server.server_name | |||
if isinstance(database_engine, PostgresEngine): | |||
# check if the constraint can be validated | |||
check_sql = """ | |||
SELECT user_id from user_filters WHERE full_user_id IS NULL | |||
""" | |||
cur.execute(check_sql) | |||
res = cur.fetchall() | |||
if res: | |||
# there are rows the background job missed, finish them here before we validate constraint | |||
process_rows_sql = """ | |||
UPDATE user_filters | |||
SET full_user_id = '@' || user_id || ? | |||
WHERE user_id IN ( | |||
SELECT user_id FROM user_filters WHERE full_user_id IS NULL | |||
) | |||
""" | |||
cur.execute(process_rows_sql, (f":{hostname}",)) | |||
# Now we can validate | |||
validate_sql = """ | |||
ALTER TABLE user_filters VALIDATE CONSTRAINT full_user_id_not_null | |||
""" | |||
cur.execute(validate_sql) | |||
else: | |||
cur.execute("DROP TABLE IF EXISTS temp_user_filters") | |||
create_sql = """ | |||
CREATE TABLE temp_user_filters ( | |||
full_user_id text NOT NULL, | |||
user_id text NOT NULL, | |||
filter_id bigint NOT NULL, | |||
filter_json bytea NOT NULL, | |||
UNIQUE (full_user_id), | |||
UNIQUE (user_id) | |||
) | |||
""" | |||
cur.execute(create_sql) | |||
index_sql = """ | |||
CREATE UNIQUE INDEX IF NOT EXISTS user_filters_unique ON | |||
temp_user_filters (user_id, filter_id) | |||
""" | |||
cur.execute(index_sql) | |||
copy_sql = """ | |||
INSERT INTO temp_user_filters ( | |||
user_id, | |||
filter_id, | |||
filter_json, | |||
full_user_id) | |||
SELECT user_id, filter_id, filter_json, '@' || user_id || ':' || ? FROM user_filters | |||
""" | |||
cur.execute(copy_sql, (f"{hostname}",)) | |||
drop_sql = """ | |||
DROP TABLE user_filters | |||
""" | |||
cur.execute(drop_sql) | |||
rename_sql = """ | |||
ALTER TABLE temp_user_filters RENAME to user_filters | |||
""" | |||
cur.execute(rename_sql) |
@@ -35,7 +35,6 @@ from tests.events.test_utils import MockEvent | |||
user_id = UserID.from_string("@test_user:test") | |||
user2_id = UserID.from_string("@test_user2:test") | |||
user_localpart = "test_user" | |||
class FilteringTestCase(unittest.HomeserverTestCase): | |||
@@ -449,9 +448,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): | |||
] | |||
user_filter = self.get_success( | |||
self.filtering.get_user_filter( | |||
user_localpart=user_localpart, filter_id=filter_id | |||
) | |||
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id) | |||
) | |||
results = self.get_success(user_filter.filter_presence(presence_states)) | |||
@@ -479,9 +476,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): | |||
] | |||
user_filter = self.get_success( | |||
self.filtering.get_user_filter( | |||
user_localpart=user_localpart + "2", filter_id=filter_id | |||
) | |||
self.filtering.get_user_filter(user_id=user2_id, filter_id=filter_id) | |||
) | |||
results = self.get_success(user_filter.filter_presence(presence_states)) | |||
@@ -498,9 +493,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): | |||
events = [event] | |||
user_filter = self.get_success( | |||
self.filtering.get_user_filter( | |||
user_localpart=user_localpart, filter_id=filter_id | |||
) | |||
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id) | |||
) | |||
results = self.get_success(user_filter.filter_room_state(events=events)) | |||
@@ -519,9 +512,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): | |||
events = [event] | |||
user_filter = self.get_success( | |||
self.filtering.get_user_filter( | |||
user_localpart=user_localpart, filter_id=filter_id | |||
) | |||
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id) | |||
) | |||
results = self.get_success(user_filter.filter_room_state(events)) | |||
@@ -603,9 +594,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): | |||
user_filter_json, | |||
( | |||
self.get_success( | |||
self.datastore.get_user_filter( | |||
user_localpart=user_localpart, filter_id=0 | |||
) | |||
self.datastore.get_user_filter(user_id=user_id, filter_id=0) | |||
) | |||
), | |||
) | |||
@@ -620,9 +609,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): | |||
) | |||
filter = self.get_success( | |||
self.filtering.get_user_filter( | |||
user_localpart=user_localpart, filter_id=filter_id | |||
) | |||
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id) | |||
) | |||
self.assertEqual(filter.get_filter_json(), user_filter_json) | |||
@@ -80,11 +80,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): | |||
) | |||
self.assertEqual( | |||
( | |||
self.get_success( | |||
self.store.get_profile_displayname(self.frank.localpart) | |||
) | |||
), | |||
(self.get_success(self.store.get_profile_displayname(self.frank))), | |||
"Frank Jr.", | |||
) | |||
@@ -96,11 +92,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): | |||
) | |||
self.assertEqual( | |||
( | |||
self.get_success( | |||
self.store.get_profile_displayname(self.frank.localpart) | |||
) | |||
), | |||
(self.get_success(self.store.get_profile_displayname(self.frank))), | |||
"Frank", | |||
) | |||
@@ -112,7 +104,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): | |||
) | |||
self.assertIsNone( | |||
self.get_success(self.store.get_profile_displayname(self.frank.localpart)) | |||
self.get_success(self.store.get_profile_displayname(self.frank)) | |||
) | |||
def test_set_my_name_if_disabled(self) -> None: | |||
@@ -122,11 +114,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): | |||
self.get_success(self.store.set_profile_displayname(self.frank, "Frank")) | |||
self.assertEqual( | |||
( | |||
self.get_success( | |||
self.store.get_profile_displayname(self.frank.localpart) | |||
) | |||
), | |||
(self.get_success(self.store.get_profile_displayname(self.frank))), | |||
"Frank", | |||
) | |||
@@ -201,7 +189,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): | |||
) | |||
self.assertEqual( | |||
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), | |||
(self.get_success(self.store.get_profile_avatar_url(self.frank))), | |||
"http://my.server/pic.gif", | |||
) | |||
@@ -215,7 +203,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): | |||
) | |||
self.assertEqual( | |||
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), | |||
(self.get_success(self.store.get_profile_avatar_url(self.frank))), | |||
"http://my.server/me.png", | |||
) | |||
@@ -229,7 +217,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): | |||
) | |||
self.assertIsNone( | |||
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), | |||
(self.get_success(self.store.get_profile_avatar_url(self.frank))), | |||
) | |||
def test_set_my_avatar_if_disabled(self) -> None: | |||
@@ -241,7 +229,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): | |||
) | |||
self.assertEqual( | |||
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), | |||
(self.get_success(self.store.get_profile_avatar_url(self.frank))), | |||
"http://my.server/me.png", | |||
) | |||
@@ -28,7 +28,7 @@ from synapse.module_api import ModuleApi | |||
from synapse.rest import admin | |||
from synapse.rest.client import login, notifications, presence, profile, room | |||
from synapse.server import HomeServer | |||
from synapse.types import JsonDict, create_requester | |||
from synapse.types import JsonDict, UserID, create_requester | |||
from synapse.util import Clock | |||
from tests.events.test_presence_router import send_presence_update, sync_presence | |||
@@ -103,7 +103,9 @@ class ModuleApiTestCase(BaseModuleApiTestCase): | |||
self.assertEqual(email["added_at"], 0) | |||
# Check that the displayname was assigned | |||
displayname = self.get_success(self.store.get_profile_displayname("bob")) | |||
displayname = self.get_success( | |||
self.store.get_profile_displayname(UserID.from_string("@bob:test")) | |||
) | |||
self.assertEqual(displayname, "Bobberino") | |||
def test_can_register_admin_user(self) -> None: | |||
@@ -46,7 +46,9 @@ class FilterTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual(channel.code, 200) | |||
self.assertEqual(channel.json_body, {"filter_id": "0"}) | |||
filter = self.get_success( | |||
self.store.get_user_filter(user_localpart="apple", filter_id=0) | |||
self.store.get_user_filter( | |||
user_id=UserID.from_string(FilterTestCase.user_id), filter_id=0 | |||
) | |||
) | |||
self.pump() | |||
self.assertEqual(filter, self.EXAMPLE_FILTER) | |||
@@ -11,6 +11,7 @@ | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.server import HomeServer | |||
@@ -35,18 +36,14 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual( | |||
"Frank", | |||
( | |||
self.get_success( | |||
self.store.get_profile_displayname(self.u_frank.localpart) | |||
) | |||
), | |||
(self.get_success(self.store.get_profile_displayname(self.u_frank))), | |||
) | |||
# test set to None | |||
self.get_success(self.store.set_profile_displayname(self.u_frank, None)) | |||
self.assertIsNone( | |||
self.get_success(self.store.get_profile_displayname(self.u_frank.localpart)) | |||
self.get_success(self.store.get_profile_displayname(self.u_frank)) | |||
) | |||
def test_avatar_url(self) -> None: | |||
@@ -58,18 +55,14 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual( | |||
"http://my.site/here", | |||
( | |||
self.get_success( | |||
self.store.get_profile_avatar_url(self.u_frank.localpart) | |||
) | |||
), | |||
(self.get_success(self.store.get_profile_avatar_url(self.u_frank))), | |||
) | |||
# test set to None | |||
self.get_success(self.store.set_profile_avatar_url(self.u_frank, None)) | |||
self.assertIsNone( | |||
self.get_success(self.store.get_profile_avatar_url(self.u_frank.localpart)) | |||
self.get_success(self.store.get_profile_avatar_url(self.u_frank)) | |||
) | |||
def test_profiles_bg_migration(self) -> None: | |||