@@ -12,19 +12,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Awaitable, Callable, Dict, Generic, Optional, TypeVar
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Generic,
Iterable,
Optional,
TypeVar,
)
import attr
from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import (
active_span,
start_active_span,
start_active_span_follows_from,
)
from synapse.util import Clock
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.async_helpers import AbstractObservableDeferred, ObservableDeferred
from synapse.util.caches import register_cache
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
import opentracing
# the type of the key in the cache
KV = TypeVar("KV")
@@ -54,6 +72,20 @@ class ResponseCacheContext(Generic[KV]):
"""
@attr.s(auto_attribs=True)
class ResponseCacheEntry:
result: AbstractObservableDeferred
"""The (possibly incomplete) result of the operation.
Note that we continue to store an ObservableDeferred even after the operation
completes (rather than switching to an immediate value), since that makes it
easier to cache Failure results.
"""
opentracing_span_context: "Optional[opentracing.SpanContext]"
"""The opentracing span which generated/is generating the result"""
class ResponseCache(Generic[KV]):
"""
This caches a deferred response. Until the deferred completes it will be
@@ -63,10 +95,7 @@ class ResponseCache(Generic[KV]):
"""
def __init__(self, clock: Clock, name: str, timeout_ms: float = 0):
# This is poorly-named: it includes both complete and incomplete results.
# We keep complete results rather than switching to absolute values because
# that makes it easier to cache Failure results.
self.pending_result_cache: Dict[KV, ObservableDeferred] = {}
self._result_cache: Dict[KV, ResponseCacheEntry] = {}
self.clock = clock
self.timeout_sec = timeout_ms / 1000.0
@@ -75,56 +104,63 @@ class ResponseCache(Generic[KV]):
self._metrics = register_cache("response_cache", name, self, resizable=False)
def size(self) -> int:
return len(self.pending _result_cache)
return len(self._result_cache)
def __len__(self) -> int:
return self.size()
def get(self, key: KV) -> Optional[defer.Deferred ]:
"""Look up the given key.
def keys(self) -> Iterable[KV ]:
"""Get the keys currently in the result cache
Returns a new Deferred (which also doesn't follow the synapse
logcontext rules). You will probably want to make_deferred_yieldable the result .
Returns both incomplete entries, and (if the timeout on this cache is non-zero),
complete entries which are still in the cache .
If there is no entry for the key, returns None.
Note that the returned iterator is not safe in the face of concurrent execution:
behaviour is undefined if `wrap` is called during iteration.
"""
return self._result_cache.keys()
def _get(self, key: KV) -> Optional[ResponseCacheEntry]:
"""Look up the given key.
Args:
key: key to get/set in the cache
key: key to get in the cache
Returns:
None if there is no entry for this key; otherwise a deferred which
resolves to the result.
The entry for this key, if any; else None.
"""
result = self.pending _result_cache.get(key)
if result is not None:
entry = self. _result_cache.get(key)
if entry is not None:
self._metrics.inc_hits()
return result.observe()
return entry
else:
self._metrics.inc_misses()
return None
def _set(
self, context: ResponseCacheContext[KV], deferred: "defer.Deferred[RV]"
) -> "defer.Deferred[RV]":
self,
context: ResponseCacheContext[KV],
deferred: "defer.Deferred[RV]",
opentracing_span_context: "Optional[opentracing.SpanContext]",
) -> ResponseCacheEntry:
"""Set the entry for the given key to the given deferred.
*deferred* should run its callbacks in the sentinel logcontext (ie,
you should wrap normal synapse deferreds with
synapse.logging.context.run_in_background).
Returns a new Deferred (which also doesn't follow the synapse logcontext rules).
You will probably want to make_deferred_yieldable the result.
Args:
context: Information about the cache miss
deferred: The deferred which resolves to the result.
opentracing_span_context: An opentracing span wrapping the calculation
Returns:
A new deferred which resolves to the actual resul t.
The cache entry objec t.
"""
result = ObservableDeferred(deferred, consumeErrors=True)
key = context.cache_key
self.pending_result_cache[key] = result
entry = ResponseCacheEntry(result, opentracing_span_context)
self._result_cache[key] = entry
def on_complete(r: RV) -> RV:
# if this cache has a non-zero timeout, and the callback has not cleared
@@ -132,18 +168,18 @@ class ResponseCache(Generic[KV]):
# its removal later.
if self.timeout_sec and context.should_cache:
self.clock.call_later(
self.timeout_sec, self.pending _result_cache.pop, key, None
self.timeout_sec, self._result_cache.pop, key, None
)
else:
# otherwise, remove the result immediately.
self.pending _result_cache.pop(key, None)
self._result_cache.pop(key, None)
return r
# make sure we do this *after* adding the entry to pending_ result_cache,
# make sure we do this *after* adding the entry to result_cache,
# in case the result is already complete (in which case flipping the order would
# leave us with a stuck entry in the cache).
result.addBoth(on_complete)
return result.observe()
return entry
async def wrap(
self,
@@ -189,20 +225,41 @@ class ResponseCache(Generic[KV]):
Returns:
The result of the callback (from the cache, or otherwise)
"""
result = self. get(key)
if not result :
entry = self._ get(key)
if not entry :
logger.debug(
"[%s]: no cached result for [%s], calculating new one", self._name, key
)
context = ResponseCacheContext(cache_key=key)
if cache_context:
kwargs["cache_context"] = context
d = run_in_background(callback, *args, **kwargs)
result = self._set(context, d)
elif not isinstance(result, defer.Deferred) or result.called:
span_context: Optional[opentracing.SpanContext] = None
async def cb() -> RV:
# NB it is important that we do not `await` before setting span_context!
nonlocal span_context
with start_active_span(f"ResponseCache[{self._name}].calculate"):
span = active_span()
if span:
span_context = span.context
return await callback(*args, **kwargs)
d = run_in_background(cb)
entry = self._set(context, d, span_context)
return await make_deferred_yieldable(entry.result.observe())
result = entry.result.observe()
if result.called:
logger.info("[%s]: using completed cached result for [%s]", self._name, key)
else:
logger.info(
"[%s]: using incomplete cached result for [%s]", self._name, key
)
return await make_deferred_yieldable(result)
span_context = entry.opentracing_span_context
with start_active_span_follows_from(
f"ResponseCache[{self._name}].wait",
contexts=(span_context,) if span_context else (),
):
return await make_deferred_yieldable(result)