当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python pyspark GeneralizedLinearRegression用法及代码示例


本文简要介绍 pyspark.ml.regression.GeneralizedLinearRegression 的用法。

用法:

class pyspark.ml.regression.GeneralizedLinearRegression(*, labelCol='label', featuresCol='features', predictionCol='prediction', family='gaussian', link=None, fitIntercept=True, maxIter=25, tol=1e-06, regParam=0.0, weightCol=None, solver='irls', linkPredictionCol=None, variancePower=0.0, linkPower=None, offsetCol=None, aggregationDepth=2)

广义线性回归。

通过给出线性预测变量(链接函数)的符号说明和误差分布(族)的说明来拟合指定的广义线性模型。它支持“gaussian”, “binomial”, “poisson”, “gamma”和“tweedie”作为系列。下面列出了每个系列的有效链接函数。每个族的第一个链接函数是默认的。

  • “gaussian” -> “identity”, “log”, “inverse”

  • “binomial” -> “logit”, “probit”, “cloglog”

  • “poisson” -> “log”, “identity”, “sqrt”

  • “gamma” -> “inverse”, “identity”, “log”

  • “tweedie” -> 通过“linkPower” 指定的电源链接函数。 tweedie 系列中的默认链接功率为 1 - variancePower。

2.0.0 版中的新函数。

注意

有关更多信息,请参阅GLM 上的维基百科页面

例子

>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([
...     (1.0, Vectors.dense(0.0, 0.0)),
...     (1.0, Vectors.dense(1.0, 2.0)),
...     (2.0, Vectors.dense(0.0, 0.0)),
...     (2.0, Vectors.dense(1.0, 1.0)),], ["label", "features"])
>>> glr = GeneralizedLinearRegression(family="gaussian", link="identity", linkPredictionCol="p")
>>> glr.setRegParam(0.1)
GeneralizedLinearRegression...
>>> glr.getRegParam()
0.1
>>> glr.clear(glr.regParam)
>>> glr.setMaxIter(10)
GeneralizedLinearRegression...
>>> glr.getMaxIter()
10
>>> glr.clear(glr.maxIter)
>>> model = glr.fit(df)
>>> model.setFeaturesCol("features")
GeneralizedLinearRegressionModel...
>>> model.getMaxIter()
25
>>> model.getAggregationDepth()
2
>>> transformed = model.transform(df)
>>> abs(transformed.head().prediction - 1.5) < 0.001
True
>>> abs(transformed.head().p - 1.5) < 0.001
True
>>> model.coefficients
DenseVector([1.5..., -1.0...])
>>> model.numFeatures
2
>>> abs(model.intercept - 1.5) < 0.001
True
>>> glr_path = temp_path + "/glr"
>>> glr.save(glr_path)
>>> glr2 = GeneralizedLinearRegression.load(glr_path)
>>> glr.getFamily() == glr2.getFamily()
True
>>> model_path = temp_path + "/glr_model"
>>> model.save(model_path)
>>> model2 = GeneralizedLinearRegressionModel.load(model_path)
>>> model.intercept == model2.intercept
True
>>> model.coefficients[0] == model2.coefficients[0]
True
>>> model.transform(df).take(1) == model2.transform(df).take(1)
True

相关用法


注:本文由纯净天空筛选整理自spark.apache.org大神的英文原创作品 pyspark.ml.regression.GeneralizedLinearRegression。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。