Speed up TFLite inference in Python with a multiprocessing Pool

…衆ロ難τιáo~ submitted on 2020-05-26 06:12:09


I was experimenting with TFLite and noticed that my multicore CPU is not heavily loaded during inference. I ruled out an I/O bottleneck by generating random input data with NumPy beforehand (random matrices resembling images), but TFLite still does not use the CPU's full capacity.

The documentation mentions that the number of threads can be tuned, but I could not find out how to do that in the Python API. Since I have seen people use multiple interpreter instances for different models, I thought one could use multiple instances of the same model and run them in different threads or processes. I wrote the following short script:

import numpy as np
import os, time
import tflite_runtime.interpreter as tflite
from multiprocessing import Pool

# Module-level globals: each worker process has its own copy, so there is one interpreter per process
interpreter = None
input_details = None
output_details = None
def init_interpreter(model_path):
    global interpreter
    global input_details
    global output_details
    interpreter = tflite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()  # must be called before set_tensor()/invoke()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    print('done init')

def do_inference(img_idx, img):
    print('Processing image %d'%img_idx)
    print('interpreter: %r' % (hex(id(interpreter)),))
    print('input_details: %r' % (hex(id(input_details)),))
    print('output_details: %r' % (hex(id(output_details)),))

    tstart = time.time()

    img = np.stack([img]*3, axis=2)  # replicate the single channel three times for RGB
    img = np.array([img])  # add the batch dimension
    interpreter.set_tensor(input_details[0]['index'], img)
    interpreter.invoke()  # run the model; without this call no inference happens

    logit = interpreter.get_tensor(output_details[0]['index'])
    pred = np.argmax(logit, axis=1)[0]
    logit = list(logit[0])
    duration = time.time() - tstart

    return logit, pred, duration

def main_par(num_test_imgs=100):  # number of random test images (arbitrary default)
    optimized_graph_def_file = r'./optimized_graph.lite'

    # init model once to find out input dimensions
    interpreter_main = tflite.Interpreter(model_path=optimized_graph_def_file)
    input_details = interpreter_main.get_input_details()
    input_h, input_w = tuple(input_details[0]['shape'][1:3])  # shape is [batch, height, width, channels]

    # pregenerate random images with values in [0, 1)
    test_imgs = np.random.rand(num_test_imgs, input_h, input_w).astype(input_details[0]['dtype'])

    scores = []
    predictions = []
    it_times = []

    tstart = time.time()
    with Pool(processes=4, initializer=init_interpreter, initargs=(optimized_graph_def_file,)) as pool:         # start 4 worker processes

        results = pool.starmap(do_inference, enumerate(test_imgs))
        scores, predictions, it_times = list(zip(*results))
    duration = time.time() - tstart

    print('Parent process time for %d images: %.2fs'%(num_test_imgs, duration))
    print('Inference time for %d images: %.2fs'%(num_test_imgs, sum(it_times)))
    print('mean time per image: %.3fs +- %.3f' % (np.mean(it_times), np.std(it_times)) )

if __name__ == '__main__':
    # main_seq()  # sequential baseline, not included in this post
    main_par()

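The script prints two totals that are easy to conflate: the parent's wall-clock time for the whole Pool block, and the sum of the per-image durations measured inside the workers. With 4 workers running concurrently, the sum can exceed the wall-clock time. A small helper (the name summarize_timing is hypothetical) that turns these numbers into throughput and an estimated speedup, assuming a serial run would cost the same per image:

```python
def summarize_timing(it_times, wall_time, workers):
    """Relate per-image worker timings to the parent's wall-clock time.

    it_times:  per-image durations measured inside the workers, in seconds
    wall_time: time the parent spent around the Pool block, in seconds
    workers:   number of worker processes (4 in the script)
    """
    busy = sum(it_times)  # total time spent inside do_inference across workers
    return {
        # Throughput as seen by the parent process.
        "images_per_s": len(it_times) / wall_time,
        # Estimated speedup, assuming a serial run would cost sum(it_times).
        "speedup_vs_serial": busy / wall_time,
        # Share of the available worker time spent inside do_inference.
        "efficiency": busy / (wall_time * workers),
    }


# Usage with the variables from main_par():
# summary = summarize_timing(it_times, duration, workers=4)
```

With made-up numbers for illustration, eight images at 0.5 s each finishing in 1.25 s of wall-clock time on 4 workers give 6.4 images/s, an estimated 3.2x speedup and 80% efficiency.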
When I run it, the memory address of the interpreter printed via hex(id(interpreter)) is the same in every process, whereas the addresses of the input/output details differ. Is this approach therefore wrong, even though I do observe a speedup? If so, how can one achieve parallel inference with TFLite in Python?

tflite_runtime version: 1.14.0 from here (the x86-64 Python 3.5 version)

python version: 3.5
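Regarding the thread count mentioned at the start: as far as I know, the Python Interpreter in tflite_runtime 1.14.0 offers no way to set it. Newer releases (assumed here: TensorFlow 2.3 or later and the matching tflite_runtime wheels) accept a num_threads argument on the Interpreter constructor. A minimal sketch of a helper that passes it only when requested; make_interpreter and its interpreter_cls hook are hypothetical names, the hook existing only so the helper can be tried without TFLite installed:

```python
def make_interpreter(model_path, num_threads=None, interpreter_cls=None):
    """Create a TFLite interpreter and allocate its tensors.

    Assumes a runtime whose Interpreter accepts num_threads (newer than
    1.14.0). interpreter_cls is a hypothetical injection hook, not part of
    any TFLite API.
    """
    if interpreter_cls is None:
        # Imported lazily so this module loads even where TFLite is absent.
        import tflite_runtime.interpreter as tflite
        interpreter_cls = tflite.Interpreter
    kwargs = {"model_path": model_path}
    if num_threads is not None:
        # Only pass num_threads when requested; older runtimes reject it.
        kwargs["num_threads"] = num_threads
    interpreter = interpreter_cls(**kwargs)
    interpreter.allocate_tensors()  # required before set_tensor()/invoke()
    return interpreter
```

For example, make_interpreter('./optimized_graph.lite', num_threads=4) in a single process. When combining it with the 4-process Pool from the question, num_threads=1 per worker is a plausible starting point to avoid oversubscribing the cores, an assumption worth benchmarking.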

