Using functools.lru_cache on functions with constant but non-hashable objects

Submitted by 这一生的挚爱 on 2021-02-16 04:47:09

Question


Is it possible to use functools.lru_cache for caching a partial function created by functools.partial?

My problem is a function that takes hashable parameters and constant, non-hashable objects such as NumPy arrays.

Consider this toy example:

import numpy as np
from functools import lru_cache, partial

def foo(key, array):
    print('%s:' % key, array)
a = np.array([1,2,3])

Since NumPy arrays are not hashable, this will not work:

@lru_cache(maxsize=None)
def foo(key, array):
    print('%s:' % key, array)
foo(1, a)

As expected, you get the following error:

/Users/ch/miniconda/envs/sci34/lib/python3.4/functools.py in __init__(self, tup, hash)
    349     def __init__(self, tup, hash=hash):
    350         self[:] = tup
--> 351         self.hashvalue = hash(tup)
    352 
    353     def __hash__(self):

TypeError: unhashable type: 'numpy.ndarray'
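The error arises because lru_cache builds its cache key from all of the call's arguments, which requires hashing each one. NumPy sets `ndarray.__hash__` to None because arrays are mutable, so the array cannot be part of the key. A minimal check, using a hypothetical helper `is_hashable` that is not part of functools or NumPy:

```python
import numpy as np


def is_hashable(obj):
    """Hypothetical helper: True if obj can be hashed (and so used in an lru_cache key)."""
    try:
        hash(obj)
    except TypeError:
        return False
    return True


a = np.array([1, 2, 3])

# ndarray defines __hash__ = None, so hashing it raises TypeError,
# which is exactly what lru_cache runs into when building the key.
print(is_hashable(a))          # False
print(is_hashable(1))          # True
print(is_hashable((1, 2, 3)))  # True
```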

So my next idea was to use functools.partial to get rid of the NumPy array (which is constant anyway):

pfoo = partial(foo, array=a)
pfoo(2)

So now I have a function that takes only hashable arguments, which should be perfect for lru_cache. But is it possible to use lru_cache in this situation? I cannot apply it as a plain wrapping function instead of using the @lru_cache decorator, can I?

Is there a clever way to solve this?
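One thing to keep in mind with the toy foo above: lru_cache caches return values, not side effects. On a cache hit the function body is skipped, so a function that only prints will print only on the first call for each key, which is why the answers below return the values instead. A minimal sketch, with a hypothetical `show` function that closes over the array:

```python
from functools import lru_cache

import numpy as np

a = np.array([1, 2, 3])


@lru_cache(maxsize=None)
def show(key):
    # The body only runs on a cache miss; on a hit lru_cache returns the
    # stored result (here None) without executing the print again.
    print('%s:' % key, a)


show(1)  # prints "1: [1 2 3]"
show(1)  # prints nothing: served from the cache
print(show.cache_info())  # CacheInfo(hits=1, misses=1, maxsize=None, currsize=1)
```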


Answer 1:


Since the array is constant, you can wrap the actual LRU-cached function in a closure that captures the array, and pass only the key value to it:

from functools import lru_cache
import numpy as np


def lru_wrapper(array=None):
    @lru_cache(maxsize=None)
    def foo(key):
        return '%s:' % key, array
    return foo


arr = np.array([1, 2, 3])
func = lru_wrapper(array=arr)

for x in [0, 0, 1, 2, 2, 1, 2, 0]:
    print(func(x))

print(func.cache_info())

Output:

('0:', array([1, 2, 3]))
('0:', array([1, 2, 3]))
('1:', array([1, 2, 3]))
('2:', array([1, 2, 3]))
('2:', array([1, 2, 3]))
('1:', array([1, 2, 3]))
('2:', array([1, 2, 3]))
('0:', array([1, 2, 3]))
CacheInfo(hits=5, misses=3, maxsize=None, currsize=3)
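This approach relies on the array really staying constant. The cache key contains only `key`, so if the array is later mutated in place, the closure sees the new values but the cache does not, and previously cached results become stale. A minimal sketch, using a hypothetical `scaled_sum` in place of foo; calling `cache_clear()` (a real method on lru_cache wrappers) forces recomputation:

```python
from functools import lru_cache

import numpy as np


def lru_wrapper(array):
    # The cache key is only `key`; `array` is captured by the closure.
    @lru_cache(maxsize=None)
    def scaled_sum(key):
        return key * int(array.sum())
    return scaled_sum


arr = np.array([1, 2, 3])
func = lru_wrapper(arr)

print(func(2))  # 12
arr[0] = 100    # in-place mutation: the closure sees it, the cache does not
print(func(2))  # still 12 (cached, now stale)
print(func(3))  # 315, computed fresh with the mutated array

# Clearing the cache after a mutation forces recomputation.
func.cache_clear()
print(func(2))  # 210
```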



Answer 2:


Here is an example of how to use lru_cache with functools.partial:

from functools import lru_cache, partial
import numpy as np


def foo(key, array):
    return '%s:' % key, array


arr = np.array([1, 2, 3])
pfoo = partial(foo, array=arr)
func = lru_cache(maxsize=None)(pfoo)

for x in [0, 0, 1, 2, 2, 1, 2, 0]:
    print(func(x))

print(func.cache_info())

Output:

('0:', array([1, 2, 3]))
('0:', array([1, 2, 3]))
('1:', array([1, 2, 3]))
('2:', array([1, 2, 3]))
('2:', array([1, 2, 3]))
('1:', array([1, 2, 3]))
('2:', array([1, 2, 3]))
('0:', array([1, 2, 3]))
CacheInfo(hits=5, misses=3, maxsize=None, currsize=3)

This is more concise than @AshwiniChaudhary's solution, and it also uses functools.partial, as the OP asked.
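One subtlety with the cached partial: lru_cache keys on the argument pattern it receives, so in CPython `func(2)` and `func(key=2)` end up as separate cache entries even though they return the same result. Passing the key the same way every time (for example, always positionally) avoids the duplicates. A minimal sketch reusing foo from this answer:

```python
from functools import lru_cache, partial

import numpy as np


def foo(key, array):
    return '%s:' % key, array


arr = np.array([1, 2, 3])
func = lru_cache(maxsize=None)(partial(foo, array=arr))

func(2)      # miss: key built from the positional argument
func(key=2)  # also a miss: a keyword call produces a different cache key
func(2)      # hit
print(func.cache_info())  # CacheInfo(hits=1, misses=2, maxsize=None, currsize=2)
```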


P.S. This solution was adapted from "Applying functools.lru_cache to lambda".
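The same idea also works without partial: lru_cache can wrap any callable whose arguments are hashable, including a lambda that closes over the constant array. A minimal sketch of that variant (not necessarily the code from the question referenced above):

```python
from functools import lru_cache

import numpy as np

arr = np.array([1, 2, 3])

# Only `key` is part of the cache key; the array is captured by the lambda.
func = lru_cache(maxsize=None)(lambda key: ('%s:' % key, arr))

for x in [0, 0, 1]:
    func(x)
print(func.cache_info())  # CacheInfo(hits=1, misses=2, maxsize=None, currsize=2)
```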



Source: https://stackoverflow.com/questions/37609772/using-functools-lru-cache-on-functions-with-constant-but-non-hashable-objects
