Is there a standard way to implement caching for functions that take numpy arrays as input? I don’t want to rely on global variables, but for most of my workflow I work on one dataset that is a collection of arrays, and want to avoid recalculating things all the time.
Are the input arrays typically large in your use case?
Yes. I’m aware that I could use tinyarray, but because @ and slicing are not implemented, it is inconvenient that one has to convert back and forth to np.array.
I don’t know of any standard caching approach that would work out of the box, but let me go through your options.
- You can implement your own version of caching e.g. by computing a hash of an array.
- You can pass the arrays together with another unique identifier (e.g. their number in your dataset, or a random string that you pre-assign), and use that identifier as the cache key.
- (my favorite option) Instead of caching by input, use a library that manages computations. dask is likely the most feature-complete solution, but ipyparallel and joblib (both listed in the link I shared) also let you define your computation as a graph (a DAG, to be specific) where follow-up tasks depend on the outputs of prior tasks. The library then manages passing those outputs along to the follow-up tasks.
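To make the graph idea concrete, here is a minimal hand-rolled sketch of a task graph (illustrative only, not dask’s actual API): each task names its dependencies, every task runs at most once, and results are passed along the DAG.

```python
import numpy as np


def run_graph(graph):
    """graph maps name -> (function, [dependency names]); returns all results."""
    results = {}

    def evaluate(name):
        if name not in results:  # each task is computed at most once
            function, deps = graph[name]
            results[name] = function(*(evaluate(d) for d in deps))
        return results[name]

    for name in graph:
        evaluate(name)
    return results


# A tiny three-task pipeline: load -> square -> sum.
graph = {
    "data": (lambda: np.arange(5), []),
    "squared": (lambda x: x**2, ["data"]),
    "total": (lambda x: int(x.sum()), ["squared"]),
}
results = run_graph(graph)
```

Libraries like dask add scheduling, parallelism, and spilling to disk on top of this basic idea, which is why I’d reach for one of them rather than rolling my own.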
I want to clarify what you mean by the first two options:
- Would this require subclassing np.ndarray to define the hash?
- This I don’t understand: as far as I can see, functools.lru_cache doesn’t have any option to specify how the inputs are hashed.
In both cases you wouldn’t use @lru_cache; you would instead need to implement something on your own. Here’s a sketch implementation:
from functools import wraps
import hashlib

import numpy as np


# Own hash approach
def array_cache(function):
    cache = {}

    @wraps(function)
    def wrapped(array):
        # Not using the built-in hash because it's prone to collisions.
        # (Hashing the raw bytes ignores shape and dtype; include those
        # in the key if your arrays vary in them.)
        key = hashlib.md5(array.tobytes()).digest()
        if key not in cache:
            cache[key] = function(array)
        return cache[key]

    return wrapped


@array_cache
def square(x):
    return x @ x


a = np.arange(101)
%time square(a)  # computed
%time square(a)  # cache hit
# uuid approach
# Same idea, except we assign the identifier externally.
# This makes sense if e.g. the array is generated from some parameters.
def key_cache(function):
    cache = {}

    @wraps(function)
    def wrapped(array, uuid):
        # Here the caller guarantees that uuid uniquely identifies the array.
        if uuid not in cache:
            cache[uuid] = function(array)
        return cache[uuid]

    return wrapped


@key_cache
def square(array, array_id):
    return array**2
I think both of these are inferior to the approach of defining a job graph, especially in a parallel context.
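As an aside on the lru_cache question above: @lru_cache can in fact be made to work without subclassing np.ndarray, by wrapping arrays in a small hashable wrapper class that defines __hash__ and __eq__ itself. The HashableArray name below is my own invention, not a numpy feature; a minimal sketch:

```python
from functools import lru_cache
import hashlib

import numpy as np


class HashableArray:
    """Hashable wrapper so arrays can be used as lru_cache keys."""

    def __init__(self, array):
        self.array = array
        # As before, hash a digest of the raw bytes to avoid collisions.
        self._hash = int.from_bytes(
            hashlib.md5(array.tobytes()).digest()[:8], "little"
        )

    def __hash__(self):
        return self._hash

    def __eq__(self, other):
        return np.array_equal(self.array, other.array)


@lru_cache(maxsize=None)
def square(wrapped):
    return wrapped.array @ wrapped.array


a = np.arange(101)
result = square(HashableArray(a))  # repeated calls with equal arrays hit the cache
```

It still suffers from the same drawbacks as the hand-rolled hash approach, so the job-graph route remains my recommendation.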