You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

282 lines
10 KiB

  1. # Copyright 2020 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Callable, Tuple
  15. from unittest.mock import Mock, call
  16. from twisted.internet import defer
  17. from twisted.internet.defer import CancelledError, Deferred
  18. from twisted.test.proto_helpers import MemoryReactor
  19. from synapse.server import HomeServer
  20. from synapse.storage.database import (
  21. DatabasePool,
  22. LoggingDatabaseConnection,
  23. LoggingTransaction,
  24. make_tuple_comparison_clause,
  25. )
  26. from synapse.util import Clock
  27. from tests import unittest
  28. class TupleComparisonClauseTestCase(unittest.TestCase):
  29. def test_native_tuple_comparison(self) -> None:
  30. clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)])
  31. self.assertEqual(clause, "(a,b) > (?,?)")
  32. self.assertEqual(args, [1, 2])
  33. class ExecuteScriptTestCase(unittest.HomeserverTestCase):
  34. """Tests for `BaseDatabaseEngine.executescript` implementations."""
  35. def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
  36. self.store = hs.get_datastores().main
  37. self.db_pool: DatabasePool = self.store.db_pool
  38. self.get_success(
  39. self.db_pool.runInteraction(
  40. "create",
  41. lambda txn: txn.execute("CREATE TABLE foo (name TEXT PRIMARY KEY)"),
  42. )
  43. )
  44. def test_transaction(self) -> None:
  45. """Test that all statements are run in a single transaction."""
  46. def run(conn: LoggingDatabaseConnection) -> None:
  47. cur = conn.cursor(txn_name="test_transaction")
  48. self.db_pool.engine.executescript(
  49. cur,
  50. ";".join(
  51. [
  52. "INSERT INTO foo (name) VALUES ('transaction test')",
  53. # This next statement will fail. When `executescript` is not
  54. # transactional, the previous row will be observed later.
  55. "INSERT INTO foo (name) VALUES ('transaction test')",
  56. ]
  57. ),
  58. )
  59. self.get_failure(
  60. self.db_pool.runWithConnection(run),
  61. self.db_pool.engine.module.IntegrityError,
  62. )
  63. self.assertIsNone(
  64. self.get_success(
  65. self.db_pool.simple_select_one_onecol(
  66. "foo",
  67. keyvalues={"name": "transaction test"},
  68. retcol="name",
  69. allow_none=True,
  70. )
  71. ),
  72. "executescript is not running statements inside a transaction",
  73. )
  74. def test_commit(self) -> None:
  75. """Test that the script transaction remains open and can be committed."""
  76. def run(conn: LoggingDatabaseConnection) -> None:
  77. cur = conn.cursor(txn_name="test_commit")
  78. self.db_pool.engine.executescript(
  79. cur, "INSERT INTO foo (name) VALUES ('commit test')"
  80. )
  81. cur.execute("COMMIT")
  82. self.get_success(self.db_pool.runWithConnection(run))
  83. self.assertIsNotNone(
  84. self.get_success(
  85. self.db_pool.simple_select_one_onecol(
  86. "foo",
  87. keyvalues={"name": "commit test"},
  88. retcol="name",
  89. allow_none=True,
  90. )
  91. ),
  92. )
  93. def test_rollback(self) -> None:
  94. """Test that the script transaction remains open and can be rolled back."""
  95. def run(conn: LoggingDatabaseConnection) -> None:
  96. cur = conn.cursor(txn_name="test_rollback")
  97. self.db_pool.engine.executescript(
  98. cur, "INSERT INTO foo (name) VALUES ('rollback test')"
  99. )
  100. cur.execute("ROLLBACK")
  101. self.get_success(self.db_pool.runWithConnection(run))
  102. self.assertIsNone(
  103. self.get_success(
  104. self.db_pool.simple_select_one_onecol(
  105. "foo",
  106. keyvalues={"name": "rollback test"},
  107. retcol="name",
  108. allow_none=True,
  109. )
  110. ),
  111. "executescript is not leaving the script transaction open",
  112. )
  113. class CallbacksTestCase(unittest.HomeserverTestCase):
  114. """Tests for transaction callbacks."""
  115. def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
  116. self.store = hs.get_datastores().main
  117. self.db_pool: DatabasePool = self.store.db_pool
  118. def _run_interaction(
  119. self, func: Callable[[LoggingTransaction], object]
  120. ) -> Tuple[Mock, Mock]:
  121. """Run the given function in a database transaction, with callbacks registered.
  122. Args:
  123. func: The function to be run in a transaction. The transaction will be
  124. retried if `func` raises an `OperationalError`.
  125. Returns:
  126. Two mocks, which were registered as an `after_callback` and an
  127. `exception_callback` respectively, on every transaction attempt.
  128. """
  129. after_callback = Mock()
  130. exception_callback = Mock()
  131. def _test_txn(txn: LoggingTransaction) -> None:
  132. txn.call_after(after_callback, 123, 456, extra=789)
  133. txn.call_on_exception(exception_callback, 987, 654, extra=321)
  134. func(txn)
  135. try:
  136. self.get_success_or_raise(
  137. self.db_pool.runInteraction("test_transaction", _test_txn)
  138. )
  139. except Exception:
  140. pass
  141. return after_callback, exception_callback
  142. def test_after_callback(self) -> None:
  143. """Test that the after callback is called when a transaction succeeds."""
  144. after_callback, exception_callback = self._run_interaction(lambda txn: None)
  145. after_callback.assert_called_once_with(123, 456, extra=789)
  146. exception_callback.assert_not_called()
  147. def test_exception_callback(self) -> None:
  148. """Test that the exception callback is called when a transaction fails."""
  149. _test_txn = Mock(side_effect=ZeroDivisionError)
  150. after_callback, exception_callback = self._run_interaction(_test_txn)
  151. after_callback.assert_not_called()
  152. exception_callback.assert_called_once_with(987, 654, extra=321)
  153. def test_failed_retry(self) -> None:
  154. """Test that the exception callback is called for every failed attempt."""
  155. # Always raise an `OperationalError`.
  156. _test_txn = Mock(side_effect=self.db_pool.engine.module.OperationalError)
  157. after_callback, exception_callback = self._run_interaction(_test_txn)
  158. after_callback.assert_not_called()
  159. exception_callback.assert_has_calls(
  160. [
  161. call(987, 654, extra=321),
  162. call(987, 654, extra=321),
  163. call(987, 654, extra=321),
  164. call(987, 654, extra=321),
  165. call(987, 654, extra=321),
  166. call(987, 654, extra=321),
  167. ]
  168. )
  169. self.assertEqual(exception_callback.call_count, 6) # no additional calls
  170. def test_successful_retry(self) -> None:
  171. """Test callbacks for a failed transaction followed by a successful attempt."""
  172. # Raise an `OperationalError` on the first attempt only.
  173. _test_txn = Mock(
  174. side_effect=[self.db_pool.engine.module.OperationalError, None]
  175. )
  176. after_callback, exception_callback = self._run_interaction(_test_txn)
  177. # Calling both `after_callback`s when the first attempt failed is rather
  178. # surprising (https://github.com/matrix-org/synapse/issues/12184).
  179. # Let's document the behaviour in a test.
  180. after_callback.assert_has_calls(
  181. [
  182. call(123, 456, extra=789),
  183. call(123, 456, extra=789),
  184. ]
  185. )
  186. self.assertEqual(after_callback.call_count, 2) # no additional calls
  187. exception_callback.assert_not_called()
  188. class CancellationTestCase(unittest.HomeserverTestCase):
  189. def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
  190. self.store = hs.get_datastores().main
  191. self.db_pool: DatabasePool = self.store.db_pool
  192. def test_after_callback(self) -> None:
  193. """Test that the after callback is called when a transaction succeeds."""
  194. d: "Deferred[None]"
  195. after_callback = Mock()
  196. exception_callback = Mock()
  197. def _test_txn(txn: LoggingTransaction) -> None:
  198. txn.call_after(after_callback, 123, 456, extra=789)
  199. txn.call_on_exception(exception_callback, 987, 654, extra=321)
  200. d.cancel()
  201. d = defer.ensureDeferred(
  202. self.db_pool.runInteraction("test_transaction", _test_txn)
  203. )
  204. self.get_failure(d, CancelledError)
  205. after_callback.assert_called_once_with(123, 456, extra=789)
  206. exception_callback.assert_not_called()
  207. def test_exception_callback(self) -> None:
  208. """Test that the exception callback is called when a transaction fails."""
  209. d: "Deferred[None]"
  210. after_callback = Mock()
  211. exception_callback = Mock()
  212. def _test_txn(txn: LoggingTransaction) -> None:
  213. txn.call_after(after_callback, 123, 456, extra=789)
  214. txn.call_on_exception(exception_callback, 987, 654, extra=321)
  215. d.cancel()
  216. # Simulate a retryable failure on every attempt.
  217. raise self.db_pool.engine.module.OperationalError()
  218. d = defer.ensureDeferred(
  219. self.db_pool.runInteraction("test_transaction", _test_txn)
  220. )
  221. self.get_failure(d, CancelledError)
  222. after_callback.assert_not_called()
  223. exception_callback.assert_has_calls(
  224. [
  225. call(987, 654, extra=321),
  226. call(987, 654, extra=321),
  227. call(987, 654, extra=321),
  228. call(987, 654, extra=321),
  229. call(987, 654, extra=321),
  230. call(987, 654, extra=321),
  231. ]
  232. )
  233. self.assertEqual(exception_callback.call_count, 6) # no additional calls