当前位置: 首页>>代码示例>>Golang>>正文


Golang base.InstancesTrainTestSplit函数代码示例

本文整理汇总了 Golang 中 github.com/sjwhitworth/golearn/base.InstancesTrainTestSplit 函数的典型用法代码示例。如果您想了解 InstancesTrainTestSplit 函数的具体用法、调用方式或实际使用示例，这里精选的代码示例或许能为您提供帮助。该函数把一个数据集划分为训练集和测试集两部分，第二个参数为划分比例（示例中取 0.1 到 0.8 不等）。


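下文共展示了 InstancesTrainTestSplit 函数的 15 个代码示例，默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将帮助系统推荐更好的 Golang 代码示例。注意：这些示例取自 golearn 的不同版本及分支，API 并不完全一致。部分示例中该函数的返回值通过下标访问（insts[0]、insts[1]），其余示例则直接接收两个返回值（trainData, testData）。评估部分有的使用 eval 包，有的使用 evaluation 包，且部分示例中 GetConfusionMatrix 还会额外返回一个 error。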

示例1: TestPruning

// TestPruning discretises the iris dataset with Chi-Merge, grows a random tree
// on one part of the training split, prunes it against the other part, and
// prints evaluation metrics for the held-out test split.
//
// NOTE(review): the test makes no assertions; it only fails if the dataset
// cannot be loaded. The Chi-Merge filter is built from the full dataset (inst),
// so test rows also influence the discretisation.
func TestPruning(testEnv *testing.T) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		// Report through the testing framework rather than panicking, so the
		// failure is attributed to this test and the rest of the suite runs.
		testEnv.Fatalf("loading iris dataset: %v", err)
	}
	trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)

	// Discretise all numeric attributes. Run's result is not captured, so it
	// presumably rewrites each split in place.
	filt := filters.NewChiMergeFilter(inst, 0.90)
	filt.AddAllNumericAttributes()
	filt.Build()
	filt.Run(testData)
	filt.Run(trainData)

	// Split the training data again: one part grows the tree, the other
	// part is used to prune it.
	root := NewRandomTree(2)
	fittrainData, fittestData := base.InstancesTrainTestSplit(trainData, 0.6)
	root.Fit(fittrainData)
	root.Prune(fittestData)
	fmt.Println(root)

	predictions := root.Predict(testData)
	fmt.Println(predictions)
	confusionMat := eval.GetConfusionMatrix(testData, predictions)
	fmt.Println(confusionMat)
	fmt.Println(eval.GetMacroPrecision(confusionMat))
	fmt.Println(eval.GetMacroRecall(confusionMat))
	fmt.Println(eval.GetSummary(confusionMat))
}
开发者ID:hsinhoyeh,项目名称:golearn,代码行数:25,代码来源:tree_test.go

示例2: main

// main loads the iris dataset, trains a KNN classifier on one split of it,
// and prints the predictions for the other split together with a summary of
// precision/recall metrics.
func main() {
	// The second argument is false, presumably because iris.csv has no header
	// row. The loaded Instances act much like a data frame in R or Pandas.
	dataset, loadErr := base.ParseCSVToInstances("datasets/iris.csv", false)
	if loadErr != nil {
		panic(loadErr)
	}
	fmt.Println(dataset)

	// KNN classifier using Euclidean distance with parameter 2 (presumably k).
	classifier := knn.NewKnnClassifier("euclidean", 2)

	// Train on one half of the split.
	train, test := base.InstancesTrainTestSplit(dataset, 0.50)
	classifier.Fit(train)

	// Predict labels for the held-out rows.
	predicted := classifier.Predict(test)
	fmt.Println(predicted)

	matrix, cmErr := evaluation.GetConfusionMatrix(test, predicted)
	if cmErr != nil {
		panic(fmt.Sprintf("Unable to get confusion matrix: %s", cmErr.Error()))
	}
	fmt.Println(evaluation.GetSummary(matrix))
}
开发者ID:raghavkgarg,项目名称:gotutorial,代码行数:30,代码来源:ml1.go

示例3: TestRandomForest1

// TestRandomForest1 trains a bagged ensemble of ten random trees on a
// Chi-Merge-discretised iris training split and prints evaluation metrics for
// the test split.
//
// NOTE(review): the test makes no assertions, and because the RNG is seeded
// from the clock, each run presumably uses a different split. The filter is
// trained on the full dataset (inst), not only on trainData.
func TestRandomForest1(testEnv *testing.T) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		// Fail this test cleanly instead of panicking the whole test binary.
		testEnv.Fatalf("loading iris dataset: %v", err)
	}

	rand.Seed(time.Now().UnixNano())
	trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)

	// Train a Chi-Merge filter over every non-class float attribute, then wrap
	// both splits with it.
	filt := filters.NewChiMergeFilter(inst, 0.90)
	for _, a := range base.NonClassFloatAttributes(inst) {
		filt.AddAttribute(a)
	}
	filt.Train()
	trainDataf := base.NewLazilyFilteredInstances(trainData, filt)
	testDataf := base.NewLazilyFilteredInstances(testData, filt)

	rf := new(BaggedModel)
	for i := 0; i < 10; i++ {
		rf.AddModel(trees.NewRandomTree(2))
	}
	rf.Fit(trainDataf)
	fmt.Println(rf)

	predictions := rf.Predict(testDataf)
	fmt.Println(predictions)
	confusionMat := eval.GetConfusionMatrix(testDataf, predictions)
	fmt.Println(confusionMat)
	fmt.Println(eval.GetMacroPrecision(confusionMat))
	fmt.Println(eval.GetMacroRecall(confusionMat))
	fmt.Println(eval.GetSummary(confusionMat))
}
开发者ID:Gudym,项目名称:golearn,代码行数:29,代码来源:bagging_test.go

示例4: main

// main discretises the iris dataset with Chi-Merge, splits it, and compares
// three classifiers (ID3, a random tree and a random forest). For each one it
// prints a headline followed by a confusion-matrix summary.
func main() {
	rand.Seed(time.Now().UTC().UnixNano())

	// Load the iris dataset; the file name indicates it has a header row.
	iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}

	// Discretise every numeric attribute with Chi-Merge before splitting.
	chiMerge := filters.NewChiMergeFilter(iris, 0.99)
	chiMerge.AddAllNumericAttributes()
	chiMerge.Build()
	chiMerge.Run(iris)

	// Split with proportion 0.60. The first element is used for training and
	// the second for evaluation.
	split := base.InstancesTrainTestSplit(iris, 0.60)
	train, test := split[0], split[1]

	// evaluate fits cls on the training split, predicts the test split, and
	// prints the headline followed by the confusion-matrix summary.
	evaluate := func(headline string, cls base.Classifier) {
		cls.Fit(train)
		predictions := cls.Predict(test)
		fmt.Println(headline)
		fmt.Println(eval.GetSummary(eval.GetConfusionMatrix(test, predictions)))
	}

	// ID3; the constructor argument controls the internal train/prune split.
	evaluate("ID3 Performance", trees.NewID3DecisionTree(0.6))

	// Random tree considering two randomly chosen attributes.
	evaluate("RandomTree Performance", trees.NewRandomTree(2))

	// Random forest; the arguments are presumably the number of trees and the
	// number of features per tree.
	evaluate("RandomForest Performance", ensemble.NewRandomForest(100, 3))
}
开发者ID:24hours,项目名称:golearn,代码行数:60,代码来源:trees.go

示例5: TestPredict

// TestPredict fits an averaged perceptron on one split of the house-votes-84
// dataset and requires an accuracy of at least 0.65 on the other split.
//
// Setup failures stop the test immediately with t.Fatal. Continuing with a
// nil model, nil dataset or nil confusion matrix would only produce a
// confusing panic further down.
func TestPredict(t *testing.T) {
	a := NewAveragePerceptron(10, 1.2, 0.5, 0.3)
	if a == nil {
		t.Fatal("Unable to create average perceptron")
	}

	absPath, err := filepath.Abs("../examples/datasets/house-votes-84.csv")
	if err != nil {
		t.Fatalf("resolving dataset path: %v", err)
	}
	rawData, err := base.ParseCSVToInstances(absPath, true)
	if err != nil {
		t.Fatalf("loading dataset %q: %v", absPath, err)
	}

	trainData, testData := base.InstancesTrainTestSplit(rawData, 0.5)
	a.Fit(trainData)

	if !a.trained {
		t.Errorf("Perceptron was not trained")
	}

	predictions := a.Predict(testData)
	cf, err := evaluation.GetConfusionMatrix(testData, predictions)
	if err != nil {
		t.Fatalf("Couldn't get confusion matrix: %s", err)
	}
	fmt.Println(evaluation.GetSummary(cf))
	fmt.Println(trainData)
	fmt.Println(testData)
	if evaluation.GetAccuracy(cf) < 0.65 {
		t.Errorf("Perceptron not trained correctly")
	}
}
开发者ID:CTLife,项目名称:golearn,代码行数:35,代码来源:average_test.go

示例6: main

// main builds a Chi-Merge-discretised view of the iris dataset, takes a
// training split of the raw data, and runs findBestSplit on it. Finally it
// prints the discretised view.
func main() {
	// Fixed seed, presumably so the random split (and anything else drawing
	// from math/rand) is reproducible between runs.
	rand.Seed(44111342)

	// NOTE(review): absolute, machine-specific path. It only resolves on the
	// original author's machine; consider a relative path or a flag.
	iris, err := base.ParseCSVToInstances("/home/kralli/go/src/github.com/sjwhitworth/golearn/examples/datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}

	// Discretise every non-class float attribute with Chi-Merge.
	filt := filters.NewChiMergeFilter(iris, 0.999)
	for _, a := range base.NonClassFloatAttributes(iris) {
		filt.AddAttribute(a)
	}
	filt.Train()
	irisf := base.NewLazilyFilteredInstances(iris, filt)

	// NOTE(review): the split is taken from the raw iris data, not irisf, so
	// the discretisation above does not affect findBestSplit. Confirm this is
	// intended (CART may deliberately want the continuous values).
	trainData, _ := base.InstancesTrainTestSplit(iris, 0.60)

	findBestSplit(trainData)

	fmt.Println(irisf)
}
开发者ID:krallistic,项目名称:go_stuff,代码行数:32,代码来源:cart_tree.go

示例7: TestRandomForest1

// TestRandomForest1 trains a bagged ensemble of ten random trees on a
// Chi-Merge-discretised iris training split and prints evaluation metrics for
// the test split.
//
// NOTE(review): the test makes no assertions, and because the RNG is seeded
// from the clock, each run presumably uses a different split. The filter is
// built from the full dataset (inst), so test rows also influence the
// discretisation.
func TestRandomForest1(testEnv *testing.T) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		// Fail this test cleanly instead of panicking the whole test binary.
		testEnv.Fatalf("loading iris dataset: %v", err)
	}

	rand.Seed(time.Now().UnixNano())
	insts := base.InstancesTrainTestSplit(inst, 0.6)
	trainData, testData := insts[0], insts[1]

	// Discretise all numeric attributes. Run's result is not captured, so it
	// presumably rewrites each split in place.
	filt := filters.NewChiMergeFilter(inst, 0.90)
	filt.AddAllNumericAttributes()
	filt.Build()
	filt.Run(testData)
	filt.Run(trainData)

	rf := new(BaggedModel)
	for i := 0; i < 10; i++ {
		rf.AddModel(trees.NewRandomTree(2))
	}
	rf.Fit(trainData)
	fmt.Println(rf)

	predictions := rf.Predict(testData)
	fmt.Println(predictions)
	confusionMat := eval.GetConfusionMatrix(testData, predictions)
	fmt.Println(confusionMat)
	fmt.Println(eval.GetMacroPrecision(confusionMat))
	fmt.Println(eval.GetMacroRecall(confusionMat))
	fmt.Println(eval.GetSummary(confusionMat))
}
开发者ID:24hours,项目名称:golearn,代码行数:27,代码来源:bagging_test.go

示例8: TestLinearRegression

// TestLinearRegression fits a linear regression on one split of the exams
// dataset and prints, for each row of the other split, the true class value
// next to the predicted one. It fails only if loading, fitting or predicting
// returns an error.
func TestLinearRegression(t *testing.T) {
	model := NewLinearRegression()

	exams, err := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
	if err != nil {
		t.Fatal(err)
	}

	train, test := base.InstancesTrainTestSplit(exams, 0.1)
	if err = model.Fit(train); err != nil {
		t.Fatal(err)
	}

	predicted, err := model.Predict(test)
	if err != nil {
		t.Fatal(err)
	}

	_, rowCount := predicted.Size()
	for row := 0; row < rowCount; row++ {
		want := base.GetClass(test, row)
		got := base.GetClass(predicted, row)
		fmt.Printf("Expected: %s || Predicted: %s\n", want, got)
	}
}
开发者ID:JacobXie,项目名称:golearn,代码行数:25,代码来源:linear_regression_test.go

示例9: TestLinearRegression

// TestLinearRegression exercises LinearRegression in three situations:
//   - predicting before any training must return NoTrainingDataError;
//   - fitting on exam.csv must return NotEnoughDataError;
//   - fitting on one split of exams.csv and predicting the other must
//     produce predictions within 5% of the true values.
func TestLinearRegression(t *testing.T) {
	Convey("Doing a  linear regression", t, func() {
		lr := NewLinearRegression()

		Convey("With no training data", func() {
			Convey("Predicting", func() {
				testData, err := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
				So(err, ShouldBeNil)

				_, err = lr.Predict(testData)

				Convey("Should result in a NoTrainingDataError", func() {
					So(err, ShouldEqual, NoTrainingDataError)
				})

			})
		})

		Convey("With not enough training data", func() {
			trainingDatum, err := base.ParseCSVToInstances("../examples/datasets/exam.csv", true)
			So(err, ShouldBeNil)

			Convey("Fitting", func() {
				err = lr.Fit(trainingDatum)

				Convey("Should result in a NotEnoughDataError", func() {
					So(err, ShouldEqual, NotEnoughDataError)
				})
			})
		})

		Convey("With sufficient training data", func() {
			instances, err := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
			So(err, ShouldBeNil)
			trainData, testData := base.InstancesTrainTestSplit(instances, 0.1)

			Convey("Fitting and Predicting", func() {
				err := lr.Fit(trainData)
				So(err, ShouldBeNil)

				predictions, err := lr.Predict(testData)
				So(err, ShouldBeNil)

				Convey("It makes reasonable predictions", func() {
					_, rows := predictions.Size()

					for i := 0; i < rows; i++ {
						// testData holds the true class values; predictions
						// holds the model's output for the same rows. Both
						// must parse as floats, or the comparison is meaningless.
						expectedValue, err := strconv.ParseFloat(base.GetClass(testData, i), 64)
						So(err, ShouldBeNil)
						predictedValue, err := strconv.ParseFloat(base.GetClass(predictions, i), 64)
						So(err, ShouldBeNil)

						// The tolerance is 5% of the true value. This assumes
						// the target values are positive.
						So(predictedValue, ShouldAlmostEqual, expectedValue, expectedValue*0.05)
					}
				})
			})
		})
	})
}
开发者ID:CTLife,项目名称:golearn,代码行数:57,代码来源:linear_regression_test.go

示例10: Fit

// Fit builds the ID3 decision tree
// Fit builds the ID3 decision tree from on, using the tree's configured Rule.
//
// When PruneSplit is greater than 0.001, on is divided with
// base.InstancesTrainTestSplit: the tree is grown on the first part and then
// pruned against the second. Otherwise the tree is grown on all of on and is
// not pruned. The returned error is currently always nil.
func (t *ID3DecisionTree) Fit(on base.FixedDataGrid) error {
	if t.PruneSplit > 0.001 {
		growSet, pruneSet := base.InstancesTrainTestSplit(on, t.PruneSplit)
		t.Root = InferID3Tree(growSet, t.Rule)
		t.Root.Prune(pruneSet)
		return nil
	}
	// Pruning disabled or negligible: grow the tree on every row.
	t.Root = InferID3Tree(on, t.Rule)
	return nil
}
开发者ID:tanduong,项目名称:golearn,代码行数:11,代码来源:id3.go

示例11: TestRandomTreeClassificationWithoutDiscretisation

// TestRandomTreeClassificationWithoutDiscretisation splits the raw
// (unfiltered) iris dataset and delegates the Random Tree checks to
// verifyTreeClassification.
//
// NOTE(review): the Convey description says "filtered data", but no filter is
// applied here; the description looks copy-pasted from a sibling test.
func TestRandomTreeClassificationWithoutDiscretisation(t *testing.T) {
	Convey("Predictions on filtered data with a Random Tree", t, func() {
		iris, loadErr := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(loadErr, ShouldBeNil)

		train, test := base.InstancesTrainTestSplit(iris, 0.6)
		verifyTreeClassification(train, test)
	})
}
开发者ID:CTLife,项目名称:golearn,代码行数:10,代码来源:tree_test.go

示例12: Fit

// Fit builds the ID3 decision tree
// Fit builds the ID3 decision tree from on, using an
// InformationGainRuleGenerator as the split rule.
//
// When PruneSplit is greater than 0.001, on is divided with
// base.InstancesTrainTestSplit: the tree is grown on the first part and then
// pruned against the second. Otherwise the tree is grown on all of on and is
// not pruned.
func (t *ID3DecisionTree) Fit(on *base.Instances) {
	gain := new(InformationGainRuleGenerator)
	if t.PruneSplit > 0.001 {
		parts := base.InstancesTrainTestSplit(on, t.PruneSplit)
		growSet, pruneSet := parts[0], parts[1]
		t.Root = InferID3Tree(growSet, gain)
		t.Root.Prune(pruneSet)
		return
	}
	// Pruning disabled or negligible: grow the tree on every row.
	t.Root = InferID3Tree(on, gain)
}
开发者ID:24hours,项目名称:golearn,代码行数:11,代码来源:id3.go

示例13: Fit

// Fit builds the ID3 decision tree
// Fit builds the ID3 decision tree from on, using an
// InformationGainRuleGenerator as the split rule.
//
// When PruneSplit is greater than 0.001, on is divided with
// base.InstancesTrainTestSplit: the tree is grown on the first part and then
// pruned against the second. Otherwise the tree is grown on all of on and is
// not pruned.
func (t *ID3DecisionTree) Fit(on base.FixedDataGrid) {
	gain := new(InformationGainRuleGenerator)
	if t.PruneSplit > 0.001 {
		growSet, pruneSet := base.InstancesTrainTestSplit(on, t.PruneSplit)
		t.Root = InferID3Tree(growSet, gain)
		t.Root.Prune(pruneSet)
		return
	}
	// Pruning disabled or negligible: grow the tree on every row.
	t.Root = InferID3Tree(on, gain)
}
开发者ID:JacobXie,项目名称:golearn,代码行数:11,代码来源:id3.go

示例14: BenchmarkFit

// BenchmarkFit measures AveragePerceptron.Fit on one split of the
// house-votes-84 dataset. Dataset loading happens before the timer is reset,
// so only Fit is timed.
//
// Setup errors abort the benchmark with b.Fatalf. Ignoring them would turn a
// missing dataset into an unrelated nil-pointer panic inside the split.
//
// NOTE(review): the same perceptron is refitted on every iteration; confirm
// that Fit does not accumulate state across calls if timings look skewed.
func BenchmarkFit(b *testing.B) {
	a := NewAveragePerceptron(10, 1.2, 0.5, 0.3)
	absPath, err := filepath.Abs("../examples/datasets/house-votes-84.csv")
	if err != nil {
		b.Fatalf("resolving dataset path: %v", err)
	}
	rawData, err := base.ParseCSVToInstances(absPath, true)
	if err != nil {
		b.Fatalf("loading dataset %q: %v", absPath, err)
	}
	// Only the training split is needed for a Fit benchmark.
	trainData, _ := base.InstancesTrainTestSplit(rawData, 0.5)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		a.Fit(trainData)
	}
}
开发者ID:CTLife,项目名称:golearn,代码行数:11,代码来源:average_test.go

示例15: main

// main trains a Euclidean-distance KNN classifier on one split of the iris
// dataset, prints its predictions for the other split, and prints a summary
// of the resulting confusion matrix.
func main() {
	// true: the file name indicates the CSV starts with a header row.
	irisData, loadErr := base.ParseCSVToInstances("iris_headers.csv", true)
	if loadErr != nil {
		panic(loadErr)
	}

	// Parameter 2 is presumably k, the number of neighbours consulted.
	model := knn.NewKnnClassifier("euclidean", 2)

	train, test := base.InstancesTrainTestSplit(irisData, 0.8)
	model.Fit(train)

	predicted := model.Predict(test)
	fmt.Println(predicted)

	fmt.Println(evaluation.GetSummary(evaluation.GetConfusionMatrix(test, predicted)))
}
开发者ID:vkarthi46,项目名称:ml-algorithms-simple,代码行数:17,代码来源:golearn_sample.go


注:本文中的github.com/sjwhitworth/golearn/base.InstancesTrainTestSplit函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。