You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

413 regels
14 KiB

  1. # Copyright 2020 Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Dict, List, Optional, Tuple, Union, cast
  15. import attr
  16. from synapse.api.constants import LoginType
  17. from synapse.api.errors import StoreError
  18. from synapse.storage._base import SQLBaseStore, db_to_json
  19. from synapse.storage.database import LoggingTransaction
  20. from synapse.types import JsonDict
  21. from synapse.util import json_encoder, stringutils
  22. @attr.s(slots=True, auto_attribs=True)
  23. class UIAuthSessionData:
  24. session_id: str
  25. # The dictionary from the client root level, not the 'auth' key.
  26. clientdict: JsonDict
  27. # The URI and method the session was intiatied with. These are checked at
  28. # each stage of the authentication to ensure that the asked for operation
  29. # has not changed.
  30. uri: str
  31. method: str
  32. # A string description of the operation that the current authentication is
  33. # authorising.
  34. description: str
  35. class UIAuthWorkerStore(SQLBaseStore):
  36. """
  37. Manage user interactive authentication sessions.
  38. """
  39. async def create_ui_auth_session(
  40. self,
  41. clientdict: JsonDict,
  42. uri: str,
  43. method: str,
  44. description: str,
  45. ) -> UIAuthSessionData:
  46. """
  47. Creates a new user interactive authentication session.
  48. The session can be used to track the stages necessary to authenticate a
  49. user across multiple HTTP requests.
  50. Args:
  51. clientdict:
  52. The dictionary from the client root level, not the 'auth' key.
  53. uri:
  54. The URI this session was initiated with, this is checked at each
  55. stage of the authentication to ensure that the asked for
  56. operation has not changed.
  57. method:
  58. The method this session was initiated with, this is checked at each
  59. stage of the authentication to ensure that the asked for
  60. operation has not changed.
  61. description:
  62. A string description of the operation that the current
  63. authentication is authorising.
  64. Returns:
  65. The newly created session.
  66. Raises:
  67. StoreError if a unique session ID cannot be generated.
  68. """
  69. # The clientdict gets stored as JSON.
  70. clientdict_json = json_encoder.encode(clientdict)
  71. # autogen a session ID and try to create it. We may clash, so just
  72. # try a few times till one goes through, giving up eventually.
  73. attempts = 0
  74. while attempts < 5:
  75. session_id = stringutils.random_string(24)
  76. try:
  77. await self.db_pool.simple_insert(
  78. table="ui_auth_sessions",
  79. values={
  80. "session_id": session_id,
  81. "clientdict": clientdict_json,
  82. "uri": uri,
  83. "method": method,
  84. "description": description,
  85. "serverdict": "{}",
  86. "creation_time": self.hs.get_clock().time_msec(),
  87. },
  88. desc="create_ui_auth_session",
  89. )
  90. return UIAuthSessionData(
  91. session_id, clientdict, uri, method, description
  92. )
  93. except self.db_pool.engine.module.IntegrityError:
  94. attempts += 1
  95. raise StoreError(500, "Couldn't generate a session ID.")
  96. async def get_ui_auth_session(self, session_id: str) -> UIAuthSessionData:
  97. """Retrieve a UI auth session.
  98. Args:
  99. session_id: The ID of the session.
  100. Returns:
  101. A dict containing the device information.
  102. Raises:
  103. StoreError if the session is not found.
  104. """
  105. result = await self.db_pool.simple_select_one(
  106. table="ui_auth_sessions",
  107. keyvalues={"session_id": session_id},
  108. retcols=("clientdict", "uri", "method", "description"),
  109. desc="get_ui_auth_session",
  110. )
  111. return UIAuthSessionData(
  112. session_id,
  113. clientdict=db_to_json(result[0]),
  114. uri=result[1],
  115. method=result[2],
  116. description=result[3],
  117. )
  118. async def mark_ui_auth_stage_complete(
  119. self,
  120. session_id: str,
  121. stage_type: str,
  122. result: Union[str, bool, JsonDict],
  123. ) -> None:
  124. """
  125. Mark a session stage as completed.
  126. Args:
  127. session_id: The ID of the corresponding session.
  128. stage_type: The completed stage type.
  129. result: The result of the stage verification.
  130. Raises:
  131. StoreError if the session cannot be found.
  132. """
  133. # Add (or update) the results of the current stage to the database.
  134. #
  135. # Note that we need to allow for the same stage to complete multiple
  136. # times here so that registration is idempotent.
  137. try:
  138. await self.db_pool.simple_upsert(
  139. table="ui_auth_sessions_credentials",
  140. keyvalues={"session_id": session_id, "stage_type": stage_type},
  141. values={"result": json_encoder.encode(result)},
  142. desc="mark_ui_auth_stage_complete",
  143. )
  144. except self.db_pool.engine.module.IntegrityError:
  145. raise StoreError(400, "Unknown session ID: %s" % (session_id,))
  146. async def get_completed_ui_auth_stages(
  147. self, session_id: str
  148. ) -> Dict[str, Union[str, bool, JsonDict]]:
  149. """
  150. Retrieve the completed stages of a UI authentication session.
  151. Args:
  152. session_id: The ID of the session.
  153. Returns:
  154. The completed stages mapped to the result of the verification of
  155. that auth-type.
  156. """
  157. results = {}
  158. rows = cast(
  159. List[Tuple[str, str]],
  160. await self.db_pool.simple_select_list(
  161. table="ui_auth_sessions_credentials",
  162. keyvalues={"session_id": session_id},
  163. retcols=("stage_type", "result"),
  164. desc="get_completed_ui_auth_stages",
  165. ),
  166. )
  167. for stage_type, result in rows:
  168. results[stage_type] = db_to_json(result)
  169. return results
  170. async def set_ui_auth_clientdict(
  171. self, session_id: str, clientdict: JsonDict
  172. ) -> None:
  173. """
  174. Store an updated clientdict for a given session ID.
  175. Args:
  176. session_id: The ID of this session as returned from check_auth
  177. clientdict:
  178. The dictionary from the client root level, not the 'auth' key.
  179. """
  180. # The clientdict gets stored as JSON.
  181. clientdict_json = json_encoder.encode(clientdict)
  182. await self.db_pool.simple_update_one(
  183. table="ui_auth_sessions",
  184. keyvalues={"session_id": session_id},
  185. updatevalues={"clientdict": clientdict_json},
  186. desc="set_ui_auth_client_dict",
  187. )
  188. async def set_ui_auth_session_data(
  189. self, session_id: str, key: str, value: Any
  190. ) -> None:
  191. """
  192. Store a key-value pair into the sessions data associated with this
  193. request. This data is stored server-side and cannot be modified by
  194. the client.
  195. Args:
  196. session_id: The ID of this session as returned from check_auth
  197. key: The key to store the data under
  198. value: The data to store
  199. Raises:
  200. StoreError if the session cannot be found.
  201. """
  202. await self.db_pool.runInteraction(
  203. "set_ui_auth_session_data",
  204. self._set_ui_auth_session_data_txn,
  205. session_id,
  206. key,
  207. value,
  208. )
  209. def _set_ui_auth_session_data_txn(
  210. self, txn: LoggingTransaction, session_id: str, key: str, value: Any
  211. ) -> None:
  212. # Get the current value.
  213. result = self.db_pool.simple_select_one_onecol_txn(
  214. txn,
  215. table="ui_auth_sessions",
  216. keyvalues={"session_id": session_id},
  217. retcol="serverdict",
  218. )
  219. # Update it and add it back to the database.
  220. serverdict = db_to_json(result)
  221. serverdict[key] = value
  222. self.db_pool.simple_update_one_txn(
  223. txn,
  224. table="ui_auth_sessions",
  225. keyvalues={"session_id": session_id},
  226. updatevalues={"serverdict": json_encoder.encode(serverdict)},
  227. )
  228. async def get_ui_auth_session_data(
  229. self, session_id: str, key: str, default: Optional[Any] = None
  230. ) -> Any:
  231. """
  232. Retrieve data stored with set_session_data
  233. Args:
  234. session_id: The ID of this session as returned from check_auth
  235. key: The key to store the data under
  236. default: Value to return if the key has not been set
  237. Raises:
  238. StoreError if the session cannot be found.
  239. """
  240. result = await self.db_pool.simple_select_one_onecol(
  241. table="ui_auth_sessions",
  242. keyvalues={"session_id": session_id},
  243. retcol="serverdict",
  244. desc="get_ui_auth_session_data",
  245. )
  246. serverdict = db_to_json(result)
  247. return serverdict.get(key, default)
  248. async def add_user_agent_ip_to_ui_auth_session(
  249. self,
  250. session_id: str,
  251. user_agent: str,
  252. ip: str,
  253. ) -> None:
  254. """Add the given user agent / IP to the tracking table"""
  255. await self.db_pool.simple_upsert(
  256. table="ui_auth_sessions_ips",
  257. keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip},
  258. values={},
  259. desc="add_user_agent_ip_to_ui_auth_session",
  260. )
  261. async def get_user_agents_ips_to_ui_auth_session(
  262. self,
  263. session_id: str,
  264. ) -> List[Tuple[str, str]]:
  265. """Get the given user agents / IPs used during the ui auth process
  266. Returns:
  267. List of user_agent/ip pairs
  268. """
  269. return cast(
  270. List[Tuple[str, str]],
  271. await self.db_pool.simple_select_list(
  272. table="ui_auth_sessions_ips",
  273. keyvalues={"session_id": session_id},
  274. retcols=("user_agent", "ip"),
  275. desc="get_user_agents_ips_to_ui_auth_session",
  276. ),
  277. )
  278. async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
  279. """
  280. Remove sessions which were last used earlier than the expiration time.
  281. Args:
  282. expiration_time: The latest time that is still considered valid.
  283. This is an epoch time in milliseconds.
  284. """
  285. await self.db_pool.runInteraction(
  286. "delete_old_ui_auth_sessions",
  287. self._delete_old_ui_auth_sessions_txn,
  288. expiration_time,
  289. )
  290. def _delete_old_ui_auth_sessions_txn(
  291. self, txn: LoggingTransaction, expiration_time: int
  292. ) -> None:
  293. # Get the expired sessions.
  294. sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
  295. txn.execute(sql, [expiration_time])
  296. session_ids = [r[0] for r in txn.fetchall()]
  297. # Delete the corresponding IP/user agents.
  298. self.db_pool.simple_delete_many_txn(
  299. txn,
  300. table="ui_auth_sessions_ips",
  301. column="session_id",
  302. values=session_ids,
  303. keyvalues={},
  304. )
  305. # If a registration token was used, decrement the pending counter
  306. # before deleting the session.
  307. rows = cast(
  308. List[Tuple[str]],
  309. self.db_pool.simple_select_many_txn(
  310. txn,
  311. table="ui_auth_sessions_credentials",
  312. column="session_id",
  313. iterable=session_ids,
  314. keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
  315. retcols=["result"],
  316. ),
  317. )
  318. # Get the tokens used and how much pending needs to be decremented by.
  319. token_counts: Dict[str, int] = {}
  320. for r in rows:
  321. # If registration was successfully completed, the result of the
  322. # registration token stage for that session will be True.
  323. # If a token was used to authenticate, but registration was
  324. # never completed, the result will be the token used.
  325. token = db_to_json(r[0])
  326. if isinstance(token, str):
  327. token_counts[token] = token_counts.get(token, 0) + 1
  328. # Update the `pending` counters.
  329. if len(token_counts) > 0:
  330. token_rows = cast(
  331. List[Tuple[str, int]],
  332. self.db_pool.simple_select_many_txn(
  333. txn,
  334. table="registration_tokens",
  335. column="token",
  336. iterable=list(token_counts.keys()),
  337. keyvalues={},
  338. retcols=["token", "pending"],
  339. ),
  340. )
  341. for token, pending in token_rows:
  342. new_pending = pending - token_counts[token]
  343. self.db_pool.simple_update_one_txn(
  344. txn,
  345. table="registration_tokens",
  346. keyvalues={"token": token},
  347. updatevalues={"pending": new_pending},
  348. )
  349. # Delete the corresponding completed credentials.
  350. self.db_pool.simple_delete_many_txn(
  351. txn,
  352. table="ui_auth_sessions_credentials",
  353. column="session_id",
  354. values=session_ids,
  355. keyvalues={},
  356. )
  357. # Finally, delete the sessions.
  358. self.db_pool.simple_delete_many_txn(
  359. txn,
  360. table="ui_auth_sessions",
  361. column="session_id",
  362. values=session_ids,
  363. keyvalues={},
  364. )
  365. class UIAuthStore(UIAuthWorkerStore):
  366. pass