本文简要介绍
pyspark.ml.feature.StringIndexer
的用法。用法:
class pyspark.ml.feature.StringIndexer(*, inputCol=None, outputCol=None, inputCols=None, outputCols=None, handleInvalid='error', stringOrderType='frequencyDesc')
将标签字符串列映射到标签索引的 ML 列的标签索引器。如果输入列是数字,我们将其转换为字符串并为字符串值建立索引。索引位于 [0, numLabels) 中。默认情况下,这是按标签频率排序的,因此最常见的标签获得索引 0。排序行为通过设置
stringOrderType
进行控制。它的默认值为“FrequencyDesc”。1.4.0 版中的新函数。
例子:
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", ... stringOrderType="frequencyDesc") >>> stringIndexer.setHandleInvalid("error") StringIndexer... >>> model = stringIndexer.fit(stringIndDf) >>> model.setHandleInvalid("error") StringIndexerModel... >>> td = model.transform(stringIndDf) >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]), ... key=lambda x: x[0]) [(0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)] >>> inverter = IndexToString(inputCol="indexed", outputCol="label2", labels=model.labels) >>> itd = inverter.transform(td) >>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]), ... key=lambda x: x[0]) [(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')] >>> stringIndexerPath = temp_path + "/string-indexer" >>> stringIndexer.save(stringIndexerPath) >>> loadedIndexer = StringIndexer.load(stringIndexerPath) >>> loadedIndexer.getHandleInvalid() == stringIndexer.getHandleInvalid() True >>> modelPath = temp_path + "/string-indexer-model" >>> model.save(modelPath) >>> loadedModel = StringIndexerModel.load(modelPath) >>> loadedModel.labels == model.labels True >>> indexToStringPath = temp_path + "/index-to-string" >>> inverter.save(indexToStringPath) >>> loadedInverter = IndexToString.load(indexToStringPath) >>> loadedInverter.getLabels() == inverter.getLabels() True >>> loadedModel.transform(stringIndDf).take(1) == model.transform(stringIndDf).take(1) True >>> stringIndexer.getStringOrderType() 'frequencyDesc' >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="error", ... stringOrderType="alphabetDesc") >>> model = stringIndexer.fit(stringIndDf) >>> td = model.transform(stringIndDf) >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]), ... key=lambda x: x[0]) [(0, 2.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 0.0)] >>> fromlabelsModel = StringIndexerModel.from_labels(["a", "b", "c"], ... inputCol="label", outputCol="indexed", handleInvalid="error") >>> result = fromlabelsModel.transform(stringIndDf) >>> sorted(set([(i[0], i[1]) for i in result.select(result.id, result.indexed).collect()]), ... key=lambda x: x[0]) [(0, 0.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 2.0)] >>> testData = sc.parallelize([Row(id=0, label1="a", label2="e"), ... Row(id=1, label1="b", label2="f"), ... Row(id=2, label1="c", label2="e"), ... Row(id=3, label1="a", label2="f"), ... Row(id=4, label1="a", label2="f"), ... Row(id=5, label1="c", label2="f")], 3) >>> multiRowDf = spark.createDataFrame(testData) >>> inputs = ["label1", "label2"] >>> outputs = ["index1", "index2"] >>> stringIndexer = StringIndexer(inputCols=inputs, outputCols=outputs) >>> model = stringIndexer.fit(multiRowDf) >>> result = model.transform(multiRowDf) >>> sorted(set([(i[0], i[1], i[2]) for i in result.select(result.id, result.index1, ... result.index2).collect()]), key=lambda x: x[0]) [(0, 0.0, 1.0), (1, 2.0, 0.0), (2, 1.0, 1.0), (3, 0.0, 0.0), (4, 0.0, 0.0), (5, 1.0, 0.0)] >>> fromlabelsModel = StringIndexerModel.from_arrays_of_labels([["a", "b", "c"], ["e", "f"]], ... inputCols=inputs, outputCols=outputs) >>> result = fromlabelsModel.transform(multiRowDf) >>> sorted(set([(i[0], i[1], i[2]) for i in result.select(result.id, result.index1, ... result.index2).collect()]), key=lambda x: x[0]) [(0, 0.0, 0.0), (1, 1.0, 1.0), (2, 2.0, 0.0), (3, 0.0, 1.0), (4, 0.0, 1.0), (5, 2.0, 1.0)]
相关用法
- Python pyspark StructType用法及代码示例
- Python pyspark StreamingQueryManager.get用法及代码示例
- Python pyspark StructField用法及代码示例
- Python pyspark StreamingQueryManager.resetTerminated用法及代码示例
- Python pyspark StreamingKMeansModel用法及代码示例
- Python pyspark StructType.fieldNames用法及代码示例
- Python pyspark StreamingQueryManager.active用法及代码示例
- Python pyspark StructType.add用法及代码示例
- Python pyspark StreamingQuery.explain用法及代码示例
- Python pyspark StopWordsRemover用法及代码示例
- Python pyspark Statistics.corr用法及代码示例
- Python pyspark StandardScaler用法及代码示例
- Python pyspark Statistics.kolmogorovSmirnovTest用法及代码示例
- Python pyspark Statistics.colStats用法及代码示例
- Python pyspark Statistics.chiSqTest用法及代码示例
- Python pyspark Series.asof用法及代码示例
- Python pyspark Series.to_frame用法及代码示例
- Python pyspark Series.rsub用法及代码示例
- Python pyspark Series.mod用法及代码示例
- Python pyspark Series.str.join用法及代码示例
- Python pyspark Series.str.startswith用法及代码示例
- Python pyspark Series.dt.is_quarter_end用法及代码示例
- Python pyspark Series.dropna用法及代码示例
- Python pyspark Series.sub用法及代码示例
- Python pyspark Series.sum用法及代码示例
注:本文由纯净天空筛选整理自spark.apache.org大神的英文原创作品 pyspark.ml.feature.StringIndexer。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。