

Python GridSearchCV.save Method Code Examples

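This article collects typical usage examples of sklearn.model_selection.GridSearchCV.save in Python. Note that scikit-learn's GridSearchCV does not define a save() method. In the example below, the only save() call is in the neural-network branch (args.model == "nn"), whose saved file is later reloaded with keras.models.load_model. Every other model, including a fitted GridSearchCV, is written to disk with joblib.dump.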
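Because GridSearchCV has no save() method of its own, the usual way to persist a fitted search is to pickle the whole object with joblib. The following is a minimal sketch on the bundled iris dataset. The grid values, the temporary file name `gridsearch.joblib` and the choice of RandomForestClassifier are illustrative assumptions and are not taken from the original project.

```python
import os
import tempfile

import joblib  # standalone package; sklearn.externals.joblib is deprecated and later removed
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

X, y = load_iris(return_X_y=True)

# A deliberately small grid so the search finishes quickly (assumed values)
search = GridSearchCV(
    RandomForestClassifier(random_state=0),
    param_grid={"n_estimators": [10, 20], "max_depth": [2, None]},
    cv=3,
)
search.fit(X, y)

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "gridsearch.joblib")  # hypothetical file name
    # There is no search.save(); dump the whole fitted GridSearchCV object instead
    joblib.dump(search, path)
    restored = joblib.load(path)

# The reloaded search keeps best_params_ and predicts like the original
print(restored.best_params_)
assert np.array_equal(restored.predict(X), search.predict(X))
```

Loading a pickle can execute arbitrary code, so only load files you trust; a pickled estimator is also tied to the scikit-learn version that wrote it.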


One code example is shown below.

Example 1: main

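# Required imports. Project-specific helpers from featurize.py (printd, evalData,
# evalModel, NNModel, default_params, param_grids, dataset, Match) are omitted.
import os
import sys
try:
	from itertools import izip  # Python 2
except ImportError:
	izip = zip  # Python 3
import joblib
import numpy as np
from sklearn import ensemble
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeRegressor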

#......... part of the code is omitted here .........
	elif args.model == "gradientboosting":
		if args.classify:
			model = ensemble.GradientBoostingClassifier
		else:
			model = ensemble.GradientBoostingRegressor
	elif args.model == "decisiontree":
		model = DecisionTreeRegressor
	elif args.model == "adaboost":
		model = ensemble.AdaBoostRegressor
	elif args.model == "linreg":
		model = LinearRegression
	elif args.model == "autolearn":
		printd("AutoLearn disabled as it does not work properly")
		sys.exit(-1)
		#model = AutoSklearnClassifier
		fitargs["dataset_name"] = "semeval"
	elif args.model == "nn":
		model = NNModel
		fitargs["nb_epoch"] = 10
		fitargs["batch_size"] = 32
		fitargs["verbose"] = 2
		predictargs["verbose"] = 0
	elif args.model == "None":
		printd("No Model specified, exiting")
		sys.exit(-1)
	else:
		printd("Invalid model %s" % args.model)
		sys.exit(-1)

	if args.classify:
		# Forest Classifiers do not allow non-binary labels, so we do it by sample weight instead
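		# sklearn.ensemble.forest was deprecated in scikit-learn 0.22 and later removed,
		# so test against the public forest classifiers instead
		byweight = issubclass(model, (ensemble.RandomForestClassifier, ensemble.ExtraTreesClassifier))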
		lintrainlabels = np.copy(fs.trainlabels)
		fs.discretizeLabels(byweight=byweight)
		if byweight:
			fitargs["sample_weight"] = fs.trainweights
	else:
		lintrainlabels = np.array(fs.trainlabels)
	fs.freeze()
	printd("Train labels:" + str(fs.trainlabels.shape))

	if (not args.force) and args.modelfile and os.path.exists(args.modelfile):
		if args.model == "nn":
			import keras
			model = keras.models.load_model(args.modelfile)
		else:
			model = joblib.load(args.modelfile)
	else:
		# Copy, so overriding values below does not mutate the shared defaults
		params = dict(default_params[args.model])
		for param_name, param_value in params.items():
			try:
				pval = getattr(args, param_name)
				if pval is not None:
					params[param_name] = pval
			except AttributeError:
				pass

		if "input_dim" in params:
			# -1 for label
			params["input_dim"] = len(fs.names) - 1
		model = model(**params)

		if args.gridsearch:
			model = GridSearchCV(model, scoring=evalModel, cv=5, error_score=0,
					param_grid=param_grids[args.model], n_jobs=16, pre_dispatch="2*n_jobs", verbose=10)
		#model = Pipeline(steps=[('pca', kpca), ('dtree', dtree)])
		printd("Training")
		model.fit(fs.train, fs.trainlabels, **fitargs)
		#X_kpca = kpca.fit_transform(X)
		#dtree.fit(traindata, trainlabels)
		if args.modelfile:
			try:
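				if args.model == "nn":
					# GridSearchCV has no save(); save the refitted best network instead
					getattr(model, "best_estimator_", model).save(args.modelfile)
				else:
					# scikit-learn estimators, including a fitted GridSearchCV, are pickled with joblib
					joblib.dump(model, args.modelfile)
			except Exception as e:
				printd("Could not save model: %s" % e)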

	printd("Evaluating")
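	print("Using Features: %s" % args.basefeatures)
	print("Using Matchers: %s" % args.matchers)
	print("Train Accuracy")
	# oob_prediction_ exists only on forest regressors fitted with oob_score=True;
	# when the model is a GridSearchCV, it lives on best_estimator_
	base = getattr(model, "best_estimator_", model)
	trainargs = dict(predictargs)
	if hasattr(base, "oob_prediction_"):
		trainargs["obs"] = base.oob_prediction_
	evalData(model=model, data=fs.train, labels=lintrainlabels, classify=args.classify, **trainargs)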

	print("Test Accuracy")
	testobs = evalData(model=model, data=fs.test, labels=fs.testlabels, classify=args.classify, **predictargs)

	if args.writematches:
		try:
			fs.data.writer
		except AttributeError:
			fs.data = dataset.Dataset.load(args.input_data, args.dataset)
		trainwriter = fs.data.writer(args.writematches + ".train")
		testwriter = fs.data.writer(args.writematches + ".test")
		for pair in fs.data.train():
			trainwriter.write(pair, Match(score=pair.label, autop=0))
		for pair, obs in izip(fs.data.test(), testobs):
			testwriter.write(pair, Match(score=obs))
Developer ID: mattea, Project: mattea-utils, Lines of code: 104, Source file: featurize.py
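The training-set evaluation in this example uses oob_prediction_, an attribute that only forest regressors such as RandomForestRegressor expose, and only when fitted with oob_score=True. When the model is wrapped in GridSearchCV (the args.gridsearch path), the attribute lives on the refitted best_estimator_, not on the search object. The sketch below shows this on synthetic data; the dataset and the grid are assumptions chosen for illustration.

```python
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import GridSearchCV

X, y = make_regression(n_samples=200, n_features=5, noise=0.1, random_state=0)

# oob_score=True is required for the forest to compute out-of-bag predictions
search = GridSearchCV(
    RandomForestRegressor(oob_score=True, random_state=0),
    param_grid={"n_estimators": [50, 100]},  # assumed grid
    cv=3,
)
search.fit(X, y)

# GridSearchCV does not forward arbitrary attributes to the wrapped estimator
print(hasattr(search, "oob_prediction_"))  # False

# With refit=True (the default), best_estimator_ is refitted on all of X
best = getattr(search, "best_estimator_", search)
oob = getattr(best, "oob_prediction_", None)  # None for models without OOB predictions
print(oob.shape)  # (200,)
```

Falling back with getattr keeps the same code path working for plain estimators, for GridSearchCV wrappers, and for models that have no out-of-bag predictions at all.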


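Note: The sklearn.model_selection.GridSearchCV.save examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets come from open-source projects, and copyright remains with their original authors. Consult the corresponding project's license before distributing or using the code, and do not republish without permission.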