当前位置: 首页>>代码示例>>Python>>正文


Python Earth.predict_deriv方法代码示例

本文整理汇总了Python中pyearth.Earth.predict_deriv方法的典型用法代码示例。如果您正苦于以下问题:Python Earth.predict_deriv方法的具体用法?Python Earth.predict_deriv怎么用?Python Earth.predict_deriv使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在pyearth.Earth的用法示例。


在下文中一共展示了Earth.predict_deriv方法的4个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_deriv

# 需要导入模块: from pyearth import Earth [as 别名]
# 注意: predict_deriv 是 Earth 类的实例方法, 不能单独导入, 需通过 Earth 实例调用, 例如 Earth().predict_deriv(X)
def test_deriv():
    """Check the shape returned by ``predict_deriv`` for several selectors.

    The model is fit once on the module-level ``X`` and ``y``; each call is
    then compared against the expected ``(rows, selected variables, 1)``
    shape.
    """
    earth = Earth(**default_params)
    earth.fit(X, y)
    n_rows = X.shape[0]

    # Without a selector the result has one entry per column of X, followed
    # by a trailing axis of length 1.
    assert_equal(X.shape + (1,), earth.predict_deriv(X).shape)

    # (selector passed as ``variables``, expected size of the second axis)
    cases = (
        (0, 1),
        ('x0', 1),
        ([1, 5, 7], 3),
        ([], 0),
        (['x2', 'x7', 'x0', 'x1'], 4),
        (['x0'], 1),
        ([0], 1),
    )
    for selector, n_selected in cases:
        deriv = earth.predict_deriv(X, variables=selector)
        assert_equal((n_rows, n_selected, 1), deriv.shape)
开发者ID:RPGOne,项目名称:han-solo,代码行数:20,代码来源:test_earth.py

示例2: Earth

# 需要导入模块: from pyearth import Earth [as 别名]
# 注意: predict_deriv 是 Earth 类的实例方法, 不能单独导入, 需通过 Earth 实例调用, 例如 Earth().predict_deriv(X)
# Example script: fit an Earth model to noisy samples of 10*sin(x6) and plot
# its predictions and predicted derivative against the known true values.
# NOTE(review): `m`, `numpy`, `plt` and `Earth` are not defined in this
# excerpt -- presumably they are set up earlier in the original script
# (`m` looks like the number of samples); confirm before running.
n = 10
# Inputs: 20 * U[0, 1) - 10, i.e. values in [-10, 10), arranged as (m, n).
X = 20 * numpy.random.uniform(size=(m, n)) - 10
# Only column 6 drives the response; normal noise scaled by 0.25 is added.
y = 10*numpy.sin(X[:, 6]) + 0.25*numpy.random.normal(size=m)

# Compute the known true derivative with respect to the predictive variable
y_prime = 10*numpy.cos(X[:, 6])

# Fit an Earth model
model = Earth(max_degree=2, minspan_alpha=.5, smooth=True)
model.fit(X, y)

# Print the model
print(model.trace())
print(model.summary())

# Get the predicted values and derivatives
y_hat = model.predict(X)
# Request the derivative with respect to the variable named 'x6' only.
y_prime_hat = model.predict_deriv(X, 'x6')

# Plot true and predicted function values and derivatives
# for the predictive variable
# Top panel: observed (red) versus predicted (blue) function values.
plt.subplot(211)
plt.plot(X[:, 6], y, 'r.')
plt.plot(X[:, 6], y_hat, 'b.')
plt.ylabel('function')
# Bottom panel: true (red) versus predicted (blue) derivative.
plt.subplot(212)
plt.plot(X[:, 6], y_prime, 'r.')
# Index 0 on the second axis picks the single requested variable.
plt.plot(X[:, 6], y_prime_hat[:, 0], 'b.')
plt.ylabel('derivative')
plt.show()
开发者ID:Biodun,项目名称:py-earth,代码行数:32,代码来源:plot_derivatives.py

示例3: MARSInterpolant

# 需要导入模块: from pyearth import Earth [as 别名]
# 注意: predict_deriv 是 Earth 类的实例方法, 不能单独导入, 需通过 Earth 实例调用, 例如 Earth().predict_deriv(X)
class MARSInterpolant(Surrogate):
    """Compute and evaluate a MARS interpolant

    MARS builds a model of the form

    .. math::

        \\hat{f}(x) = \\sum_{i=1}^{k} c_i B_i(x).

    The model is a weighted sum of basis functions :math:`B_i(x)`. Each basis
    function :math:`B_i(x)` takes one of the following three forms:

    1. a constant 1.
    2. a hinge function of the form :math:`\\max(0, x - const)` or \
       :math:`\\max(0, const - x)`. MARS automatically selects variables \
       and values of those variables for knots of the hinge functions.
    3. a product of two or more hinge functions. These basis functions \
       can model interaction between two or more variables.

    :param dim: Number of dimensions
    :type dim: int

    :ivar dim: Number of dimensions
    :ivar num_pts: Number of points in surrogate model
    :ivar X: Point incorporated in surrogate model (num_pts x dim)
    :ivar fX: Function values in surrogate model (num_pts x 1)
    :ivar updated: True if model is up-to-date (no refit needed)
    :ivar model: Earth object
    """
    def __init__(self, dim):
        self.num_pts = 0
        self.X = np.empty([0, dim])
        self.fX = np.empty([0, 1])
        self.dim = dim
        self.updated = False

        # pyearth is imported here rather than at module level, so the
        # dependency is only required when a MARS surrogate is constructed.
        try:
            from pyearth import Earth
        except ImportError:
            print("Failed to import pyearth")
            raise
        else:
            self.model = Earth()

    def _fit(self):
        """Compute new coefficients if the MARS interpolant is not updated."""
        if not self.updated:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")  # Suppress deprecation warnings
                self.model.fit(self.X, self.fX)
            self.updated = True

    def predict(self, xx):
        """Evaluate the MARS interpolant at the points xx

        :param xx: Prediction points, must be of size num_pts x dim or (dim, )
        :type xx: numpy.ndarray

        :return: Prediction of size num_pts x 1
        :rtype: numpy.ndarray
        """
        self._fit()
        xx = np.atleast_2d(xx)
        return np.expand_dims(self.model.predict(xx), axis=1)

    def predict_deriv(self, xx):
        """Evaluate the derivative of the MARS interpolant at points xx

        :param xx: Prediction points, must be of size num_pts x dim or (dim, )
        :type xx: numpy.array

        :return: Derivative of the MARS interpolant at xx. For a single
            point of shape (dim, ) the derivative at that point is returned;
            for num_pts x dim input the full array produced by
            ``Earth.predict_deriv`` (one leading entry per point) is returned.
        :rtype: numpy.array
        """
        self._fit()
        xx = np.asarray(xx)
        # Promote a lone point to a 1 x dim batch, as predict() does, instead
        # of unconditionally adding an axis (which broke num_pts x dim input).
        dfx = self.model.predict_deriv(np.atleast_2d(xx), variables=None)
        # Unwrap the batch axis again for a lone point so callers passing
        # (dim, ) keep receiving a single derivative.
        return dfx[0] if xx.ndim < 2 else dfx
开发者ID:dme65,项目名称:pySOT,代码行数:79,代码来源:surrogate.py

示例4: MARSInterpolant

# 需要导入模块: from pyearth import Earth [as 别名]
# 注意: predict_deriv 是 Earth 类的实例方法, 不能单独导入, 需通过 Earth 实例调用, 例如 Earth().predict_deriv(X)

#.........这里部分代码省略.........

    def reset(self):
        """Drop all stored points and mark the model as needing a refit."""
        self.x = self.fx = None
        self.nump = 0
        self.updated = False

    def _alloc(self, dim):
        """Allocate storage for x, fx, rhs, and A.

        :param dim: Number of dimensions
        """
        maxp = self.maxp
        self.dim = dim
        self.x = np.zeros((maxp, dim))
        self.fx = np.zeros((maxp, 1))

    def _realloc(self, dim, extra=1):
        """Expand allocation to accommodate more points (if needed)

        :param dim: Number of dimensions
        :param extra: Number of additional points to accommodate
        """
        if self.nump == 0:
            self._alloc(dim)
        elif self.nump+extra > self.maxp:
            self.maxp = max(self.maxp*2, self.maxp+extra)
            self.x.resize((self.maxp, dim))
            self.fx.resize((self.maxp, 1))

    def get_x(self):
        """Return the data points added so far.

        :return: The first ``nump`` rows of the point storage
        """
        filled = self.nump
        return self.x[:filled, :]

    def get_fx(self):
        """Return the function values of the data points added so far.

        :return: The first ``nump`` rows of the function-value storage
        """
        filled = self.nump
        return self.fx[:filled, :]

    def add_point(self, xx, fx):
        """Store one more function evaluation and mark the model as stale.

        :param xx: Point to add
        :param fx: The function value of the point to add
        """
        # The dimension is taken from the point itself.
        self._realloc(len(xx))
        row = self.nump
        self.x[row, :] = xx
        self.fx[row, :] = fx
        self.nump = row + 1
        self.updated = False

    def eval(self, xx, d=None):
        """Evaluate the MARS interpolant at the point xx

        The model is refit first if it is flagged as out of date. Only the
        ``nump`` points stored so far are used for the fit.

        :param xx: Point where to evaluate
        :param d: Unused by this method
        :return: Value of the MARS interpolant at x
        """
        if self.updated is False:
            # x and fx are preallocated with maxp zero-filled rows; fitting
            # the whole buffers would treat the unused tail as real data.
            self.model.fit(self.x[:self.nump, :], self.fx[:self.nump, :])
        self.updated = True

        # The model expects a batch of points, so wrap the single point.
        xx = np.expand_dims(xx, axis=0)
        fx = self.model.predict(xx)
        return fx[0]

    def evals(self, xx, d=None):
        """Evaluate the MARS interpolant at the points xx

        The model is refit first if it is flagged as out of date. Only the
        ``nump`` points stored so far are used for the fit.

        :param xx: Points where to evaluate
        :param d: Unused by this method
        :return: Values of the MARS interpolant at x, as a column with one
            row per row of xx
        """
        if self.updated is False:
            # x and fx are preallocated with maxp zero-filled rows; fitting
            # the whole buffers would treat the unused tail as real data.
            self.model.fit(self.x[:self.nump, :], self.fx[:self.nump, :])
        self.updated = True

        # Place the predictions in a column so the result is 2-D.
        fx = np.zeros(shape=(xx.shape[0], 1))
        fx[:, 0] = self.model.predict(xx)
        return fx

    def deriv(self, x, d=None):
        """Evaluate the derivative of the MARS interpolant at x

        The model is refit first if it is flagged as out of date. Only the
        ``nump`` points stored so far are used for the fit.

        :param x: Data point
        :param d: Unused by this method
        :return: Derivative of the MARS interpolant at x
        """
        if self.updated is False:
            # x and fx are preallocated with maxp zero-filled rows; fitting
            # the whole buffers would treat the unused tail as real data.
            self.model.fit(self.x[:self.nump, :], self.fx[:self.nump, :])
        self.updated = True

        # The model expects a batch of points, so wrap the single point and
        # unwrap the single result again.
        x = np.expand_dims(x, axis=0)
        dfx = self.model.predict_deriv(x, variables=None)
        return dfx[0]
开发者ID:NoobSajbot,项目名称:pySOT,代码行数:104,代码来源:mars_interpolant.py


注:本文中的pyearth.Earth.predict_deriv方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。