Optimize the kernel parameters of an RBF kernel for GPR in scikit-learn using internally supported optimizers

Submitted by 微笑、不失礼 on 2019-12-03 21:41:20

There are three primary problems here:

  1. The objective function being optimized is the Rosenbrock function, which is a test function for benchmarking optimizers. It needs to be a cost function of the kernel parameters instead. Internally, for GaussianProcessRegressor this is the log-marginal-likelihood, which scikit-learn passes to the optimizer as a parameter (obj_func).
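  2. The log-marginal-likelihood needs to be maximized (see section 1.7.1 of the scikit-learn user guide), whereas the SciPy optimizers minimize their objective. scikit-learn already accounts for this: the obj_func it passes to a custom optimizer returns the negative log-marginal-likelihood, so the optimizer should minimize obj_func directly rather than its inverse. Also note that obj_func(theta) returns a (value, gradient) tuple by default; call obj_func(theta, eval_gradient=False) to get the scalar value alone.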
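  3. The optimizer passed into GaussianProcessRegressor must follow the format specified under the 'optimizer' parameter in the docs: a callable optimizer(obj_func, initial_theta, bounds) that returns a tuple (theta_opt, func_min), where func_min is the scalar value of obj_func at theta_opt.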
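To make the contract in point 3 concrete, here is a minimal sketch of a custom optimizer that uses the gradient which obj_func returns by default, via SciPy's L-BFGS-B. The names lbfgsb_optimizer and calls, the toy sine data, and the kernel bounds are illustrative assumptions, not part of the original answer. The calls list only records each invocation, to show that scikit-learn calls the optimizer once plus once per restart.

```python
import numpy as np
from scipy.optimize import minimize
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF

calls = []  # hypothetical helper: records each starting point we are given


def lbfgsb_optimizer(obj_func, initial_theta, bounds):
    """Custom optimizer following the (obj_func, initial_theta, bounds) contract.

    By default obj_func(theta) returns (-LML, -gradient), so jac=True lets
    L-BFGS-B use the gradient; the objective is minimized as-is.
    """
    calls.append(np.copy(initial_theta))
    res = minimize(obj_func, initial_theta, method="L-BFGS-B",
                   jac=True, bounds=bounds)
    # scikit-learn expects (theta_opt, func_min), func_min being a scalar
    return res.x, res.fun


# Toy data (assumption): a noisy sine curve
rng = np.random.RandomState(0)
X = rng.uniform(0, 5, size=(20, 1))
y = np.sin(X).ravel() + 0.1 * rng.randn(20)

gp = GaussianProcessRegressor(
    kernel=RBF(length_scale=1.0, length_scale_bounds=(1e-2, 1e2)),
    optimizer=lbfgsb_optimizer,
    alpha=1e-2,
    n_restarts_optimizer=2,
    random_state=0,
)
gp.fit(X, y)
print(len(calls), gp.kernel_)
```

With n_restarts_optimizer=2 the optimizer runs three times: once from the kernel's initial theta and twice from random starting points drawn within the bounds.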

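As a working example, using an RBF kernel whose length scale is free to vary (with the default kernel every hyperparameter is fixed, so a custom optimizer is never called):

import numpy as np
from scipy.optimize import minimize
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF

def trust_region_optimizer(obj_func, initial_theta, bounds):
    # obj_func is already the negative log-marginal-likelihood; eval_gradient=False
    # makes it return only the scalar. least_squares minimizes a sum of squared
    # residuals, so minimize() with the trust-region 'trust-constr' method is used.
    res = minimize(lambda theta: obj_func(theta, eval_gradient=False),
                   initial_theta, method='trust-constr', bounds=bounds)
    # scikit-learn expects (theta_opt, func_min) with func_min a scalar
    return res.x, res.fun

X = np.random.random((10, 4))
y = np.random.random((10, 1))
# A bounded (non-fixed) length scale gives the optimizer something to tune
kernel = RBF(length_scale=1.0, length_scale_bounds=(1e-2, 1e2))
gp = GaussianProcessRegressor(kernel=kernel, optimizer=trust_region_optimizer,
                              alpha=1.2, n_restarts_optimizer=10)
gp.fit(X, y)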
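The initial_theta and bounds that scikit-learn hands to a custom optimizer are not the raw hyperparameters: kernel theta and bounds are stored as natural logarithms. The short sketch below shows this for an RBF kernel; theta_opt is a hypothetical value standing in for what an optimizer might return, and clone_with_theta maps it back to a length scale.

```python
import numpy as np
from sklearn.gaussian_process.kernels import RBF

kernel = RBF(length_scale=1.0, length_scale_bounds=(1e-2, 1e2))

# theta and bounds hold the natural logs of the hyperparameters
print(kernel.theta)   # log(1.0)
print(kernel.bounds)  # [[log(1e-2), log(1e2)]]

# Mapping an optimizer result back to a length scale
# (theta_opt is hypothetical, standing in for an optimizer's output)
theta_opt = np.array([np.log(0.5)])
fitted = kernel.clone_with_theta(theta_opt)
print(fitted.length_scale)
```

Optimizing in log space keeps positive hyperparameters positive and lets length scales spanning several orders of magnitude be searched evenly.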

The SciPy optimizers return a results object (an OptimizeResult). Using the minimization of the Rosenbrock test function as an example:

import numpy as np
from scipy.optimize import least_squares, rosen
res = least_squares(rosen, np.array([0.0, 0.0]), method='trf')
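The fun field means different things for least_squares and minimize, which matters when choosing what a custom optimizer returns. This sketch runs both on the Rosenbrock function from the same start; the variable names ls and mn are illustrative, and passing rosen_der as the gradient to BFGS is an assumption made only to keep the run robust.

```python
import numpy as np
from scipy.optimize import least_squares, minimize, rosen, rosen_der

x0 = np.array([0.0, 0.0])

# least_squares treats the scalar returned by rosen as a one-element residual
# vector: ls.fun is that vector and ls.cost is 0.5 * sum(ls.fun**2)
ls = least_squares(rosen, x0, method="trf")
print(ls.x, ls.fun, ls.cost)

# minimize works on the objective itself and reports it as a scalar in mn.fun
mn = minimize(rosen, x0, method="BFGS", jac=rosen_der)
print(mn.x, mn.fun)
```

Because least_squares minimizes the square of the returned value, it only coincides with minimizing the value itself when that value cannot go negative; a negative log-marginal-likelihood can.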

For this example, the optimized parameter values can be accessed using:

res.x

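and the value of the function being minimized at that point (for least_squares this is the vector of residuals, here a one-element array, whereas for minimize it is the scalar objective):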

res.fun

which, as a scalar, is what a custom optimizer must return as the second element of its returned tuple. Once the optimizer runs inside fit(), the resulting log-marginal-likelihood is read from the fitted scikit-learn model instead:

gp.log_marginal_likelihood_value_ 
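As a sanity check on that attribute, the sketch below fits a model with the default optimizer ("fmin_l_bfgs_b") so it stands alone, then recomputes the log-marginal-likelihood at the optimized hyperparameters in gp.kernel_ and at the initial ones in gp.kernel. The toy sine data and kernel bounds are assumptions for illustration.

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF

# Toy data (assumption): a noisy sine curve
rng = np.random.RandomState(0)
X = rng.uniform(0, 5, size=(20, 1))
y = np.sin(X).ravel() + 0.1 * rng.randn(20)

# Default optimizer, no restarts
gp = GaussianProcessRegressor(kernel=RBF(1.0, (1e-2, 1e2)), alpha=1e-2)
gp.fit(X, y)

lml_opt = gp.log_marginal_likelihood_value_                   # stored by fit()
lml_recomputed = gp.log_marginal_likelihood(gp.kernel_.theta)  # optimized kernel
lml_initial = gp.log_marginal_likelihood(gp.kernel.theta)      # unfitted kernel
print(gp.kernel_, lml_opt, lml_recomputed, lml_initial)
```

The stored value should match the recomputed one, and it should be no lower than the value at the starting hyperparameters, since the optimizer only accepts steps that decrease the negative log-marginal-likelihood.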