import inspect
import pickle
import tempfile
from collections.abc import Callable
from functools import update_wrapper
from hashlib import sha256 as hash_algorithm
from logging import getLogger
from os import PathLike
from pathlib import Path
from shutil import rmtree
from typing import Any, Generic, ParamSpec, TypeVar, cast

import pandas as pd
from platformdirs import user_cache_path
14
15logger = getLogger(__name__)
16
17P = ParamSpec("P")
18R = TypeVar("R")
19
20
class StashedFunction(Generic[P, R]):
    """Callable object that wraps your original function.

    Each distinct set of call arguments maps to one pickle file inside a
    directory derived from the function's qualified name and a digest of the
    function itself.
    """

    def __init__(self, base_dir: Path, f: Callable[P, R]) -> None:
        """Wrap ``f`` so that its results are cached below ``base_dir``.

        :param base_dir: Root directory under which this function's cache
            directory is placed.
        :param f: The function whose calls should be cached.
        """
        self.f = f
        self.signature = inspect.signature(f)
        # Parent folder for this function's data is computed from its name and
        # source code.
        self.function_dir = base_dir / f.__qualname__ / digest_function(f)
        # Copy __name__, __doc__, etc. from f so this object looks like f.
        update_wrapper(self, f)

    def path_for(self, *args: P.args, **kwargs: P.kwargs) -> Path:
        """File where data for a function call with these arguments should be stored."""
        # Binding against the signature makes positional and keyword spellings
        # of the same parameter produce the same digest.
        bound = self.signature.bind(*args, **kwargs)
        cache_path = self.function_dir / digest_args(bound)
        return cache_path.with_suffix(".pickle")

    def clear_for(self, *args: P.args, **kwargs: P.kwargs) -> None:
        """Delete cached data (if any exists) for this specific set of arguments."""
        self.path_for(*args, **kwargs).unlink(missing_ok=True)

    def clear(self) -> None:
        """Delete all cached data for this function.

        Does nothing if no data has been cached for this function yet.
        """
        logger.info("Deleting cached data for function %s", self.f.__qualname__)
        try:
            rmtree(self.function_dir)
        except FileNotFoundError:
            # The directory is only created on the first cache write, so a
            # missing directory simply means there is nothing to delete.
            pass

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
        """If a cached file for these arguments exist, returns the cached result.

        Otherwise, calls the wrapped function, caches that result to a file, and
        returns that result. A cache file that is empty or truncated is treated
        as a cache miss and overwritten.
        """
        cache_path = self.path_for(*args, **kwargs)
        logger.debug(
            "Call to %s will use cache path: %s", self.f.__name__, cache_path
        )
        # Try fetching from cache.
        if cache_path.exists():
            try:
                return cast(R, load(cache_path))
            except (EOFError, pickle.UnpicklingError):
                # An interrupted or failed earlier write can leave an
                # unreadable file behind; recompute instead of failing on
                # every subsequent call.
                logger.warning(
                    "Ignoring unreadable cache file %s; recomputing", cache_path
                )
        # Cache miss; fallback to actual function and cache the result
        result = self.f(*args, **kwargs)
        self.function_dir.mkdir(parents=True, exist_ok=True)
        dump(result, cache_path)
        return result
64
65
class Stash:
    """An object that can be used to decorate functions to transparently cache its calls.

    If the chosen directory doesn't exist, it will be created (along with its
    parents) the first time a function call is cached to disk.
    """

    base_dir: Path
    """Base directory for storing cached data."""

    def __init__(self, base_dir: str | PathLike[str] | None = None) -> None:
        """Create a stash rooted at ``base_dir``.

        :param base_dir: Directory for storing cached data. If the value is
            :py:data:`None` (the default), an appropriate cache directory is
            automatically chosen in the user's home directory. This automatically chosen
            value can be seen using :py:attr:`Stash.base_dir`.

        """
        if base_dir is None:
            base_dir = user_cache_path("bamboo-stash")
        self.base_dir = Path(base_dir)
        logger.info("Data will be cached in %s", base_dir)

    def clear(self) -> None:
        """Delete all cached data.

        Does nothing if the cache directory has not been created yet.
        """
        logger.info("Deleting cached data in %s", self.base_dir)
        try:
            rmtree(self.base_dir)
        except FileNotFoundError:
            # The directory is created lazily on the first cache write, so it
            # may legitimately not exist; there is nothing to delete then.
            pass

    def __call__(self, f: Callable[P, R]) -> StashedFunction[P, R]:
        """Decorator to wrap a function to cache its calls.

        You wouldn't call this method explicitly; this method exists to make the
        :py:class:`Stash` object itself callable as a decorator.

        For example:

        .. code:: python

            from bamboo_stash import Stash

            stash = Stash()

            @stash  # <-- This line invokes stash.__call__
            def my_function(): ...
        """
        return StashedFunction(self.base_dir, f)
113
114
def load(path: Path) -> Any:
    """Read back a value previously written to ``path`` as a pickle.

    :param path: File containing exactly one pickled object.
    :returns: The unpickled object.
    """
    # NOTE(security): unpickling can execute arbitrary code. This is only
    # appropriate for files this module wrote itself; never point the cache
    # directory at a location untrusted parties can write to.
    with path.open("rb") as f:
        return pickle.load(f)
118
119
def dump(result: Any, path: Path) -> None:
    """Pickle ``result`` to ``path`` without ever leaving a partial file there.

    The data is first written to a temporary file in the same directory and
    then renamed onto ``path``. If pickling fails or is interrupted, ``path``
    is left untouched and the temporary file is removed.

    :param result: Object to serialize.
    :param path: Destination file; its parent directory must already exist.
    """
    tmp_path: Path | None = None
    try:
        # The temp file lives next to the target so the final rename stays on
        # one filesystem. delete=False because we rename it ourselves.
        with tempfile.NamedTemporaryFile(
            "wb",
            dir=path.parent,
            prefix=f".{path.name}.",
            suffix=".tmp",
            delete=False,
        ) as f:
            tmp_path = Path(f.name)
            pickle.dump(result, f)
        tmp_path.replace(path)
    except BaseException:
        # Covers unpicklable results as well as interruption (Ctrl-C).
        if tmp_path is not None:
            tmp_path.unlink(missing_ok=True)
        raise
123
124
def arg_to_bytes(x: Any) -> bytes:
    """Lossily condense arbitrary value to a byte sequence.

    pandas objects are condensed via :py:func:`pandas.util.hash_pandas_object`.
    ``None``, ``str`` and ``bytes`` are encoded directly, because their
    built-in ``hash()`` is not guaranteed to be the same in every interpreter
    process, which would make cache files written by one run unreachable in
    the next. Every other value falls back to ``hash(x)``.

    :param x: Argument value; must be hashable unless it is a pandas object.
    :returns: Bytes suitable for feeding into a digest.
    :raises TypeError: If ``x`` is unhashable and not a pandas object.
    """
    if isinstance(x, (pd.Series, pd.DataFrame)):
        hashes = pd.util.hash_pandas_object(x)
        return hashes.to_numpy().tobytes()
    if x is None:
        return b"None"
    if isinstance(x, str):
        # surrogatepass keeps lone surrogates from raising UnicodeEncodeError.
        return b"str:" + x.encode("utf-8", "surrogatepass")
    if isinstance(x, bytes):
        return b"bytes:" + x
    # NOTE: containers and custom objects still go through hash(); if their
    # hash depends on str/bytes contents it may differ between processes.
    hashed = hash(x)
    byte_length = (hashed.bit_length() + 7) // 8
    try:
        return hashed.to_bytes(byte_length, signed=True, byteorder="little")
    except OverflowError:
        # The sign bit needs one more byte when the magnitude already fills
        # byte_length bytes (e.g. 128..255). Widening only in that case keeps
        # the encoding of every previously working value unchanged.
        return hashed.to_bytes(byte_length + 1, signed=True, byteorder="little")
133
134
def digest_args(binding: inspect.BoundArguments) -> str:
    """Lossily condense function arguments to a fixed-length string.

    :param binding: Arguments already bound to the target function's signature.
    :returns: Hex digest computed over each parameter name followed by the
        condensed bytes of its value, taken in sorted parameter-name order.
    """
    h = hash_algorithm()
    arguments = binding.arguments
    # Parameter names are unique, so sorting the names alone fixes the order.
    for name in sorted(arguments):
        h.update(name.encode())
        h.update(arg_to_bytes(arguments[name]))
    return h.hexdigest()
142
143
def _code_fingerprint(code: Any) -> bytes:
    """Condense a code object into bytes that contain no memory addresses."""
    parts = [code.co_code, repr(code.co_names).encode()]
    for const in code.co_consts:
        if isinstance(const, type(code)):
            # Nested functions/lambdas: repr() would embed an address, so
            # recurse into the nested code object instead.
            parts.append(_code_fingerprint(const))
        else:
            parts.append(repr(const).encode())
    return b"\x00".join(parts)


def digest_function(f: Callable[..., Any]) -> str:
    """Lossily condense function definition into a fixed-length string.

    The digest is taken over the function's source text. When the source
    cannot be retrieved (for example a function defined via ``exec`` or typed
    into an interactive session), the function's compiled bytecode, names and
    constants are used instead.

    :param f: Function to fingerprint.
    :returns: Hex digest identifying the current definition of ``f``.
    :raises OSError, TypeError: If neither the source nor a code object is
        available for ``f``.
    """
    try:
        payload = inspect.getsource(f).encode()
    except (OSError, TypeError):
        code = getattr(f, "__code__", None)
        if code is None:
            raise
        payload = _code_fingerprint(code)
    return hash_algorithm(payload).hexdigest()
147
148
def stash(f: Callable[P, R]) -> StashedFunction[P, R]:
    """Convenience decorator for when you don't care about where the cached data is stored.

    The first time this function is called, this automatically creates a
    :py:class:`Stash` object for you with default arguments. Subsequent calls
    will re-use that object.

    The automatically created :py:class:`Stash` object is intentionally hidden
    from you. If you need to access attributes such as
    :py:attr:`Stash.base_dir`, you should explicitly create a :py:class:`Stash`
    object instead.

    """
    global default_stash
    # Created lazily so that merely importing this module does not construct
    # a Stash.
    if default_stash is None:
        default_stash = Stash()
    return default_stash(f)


default_stash: Stash | None = None
"""Default-constructed Stash that is only initialized if stash() is called."""