當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python pyspark GeneralizedLinearRegression用法及代碼示例

本文簡要介紹 pyspark.ml.regression.GeneralizedLinearRegression 的用法。

用法:

class pyspark.ml.regression.GeneralizedLinearRegression(*, labelCol='label', featuresCol='features', predictionCol='prediction', family='gaussian', link=None, fitIntercept=True, maxIter=25, tol=1e-06, regParam=0.0, weightCol=None, solver='irls', linkPredictionCol=None, variancePower=0.0, linkPower=None, offsetCol=None, aggregationDepth=2)

廣義線性回歸。

通過給出線性預測變量(鏈接函數)的符號說明和誤差分布(族)的說明來擬合指定的廣義線性模型。它支持“gaussian”, “binomial”, “poisson”, “gamma”和“tweedie”作為係列。下麵列出了每個係列的有效鏈接函數。每個族的第一個鏈接函數是默認的。

  • “gaussian” -> “identity”, “log”, “inverse”

  • “binomial” -> “logit”, “probit”, “cloglog”

  • “poisson” -> “log”, “identity”, “sqrt”

  • “gamma” -> “inverse”, “identity”, “log”

  • “tweedie” -> 通過“linkPower” 指定的電源鏈接函數。 tweedie 係列中的默認鏈接功率為 1 - variancePower。

2.0.0 版中的新函數。

注意

有關更多信息,請參閱GLM 上的維基百科頁麵

例子

>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([
...     (1.0, Vectors.dense(0.0, 0.0)),
...     (1.0, Vectors.dense(1.0, 2.0)),
...     (2.0, Vectors.dense(0.0, 0.0)),
...     (2.0, Vectors.dense(1.0, 1.0)),], ["label", "features"])
>>> glr = GeneralizedLinearRegression(family="gaussian", link="identity", linkPredictionCol="p")
>>> glr.setRegParam(0.1)
GeneralizedLinearRegression...
>>> glr.getRegParam()
0.1
>>> glr.clear(glr.regParam)
>>> glr.setMaxIter(10)
GeneralizedLinearRegression...
>>> glr.getMaxIter()
10
>>> glr.clear(glr.maxIter)
>>> model = glr.fit(df)
>>> model.setFeaturesCol("features")
GeneralizedLinearRegressionModel...
>>> model.getMaxIter()
25
>>> model.getAggregationDepth()
2
>>> transformed = model.transform(df)
>>> abs(transformed.head().prediction - 1.5) < 0.001
True
>>> abs(transformed.head().p - 1.5) < 0.001
True
>>> model.coefficients
DenseVector([1.5..., -1.0...])
>>> model.numFeatures
2
>>> abs(model.intercept - 1.5) < 0.001
True
>>> glr_path = temp_path + "/glr"
>>> glr.save(glr_path)
>>> glr2 = GeneralizedLinearRegression.load(glr_path)
>>> glr.getFamily() == glr2.getFamily()
True
>>> model_path = temp_path + "/glr_model"
>>> model.save(model_path)
>>> model2 = GeneralizedLinearRegressionModel.load(model_path)
>>> model.intercept == model2.intercept
True
>>> model.coefficients[0] == model2.coefficients[0]
True
>>> model.transform(df).take(1) == model2.transform(df).take(1)
True

相關用法


注:本文由純淨天空篩選整理自spark.apache.org大神的英文原創作品 pyspark.ml.regression.GeneralizedLinearRegression。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。