This page collects typical usage examples of the Python method pyspark.SparkContext.setLogLevel. If you have been wondering what SparkContext.setLogLevel does, how to call it, or what it looks like in real code, the curated examples below should help. You can also read further about its containing class, pyspark.SparkContext.
Below are 15 code examples of SparkContext.setLogLevel, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code samples.
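Before the collected examples, a minimal self-contained sketch of the method itself may be useful; the master, application name, and chosen level below are illustrative rather than taken from any of the examples.

from pyspark import SparkConf, SparkContext

conf = SparkConf().setMaster("local[*]").setAppName("set-log-level-demo")  # illustrative settings
sc = SparkContext(conf=conf)
sc.setLogLevel("WARN")  # valid levels include ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN
# ... run jobs here; only WARN and above will reach the driver logs ...
sc.stop()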
Example 1: save_data_to_db
# Required import: from pyspark import SparkContext [as alias]
# Or: from pyspark.SparkContext import setLogLevel [as alias]
def save_data_to_db():
    from pyspark import SparkContext, SparkConf
    from pyspark.streaming import StreamingContext
    conf = SparkConf().setMaster("localhost")
    sc = SparkContext("local[*]", "camera_mechine_gen")
    sc.setLogLevel("WARN")
    data_used_by_camera_mechine_gen.drop()
    path = '/3/2014-10-15'
    for station in stations:
        station_info = station_destinations_by_directions.find_one({"station_name": station})
        if station_info is None:
            continue
        destinations_by_directions = station_info['destinations_by_directions']
        full_path = data_dir_path + 'v0/' + station + path
        print full_path
        func = map_anlalyser_gen(station, destinations_by_directions)
        file_data = sc.textFile(full_path).map(pre_process_1).groupByKey().map(func).collect()
        for i in sorted(file_data, key=lambda x: x[0]):
            time = i[0]
            C1_by_directions = list(i[1].iteritems())
            # print station, time, C1_by_directions
            data_used_by_camera_mechine_gen.insert({'station_name': station, 'time': time, 'C1_by_directions': C1_by_directions})
Example 2: main
# Required import: from pyspark import SparkContext [as alias]
# Or: from pyspark.SparkContext import setLogLevel [as alias]
def main():
    sc = SparkContext(appName="MyApp")
    sc.setLogLevel('ERROR')
    # Parse data
    train_labels, train_data = load_data('train.csv')
    dummy_labels, test_data = load_data('test.csv', use_labels=False)
    # Truncate the last 2 features of the data
    for dataPoint in train_data:
        len = np.size(dataPoint)
        dataPoint = np.delete(dataPoint, [len - 2, len - 1])
    for dataPoint in test_data:
        len = np.size(dataPoint)
        dataPoint = np.delete(dataPoint, [len - 2, len - 1])
    # Map each data point's label to its features
    train_set = reformatData(train_data, train_labels)
    test_set = reformatData(test_data, dummy_labels)
    # Parallelize the data
    parallelized_train_set = sc.parallelize(train_set)
    parallelized_test_set = sc.parallelize(test_set)
    # Split the data
    trainSet, validationSet = parallelized_train_set.randomSplit([0.01, 0.99], seed=42)
    # Train the models
    randomForestModel = RandomForest.trainClassifier(trainSet, numClasses=4, impurity='gini', categoricalFeaturesInfo={},
                                                     numTrees=750, seed=42, maxDepth=30, maxBins=32)
    # Test the model
    testRandomForest(randomForestModel, parallelized_test_set)
Example 3: main
# Required import: from pyspark import SparkContext [as alias]
# Or: from pyspark.SparkContext import setLogLevel [as alias]
def main():
    sc = SparkContext(appName="MyApp")
    sc.setLogLevel('ERROR')
    # Parse data
    train_labels, train_data = load_data('train.csv')
    dummy_labels, test_data = load_data('test.csv', use_labels=False)
    # Map each data point's label to its features
    train_set = reformatData(train_data, train_labels)
    test_set = reformatData(test_data, dummy_labels)
    # Parallelize the data
    parallelized_train_set = sc.parallelize(train_set)
    parallelized_test_set = sc.parallelize(test_set)
    # Split the data
    trainSet, validationSet = parallelized_train_set.randomSplit([1.0, 0.0], seed=42)
    # Train the models
    decisionTreeModel = DecisionTree.trainClassifier(trainSet, numClasses=5, categoricalFeaturesInfo={},
                                                     impurity='gini', maxBins=55, maxDepth=30, minInstancesPerNode=2)
    # Test the model
    testDecisionTree(decisionTreeModel, parallelized_test_set)
Example 4: create_spark_application
# Required import: from pyspark import SparkContext [as alias]
# Or: from pyspark.SparkContext import setLogLevel [as alias]
def create_spark_application(app_name):
    """Creates and returns a Spark & SQL Context."""
    conf = (SparkConf().setAppName(app_name))
    spark_context = SparkContext(conf=conf)
    spark_context.setLogLevel('WARN')
    sql_context = SQLContext(spark_context)
    return (spark_context, sql_context)
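A brief usage sketch (hypothetical caller, not part of the original source) that unpacks the returned tuple:

spark_context, sql_context = create_spark_application("my-app")  # "my-app" is an illustrative name
df = sql_context.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])
df.show()
spark_context.stop()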
Example 5: spark_context
# Required import: from pyspark import SparkContext [as alias]
# Or: from pyspark.SparkContext import setLogLevel [as alias]
def spark_context(request):
    """
    Pytest fixture for creating a spark context.
    Args:
        :param request: pytest.FixtureRequest object
    """
    conf = (SparkConf().setMaster("local").setAppName("pyspark-local-testing"))
    sc = SparkContext(conf=conf)
    sc.setLogLevel("ERROR")
    request.addfinalizer(lambda: sc.stop())
    return sc
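As a usage note (not in the original source): once registered with the @pytest.fixture decorator, a function like the one above is injected into tests by parameter name. A minimal hypothetical test:

def test_word_count(spark_context):  # pytest passes in the fixture's return value
    rdd = spark_context.parallelize(["spark", "pytest", "spark"])
    counts = rdd.countByValue()
    assert counts["spark"] == 2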
Example 6: functionToCreateContext
# Required import: from pyspark import SparkContext [as alias]
# Or: from pyspark.SparkContext import setLogLevel [as alias]
def functionToCreateContext():
    sc = SparkContext("local[*]", "streaming_part")
    sc.setLogLevel("ERROR")
    ssc = StreamingContext(sc, 5)
    data_from_ticket_mechine = ssc.socketTextStream("localhost", 9999)
    data_from_camera_mechine = ssc.socketTextStream("localhost", 9998)
    # meat
    data_from_ticket_mechine.map(ticket_mechine_pre_process).updateStateByKey(updateFunction).foreachRDD(ticket_mechine_RDD_handler)
    data_from_camera_mechine.map(camera_mechine_pre_process).updateStateByKey(updateFunction).foreachRDD(camera_mechine_RDD_handler)
    ssc.checkpoint(checkpointDirectory)  # set checkpoint directory
    return ssc
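A factory like functionToCreateContext is normally handed to StreamingContext.getOrCreate so the driver can rebuild the context from the checkpoint directory after a restart. A sketch of that pattern, assuming checkpointDirectory is the same variable used above:

from pyspark.streaming import StreamingContext

ssc = StreamingContext.getOrCreate(checkpointDirectory, functionToCreateContext)  # recover or create
ssc.start()
ssc.awaitTermination()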
Example 7: save_data_to_db
# Required import: from pyspark import SparkContext [as alias]
# Or: from pyspark.SparkContext import setLogLevel [as alias]
def save_data_to_db():
    from pyspark import SparkContext, SparkConf
    from pyspark.streaming import StreamingContext
    conf = SparkConf().setMaster("localhost")
    sc = SparkContext("local[*]", "tikcket_mechine_gen")
    sc.setLogLevel("WARN")
    sc.addFile(lib_dir + '/getDistance.py')
    data_used_by_ticket_mechine_gen.drop()
    path = '/3/2014-10-15'
    for s in stations:
        full_path = data_dir_path + 'v0/' + s + path
        print full_path
        data_to_save = getDistance.get_one_day_group_by_time(full_path, sc)
        for item in data_to_save:
            data_used_by_ticket_mechine_gen.insert({'station_name': s, 'time': item[0], 'data': item[1]})
Example 8: spark_context
# Required import: from pyspark import SparkContext [as alias]
# Or: from pyspark.SparkContext import setLogLevel [as alias]
def spark_context(request):
    # If RIAK_HOSTS is not set, use Docker to start a Riak node
    if not os.environ.has_key('RIAK_HOSTS'):
        docker_cli = request.getfuncargvalue('docker_cli')
        host_and_port = get_host_and_port(docker_cli)
        os.environ['RIAK_HOSTS'] = host_and_port
        os.environ['USE_DOCKER'] = 'true'
    # Start new spark context
    conf = SparkConf().setMaster('local[*]').setAppName('pytest-pyspark-local-testing')
    conf.set('spark.riak.connection.host', os.environ['RIAK_HOSTS'])
    conf.set('spark.driver.memory', '4g')
    conf.set('spark.executor.memory', '4g')
    spark_context = SparkContext(conf=conf)
    spark_context.setLogLevel('INFO')
    pyspark_riak.riak_context(spark_context)
    request.addfinalizer(lambda: spark_context.stop())
    return spark_context
Example 9: test
# Required import: from pyspark import SparkContext [as alias]
# Or: from pyspark.SparkContext import setLogLevel [as alias]
def test():
    sc = SparkContext(master='local[4]', appName='lda')
    sc.setLogLevel('ERROR')

    def train():
        data = sc.textFile(corpus_filename).map(lambda line: Vectors.dense([float(i) for i in line.strip().split()]))
        corpus = data.zipWithIndex().map(lambda x: [x[1], x[0]]).cache()
        # print(corpus.take(5))
        lda_model = LDA.train(rdd=corpus, maxIterations=max_iter, seed=seed, checkpointInterval=checkin_point_interval,
                              k=K,
                              optimizer=optimizer, docConcentration=alpha, topicConcentration=beta)
        if os.path.exists('./ldamodel'): __import__('shutil').rmtree('./ldamodel')
        lda_model.save(sc, "./ldamodel")

    # train()
    lda_model = LDAModel.load(sc, "./ldamodel")
    # topic-word distribution (unnormalized; each column represents one topic)
    topics = lda_model.topicsMatrix()
    # for tid in range(3):
    #     print('Topic' + str(tid) + ':')
    #     for wid in range(0, lda_model.vocabSize()):
    #         print(' ' + str(topics[wid, tid] / sum(topics[:, tid])))  # with normalization added
    #         # print(' ' + str(topics[wid, tid]))
    # topic-word distribution per topic ([word ids sorted by weight, descending], [word weights within the topic])
    topics_dist = lda_model.describeTopics()
    for tid, topic in enumerate(topics_dist):
        print('Topic' + str(tid) + ':' + '\n', topic)
    # per-document topic distribution (not available in mllib; only the ml API supports it)
    # doc_topic = lda_model
    sc.stop()
Example 10: main
# Required import: from pyspark import SparkContext [as alias]
# Or: from pyspark.SparkContext import setLogLevel [as alias]
def main():
    conf = (SparkConf()
            .setMaster("local[*]")
            .setAppName("compare_engine"))
    sc = SparkContext(conf=conf)
    sc.setLogLevel('INFO')
    sc.addFile(primary)

    # rdd_primary = sc.textFile(primary, minPartitions=4, use_unicode=True).distinct()
    rdd_primary = sc.textFile(SparkFiles.get(primary), minPartitions=4, use_unicode=True).distinct()
    rdd_primary.partitionBy(10).cache()

    os.system('rm -Rf collects_*')
    os.system('rm -Rf holder.txt')

    rdd_secondary = sc.textFile(secondary, minPartitions=4, use_unicode=True).distinct()
    rdd_secondary.partitionBy(10).cache()

    primary_count = rdd_primary.count()
    primary_report['count'] = primary_count
    print(primary_report)

    secondary_count = rdd_secondary.count()
    secondary_report['count'] = secondary_count
    print(secondary_report)

    # Return each Primary file line/record not contained in Secondary
    not_in_primary = rdd_primary.subtract(rdd_secondary)
    primary_diff = not_in_primary.count()
    primary_report['diff'] = primary_diff

    os.system('rm -Rf collects_*.csv')
    primary_dir = 'collects_{}_primary'.format(run_date)
    primary_report_name = 'collects_{}_primary_report.csv'.format(run_date)
    not_in_primary.coalesce(1, True).saveAsTextFile(primary_dir)
    # os.system('cat collects_{}_primary/part-0000* >> collects_{}_primary_report.csv'.format(run_date, run_date))
    os.system('cat {}/part-0000* >> {}'.format(primary_dir, primary_report_name))
    os.system('wc -l collects_{}_primary_report.csv'.format(run_date))

    # Flip Primary Vs Secondary
    # Return each Secondary file line/record not contained in Primary
    not_in_secondary = rdd_secondary.subtract(rdd_primary)
    secondary_diff = not_in_secondary.count()
    secondary_report['diff'] = secondary_diff
    not_in_secondary.coalesce(1, True).saveAsTextFile('collects_{}_secondary'.format(run_date))
    os.system('cat collects_{}_secondary/part-0000* >> collects_{}_secondary_report.csv'.format(run_date, run_date))
    os.system('wc -l collects_{}_secondary_report.csv'.format(run_date))

    process_report['primary'] = primary_report
    process_report['secondary'] = secondary_report

    print("=" * 100)
    print('\n')
    print(process_report)
    print('\n')
    print("=" * 100)

    spark_details(sc)
    sc.stop()
Example 11: aggregate
# Required import: from pyspark import SparkContext [as alias]
# Or: from pyspark.SparkContext import setLogLevel [as alias]
def aggregate(hdir, cond, precision, min_date, max_date):
    "Collect aggregated statistics from HDFS"
    start_time = time.time()
    print("Aggregating {} FWJR performance data in {} matching {} from {} to {}...".format(precision.replace('y', 'i') + 'ly', hdir, cond, min_date, max_date))

    conf = SparkConf().setAppName("wmarchive fwjr aggregator")
    sc = SparkContext(conf=conf)
    sc.setLogLevel("ERROR")
    sqlContext = HiveContext(sc)

    # To test the procedure in an interactive pyspark shell:
    #
    # 1. Open a pyspark shell with appropriate configuration with:
    #
    #    ```
    #    pyspark --packages com.databricks:spark-avro_2.10:2.0.1 --driver-class-path=/usr/lib/hive/lib/* --driver-java-options=-Dspark.executor.extraClassPath=/usr/lib/hive/lib/*
    #    ```
    #
    # 2. Paste this:
    #
    # >>>
    # from pyspark.sql.functions import *
    # from pyspark.sql.types import *
    # hdir = '/cms/wmarchive/avro/2016/06/28*'
    # precision = 'day'
    fwjr_df = sqlContext.read.format("com.databricks.spark.avro").load(hdir)
    # <<<

    # Here we process the filters given by `cond`.
    # TODO: Filter by min_date and max_date and possibly just remove the `hdir` option and instead process the entire dataset, or make it optional.
    fwjr_df = make_filters(fwjr_df, cond)

    # 3. Paste this:
    #
    # >>>
    # Select the data we are interested in
    jobs = fwjr_df.select(
        fwjr_df['meta_data.ts'].alias('timestamp'),
        fwjr_df['meta_data.jobstate'],
        fwjr_df['meta_data.host'],
        fwjr_df['meta_data.jobtype'],
        fwjr_df['task'],
        fwjr_df['steps.site'].getItem(0).alias('site'),  # TODO: improve
        fwjr_df['steps'],  # TODO: `explode` here, see below
        # TODO: also select `meta_data.fwjr_id`
    )

    # Transform each record to the data we then want to group by:
    # Transform timestamp to start_date and end_date with given precision,
    # thus producing many jobs that have the same start_date and end_date.
    # These will later be grouped by.
    timestamp = jobs['timestamp']
    if precision == "hour":
        start_date = floor(timestamp / 3600) * 3600
        end_date = start_date + 3600
    elif precision == "day":
        start_date = floor(timestamp / 86400) * 86400
        end_date = start_date + 86400
    elif precision == "week":
        end_date = next_day(to_date(from_unixtime(timestamp)), 'Mon')
        start_date = date_sub(end_date, 7)
        start_date = to_utc_timestamp(start_date, 'UTC')
        end_date = to_utc_timestamp(end_date, 'UTC')
    elif precision == "month":
        start_date = trunc(to_date(from_unixtime(timestamp)), 'month')
        end_date = date_add(last_day(start_date), 1)
        start_date = to_utc_timestamp(start_date, 'UTC')
        end_date = to_utc_timestamp(end_date, 'UTC')
    jobs = jobs.withColumn('start_date', start_date)
    jobs = jobs.withColumn('end_date', end_date)
    jobs = jobs.withColumn('timeframe_precision', lit(precision))
    jobs = jobs.drop('timestamp')

    # Transform `task` to task and workflow name
    jobs = jobs.withColumn('taskname_components', split(jobs['task'], '/'))
    jobs = jobs.withColumn('workflow', jobs['taskname_components'].getItem(1))
    jobs = jobs.withColumn('task', jobs['taskname_components'].getItem(size(jobs['taskname_components'])))
    jobs = jobs.drop('taskname_components')

    # Extract exit code and acquisition era
    stepScopeStruct = StructType([
        StructField('exitCode', StringType(), True),
        StructField('exitStep', StringType(), True),
        StructField('acquisitionEra', StringType(), True),
    ])

    def extract_step_scope(step_names, step_errors, step_outputs):
        # TODO: improve this rather crude implementation
        exitCode = None
        exitStep = None
        for (i, errors) in enumerate(step_errors):
            if len(errors) > 0:
                exitCode = errors[0].exitCode
                exitStep = step_names[i]
                break
        acquisitionEra = None
        for outputs in step_outputs:
            # ......... the rest of the code is omitted here .........
Example 12: SparkContext
# Required import: from pyspark import SparkContext [as alias]
# Or: from pyspark.SparkContext import setLogLevel [as alias]
from pyspark.mllib.linalg import Vectors
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.param import Param, Params
from pyspark.sql import SQLContext
from pyspark import SparkContext

sc = SparkContext(appName="ML Example")
sc.setLogLevel("FATAL")
sqlContext = SQLContext(sc)

# Prepare training data from a list of (label, features) tuples.
training = sqlContext.createDataFrame([
    (1.0, Vectors.dense([0.0, 1.1, 0.1])),
    (0.0, Vectors.dense([2.0, 1.0, -1.0])),
    (0.0, Vectors.dense([2.0, 1.3, 1.0])),
    (1.0, Vectors.dense([0.0, 1.2, -0.5]))], ["label", "features"])

# Create a LogisticRegression instance. This instance is an Estimator.
lr = LogisticRegression(maxIter=10, regParam=0.01)
# Print out the parameters, documentation, and any default values.
print("LogisticRegression parameters:\n" + lr.explainParams() + "\n")

# Learn a LogisticRegression model. This uses the parameters stored in lr.
model1 = lr.fit(training)

# Since model1 is a Model (i.e., a transformer produced by an Estimator),
# we can view the parameters it used during fit().
# This prints the parameter (name: value) pairs, where names are unique IDs for this
# LogisticRegression instance.
print("Model 1 was fit using parameters: ")
print(model1.extractParamMap())
Example 13: __init__
# Required import: from pyspark import SparkContext [as alias]
# Or: from pyspark.SparkContext import setLogLevel [as alias]
def __init__(self, arglist, _sc=None, _sqlContext=None):
    sc = SparkContext() if _sc is None else _sc
    sqlContext = HiveContext(sc) if _sqlContext is None else _sqlContext
    sc.setLogLevel("ERROR")

    self.sqlContext = sqlContext
    self.sc = sc
    self._jvm = sc._jvm

    from py4j.java_gateway import java_import
    java_import(self._jvm, "org.tresamigos.smv.ColumnHelper")
    java_import(self._jvm, "org.tresamigos.smv.SmvDFHelper")
    java_import(self._jvm, "org.tresamigos.smv.dqm.*")
    java_import(self._jvm, "org.tresamigos.smv.panel.*")
    java_import(self._jvm, "org.tresamigos.smv.python.SmvPythonHelper")
    java_import(self._jvm, "org.tresamigos.smv.SmvRunInfoCollector")

    self.j_smvPyClient = self.create_smv_pyclient(arglist)

    # shortcut is meant for internal use only
    self.j_smvApp = self.j_smvPyClient.j_smvApp()
    self.log = self.j_smvApp.log()

    # AFTER app is available but BEFORE stages,
    # use the dynamically configured app dir to set the source path
    self.prepend_source(self.SRC_PROJECT_PATH)

    # issue #429 set application name from smv config
    sc._conf.setAppName(self.appName())

    # user may choose a port for the callback server
    gw = sc._gateway
    cbsp = self.j_smvPyClient.callbackServerPort()
    cbs_port = cbsp.get() if cbsp.isDefined() else gw._python_proxy_port

    # check whether the port is in use or not. Try 10 times; if all fail, error out
    check_counter = 0
    while (not check_socket(cbs_port) and check_counter < 10):
        cbs_port += 1
        check_counter += 1
    if (not check_socket(cbs_port)):
        raise SmvRuntimeError("Start Python callback server failed. Port {0}-{1} are all in use".format(cbs_port - check_counter, cbs_port))

    # this was a workaround for py4j 0.8.2.1, shipped with spark
    # 1.5.x, to prevent the callback server from hanging the
    # python, and hence the java, process
    from pyspark.streaming.context import _daemonize_callback_server
    _daemonize_callback_server()

    if "_callback_server" not in gw.__dict__ or gw._callback_server is None:
        print("SMV starting Py4j callback server on port {0}".format(cbs_port))
        gw._shutdown_callback_server()  # in case another has already started
        gw._start_callback_server(cbs_port)
        gw._python_proxy_port = gw._callback_server.port

    # get the GatewayServer object in JVM by ID
    jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client)
    # update the port of CallbackClient with real port
    gw.jvm.SmvPythonHelper.updatePythonGatewayPort(jgws, gw._python_proxy_port)

    self.repoFactory = DataSetRepoFactory(self)
    self.j_smvPyClient.registerRepoFactory('Python', self.repoFactory)

    # Initialize DataFrame and Column with helper methods
    smv.helpers.init_helpers()
Example 14: main
# Required import: from pyspark import SparkContext [as alias]
# Or: from pyspark.SparkContext import setLogLevel [as alias]
def main():
    "Main function"
    optmgr = OptionParser()
    opts = optmgr.parser.parse_args()

    # setup spark/sql context to be used for communication with HDFS
    sc = SparkContext(appName="phedex_br")
    if not opts.yarn:
        sc.setLogLevel("ERROR")
    sqlContext = HiveContext(sc)

    schema_def = schema()

    # read given file(s) into RDD
    if opts.fname:
        pdf = sqlContext.read.format('com.databricks.spark.csv')\
                        .options(treatEmptyValuesAsNulls='true', nullValue='null')\
                        .load(opts.fname, schema=schema_def)
    elif opts.basedir:
        fromdate, todate = defDates(opts.fromdate, opts.todate)
        files = getFileList(opts.basedir, fromdate, todate)
        msg = "Between dates %s and %s found %d directories" % (fromdate, todate, len(files))
        print msg
        if not files:
            return
        pdf = unionAll([sqlContext.read.format('com.databricks.spark.csv')
                                  .options(treatEmptyValuesAsNulls='true', nullValue='null')
                                  .load(file_path, schema=schema_def)
                        for file_path in files])
    else:
        raise ValueError("File or directory not specified. Specify fname or basedir parameters.")

    # parse additional data (adding to the given data: group name, node kind, acquisition era, data tier, now date)
    groupdic, nodedic = getJoinDic()
    acquisition_era_reg = r"^/[^/]*/([^/^-]*)-[^/]*/[^/]*$"
    data_tier_reg = r"^/[^/]*/[^/^-]*-[^/]*/([^/]*)$"
    groupf = udf(lambda x: groupdic[x], StringType())
    nodef = udf(lambda x: nodedic[x], StringType())

    ndf = pdf.withColumn("br_user_group", groupf(pdf.br_user_group_id)) \
             .withColumn("node_kind", nodef(pdf.node_id)) \
             .withColumn("now", from_unixtime(pdf.now_sec, "YYYY-MM-dd")) \
             .withColumn("acquisition_era", when(regexp_extract(pdf.dataset_name, acquisition_era_reg, 1) == "",
                         lit("null")).otherwise(regexp_extract(pdf.dataset_name, acquisition_era_reg, 1))) \
             .withColumn("data_tier", when(regexp_extract(pdf.dataset_name, data_tier_reg, 1) == "",
                         lit("null")).otherwise(regexp_extract(pdf.dataset_name, data_tier_reg, 1)))

    # print dataframe schema
    if opts.verbose:
        ndf.show()
        print("pdf data type", type(ndf))
        ndf.printSchema()

    # process aggregation parameters
    keys = [key.lower().strip() for key in opts.keys.split(',')]
    results = [result.lower().strip() for result in opts.results.split(',')]
    aggregations = [agg.strip() for agg in opts.aggregations.split(',')]
    order = [orde.strip() for orde in opts.order.split(',')] if opts.order else []
    asc = [asce.strip() for asce in opts.asc.split(',')] if opts.order else []
    filtc, filtv = opts.filt.split(":") if opts.filt else (None, None)

    validateAggregationParams(keys, results, aggregations, order, filtc)

    if filtc and filtv:
        ndf = ndf.filter(getattr(ndf, filtc) == filtv)

    # if delta aggregation is used
    if DELTA in aggregations:
        validateDeltaParam(opts.interval, results)
        result = results[0]

        # 1. for all dates generate interval group dictionary
        datedic = generateDateDict(fromdate, todate, opts.interval)
        boundic = generateBoundDict(datedic)
        max_interval = max(datedic.values())

        interval_group = udf(lambda x: datedic[x], IntegerType())
        interval_start = udf(lambda x: boundic[x][0], StringType())
        interval_end = udf(lambda x: boundic[x][1], StringType())

        # 2. group data by block, node, interval and last result in the interval
        ndf = ndf.select(ndf.block_name, ndf.node_name, ndf.now, getattr(ndf, result))
        idf = ndf.withColumn("interval_group", interval_group(ndf.now))
        win = Window.partitionBy(idf.block_name, idf.node_name, idf.interval_group).orderBy(idf.now.desc())
        idf = idf.withColumn("row_number", rowNumber().over(win))
        rdf = idf.where((idf.row_number == 1) & (idf.interval_group != 0))\
                 .withColumn(result, when(idf.now == interval_end(idf.interval_group), getattr(idf, result)).otherwise(lit(0)))
        rdf = rdf.select(rdf.block_name, rdf.node_name, rdf.interval_group, getattr(rdf, result))
        rdf.cache()

        # 3. create intervals that do not exist but have a negative delta
        win = Window.partitionBy(idf.block_name, idf.node_name).orderBy(idf.interval_group)
        adf = rdf.withColumn("interval_group_aft", lead(rdf.interval_group, 1, 0).over(win))
        hdf = adf.filter(((adf.interval_group + 1) != adf.interval_group_aft) & (adf.interval_group != max_interval))\
                 .withColumn("interval_group", adf.interval_group + 1)\
                 .withColumn(result, lit(0))\
                 .drop(adf.interval_group_aft)

        # 4. join data frames
        # ......... the rest of the code is omitted here .........
Example 15: main
# Required import: from pyspark import SparkContext [as alias]
# Or: from pyspark.SparkContext import setLogLevel [as alias]
def main():
    root = os.path.dirname(os.path.abspath(__file__))
    print("Digits Handwriting Recognition using Spark")
    print("Root file path is = %s" % root)

    conf = SparkConf().setAppName("OCR")
    sc = SparkContext(conf=conf)
    sc.setLogLevel("WARN")
    sqlContext = SQLContext(sc)

    print("loading dataset")
    trainRDD = MLUtils.loadLibSVMFile(sc, root + "/dataset/svm/mnist")
    testRDD = MLUtils.loadLibSVMFile(sc, root + "/dataset/svm/mnist.t")

    # check if rdd support toDF
    if not hasattr(trainRDD, "toDF"):
        print("ERROR: RDD does not support toDF")
        sys.exit(1)

    ## convert RDDs to data frames
    trainDF = trainRDD.toDF()
    testDF = testRDD.toDF()
    print("INFO: train dataframe count = %u" % trainDF.count())
    print("INFO: test dataframe count = %u" % testDF.count())

    indexer = StringIndexer(inputCol="label", outputCol="indexedLabel")
    dtc = DecisionTreeClassifier(labelCol="indexedLabel")
    pipeline = Pipeline(stages=[indexer, dtc])
    model = pipeline.fit(trainDF)

    ## train multiple models of varying depth
    variedMaxDepthModels = []
    print("Create decision tree models of varied depth [1..8]")
    for mdepth in xrange(1, 9):
        start = time.time()
        ## maximum depth
        dtc.setMaxDepth(mdepth)
        ## create pipeline
        pipeline = Pipeline(stages=[indexer, dtc])
        ## create the model
        model = pipeline.fit(trainDF)
        ## add to varied container
        variedMaxDepthModels.append(model)
        end = time.time()
        print("trained a decision tree of depth %u, duration = [%.3f] secs" % (mdepth, end - start))

    print("=================================================")

    ## report model accuracies
    evaluator = MulticlassClassificationEvaluator(labelCol="indexedLabel", metricName="precision")

    ## mdepth
    print("Evaluate all models precision")
    for mdepth in xrange(1, 9):
        model = variedMaxDepthModels[mdepth - 1]
        predictions = model.transform(testDF)
        precision = evaluator.evaluate(predictions)
        print("tree depth = %u, precision = %.3f" % (mdepth, precision))

    print("Finished processing %u digits" % testDF.count())