keras/scikit-learn: using fit_generator() with cross validation

后端 未结 1 1405
难免孤独
难免孤独 2021-01-03 07:00

Is it possible to use Keras's scikit-learn API together with the fit_generator() method? Or is there another way to yield batches for training? I'm using SciPy's sparse matrices…

1条回答
  •  有刺的猬
    2021-01-03 07:49

    Actually you can use a sparse matrix as input to Keras with a generator. Here is my version that worked on a previous project:

    > class KerasClassifier(KerasClassifier):
    >     """ adds sparse matrix handling using batch generator
    >     """
    >     
    >     def fit(self, x, y, **kwargs):
    >         """ adds sparse matrix handling """
    >         if not issparse(x):
    >             return super().fit(x, y, **kwargs)
    >         
    >         ############ adapted from KerasClassifier.fit   ######################   
    >         if self.build_fn is None:
    >             self.model = self.__call__(**self.filter_sk_params(self.__call__))
    >         elif not isinstance(self.build_fn, types.FunctionType):
    >             self.model = self.build_fn(
    >                 **self.filter_sk_params(self.build_fn.__call__))
    >         else:
    >             self.model = self.build_fn(**self.filter_sk_params(self.build_fn))
    > 
    >         loss_name = self.model.loss
    >         if hasattr(loss_name, '__name__'):
    >             loss_name = loss_name.__name__
    >         if loss_name == 'categorical_crossentropy' and len(y.shape) != 2:
    >             y = to_categorical(y)
    >         ### fit => fit_generator
    >         fit_args = copy.deepcopy(self.filter_sk_params(Sequential.fit_generator))
    >         fit_args.update(kwargs)
    >         ############################################################
    >         self.model.fit_generator(
    >                     self.get_batch(x, y, self.sk_params["batch_size"]),
    >                                         samples_per_epoch=x.shape[0],
    >                                         **fit_args)                      
    >         return self                               
    > 
    >     def get_batch(self, x, y=None, batch_size=32):
    >         """ batch generator to enable sparse input """
    >         index = np.arange(x.shape[0])
    >         start = 0
    >         while True:
    >             if start == 0 and y is not None:
    >                 np.random.shuffle(index)
    >             batch = index[start:start+batch_size]
    >             if y is not None:
    >                 yield x[batch].toarray(), y[batch]
    >             else:
    >                 yield x[batch].toarray()
    >             start += batch_size
    >             if start >= x.shape[0]:
    >                 start = 0
    >   
    >     def predict_proba(self, x):
    >         """ adds sparse matrix handling """
    >         if not issparse(x):
    >             return super().predict_proba(x)
    >             
    >         preds = self.model.predict_generator(
    >                     self.get_batch(x, None, self.sk_params["batch_size"]), 
    >                                                val_samples=x.shape[0])
    >         return preds
    

    0 讨论(0)
提交回复
热议问题