You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

221 lines
7.4 KiB

  1. # Copyright 2014-2016 OpenMarket Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import TYPE_CHECKING, Optional
  15. from synapse.storage._base import SQLBaseStore
  16. from synapse.storage.database import (
  17. DatabasePool,
  18. LoggingDatabaseConnection,
  19. LoggingTransaction,
  20. )
  21. from synapse.storage.databases.main.roommember import ProfileInfo
  22. from synapse.storage.engines import PostgresEngine
  23. from synapse.types import JsonDict, UserID
  24. if TYPE_CHECKING:
  25. from synapse.server import HomeServer
  class ProfileWorkerStore(SQLBaseStore):
      """Storage for user profiles (display name and avatar URL).

      Rows in the `profiles` table are written keyed by `user_id` (the user's
      localpart) and also carry `full_user_id` (the full user ID string), which
      the read methods below use for lookups.
      """

      def __init__(
          self,
          database: DatabasePool,
          db_conn: LoggingDatabaseConnection,
          hs: "HomeServer",
      ):
          super().__init__(database, db_conn, hs)
          self.server_name: str = hs.hostname
          self.database_engine = database.engine

          # Unique index over the `full_user_id` column, built as a background update.
          self.db_pool.updates.register_background_index_update(
              "profiles_full_user_id_key_idx",
              index_name="profiles_full_user_id_key",
              table="profiles",
              columns=["full_user_id"],
              unique=True,
          )

          # Backfills `full_user_id` for rows that only have the localpart.
          self.db_pool.updates.register_background_update_handler(
              "populate_full_user_id_profiles", self.populate_full_user_id_profiles
          )

      async def populate_full_user_id_profiles(
          self, progress: JsonDict, batch_size: int
      ) -> int:
          """
          Background update to populate the column `full_user_id` of the table
          profiles from entries in the column `user_id` (the localpart) of the
          same table.

          Rows are visited in `user_id` order, `batch_size` rows per run. The last
          `user_id` handled is stored in `progress["lower_bound_id"]`, so the update
          resumes where it left off.

          Args:
              progress: The persisted progress of this background update.
              batch_size: How many rows to cover in this run. Values below 1 are
                  treated as 1.

          Returns:
              The number of rows covered by this run (1 for the final run, which
              also marks the update as finished).
          """
          lower_bound_id: str = progress.get("lower_bound_id", "")
          rows_per_batch = max(batch_size, 1)

          def _get_last_id(txn: LoggingTransaction) -> Optional[str]:
              # Find the user_id that closes this batch: the `rows_per_batch`-th row
              # after `lower_bound_id`. The range (lower_bound_id, result] then spans
              # `rows_per_batch` rows, assuming user_id values are unique.
              # Returns None when fewer than `rows_per_batch` rows remain.
              sql = """
                  SELECT user_id FROM profiles
                  WHERE user_id > ?
                  ORDER BY user_id
                  LIMIT 1 OFFSET ?
              """
              txn.execute(sql, (lower_bound_id, rows_per_batch - 1))
              row = txn.fetchone()
              return row[0] if row else None

          def _process_batch(
              txn: LoggingTransaction, lower_bound_id: str, upper_bound_id: str
          ) -> None:
              # Fill full_user_id as '@' + localpart + ':' + server_name for rows in
              # (lower_bound_id, upper_bound_id] that do not have it yet.
              sql = """
                  UPDATE profiles
                  SET full_user_id = '@' || user_id || ?
                  WHERE ? < user_id AND user_id <= ? AND full_user_id IS NULL
              """
              txn.execute(sql, (f":{self.server_name}", lower_bound_id, upper_bound_id))

          def _final_batch(txn: LoggingTransaction, lower_bound_id: str) -> None:
              # Fill every remaining row above the lower bound, then (Postgres only)
              # validate the `full_user_id_not_null` constraint now that it should hold.
              sql = """
                  UPDATE profiles
                  SET full_user_id = '@' || user_id || ?
                  WHERE ? < user_id AND full_user_id IS NULL
              """
              txn.execute(
                  sql,
                  (
                      f":{self.server_name}",
                      lower_bound_id,
                  ),
              )

              if isinstance(self.database_engine, PostgresEngine):
                  sql = """
                      ALTER TABLE profiles VALIDATE CONSTRAINT full_user_id_not_null
                  """
                  txn.execute(sql)

          upper_bound_id = await self.db_pool.runInteraction(
              "populate_full_user_id_profiles", _get_last_id
          )

          if upper_bound_id is None:
              # Less than a full batch left: finish the remaining rows and mark the
              # background update as complete.
              await self.db_pool.runInteraction(
                  "populate_full_user_id_profiles", _final_batch, lower_bound_id
              )

              await self.db_pool.updates._end_background_update(
                  "populate_full_user_id_profiles"
              )
              return 1

          await self.db_pool.runInteraction(
              "populate_full_user_id_profiles",
              _process_batch,
              lower_bound_id,
              upper_bound_id,
          )

          # Persist the new lower bound so the next run starts after this batch.
          progress["lower_bound_id"] = upper_bound_id

          await self.db_pool.runInteraction(
              "populate_full_user_id_profiles",
              self.db_pool.updates._background_update_progress_txn,
              "populate_full_user_id_profiles",
              progress,
          )

          # Report the number of rows this run actually covered.
          return rows_per_batch

      async def get_profileinfo(self, user_id: UserID) -> ProfileInfo:
          """Fetch the display name and avatar URL of a user.

          Args:
              user_id: The user whose profile to look up (matched on `full_user_id`).

          Returns:
              The user's profile, or `ProfileInfo(None, None)` if there is no row.
          """
          profile = await self.db_pool.simple_select_one(
              table="profiles",
              keyvalues={"full_user_id": user_id.to_string()},
              retcols=("displayname", "avatar_url"),
              desc="get_profileinfo",
              allow_none=True,
          )

          if profile is None:
              # no match
              return ProfileInfo(None, None)

          # Columns are indexed in `retcols` order: (displayname, avatar_url).
          return ProfileInfo(avatar_url=profile[1], display_name=profile[0])

      async def get_profile_displayname(self, user_id: UserID) -> Optional[str]:
          """Fetch the display name of a user (None if the column is unset).

          NOTE(review): `allow_none` is not passed, so a missing profile row is
          presumably reported as an error rather than None; confirm against
          `simple_select_one_onecol`.
          """
          return await self.db_pool.simple_select_one_onecol(
              table="profiles",
              keyvalues={"full_user_id": user_id.to_string()},
              retcol="displayname",
              desc="get_profile_displayname",
          )

      async def get_profile_avatar_url(self, user_id: UserID) -> Optional[str]:
          """Fetch the avatar URL of a user (None if the column is unset).

          NOTE(review): `allow_none` is not passed, so a missing profile row is
          presumably reported as an error rather than None; confirm against
          `simple_select_one_onecol`.
          """
          return await self.db_pool.simple_select_one_onecol(
              table="profiles",
              keyvalues={"full_user_id": user_id.to_string()},
              retcol="avatar_url",
              desc="get_profile_avatar_url",
          )

      async def create_profile(self, user_id: UserID) -> None:
          """Insert a profile row for a user with only its localpart and full ID set."""
          user_localpart = user_id.localpart
          await self.db_pool.simple_insert(
              table="profiles",
              values={"user_id": user_localpart, "full_user_id": user_id.to_string()},
              desc="create_profile",
          )

      async def set_profile_displayname(
          self, user_id: UserID, new_displayname: Optional[str]
      ) -> None:
          """
          Set the display name of a user.

          The row is upserted keyed on the localpart, and `full_user_id` is written
          alongside it.

          Args:
              user_id: The user's ID.
              new_displayname: The new display name. If this is None, the user's display
                  name is removed.
          """
          user_localpart = user_id.localpart
          await self.db_pool.simple_upsert(
              table="profiles",
              keyvalues={"user_id": user_localpart},
              values={
                  "displayname": new_displayname,
                  "full_user_id": user_id.to_string(),
              },
              desc="set_profile_displayname",
          )

      async def set_profile_avatar_url(
          self, user_id: UserID, new_avatar_url: Optional[str]
      ) -> None:
          """
          Set the avatar of a user.

          The row is upserted keyed on the localpart, and `full_user_id` is written
          alongside it.

          Args:
              user_id: The user's ID.
              new_avatar_url: The new avatar URL. If this is None, the user's avatar is
                  removed.
          """
          user_localpart = user_id.localpart
          await self.db_pool.simple_upsert(
              table="profiles",
              keyvalues={"user_id": user_localpart},
              values={"avatar_url": new_avatar_url, "full_user_id": user_id.to_string()},
              desc="set_profile_avatar_url",
          )
  class ProfileStore(ProfileWorkerStore):
      """Profile store that adds nothing beyond `ProfileWorkerStore`."""