Python pyspark StringIndexer Usage and Code Examples


This article briefly introduces the usage of pyspark.ml.feature.StringIndexer.

Usage:

class pyspark.ml.feature.StringIndexer(*, inputCol=None, outputCol=None, inputCols=None, outputCols=None, handleInvalid='error', stringOrderType='frequencyDesc')

A label indexer that maps a string column of labels to an ML column of label indices. If the input column is numeric, it is cast to string and the string values are indexed. The indices are in [0, numLabels). By default this is ordered by label frequencies, so the most frequent label gets index 0. The ordering behavior is controlled by setting stringOrderType; its default value is 'frequencyDesc'.

New in version 1.4.0.

Examples
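
The examples below are doctest snippets that rely on a few names prepared outside this page (spark, sc, temp_path, and the sample DataFrame stringIndDf). A minimal setup sketch is shown here; the exact contents of stringIndDf are an assumption inferred from the expected outputs further down (labels a, b, c, a, a, c):

>>> import tempfile
>>> from pyspark.sql import SparkSession, Row
>>> from pyspark.ml.feature import StringIndexer, StringIndexerModel, IndexToString
>>> spark = SparkSession.builder.getOrCreate()   # reuse an existing session if one is running
>>> sc = spark.sparkContext
>>> temp_path = tempfile.mkdtemp()               # scratch directory for the save/load examples
>>> stringIndDf = spark.createDataFrame(
...     [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")],
...     ["id", "label"])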

>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed",
...     stringOrderType="frequencyDesc")
>>> stringIndexer.setHandleInvalid("error")
StringIndexer...
>>> model = stringIndexer.fit(stringIndDf)
>>> model.setHandleInvalid("error")
StringIndexerModel...
>>> td = model.transform(stringIndDf)
>>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]),
...     key=lambda x: x[0])
[(0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)]
>>> inverter = IndexToString(inputCol="indexed", outputCol="label2", labels=model.labels)
>>> itd = inverter.transform(td)
>>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]),
...     key=lambda x: x[0])
[(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')]
>>> stringIndexerPath = temp_path + "/string-indexer"
>>> stringIndexer.save(stringIndexerPath)
>>> loadedIndexer = StringIndexer.load(stringIndexerPath)
>>> loadedIndexer.getHandleInvalid() == stringIndexer.getHandleInvalid()
True
>>> modelPath = temp_path + "/string-indexer-model"
>>> model.save(modelPath)
>>> loadedModel = StringIndexerModel.load(modelPath)
>>> loadedModel.labels == model.labels
True
>>> indexToStringPath = temp_path + "/index-to-string"
>>> inverter.save(indexToStringPath)
>>> loadedInverter = IndexToString.load(indexToStringPath)
>>> loadedInverter.getLabels() == inverter.getLabels()
True
>>> loadedModel.transform(stringIndDf).take(1) == model.transform(stringIndDf).take(1)
True
>>> stringIndexer.getStringOrderType()
'frequencyDesc'
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="error",
...     stringOrderType="alphabetDesc")
>>> model = stringIndexer.fit(stringIndDf)
>>> td = model.transform(stringIndDf)
>>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]),
...     key=lambda x: x[0])
[(0, 2.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 0.0)]
>>> fromlabelsModel = StringIndexerModel.from_labels(["a", "b", "c"],
...     inputCol="label", outputCol="indexed", handleInvalid="error")
>>> result = fromlabelsModel.transform(stringIndDf)
>>> sorted(set([(i[0], i[1]) for i in result.select(result.id, result.indexed).collect()]),
...     key=lambda x: x[0])
[(0, 0.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 2.0)]
>>> testData = sc.parallelize([Row(id=0, label1="a", label2="e"),
...                            Row(id=1, label1="b", label2="f"),
...                            Row(id=2, label1="c", label2="e"),
...                            Row(id=3, label1="a", label2="f"),
...                            Row(id=4, label1="a", label2="f"),
...                            Row(id=5, label1="c", label2="f")], 3)
>>> multiRowDf = spark.createDataFrame(testData)
>>> inputs = ["label1", "label2"]
>>> outputs = ["index1", "index2"]
>>> stringIndexer = StringIndexer(inputCols=inputs, outputCols=outputs)
>>> model = stringIndexer.fit(multiRowDf)
>>> result = model.transform(multiRowDf)
>>> sorted(set([(i[0], i[1], i[2]) for i in result.select(result.id, result.index1,
...     result.index2).collect()]), key=lambda x: x[0])
[(0, 0.0, 1.0), (1, 2.0, 0.0), (2, 1.0, 1.0), (3, 0.0, 0.0), (4, 0.0, 0.0), (5, 1.0, 0.0)]
>>> fromlabelsModel = StringIndexerModel.from_arrays_of_labels([["a", "b", "c"], ["e", "f"]],
...     inputCols=inputs, outputCols=outputs)
>>> result = fromlabelsModel.transform(multiRowDf)
>>> sorted(set([(i[0], i[1], i[2]) for i in result.select(result.id, result.index1,
...     result.index2).collect()]), key=lambda x: x[0])
[(0, 0.0, 0.0), (1, 1.0, 1.0), (2, 2.0, 0.0), (3, 0.0, 1.0), (4, 0.0, 1.0), (5, 2.0, 1.0)]
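
Besides the 'error' setting used above, handleInvalid also accepts 'skip' and 'keep' for labels that were not seen during fitting. The following is a hedged sketch (unseenDf and the two refit models are illustrative and not part of the original example): with 'keep', an unseen label is assumed to be mapped to an extra bucket whose index equals the number of known labels, while 'skip' simply drops such rows.

>>> unseenDf = spark.createDataFrame([(6, "d")], ["id", "label"])   # 'd' was never seen during fit
>>> keepModel = StringIndexer(inputCol="label", outputCol="indexed",
...     handleInvalid="keep").fit(stringIndDf)
>>> keepModel.transform(unseenDf).head().indexed   # unseen label goes to the extra index numLabels
3.0
>>> skipModel = StringIndexer(inputCol="label", outputCol="indexed",
...     handleInvalid="skip").fit(stringIndDf)
>>> skipModel.transform(unseenDf).count()          # rows with unseen labels are filtered out
0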

Note: This article was compiled by 純淨天空 from the original English work pyspark.ml.feature.StringIndexer on spark.apache.org. Unless otherwise stated, the copyright of the original code belongs to its original author; please do not reproduce or copy this translation without permission or authorization.