You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

614 lines
21 KiB

  1. # Copyright 2014-2016 OpenMarket Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. from enum import Enum
  16. from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, cast
  17. import attr
  18. from canonicaljson import encode_canonical_json
  19. from synapse.api.constants import Direction
  20. from synapse.metrics.background_process_metrics import wrap_as_background_process
  21. from synapse.storage._base import db_to_json
  22. from synapse.storage.database import (
  23. DatabasePool,
  24. LoggingDatabaseConnection,
  25. LoggingTransaction,
  26. )
  27. from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
  28. from synapse.types import JsonDict
  29. from synapse.util.caches.descriptors import cached
  30. if TYPE_CHECKING:
  31. from synapse.server import HomeServer
# Type used to store binary blobs (e.g. canonical-JSON response bodies) in the
# database; see set_received_txn_response below.
db_binary_type = memoryview

logger = logging.getLogger(__name__)
  34. class DestinationSortOrder(Enum):
  35. """Enum to define the sorting method used when returning destinations."""
  36. DESTINATION = "destination"
  37. RETRY_LAST_TS = "retry_last_ts"
  38. RETTRY_INTERVAL = "retry_interval"
  39. FAILURE_TS = "failure_ts"
  40. LAST_SUCCESSFUL_STREAM_ORDERING = "last_successful_stream_ordering"
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DestinationRetryTimings:
    """The current destination retry timing info for a remote server.

    All fields are millisecond values. Instances are immutable (frozen).
    """

    # The first time we tried and failed to reach the remote server, in ms.
    failure_ts: int

    # The last time we tried and failed to reach the remote server, in ms.
    retry_last_ts: int

    # How long since the last time we tried to reach the remote server before
    # trying again, in ms.
    retry_interval: int
  51. class TransactionWorkerStore(CacheInvalidationWorkerStore):
  52. def __init__(
  53. self,
  54. database: DatabasePool,
  55. db_conn: LoggingDatabaseConnection,
  56. hs: "HomeServer",
  57. ):
  58. super().__init__(database, db_conn, hs)
  59. if hs.config.worker.run_background_tasks:
  60. self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)
  61. @wrap_as_background_process("cleanup_transactions")
  62. async def _cleanup_transactions(self) -> None:
  63. now = self._clock.time_msec()
  64. month_ago = now - 30 * 24 * 60 * 60 * 1000
  65. def _cleanup_transactions_txn(txn: LoggingTransaction) -> None:
  66. txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
  67. await self.db_pool.runInteraction(
  68. "_cleanup_transactions", _cleanup_transactions_txn
  69. )
  70. async def get_received_txn_response(
  71. self, transaction_id: str, origin: str
  72. ) -> Optional[Tuple[int, JsonDict]]:
  73. """For an incoming transaction from a given origin, check if we have
  74. already responded to it. If so, return the response code and response
  75. body (as a dict).
  76. Args:
  77. transaction_id
  78. origin
  79. Returns:
  80. None if we have not previously responded to this transaction or a
  81. 2-tuple of (int, dict)
  82. """
  83. return await self.db_pool.runInteraction(
  84. "get_received_txn_response",
  85. self._get_received_txn_response,
  86. transaction_id,
  87. origin,
  88. )
  89. def _get_received_txn_response(
  90. self, txn: LoggingTransaction, transaction_id: str, origin: str
  91. ) -> Optional[Tuple[int, JsonDict]]:
  92. result = self.db_pool.simple_select_one_txn(
  93. txn,
  94. table="received_transactions",
  95. keyvalues={"transaction_id": transaction_id, "origin": origin},
  96. retcols=(
  97. "transaction_id",
  98. "origin",
  99. "ts",
  100. "response_code",
  101. "response_json",
  102. "has_been_referenced",
  103. ),
  104. allow_none=True,
  105. )
  106. if result and result["response_code"]:
  107. return result["response_code"], db_to_json(result["response_json"])
  108. else:
  109. return None
  110. async def set_received_txn_response(
  111. self, transaction_id: str, origin: str, code: int, response_dict: JsonDict
  112. ) -> None:
  113. """Persist the response we returned for an incoming transaction, and
  114. should return for subsequent transactions with the same transaction_id
  115. and origin.
  116. Args:
  117. transaction_id: The incoming transaction ID.
  118. origin: The origin server.
  119. code: The response code.
  120. response_dict: The response, to be encoded into JSON.
  121. """
  122. await self.db_pool.simple_upsert(
  123. table="received_transactions",
  124. keyvalues={
  125. "transaction_id": transaction_id,
  126. "origin": origin,
  127. },
  128. values={},
  129. insertion_values={
  130. "response_code": code,
  131. "response_json": db_binary_type(encode_canonical_json(response_dict)),
  132. "ts": self._clock.time_msec(),
  133. },
  134. desc="set_received_txn_response",
  135. )
  136. @cached(max_entries=10000)
  137. async def get_destination_retry_timings(
  138. self,
  139. destination: str,
  140. ) -> Optional[DestinationRetryTimings]:
  141. """Gets the current retry timings (if any) for a given destination.
  142. Args:
  143. destination (str)
  144. Returns:
  145. None if not retrying
  146. Otherwise a dict for the retry scheme
  147. """
  148. result = await self.db_pool.runInteraction(
  149. "get_destination_retry_timings",
  150. self._get_destination_retry_timings,
  151. destination,
  152. )
  153. return result
  154. def _get_destination_retry_timings(
  155. self, txn: LoggingTransaction, destination: str
  156. ) -> Optional[DestinationRetryTimings]:
  157. result = self.db_pool.simple_select_one_txn(
  158. txn,
  159. table="destinations",
  160. keyvalues={"destination": destination},
  161. retcols=("failure_ts", "retry_last_ts", "retry_interval"),
  162. allow_none=True,
  163. )
  164. # check we have a row and retry_last_ts is not null or zero
  165. # (retry_last_ts can't be negative)
  166. if result and result["retry_last_ts"]:
  167. return DestinationRetryTimings(**result)
  168. else:
  169. return None
    async def set_destination_retry_timings(
        self,
        destination: str,
        failure_ts: Optional[int],
        retry_last_ts: int,
        retry_interval: int,
    ) -> None:
        """Sets the current retry timings for a given destination.
        Both timings should be zero if retrying is no longer occurring.

        Args:
            destination: the remote server name
            failure_ts: when the server started failing (ms since epoch)
            retry_last_ts: time of last retry attempt in unix epoch ms
            retry_interval: how long until next retry in ms
        """
        await self.db_pool.runInteraction(
            "set_destination_retry_timings",
            self._set_destination_retry_timings_native,
            destination,
            failure_ts,
            retry_last_ts,
            retry_interval,
            db_autocommit=True,  # Safe as it's a single upsert
        )
    def _set_destination_retry_timings_native(
        self,
        txn: LoggingTransaction,
        destination: str,
        failure_ts: Optional[int],
        retry_last_ts: int,
        retry_interval: int,
    ) -> None:
        """Transaction function backing set_destination_retry_timings, using a
        native ``INSERT ... ON CONFLICT`` upsert so the whole update is one
        statement.
        """
        # Upsert retry time interval if retry_interval is zero (i.e. we're
        # resetting it) or greater than the existing retry interval.
        #
        # WARNING: This is executed in autocommit, so we shouldn't add any more
        # SQL calls in here (without being very careful).
        sql = """
            INSERT INTO destinations (
                destination, failure_ts, retry_last_ts, retry_interval
            )
            VALUES (?, ?, ?, ?)
            ON CONFLICT (destination) DO UPDATE SET
                failure_ts = EXCLUDED.failure_ts,
                retry_last_ts = EXCLUDED.retry_last_ts,
                retry_interval = EXCLUDED.retry_interval
            WHERE
                EXCLUDED.retry_interval = 0
                OR destinations.retry_interval IS NULL
                OR destinations.retry_interval < EXCLUDED.retry_interval
        """

        txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval))

        # Invalidate the cached get_destination_retry_timings locally and
        # stream the invalidation to other workers.
        self._invalidate_cache_and_stream(
            txn, self.get_destination_retry_timings, (destination,)
        )
    def _set_destination_retry_timings_emulated(
        self,
        txn: LoggingTransaction,
        destination: str,
        failure_ts: Optional[int],
        retry_last_ts: int,
        retry_interval: int,
    ) -> None:
        """Alternative transaction function for set_destination_retry_timings,
        emulating the conditional upsert with a table lock plus an explicit
        SELECT and INSERT/UPDATE.

        NOTE(review): presumably intended for database engines without native
        upsert support — the caller visible in this file only uses the
        ``_native`` variant; confirm how this one is wired up.
        """
        # Lock the table so no concurrent writer can race us between the
        # SELECT and the INSERT/UPDATE below.
        self.database_engine.lock_table(txn, "destinations")

        # We need to be careful here as the data may have changed from under us
        # due to a worker setting the timings.

        prev_row = self.db_pool.simple_select_one_txn(
            txn,
            table="destinations",
            keyvalues={"destination": destination},
            retcols=("failure_ts", "retry_last_ts", "retry_interval"),
            allow_none=True,
        )

        if not prev_row:
            # First time we've recorded timings for this destination.
            self.db_pool.simple_insert_txn(
                txn,
                table="destinations",
                values={
                    "destination": destination,
                    "failure_ts": failure_ts,
                    "retry_last_ts": retry_last_ts,
                    "retry_interval": retry_interval,
                },
            )
        elif (
            # Only overwrite when resetting (zero) or when the new interval is
            # longer than the stored one — mirrors the WHERE clause of the
            # native upsert variant.
            retry_interval == 0
            or prev_row["retry_interval"] is None
            or prev_row["retry_interval"] < retry_interval
        ):
            self.db_pool.simple_update_one_txn(
                txn,
                "destinations",
                keyvalues={"destination": destination},
                updatevalues={
                    "failure_ts": failure_ts,
                    "retry_last_ts": retry_last_ts,
                    "retry_interval": retry_interval,
                },
            )

        # Invalidate the cached timings locally and on other workers.
        self._invalidate_cache_and_stream(
            txn, self.get_destination_retry_timings, (destination,)
        )
  272. async def store_destination_rooms_entries(
  273. self,
  274. destinations: Iterable[str],
  275. room_id: str,
  276. stream_ordering: int,
  277. ) -> None:
  278. """
  279. Updates or creates `destination_rooms` entries in batch for a single event.
  280. Args:
  281. destinations: list of destinations
  282. room_id: the room_id of the event
  283. stream_ordering: the stream_ordering of the event
  284. """
  285. await self.db_pool.simple_upsert_many(
  286. table="destinations",
  287. key_names=("destination",),
  288. key_values=[(d,) for d in destinations],
  289. value_names=[],
  290. value_values=[],
  291. desc="store_destination_rooms_entries_dests",
  292. )
  293. rows = [(destination, room_id) for destination in destinations]
  294. await self.db_pool.simple_upsert_many(
  295. table="destination_rooms",
  296. key_names=("destination", "room_id"),
  297. key_values=rows,
  298. value_names=["stream_ordering"],
  299. value_values=[(stream_ordering,)] * len(rows),
  300. desc="store_destination_rooms_entries_rooms",
  301. )
  302. async def get_destination_last_successful_stream_ordering(
  303. self, destination: str
  304. ) -> Optional[int]:
  305. """
  306. Gets the stream ordering of the PDU most-recently successfully sent
  307. to the specified destination, or None if this information has not been
  308. tracked yet.
  309. Args:
  310. destination: the destination to query
  311. """
  312. return await self.db_pool.simple_select_one_onecol(
  313. "destinations",
  314. {"destination": destination},
  315. "last_successful_stream_ordering",
  316. allow_none=True,
  317. desc="get_last_successful_stream_ordering",
  318. )
  319. async def set_destination_last_successful_stream_ordering(
  320. self, destination: str, last_successful_stream_ordering: int
  321. ) -> None:
  322. """
  323. Marks that we have successfully sent the PDUs up to and including the
  324. one specified.
  325. Args:
  326. destination: the destination we have successfully sent to
  327. last_successful_stream_ordering: the stream_ordering of the most
  328. recent successfully-sent PDU
  329. """
  330. await self.db_pool.simple_upsert(
  331. "destinations",
  332. keyvalues={"destination": destination},
  333. values={"last_successful_stream_ordering": last_successful_stream_ordering},
  334. desc="set_last_successful_stream_ordering",
  335. )
  336. async def get_catch_up_room_event_ids(
  337. self,
  338. destination: str,
  339. last_successful_stream_ordering: int,
  340. ) -> List[str]:
  341. """
  342. Returns at most 50 event IDs and their corresponding stream_orderings
  343. that correspond to the oldest events that have not yet been sent to
  344. the destination.
  345. Args:
  346. destination: the destination in question
  347. last_successful_stream_ordering: the stream_ordering of the
  348. most-recently successfully-transmitted event to the destination
  349. Returns:
  350. list of event_ids
  351. """
  352. return await self.db_pool.runInteraction(
  353. "get_catch_up_room_event_ids",
  354. self._get_catch_up_room_event_ids_txn,
  355. destination,
  356. last_successful_stream_ordering,
  357. )
  358. @staticmethod
  359. def _get_catch_up_room_event_ids_txn(
  360. txn: LoggingTransaction,
  361. destination: str,
  362. last_successful_stream_ordering: int,
  363. ) -> List[str]:
  364. q = """
  365. SELECT event_id FROM destination_rooms
  366. JOIN events USING (stream_ordering)
  367. WHERE destination = ?
  368. AND stream_ordering > ?
  369. ORDER BY stream_ordering
  370. LIMIT 50
  371. """
  372. txn.execute(
  373. q,
  374. (destination, last_successful_stream_ordering),
  375. )
  376. event_ids = [row[0] for row in txn]
  377. return event_ids
  378. async def get_catch_up_outstanding_destinations(
  379. self, after_destination: Optional[str]
  380. ) -> List[str]:
  381. """
  382. Gets at most 25 destinations which have outstanding PDUs to be caught up,
  383. and are not being backed off from
  384. Args:
  385. after_destination:
  386. If provided, all destinations must be lexicographically greater
  387. than this one.
  388. Returns:
  389. list of up to 25 destinations with outstanding catch-up.
  390. These are the lexicographically first destinations which are
  391. lexicographically greater than after_destination (if provided).
  392. """
  393. time = self.hs.get_clock().time_msec()
  394. return await self.db_pool.runInteraction(
  395. "get_catch_up_outstanding_destinations",
  396. self._get_catch_up_outstanding_destinations_txn,
  397. time,
  398. after_destination,
  399. )
  400. @staticmethod
  401. def _get_catch_up_outstanding_destinations_txn(
  402. txn: LoggingTransaction, now_time_ms: int, after_destination: Optional[str]
  403. ) -> List[str]:
  404. q = """
  405. SELECT DISTINCT destination FROM destinations
  406. INNER JOIN destination_rooms USING (destination)
  407. WHERE
  408. stream_ordering > last_successful_stream_ordering
  409. AND destination > ?
  410. AND (
  411. retry_last_ts IS NULL OR
  412. retry_last_ts + retry_interval < ?
  413. )
  414. ORDER BY destination
  415. LIMIT 25
  416. """
  417. txn.execute(
  418. q,
  419. (
  420. # everything is lexicographically greater than "" so this gives
  421. # us the first batch of up to 25.
  422. after_destination or "",
  423. now_time_ms,
  424. ),
  425. )
  426. destinations = [row[0] for row in txn]
  427. return destinations
  428. async def get_destinations_paginate(
  429. self,
  430. start: int,
  431. limit: int,
  432. destination: Optional[str] = None,
  433. order_by: str = DestinationSortOrder.DESTINATION.value,
  434. direction: Direction = Direction.FORWARDS,
  435. ) -> Tuple[List[JsonDict], int]:
  436. """Function to retrieve a paginated list of destinations.
  437. This will return a json list of destinations and the
  438. total number of destinations matching the filter criteria.
  439. Args:
  440. start: start number to begin the query from
  441. limit: number of rows to retrieve
  442. destination: search string in destination
  443. order_by: the sort order of the returned list
  444. direction: sort ascending or descending
  445. Returns:
  446. A tuple of a list of mappings from destination to information
  447. and a count of total destinations.
  448. """
  449. def get_destinations_paginate_txn(
  450. txn: LoggingTransaction,
  451. ) -> Tuple[List[JsonDict], int]:
  452. order_by_column = DestinationSortOrder(order_by).value
  453. if direction == Direction.BACKWARDS:
  454. order = "DESC"
  455. else:
  456. order = "ASC"
  457. args: List[object] = []
  458. where_statement = ""
  459. if destination:
  460. args.extend(["%" + destination.lower() + "%"])
  461. where_statement = "WHERE LOWER(destination) LIKE ?"
  462. sql_base = f"FROM destinations {where_statement} "
  463. sql = f"SELECT COUNT(*) as total_destinations {sql_base}"
  464. txn.execute(sql, args)
  465. count = cast(Tuple[int], txn.fetchone())[0]
  466. sql = f"""
  467. SELECT destination, retry_last_ts, retry_interval, failure_ts,
  468. last_successful_stream_ordering
  469. {sql_base}
  470. ORDER BY {order_by_column} {order}, destination ASC
  471. LIMIT ? OFFSET ?
  472. """
  473. txn.execute(sql, args + [limit, start])
  474. destinations = self.db_pool.cursor_to_dict(txn)
  475. return destinations, count
  476. return await self.db_pool.runInteraction(
  477. "get_destinations_paginate_txn", get_destinations_paginate_txn
  478. )
  479. async def get_destination_rooms_paginate(
  480. self,
  481. destination: str,
  482. start: int,
  483. limit: int,
  484. direction: Direction = Direction.FORWARDS,
  485. ) -> Tuple[List[JsonDict], int]:
  486. """Function to retrieve a paginated list of destination's rooms.
  487. This will return a json list of rooms and the
  488. total number of rooms.
  489. Args:
  490. destination: the destination to query
  491. start: start number to begin the query from
  492. limit: number of rows to retrieve
  493. direction: sort ascending or descending by room_id
  494. Returns:
  495. A tuple of a dict of rooms and a count of total rooms.
  496. """
  497. def get_destination_rooms_paginate_txn(
  498. txn: LoggingTransaction,
  499. ) -> Tuple[List[JsonDict], int]:
  500. if direction == Direction.BACKWARDS:
  501. order = "DESC"
  502. else:
  503. order = "ASC"
  504. sql = """
  505. SELECT COUNT(*) as total_rooms
  506. FROM destination_rooms
  507. WHERE destination = ?
  508. """
  509. txn.execute(sql, [destination])
  510. count = cast(Tuple[int], txn.fetchone())[0]
  511. rooms = self.db_pool.simple_select_list_paginate_txn(
  512. txn=txn,
  513. table="destination_rooms",
  514. orderby="room_id",
  515. start=start,
  516. limit=limit,
  517. retcols=("room_id", "stream_ordering"),
  518. order_direction=order,
  519. )
  520. return rooms, count
  521. return await self.db_pool.runInteraction(
  522. "get_destination_rooms_paginate_txn", get_destination_rooms_paginate_txn
  523. )
  524. async def is_destination_known(self, destination: str) -> bool:
  525. """Check if a destination is known to the server."""
  526. result = await self.db_pool.simple_select_one_onecol(
  527. table="destinations",
  528. keyvalues={"destination": destination},
  529. retcol="1",
  530. allow_none=True,
  531. desc="is_destination_known",
  532. )
  533. return bool(result)