Caching of Python functions with array input

Is there a standard way to implement caching for functions that take numpy arrays as input? I don’t want to rely on global variables, but for most of my workflow I work on one dataset that is a collection of arrays, and want to avoid recalculating things all the time.

Are the input arrays typically large in your use case?

Yes. I’m aware that I could use tinyarray, but since @ and slicing are not implemented, having to convert back and forth to np.array is inconvenient.

I don’t know of any standard caching approach that would work out of the box. Let me think through your options.

  • You can implement your own version of caching, e.g. by computing a hash of the array and using it as the cache key.
  • You can pass the arrays together with another unique identifier (e.g. a number within your dataset, or a random string that you pre-assign), and use that identifier as the cache key.
  • (my favorite option) Instead of caching by input, use a library that manages computations for you. dask is likely the most feature-complete solution, but ipyparallel and joblib (both listed in the link I shared) also let you define your computation as a graph (a DAG, to be specific) in which follow-up tasks depend on the outputs of prior tasks. The library then manages passing those outputs to the downstream tasks; see the sketch after this list.
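To make the third option concrete, here is a minimal sketch using dask.delayed; the functions load and square are made-up examples, and the sketch assumes dask is installed.

import dask
import numpy as np

@dask.delayed
def load():
    # Stand-in for producing one array of the dataset.
    return np.arange(101)

@dask.delayed
def square(x):
    return x @ x

@dask.delayed
def add(x, y):
    return x + y

data = load()
s = square(data)    # this task appears once in the graph,
result = add(s, s)  # so dask computes it once even though it is used twice

print(result.compute())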

I want to clarify what you mean by the first two options:

  • Would this require subclassing np.ndarray to define the hash?
  • This I don’t understand: as far as I can see, functools.lru_cache doesn’t have any option to specify how the inputs are hashed.

In both cases you wouldn’t use @lru_cache; instead you would need to implement something of your own. Here’s a sketch implementation:

from functools import wraps
import hashlib

import numpy as np

# Own hash approach

def array_cache(function):
    cache = {}

    @wraps(function)
    def wrapped(array):
        # Hash the raw bytes with md5: arrays themselves are unhashable,
        # and cheaper hashes of the bytes are more prone to collisions.
        key = hashlib.md5(array.tobytes()).digest()
        if key not in cache:
            cache[key] = function(array)
        return cache[key]

    return wrapped

@array_cache
def square(x):
    return x @ x

a = np.arange(101)

%time square(a)
%time square(a)
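
The second call should be much faster: the wrapper only hashes the array and looks up the stored result. Note that hashing still scales with the array size, which is what the identifier-based approach below avoids.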

# uuid approach

# Same idea, except we assign the identifier externally.
# This makes sense e.g. if the array is generated from some known parameters.

def key_cache(function):
    cache = {}

    @wraps(function)
    def wrapped(array, uuid):
        # The caller guarantees that the identifier uniquely determines
        # the array, so it can serve as the cache key directly.
        if uuid not in cache:
            cache[uuid] = function(array, uuid)
        return cache[uuid]

    return wrapped

@key_cache
def square(array, array_id):
    return array**2
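
Usage would then look as follows; the identifier "dataset-0" is an arbitrary example, and keeping it in sync with the array is the caller’s responsibility.

b = np.arange(101)

square(b, "dataset-0")  # computed and stored
square(b, "dataset-0")  # served from the cache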

I think both of these are inferior to the approach of defining a job graph, especially in a parallel context.