Python pyspark StringIndexer usage and code examples


This article briefly introduces the usage of pyspark.ml.feature.StringIndexer.

Usage:

class pyspark.ml.feature.StringIndexer(*, inputCol=None, outputCol=None, inputCols=None, outputCols=None, handleInvalid='error', stringOrderType='frequencyDesc')

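A label indexer that maps a string column of labels to an ML column of label indices. If the input column is numeric, it is cast to string and the string values are indexed. The indices are in [0, numLabels). By default, labels are ordered by frequency, so the most frequent label gets index 0. The ordering behavior is controlled by setting stringOrderType, whose default value is 'frequencyDesc'.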
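The ordering rule can be illustrated without Spark. The sketch below is plain Python, not Spark's implementation: the helper name `build_label_index` is hypothetical, and the alphabetical tie-break for labels of equal frequency is an assumption based on the behavior documented for recent Spark releases. It uses the same six labels that appear in the examples in this article.

```python
from collections import Counter


def build_label_index(values, string_order_type="frequencyDesc"):
    """Hypothetical helper: map each distinct label to a float index in [0, numLabels)."""
    # Numeric input is cast to string before indexing, as described above.
    counts = Counter(str(v) for v in values)
    if string_order_type == "frequencyDesc":
        # Most frequent label first; ties broken alphabetically (assumption).
        ordered = sorted(counts, key=lambda s: (-counts[s], s))
    elif string_order_type == "frequencyAsc":
        ordered = sorted(counts, key=lambda s: (counts[s], s))
    elif string_order_type == "alphabetDesc":
        ordered = sorted(counts, reverse=True)
    elif string_order_type == "alphabetAsc":
        ordered = sorted(counts)
    else:
        raise ValueError(f"Unsupported stringOrderType: {string_order_type!r}")
    return {label: float(i) for i, label in enumerate(ordered)}


labels = ["a", "b", "c", "a", "a", "c"]
freq_index = build_label_index(labels)
alpha_index = build_label_index(labels, "alphabetDesc")
print(freq_index)   # {'a': 0.0, 'c': 1.0, 'b': 2.0}
print(alpha_index)  # {'c': 0.0, 'b': 1.0, 'a': 2.0}
```

With 'frequencyDesc', the label "a" (three occurrences) gets index 0.0, which agrees with the indices shown for rows 0, 3 and 4 in the examples.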

New in version 1.4.0.

Examples

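>>> # Fixtures not defined in this session: `spark` (a SparkSession), `sc` (its SparkContext),
>>> # `temp_path` (a writable directory) and `stringIndDf` (a DataFrame with columns `id` and `label`).
>>> from pyspark.ml.feature import IndexToString, StringIndexer, StringIndexerModel
>>> from pyspark.sql import Row
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed",
...     stringOrderType="frequencyDesc")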
>>> stringIndexer.setHandleInvalid("error")
StringIndexer...
>>> model = stringIndexer.fit(stringIndDf)
>>> model.setHandleInvalid("error")
StringIndexerModel...
>>> td = model.transform(stringIndDf)
>>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]),
...     key=lambda x: x[0])
[(0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)]
>>> inverter = IndexToString(inputCol="indexed", outputCol="label2", labels=model.labels)
>>> itd = inverter.transform(td)
>>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]),
...     key=lambda x: x[0])
[(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')]
>>> stringIndexerPath = temp_path + "/string-indexer"
>>> stringIndexer.save(stringIndexerPath)
>>> loadedIndexer = StringIndexer.load(stringIndexerPath)
>>> loadedIndexer.getHandleInvalid() == stringIndexer.getHandleInvalid()
True
>>> modelPath = temp_path + "/string-indexer-model"
>>> model.save(modelPath)
>>> loadedModel = StringIndexerModel.load(modelPath)
>>> loadedModel.labels == model.labels
True
>>> indexToStringPath = temp_path + "/index-to-string"
>>> inverter.save(indexToStringPath)
>>> loadedInverter = IndexToString.load(indexToStringPath)
>>> loadedInverter.getLabels() == inverter.getLabels()
True
>>> loadedModel.transform(stringIndDf).take(1) == model.transform(stringIndDf).take(1)
True
>>> stringIndexer.getStringOrderType()
'frequencyDesc'
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="error",
...     stringOrderType="alphabetDesc")
>>> model = stringIndexer.fit(stringIndDf)
>>> td = model.transform(stringIndDf)
>>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]),
...     key=lambda x: x[0])
[(0, 2.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 0.0)]
>>> fromlabelsModel = StringIndexerModel.from_labels(["a", "b", "c"],
...     inputCol="label", outputCol="indexed", handleInvalid="error")
>>> result = fromlabelsModel.transform(stringIndDf)
>>> sorted(set([(i[0], i[1]) for i in result.select(result.id, result.indexed).collect()]),
...     key=lambda x: x[0])
[(0, 0.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 2.0)]
>>> testData = sc.parallelize([Row(id=0, label1="a", label2="e"),
...                            Row(id=1, label1="b", label2="f"),
...                            Row(id=2, label1="c", label2="e"),
...                            Row(id=3, label1="a", label2="f"),
...                            Row(id=4, label1="a", label2="f"),
...                            Row(id=5, label1="c", label2="f")], 3)
>>> multiRowDf = spark.createDataFrame(testData)
>>> inputs = ["label1", "label2"]
>>> outputs = ["index1", "index2"]
>>> stringIndexer = StringIndexer(inputCols=inputs, outputCols=outputs)
>>> model = stringIndexer.fit(multiRowDf)
>>> result = model.transform(multiRowDf)
>>> sorted(set([(i[0], i[1], i[2]) for i in result.select(result.id, result.index1,
...     result.index2).collect()]), key=lambda x: x[0])
[(0, 0.0, 1.0), (1, 2.0, 0.0), (2, 1.0, 1.0), (3, 0.0, 0.0), (4, 0.0, 0.0), (5, 1.0, 0.0)]
>>> fromlabelsModel = StringIndexerModel.from_arrays_of_labels([["a", "b", "c"], ["e", "f"]],
...     inputCols=inputs, outputCols=outputs)
>>> result = fromlabelsModel.transform(multiRowDf)
>>> sorted(set([(i[0], i[1], i[2]) for i in result.select(result.id, result.index1,
...     result.index2).collect()]), key=lambda x: x[0])
[(0, 0.0, 0.0), (1, 1.0, 1.0), (2, 2.0, 0.0), (3, 0.0, 1.0), (4, 0.0, 1.0), (5, 2.0, 1.0)]
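The examples above leave handleInvalid at 'error'. As a rough illustration of how the three documented settings ('error', 'skip', 'keep') treat a label that was not seen during fitting, here is a plain Python sketch. It is not Spark code: `apply_index` is a hypothetical helper, and it raises ValueError where Spark would raise its own exception. Under 'keep', unseen labels are placed in one extra bucket at index numLabels, as the Spark documentation describes.

```python
def apply_index(rows, labels, handle_invalid="error"):
    """Hypothetical helper: index (id, label) rows against a fixed label list."""
    if handle_invalid not in ("error", "skip", "keep"):
        raise ValueError(f"Unsupported handleInvalid: {handle_invalid!r}")
    lookup = {label: float(i) for i, label in enumerate(labels)}
    indexed = []
    for row_id, label in rows:
        if label in lookup:
            indexed.append((row_id, lookup[label]))
        elif handle_invalid == "keep":
            # All unseen labels share one extra bucket at index numLabels.
            indexed.append((row_id, float(len(labels))))
        elif handle_invalid == "skip":
            # Rows with unseen labels are dropped from the output.
            continue
        else:
            # 'error': stand-in for the exception Spark raises at transform time.
            raise ValueError(f"Unseen label: {label!r}")
    return indexed


labels = ["a", "b", "c"]
rows = [(0, "a"), (1, "d"), (2, "c")]
kept = apply_index(rows, labels, "keep")
skipped = apply_index(rows, labels, "skip")
print(kept)     # [(0, 0.0), (1, 3.0), (2, 2.0)]
print(skipped)  # [(0, 0.0), (2, 2.0)]
```

Choosing 'keep' preserves the row count at the cost of one additional index value, while 'skip' keeps the index range unchanged but silently shortens the output.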

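Note: This article was selected and compiled by 纯净天空 from the original English documentation for pyspark.ml.feature.StringIndexer on spark.apache.org. Unless otherwise stated, copyright in the original code belongs to the original authors; do not reproduce or copy this translation without permission or authorization.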