You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

194 lines
6.0 KiB

  1. # Copyright 2018 New Vector Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import threading
  15. from io import BytesIO
  16. from typing import BinaryIO, Generator, Optional, cast
  17. from unittest.mock import NonCallableMock
  18. from zope.interface import implementer
  19. from twisted.internet import defer, reactor as _reactor
  20. from twisted.internet.interfaces import IPullProducer
  21. from synapse.types import ISynapseReactor
  22. from synapse.util.file_consumer import BackgroundFileConsumer
  23. from tests import unittest
  24. reactor = cast(ISynapseReactor, _reactor)
  25. class FileConsumerTests(unittest.TestCase):
  26. @defer.inlineCallbacks
  27. def test_pull_consumer(self) -> Generator["defer.Deferred[object]", object, None]:
  28. string_file = BytesIO()
  29. consumer = BackgroundFileConsumer(string_file, reactor=reactor)
  30. try:
  31. producer = DummyPullProducer()
  32. yield producer.register_with_consumer(consumer)
  33. yield producer.write_and_wait(b"Foo")
  34. self.assertEqual(string_file.getvalue(), b"Foo")
  35. yield producer.write_and_wait(b"Bar")
  36. self.assertEqual(string_file.getvalue(), b"FooBar")
  37. finally:
  38. consumer.unregisterProducer()
  39. yield consumer.wait() # type: ignore[misc]
  40. self.assertTrue(string_file.closed)
  41. @defer.inlineCallbacks
  42. def test_push_consumer(self) -> Generator["defer.Deferred[object]", object, None]:
  43. string_file = BlockingBytesWrite()
  44. consumer = BackgroundFileConsumer(cast(BinaryIO, string_file), reactor=reactor)
  45. try:
  46. producer = NonCallableMock(spec_set=[])
  47. consumer.registerProducer(producer, True)
  48. consumer.write(b"Foo")
  49. yield string_file.wait_for_n_writes(1) # type: ignore[misc]
  50. self.assertEqual(string_file.buffer, b"Foo")
  51. consumer.write(b"Bar")
  52. yield string_file.wait_for_n_writes(2) # type: ignore[misc]
  53. self.assertEqual(string_file.buffer, b"FooBar")
  54. finally:
  55. consumer.unregisterProducer()
  56. yield consumer.wait() # type: ignore[misc]
  57. self.assertTrue(string_file.closed)
  58. @defer.inlineCallbacks
  59. def test_push_producer_feedback(
  60. self,
  61. ) -> Generator["defer.Deferred[object]", object, None]:
  62. string_file = BlockingBytesWrite()
  63. consumer = BackgroundFileConsumer(cast(BinaryIO, string_file), reactor=reactor)
  64. try:
  65. producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
  66. resume_deferred: defer.Deferred = defer.Deferred()
  67. producer.resumeProducing.side_effect = lambda: resume_deferred.callback(
  68. None
  69. )
  70. consumer.registerProducer(producer, True)
  71. number_writes = 0
  72. with string_file.write_lock:
  73. for _ in range(consumer._PAUSE_ON_QUEUE_SIZE):
  74. consumer.write(b"Foo")
  75. number_writes += 1
  76. producer.pauseProducing.assert_called_once()
  77. yield string_file.wait_for_n_writes(number_writes) # type: ignore[misc]
  78. yield resume_deferred
  79. producer.resumeProducing.assert_called_once()
  80. finally:
  81. consumer.unregisterProducer()
  82. yield consumer.wait() # type: ignore[misc]
  83. self.assertTrue(string_file.closed)
  84. @implementer(IPullProducer)
  85. class DummyPullProducer:
  86. def __init__(self) -> None:
  87. self.consumer: Optional[BackgroundFileConsumer] = None
  88. self.deferred: "defer.Deferred[object]" = defer.Deferred()
  89. def resumeProducing(self) -> None:
  90. d = self.deferred
  91. self.deferred = defer.Deferred()
  92. d.callback(None)
  93. def stopProducing(self) -> None:
  94. raise RuntimeError("Unexpected call")
  95. def write_and_wait(self, write_bytes: bytes) -> "defer.Deferred[object]":
  96. assert self.consumer is not None
  97. d = self.deferred
  98. self.consumer.write(write_bytes)
  99. return d
  100. def register_with_consumer(
  101. self, consumer: BackgroundFileConsumer
  102. ) -> "defer.Deferred[object]":
  103. d = self.deferred
  104. self.consumer = consumer
  105. self.consumer.registerProducer(self, False)
  106. return d
  107. class BlockingBytesWrite:
  108. def __init__(self) -> None:
  109. self.buffer = b""
  110. self.closed = False
  111. self.write_lock = threading.Lock()
  112. self._notify_write_deferred: Optional[defer.Deferred] = None
  113. self._number_of_writes = 0
  114. def write(self, write_bytes: bytes) -> None:
  115. with self.write_lock:
  116. self.buffer += write_bytes
  117. self._number_of_writes += 1
  118. reactor.callFromThread(self._notify_write)
  119. def close(self) -> None:
  120. self.closed = True
  121. def _notify_write(self) -> None:
  122. "Called by write to indicate a write happened"
  123. with self.write_lock:
  124. if not self._notify_write_deferred:
  125. return
  126. d = self._notify_write_deferred
  127. self._notify_write_deferred = None
  128. d.callback(None)
  129. @defer.inlineCallbacks
  130. def wait_for_n_writes(
  131. self, n: int
  132. ) -> Generator["defer.Deferred[object]", object, None]:
  133. "Wait for n writes to have happened"
  134. while True:
  135. with self.write_lock:
  136. if n <= self._number_of_writes:
  137. return
  138. if not self._notify_write_deferred:
  139. self._notify_write_deferred = defer.Deferred()
  140. d = self._notify_write_deferred
  141. yield d