Keras + scikit-learn wrapper appears to hang when GridSearchCV uses n_jobs > 1

北海茫月 2021-02-09 22:07

UPDATE: I have had to rewrite this question because, after some investigation, I realised that this is a different problem.

Context: running Keras in a grid search.

3 answers
  • 2021-02-09 22:34

    TL;DR: You can't, because your Keras model can't be serialized, and serialization is required to parallelize work in Python with joblib.
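    A minimal stdlib-only sketch of why this bites: process-based parallelism (joblib's default backend, like `multiprocessing`) pickles each estimator to ship it to a worker process. Here a `threading.Lock` is a hypothetical stand-in for the live session/graph handle a Keras model holds, and `FakeKerasEstimator` is an illustrative name, not a real Keras or scikit-learn class.

```python
import pickle
import threading


class FakeKerasEstimator:
    """Hypothetical estimator that, like a session-backed Keras model, holds a live handle."""

    def __init__(self, units):
        self.units = units
        # Stand-in for a TF session/graph: lock objects cannot be pickled.
        self.session = threading.Lock()


def can_pickle(obj):
    """Return True if obj survives pickling, which process-based workers require."""
    try:
        pickle.dumps(obj)
        return True
    except (TypeError, pickle.PicklingError):
        return False


if __name__ == "__main__":
    print(can_pickle({"units": 32}))                 # True: plain parameters pickle fine
    print(can_pickle(FakeKerasEstimator(units=32)))  # False: the live handle cannot be pickled
```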

    This problem is described in detail here: https://www.neuraxle.org/stable/scikit-learn_problems_solutions.html#problem-you-can-t-parallelize-nor-save-pipelines-using-steps-that-can-t-be-serialized-as-is-by-joblib

    To parallelize your code, you need to make your Keras estimator serializable. This can be done using savers, as described at the link above.
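    The "saver" idea can be sketched with Python's own pickling hooks: drop the unpicklable handle in `__getstate__` and rebuild it in `__setstate__`. This is a hypothetical stdlib-only illustration, not Neuraxle's API; the `threading.Lock` again stands in for a TF session, and in a real Keras estimator the saved state would be the model's weights or a saved-model file, which this sketch does not show.

```python
import pickle
import threading


class SerializableEstimator:
    """Hypothetical wrapper showing the saver idea: persist what can be saved,
    drop the live handle before pickling, and rebuild it after unpickling."""

    def __init__(self, units):
        self.units = units
        self.weights = [0.0] * units      # stand-in for trained weights
        self._session = threading.Lock()  # stand-in for an unpicklable TF session

    def __getstate__(self):
        # Copy the instance dict and leave out the handle pickle cannot handle.
        state = self.__dict__.copy()
        del state["_session"]
        return state

    def __setstate__(self, state):
        # Restore the saved attributes, then recreate the handle in this process.
        self.__dict__.update(state)
        self._session = threading.Lock()


if __name__ == "__main__":
    est = SerializableEstimator(units=4)
    est.weights[0] = 1.5
    clone = pickle.loads(pickle.dumps(est))  # roughly what a joblib worker receives
    print(clone.weights)  # [1.5, 0.0, 0.0, 0.0]
```

    Each process ends up with its own fresh handle, which is exactly what process-based parallelism needs.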

    If you are lucky enough to be using the Keras module built into TensorFlow 2 (tf.keras), the following code sample should prove useful, since you would mostly just need to take the code and adapt it to your own model:

    • https://github.com/guillaume-chevalier/seq2seq-signal-prediction

    In this example, all the saving and loading code is already written for you using Neuraxle-TensorFlow, which makes it parallelizable if you use Neuraxle's AutoML methods (e.g. Neuraxle's grid search and its own parallelism utilities).

  • 2021-02-09 22:38

    I know this is a late answer, but I dealt with this problem too, and not being able to run what is essentially trivially parallelizable code really slowed me down. The issue is indeed with the TensorFlow session: if a session is created in the parent process before GridSearchCV.fit(), it will hang!

    The solution for me was to keep all session/graph creation code restricted to the KerasClassifier class and the model creation function I passed to it.
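    A minimal sketch of that pattern, assuming the wrapper stores only a build function plus plain parameters and creates anything session-like inside `fit`, which runs in the worker. `build_model` and `LazyWrapper` are hypothetical names (not the real KerasClassifier), and the `threading.Lock` stands in for a TF session/graph so the example runs with the standard library alone.

```python
import pickle
import threading


def build_model(units):
    """Hypothetical model-creation function: the only place a 'session' is made."""
    return {"units": units, "session": threading.Lock()}  # stand-in for graph/session


class LazyWrapper:
    """Hypothetical KerasClassifier-like wrapper holding only build_fn and params.

    Nothing unpicklable exists until fit() runs, so the parent process never
    creates a session before the grid search dispatches work to its workers."""

    def __init__(self, build_fn, **params):
        self.build_fn = build_fn  # module-level function, pickled by reference
        self.params = params
        self.model_ = None

    def fit(self, X, y):
        # The session is created here, in whichever process calls fit().
        self.model_ = self.build_fn(**self.params)
        return self


if __name__ == "__main__":
    wrapper = LazyWrapper(build_model, units=8)
    copy = pickle.loads(pickle.dumps(wrapper))  # succeeds: no session exists yet
    copy.fit(X=None, y=None)                    # the session is created only now
    print(copy.model_["units"])                 # 8
```

    The design choice is that the parent only ever holds picklable configuration; anything tied to a live session is built lazily in the process that does the training.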

    Also, what Felipe said about memory is true: you will want to restrict TF's memory usage, either in the model creation function or in a subclass of KerasClassifier.

    Related info:

    • Session hang issue with python multiprocessing
    • Keras + Tensorflow and Multiprocessing in Python
  • 2021-02-09 22:41

    Are you using a GPU? If so, you can't have multiple parallel jobs each running a different combination of parameters, because they won't be able to share the GPU.

    Here's a full example of how to use the Keras scikit-learn wrappers in a Pipeline with GridSearchCV: Pipeline with a Keras Model

    If you really want to run multiple jobs in GridSearchCV, you can try limiting the fraction of GPU memory used by each job (e.g. if each job allocates only 0.5 of the available GPU memory, you can run 2 jobs simultaneously).
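    The arithmetic behind that suggestion, as a small hedged sketch: split the GPU's memory across `n_jobs`, optionally keeping some headroom free. The `headroom` parameter is a hypothetical safety margin, not a TensorFlow setting; the resulting fraction is the kind of value you would pass to TF 1.x's `tf.GPUOptions(per_process_gpu_memory_fraction=...)`, which is not shown here.

```python
import math


def gpu_fraction_per_job(n_jobs, headroom=0.05):
    """Share of GPU memory each of n_jobs parallel fits may claim.

    headroom is a hypothetical margin left unallocated; it is an assumption,
    not a TensorFlow default."""
    if n_jobs < 1:
        raise ValueError("n_jobs must be >= 1")
    if not 0.0 <= headroom < 1.0:
        raise ValueError("headroom must be in [0, 1)")
    return (1.0 - headroom) / n_jobs


def max_parallel_jobs(fraction):
    """How many jobs fit on one GPU if each allocates `fraction` of its memory."""
    if not 0.0 < fraction <= 1.0:
        raise ValueError("fraction must be in (0, 1]")
    # A small epsilon guards against float error when 1 / fraction should be whole.
    return math.floor(1.0 / fraction + 1e-9)


if __name__ == "__main__":
    print(max_parallel_jobs(0.5))                 # 2, matching the example above
    print(gpu_fraction_per_job(2, headroom=0.0))  # 0.5
```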

    See these issues:

    • Limit the resource usage for tensorflow backend

    • GPU memory fraction does not work in keras 2.0.9 but it works in 2.0.8
