当前位置: 首页>>代码示例>>Golang>>正文


Golang base.GetClass函数代码示例

本文整理汇总了Golang中github.com/sjwhitworth/golearn/base.GetClass函数的典型用法代码示例。如果您正苦于以下问题：Golang GetClass函数的具体用法？Golang GetClass怎么用？Golang GetClass使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了GetClass函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Golang代码示例。

示例1: TestLinearRegression

// TestLinearRegression loads the exams dataset, splits it with
// base.InstancesTrainTestSplit, fits a LinearRegression on the training part
// and prints the reference and predicted class value of every test row.
// Any error from loading, fitting or predicting aborts the test.
func TestLinearRegression(t *testing.T) {
	model := NewLinearRegression()

	dataset, err := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
	if err != nil {
		t.Fatal(err)
	}

	train, test := base.InstancesTrainTestSplit(dataset, 0.1)
	if err = model.Fit(train); err != nil {
		t.Fatal(err)
	}

	predicted, err := model.Predict(test)
	if err != nil {
		t.Fatal(err)
	}

	// Size() yields the row count as its second result.
	_, n := predicted.Size()
	for row := 0; row < n; row++ {
		want := base.GetClass(test, row)
		got := base.GetClass(predicted, row)
		fmt.Printf("Expected: %s || Predicted: %s\n", want, got)
	}
}
开发者ID:JacobXie,项目名称:golearn,代码行数:25,代码来源:linear_regression_test.go

示例2: TestKnnClassifier

// TestKnnClassifier fits a classifier created with
// NewKnnClassifier("euclidean", 2) on knn_train.csv, predicts knn_test.csv and
// asserts the class assigned to the first two test rows.
func TestKnnClassifier(t *testing.T) {
	Convey("Given labels, a classifier and data", t, func() {
		train, err := base.ParseCSVToInstances("knn_train.csv", false)
		So(err, ShouldBeNil)

		test, err := base.ParseCSVToInstances("knn_test.csv", false)
		So(err, ShouldBeNil)

		classifier := NewKnnClassifier("euclidean", 2)
		classifier.Fit(train)
		predicted := classifier.Predict(test)
		So(predicted, ShouldNotEqual, nil)

		Convey("When predicting the label for our first vector", func() {
			first := base.GetClass(predicted, 0)
			Convey("The result should be 'blue", func() {
				So(first, ShouldEqual, "blue")
			})
		})

		Convey("When predicting the label for our second vector", func() {
			second := base.GetClass(predicted, 1)
			Convey("The result should be 'red", func() {
				So(second, ShouldEqual, "red")
			})
		})
	})
}
开发者ID:GeekFreaker,项目名称:golearn,代码行数:28,代码来源:knn_test.go

示例3: TestLinearRegression

// TestLinearRegression exercises LinearRegression in three situations:
// predicting before any Fit call (expects NoTrainingDataError), fitting on
// ../examples/datasets/exam.csv (expects NotEnoughDataError), and fitting then
// predicting on a train/test split of exams.csv, where every prediction must
// lie within 5% of the reference value.
func TestLinearRegression(t *testing.T) {
	Convey("Doing a  linear regression", t, func() {
		lr := NewLinearRegression()

		Convey("With no training data", func() {
			Convey("Predicting", func() {
				testData, err := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
				So(err, ShouldBeNil)

				_, err = lr.Predict(testData)

				Convey("Should result in a NoTrainingDataError", func() {
					So(err, ShouldEqual, NoTrainingDataError)
				})

			})
		})

		Convey("With not enough training data", func() {
			trainingDatum, err := base.ParseCSVToInstances("../examples/datasets/exam.csv", true)
			So(err, ShouldBeNil)

			Convey("Fitting", func() {
				err = lr.Fit(trainingDatum)

				Convey("Should result in a NotEnoughDataError", func() {
					So(err, ShouldEqual, NotEnoughDataError)
				})
			})
		})

		Convey("With sufficient training data", func() {
			instances, err := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
			So(err, ShouldBeNil)
			trainData, testData := base.InstancesTrainTestSplit(instances, 0.1)

			Convey("Fitting and Predicting", func() {
				err := lr.Fit(trainData)
				So(err, ShouldBeNil)

				predictions, err := lr.Predict(testData)
				So(err, ShouldBeNil)

				Convey("It makes reasonable predictions", func() {
					_, rows := predictions.Size()

					for i := 0; i < rows; i++ {
						// A class value that does not parse would otherwise
						// become 0 on both sides and pass vacuously, so the
						// parse errors are asserted explicitly.
						actualValue, err := strconv.ParseFloat(base.GetClass(testData, i), 64)
						So(err, ShouldBeNil)
						predictedValue, err := strconv.ParseFloat(base.GetClass(predictions, i), 64)
						So(err, ShouldBeNil)

						// 5% of the reference value; kept non-negative so a
						// negative target does not produce a negative delta.
						tolerance := actualValue * 0.05
						if tolerance < 0 {
							tolerance = -tolerance
						}

						So(predictedValue, ShouldAlmostEqual, actualValue, tolerance)
					}
				})
			})
		})
	})
}
开发者ID:CTLife,项目名称:golearn,代码行数:57,代码来源:linear_regression_test.go

示例4: TestLayeredXORInline

// TestLayeredXORInline builds a four-row dataset from an inline matrix, marks
// the attribute named "2" as the class, fits a net created with
// NewMultiLayerNet([]int{3}) for up to 20000 iterations and then checks
// selected weights, one activation and the predicted class of every row.
func TestLayeredXORInline(t *testing.T) {

	Convey("Given an inline XOR dataset...", t, func() {

		raw := mat64.NewDense(4, 3, []float64{
			1, 0, 1,
			0, 1, 1,
			0, 0, 0,
			1, 1, 0,
		})

		xor := base.InstancesFromMat64(4, 3, raw)
		xor.AddClassAttribute(base.GetAttributeByName(xor, "2"))

		mlp := NewMultiLayerNet([]int{3})
		mlp.MaxIterations = 20000
		mlp.Fit(xor)

		Convey("After running for 20000 iterations, should have some predictive power...", func() {

			Convey("The right nodes should be connected in the network...", func() {
				So(mlp.network.GetWeight(1, 1), ShouldAlmostEqual, 1.000)
				So(mlp.network.GetWeight(2, 2), ShouldAlmostEqual, 1.000)

				for node := 1; node <= 6; node++ {
					So(mlp.network.GetWeight(6, node), ShouldAlmostEqual, 0.000)
				}

			})
			activation := mat64.NewDense(6, 1, []float64{1.0, 0.0, 0.0, 0.0, 0.0, 0.0})
			mlp.network.Activate(activation, 2)
			So(activation.At(5, 0), ShouldAlmostEqual, 1.0, 0.1)

			Convey("And Predict() should do OK too...", func() {

				predicted := mlp.Predict(xor)

				// Every attribute must be a FloatAttribute; its Precision is
				// set to 1 before the class strings are compared below.
				for _, attr := range predicted.AllAttributes() {
					floatAttr, isFloat := attr.(*base.FloatAttribute)
					So(isFloat, ShouldBeTrue)

					floatAttr.Precision = 1
				}

				// Expected class string per row, in row order.
				for row, want := range []string{"1.0", "1.0", "0.0", "0.0"} {
					So(base.GetClass(predicted, row), ShouldEqual, want)
				}

			})
		})

	})

}
开发者ID:thedadams,项目名称:golearn,代码行数:56,代码来源:layered_test.go

示例5: TestLayeredXOR

// TestLayeredXOR loads xor.csv, fits a net created with
// NewMultiLayerNet([]int{3}) for up to 20000 iterations and then checks
// selected weights, one activation and the predicted class of the four rows.
// The dataset, the network and the activation vector are printed for
// debugging.
func TestLayeredXOR(t *testing.T) {

	Convey("Given an XOR dataset...", t, func() {

		XORData, err := base.ParseCSVToInstances("xor.csv", false)
		So(err, ShouldBeNil)

		fmt.Println(XORData)
		net := NewMultiLayerNet([]int{3})
		net.MaxIterations = 20000
		net.Fit(XORData)

		Convey("After running for 20000 iterations, should have some predictive power...", func() {

			Convey("The right nodes should be connected in the network...", func() {

				fmt.Println(net.network)
				So(net.network.GetWeight(1, 1), ShouldAlmostEqual, 1.000)
				So(net.network.GetWeight(2, 2), ShouldAlmostEqual, 1.000)

				for i := 1; i <= 6; i++ {
					So(net.network.GetWeight(6, i), ShouldAlmostEqual, 0.000)
				}

			})
			out := mat64.NewDense(6, 1, []float64{1.0, 0.0, 0.0, 0.0, 0.0, 0.0})
			net.network.Activate(out, 2)
			fmt.Println(out)
			So(out.At(5, 0), ShouldAlmostEqual, 1.0, 0.1)

			Convey("And Predict() should do OK too...", func() {

				pred := net.Predict(XORData)

				for _, a := range pred.AllAttributes() {
					// Report a non-float attribute as a regular assertion
					// failure instead of panicking inside the test.
					af, ok := a.(*base.FloatAttribute)
					So(ok, ShouldBeTrue)

					af.Precision = 1
				}

				So(base.GetClass(pred, 0), ShouldEqual, "0.0")
				So(base.GetClass(pred, 1), ShouldEqual, "1.0")
				So(base.GetClass(pred, 2), ShouldEqual, "1.0")
				So(base.GetClass(pred, 3), ShouldEqual, "0.0")

			})
		})

	})

}
开发者ID:JacobXie,项目名称:golearn,代码行数:53,代码来源:layered_test.go

示例6: ChiMBuildFrequencyTable

// ChiMBuildFrequencyTable scans every row of inst and returns one
// FrequencyTableEntry per distinct value of attr, in order of first
// appearance. Each entry's Frequency map counts, per class label, how many
// rows carried that value.
//
// attr must be a *base.FloatAttribute; any other type makes the type
// assertion panic. The function also panics if attr cannot be resolved in
// inst or if iterating over the rows reports an error.
func ChiMBuildFrequencyTable(attr base.Attribute, inst base.FixedDataGrid) []*FrequencyTableEntry {
	ret := make([]*FrequencyTableEntry, 0)
	attribute := attr.(*base.FloatAttribute)

	attrSpec, err := inst.GetAttribute(attr)
	if err != nil {
		panic(err)
	}
	attrSpecs := []base.AttributeSpec{attrSpec}

	err = inst.MapOverRows(attrSpecs, func(row [][]byte, rowNo int) (bool, error) {
		valueConv := attribute.GetFloatFromSysVal(row[0])
		class := base.GetClass(inst, rowNo)

		// Values are unique within the table (an entry is only appended when
		// no match exists), so the search can stop at the first hit.
		for _, entry := range ret {
			if entry.Value == valueConv {
				entry.Frequency[class]++
				return true, nil
			}
		}

		ret = append(ret, &FrequencyTableEntry{
			Value:     valueConv,
			Frequency: map[string]int{class: 1},
		})
		return true, nil
	})
	// Previously this error was assigned but never inspected; fail the same
	// way as the attribute lookup above instead of returning a partial table.
	if err != nil {
		panic(err)
	}

	return ret
}
开发者ID:Gudym,项目名称:golearn,代码行数:35,代码来源:chimerge_funcs.go

示例7: vote

// vote returns the class label that occurs most often among the training rows
// whose indices are listed in values.
//
// maxmap is caller-supplied scratch storage: all of its existing counts are
// reset to zero first, then one vote is added per entry of values. Labels left
// over from earlier calls therefore stay in the map with a count of zero.
//
// Ties are broken in favour of the lexicographically smallest label, so the
// result no longer depends on Go's randomized map iteration order. If both
// maxmap and values are empty the empty string is returned.
func (KNN *KNNClassifier) vote(maxmap map[string]int, values []int) string {
	// Reset the counts but keep the keys.
	for label := range maxmap {
		maxmap[label] = 0
	}

	// Tally one vote per neighbour; a missing key starts from zero.
	for _, row := range values {
		maxmap[base.GetClass(KNN.TrainingData, row)]++
	}

	// Pick the label with the highest count.
	var maxClass string
	maxVal := -1
	for label, count := range maxmap {
		if count > maxVal || (count == maxVal && label < maxClass) {
			maxVal = count
			maxClass = label
		}
	}
	return maxClass
}
开发者ID:nickpoorman,项目名称:golearn,代码行数:27,代码来源:knn.go

示例8: getNumericAttributeEntropy

// getNumericAttributeEntropy searches for the threshold on the numeric
// attribute attr that gives the lowest split entropy over the rows of f.
//
// Candidate thresholds are the midpoints between consecutive values after
// sorting. For each candidate, rows with a value below the threshold are
// counted per class under key "0" and the remaining rows under key "1"; the
// resulting distribution is scored with getSplitEntropy.
//
// It returns the lowest entropy found and the threshold that produced it. With
// fewer than two rows there is no candidate and both results are +Inf. The
// function panics if attr cannot be resolved in f or if iterating over the
// rows reports an error.
func getNumericAttributeEntropy(f base.FixedDataGrid, attr *base.FloatAttribute) (float64, float64) {

	// Resolve Attribute
	attrSpec, err := f.GetAttribute(attr)
	if err != nil {
		panic(err)
	}

	// Build sortable vector of (value, class) pairs, one per row
	_, rows := f.Size()
	refs := make([]numericSplitRef, rows)
	err = f.MapOverRows([]base.AttributeSpec{attrSpec}, func(val [][]byte, row int) (bool, error) {
		cls := base.GetClass(f, row)
		v := base.UnpackBytesToFloat(val[0])
		refs[row] = numericSplitRef{v, cls}
		return true, nil
	})
	// A failed iteration would leave refs partially filled with zero values
	// and silently skew the result, so fail like the lookup above does.
	if err != nil {
		panic(err)
	}

	// Sort
	sort.Sort(splitVec(refs))

	// Class counts on either side of a candidate threshold.
	generateCandidateSplitDistribution := func(val float64) map[string]map[string]int {
		presplit := make(map[string]int)
		postplit := make(map[string]int)
		for _, i := range refs {
			if i.val < val {
				presplit[i.class]++
			} else {
				postplit[i.class]++
			}
		}
		ret := make(map[string]map[string]int)
		ret["0"] = presplit
		ret["1"] = postplit
		return ret
	}

	minSplitEntropy := math.Inf(1)
	minSplitVal := math.Inf(1)
	// Consider each possible split point
	for i := 0; i < len(refs)-1; i++ {
		val := refs[i].val + refs[i+1].val
		val /= 2
		splitDist := generateCandidateSplitDistribution(val)
		splitEntropy := getSplitEntropy(splitDist)
		if splitEntropy < minSplitEntropy {
			minSplitEntropy = splitEntropy
			minSplitVal = val
		}
	}

	return minSplitEntropy, minSplitVal
}
开发者ID:CTLife,项目名称:golearn,代码行数:53,代码来源:entropy.go

示例9: GetConfusionMatrix

// GetConfusionMatrix tallies, for every row index, the class found in ref
// against the class found in gen. The result is indexed first by the
// reference class and then by the generated class.
//
// An error is returned (with a nil map) when ref and gen have different row
// counts.
func GetConfusionMatrix(ref base.FixedDataGrid, gen base.FixedDataGrid) (map[string]map[string]int, error) {
	_, refRows := ref.Size()
	_, genRows := gen.Size()
	if refRows != genRows {
		msg := fmt.Sprintf("Row count mismatch: ref has %d rows, gen has %d rows", refRows, genRows)
		return nil, errors.New(msg)
	}

	matrix := make(map[string]map[string]int)
	for row := 0; row < int(refRows); row++ {
		actual := base.GetClass(ref, row)
		predicted := base.GetClass(gen, row)

		// Create the inner map the first time a reference class shows up.
		counts, seen := matrix[actual]
		if !seen {
			counts = make(map[string]int)
			matrix[actual] = counts
		}
		counts[predicted]++
	}
	return matrix, nil
}
开发者ID:CTLife,项目名称:golearn,代码行数:24,代码来源:confusion.go

示例10: GetConfusionMatrix

// GetConfusionMatrix counts how often each class in ref coincides, row by
// row, with each class in gen. The outer key of the result is the reference
// class and the inner key is the generated class.
//
// It panics when the two grids do not have the same number of rows.
func GetConfusionMatrix(ref base.FixedDataGrid, gen base.FixedDataGrid) map[string]map[string]int {
	_, wantRows := ref.Size()
	_, gotRows := gen.Size()
	if wantRows != gotRows {
		panic("Row counts should match")
	}

	confusion := make(map[string]map[string]int)
	for row := 0; row < int(wantRows); row++ {
		refClass := base.GetClass(ref, row)
		genClass := base.GetClass(gen, row)

		// Lazily allocate the per-reference-class counter map.
		if confusion[refClass] == nil {
			confusion[refClass] = make(map[string]int)
		}
		confusion[refClass][genClass]++
	}
	return confusion
}
开发者ID:Gudym,项目名称:golearn,代码行数:25,代码来源:confusion.go

示例11: main

// main reads two integer command-line arguments, each expected to be 0 or 1,
// trains a multi-layer net on xnor.csv and prints the class the net predicts
// for the matching row of that file.
//
// The row index is 2*first + second, i.e. (0,0)->0, (0,1)->1, (1,0)->2,
// (1,1)->3. NOTE(review): this presumes xnor.csv lists its rows in that
// order; confirm against the data file.
//
// The program exits with status 1 when fewer than two arguments are given,
// and panics when an argument is not an integer or xnor.csv cannot be loaded.
func main() {
	if len(os.Args) < 3 {
		fmt.Fprintln(os.Stderr, "usage: expected two arguments, each 0 or 1")
		os.Exit(1)
	}

	// Check each parse error on its own; checking only after the second call
	// would let the second result overwrite a failure of the first.
	xOne, err := strconv.ParseInt(os.Args[1], 10, 64)
	if err != nil {
		panic(err)
	}
	xTwo, err := strconv.ParseInt(os.Args[2], 10, 64)
	if err != nil {
		panic(err)
	}
	fmt.Println("")

	// Reject out-of-range input before spending time on training.
	if (xOne != 0 && xOne != 1) || (xTwo != 0 && xTwo != 1) {
		fmt.Println("Your input is incorrect. Quitting. ")
		fmt.Println("")
		return
	}

	XNORData, err := base.ParseCSVToInstances("xnor.csv", false)
	if err != nil {
		panic(err)
	}

	net := neural.NewMultiLayerNet([]int{3})
	net.MaxIterations = 20000
	net.Fit(XNORData)

	pred := net.Predict(XNORData)

	inputVectorIndex := int(xOne*2 + xTwo)
	fmt.Println(base.GetClass(pred, inputVectorIndex))

	fmt.Println("")

}
开发者ID:nick11roberts,项目名称:neural-xnor,代码行数:37,代码来源:main.go

示例12: processData

// processData converts every row of x into an instance holding the row's class
// label and the float values of its numeric non-class attributes, and returns
// them indexed by row number.
//
// x must define exactly one class attribute, which has to be either a
// CategoricalAttribute with exactly two values or a BinaryAttribute; anything
// else panics. The function also panics if iterating over the rows reports an
// error.
func processData(x base.FixedDataGrid) instances {
	_, rows := x.Size()

	result := make(instances, rows)

	// Retrieve numeric non-class Attributes
	numericAttrs := base.NonClassFloatAttributes(x)
	numericAttrSpecs := base.ResolveAttributes(x, numericAttrs)

	// Retrieve class Attributes
	classAttrs := x.AllClassAttributes()
	if len(classAttrs) != 1 {
		panic("Only one classAttribute supported!")
	}

	// Check that the class Attribute is categorical
	// (with two values) or binary
	switch attr := classAttrs[0].(type) {
	case *base.CategoricalAttribute:
		if len(attr.GetValues()) != 2 {
			panic("To many values for Attribute!")
		}
	case *base.BinaryAttribute:
		// Binary class attributes are accepted as they are.
	default:
		panic("Wrong class Attribute type!")
	}

	// Convert each row
	err := x.MapOverRows(numericAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
		// Allocate a new row and read the float values out
		probRow := make([]float64, len(numericAttrSpecs))
		for i := range numericAttrSpecs {
			probRow[i] = base.UnpackBytesToFloat(row[i])
		}

		// Pair the values with the row's class
		class := base.GetClass(x, rowNo)
		result[rowNo] = instance{class, probRow}
		return true, nil
	})
	// A failed iteration would leave zero-valued entries in result; fail
	// loudly like the validation above rather than returning partial data.
	if err != nil {
		panic(err)
	}
	return result
}
开发者ID:CTLife,项目名称:golearn,代码行数:45,代码来源:average.go

示例13: Predict

// Predict returns a classification for the vector, based on a vector input, using the KNN algorithm.
func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid {

	// Check what distance function we are using
	var distanceFunc pairwise.PairwiseDistanceFunc
	switch KNN.DistanceFunc {
	case "euclidean":
		distanceFunc = pairwise.NewEuclidean()
	case "manhattan":
		distanceFunc = pairwise.NewManhattan()
	default:
		panic("unsupported distance function")

	}
	// Check Compatibility
	allAttrs := base.CheckCompatible(what, KNN.TrainingData)
	if allAttrs == nil {
		// Don't have the same Attributes
		return nil
	}

	// Remove the Attributes which aren't numeric
	allNumericAttrs := make([]base.Attribute, 0)
	for _, a := range allAttrs {
		if fAttr, ok := a.(*base.FloatAttribute); ok {
			allNumericAttrs = append(allNumericAttrs, fAttr)
		}
	}

	// Generate return vector
	ret := base.GeneratePredictionVector(what)

	// Resolve Attribute specifications for both
	whatAttrSpecs := base.ResolveAttributes(what, allNumericAttrs)
	trainAttrSpecs := base.ResolveAttributes(KNN.TrainingData, allNumericAttrs)

	// Reserve storage for most the most similar items
	distances := make(map[int]float64)

	// Reserve storage for voting map
	maxmap := make(map[string]int)

	// Reserve storage for row computations
	trainRowBuf := make([]float64, len(allNumericAttrs))
	predRowBuf := make([]float64, len(allNumericAttrs))

	// Iterate over all outer rows
	what.MapOverRows(whatAttrSpecs, func(predRow [][]byte, predRowNo int) (bool, error) {
		// Read the float values out
		for i, _ := range allNumericAttrs {
			predRowBuf[i] = base.UnpackBytesToFloat(predRow[i])
		}

		predMat := utilities.FloatsToMatrix(predRowBuf)

		// Find the closest match in the training data
		KNN.TrainingData.MapOverRows(trainAttrSpecs, func(trainRow [][]byte, srcRowNo int) (bool, error) {

			// Read the float values out
			for i, _ := range allNumericAttrs {
				trainRowBuf[i] = base.UnpackBytesToFloat(trainRow[i])
			}

			// Compute the distance
			trainMat := utilities.FloatsToMatrix(trainRowBuf)
			distances[srcRowNo] = distanceFunc.Distance(predMat, trainMat)
			return true, nil
		})

		sorted := utilities.SortIntMap(distances)
		values := sorted[:KNN.NearestNeighbours]

		// Reset maxMap
		for a := range maxmap {
			maxmap[a] = 0
		}

		// Refresh maxMap
		for _, elem := range values {
			label := base.GetClass(KNN.TrainingData, elem)
			if _, ok := maxmap[label]; ok {
				maxmap[label]++
			} else {
				maxmap[label] = 1
			}
		}

		// Sort the maxMap
		var maxClass string
		maxVal := -1
		for a := range maxmap {
			if maxmap[a] > maxVal {
				maxVal = maxmap[a]
				maxClass = a
			}
		}

		base.SetClass(ret, predRowNo, maxClass)
		return true, nil

//.........这里部分代码省略.........
开发者ID:hpxro7,项目名称:golearn,代码行数:101,代码来源:knn.go

示例14: TestSimple

// TestSimple fits a BernoulliNBClassifier on test/simple_train.csv (after
// convertToBinary) and verifies the per-class instance counts, the stored
// conditional probabilities, PredictOne on hand-built documents and Predict on
// test/simple_test.csv.
func TestSimple(t *testing.T) {
	Convey("Given a simple training data", t, func() {
		train, err := base.ParseCSVToInstances("test/simple_train.csv", false)
		So(err, ShouldBeNil)

		nb := NewBernoulliNBClassifier()
		nb.Fit(convertToBinary(train))

		Convey("Check if Fit is working as expected", func() {
			Convey("All data needed for prior should be correctly calculated", func() {
				So(nb.classInstances["blue"], ShouldEqual, 2)
				So(nb.classInstances["red"], ShouldEqual, 2)
				So(nb.trainingInstances, ShouldEqual, 4)
			})

			Convey("'red' conditional probabilities should be correct", func() {
				tok0 := nb.condProb["red"][0]
				tok1 := nb.condProb["red"][1]
				tok2 := nb.condProb["red"][2]

				So(tok0, ShouldAlmostEqual, 1.0)
				So(tok1, ShouldAlmostEqual, 1.0/3.0)
				So(tok2, ShouldAlmostEqual, 1.0)
			})

			Convey("'blue' conditional probabilities should be correct", func() {
				tok0 := nb.condProb["blue"][0]
				tok1 := nb.condProb["blue"][1]
				tok2 := nb.condProb["blue"][2]

				So(tok0, ShouldAlmostEqual, 1.0)
				So(tok1, ShouldAlmostEqual, 1.0)
				So(tok2, ShouldAlmostEqual, 1.0/3.0)
			})
		})

		Convey("PredictOne should work as expected", func() {
			Convey("Using a document with different number of cols should panic", func() {
				// Two columns instead of three.
				doc := [][]byte{{0}, {2}}
				So(func() { nb.PredictOne(doc) }, ShouldPanic)
			})

			Convey("Token 1 should be a good predictor of the blue class", func() {
				doc := [][]byte{{0}, {1}, {0}}
				So(nb.PredictOne(doc), ShouldEqual, "blue")

				doc = [][]byte{{1}, {1}, {0}}
				So(nb.PredictOne(doc), ShouldEqual, "blue")
			})

			Convey("Token 2 should be a good predictor of the red class", func() {
				doc := [][]byte{{0}, {0}, {1}}
				So(nb.PredictOne(doc), ShouldEqual, "red")
				doc = [][]byte{{1}, {0}, {1}}
				So(nb.PredictOne(doc), ShouldEqual, "red")
			})
		})

		Convey("Predict should work as expected", func() {
			test, err := base.ParseCSVToInstances("test/simple_test.csv", false)
			So(err, ShouldBeNil)

			predicted := nb.Predict(convertToBinary(test))

			Convey("All simple predicitions should be correct", func() {
				// Expected class per row, in row order.
				for row, want := range []string{"blue", "red", "blue", "red"} {
					So(base.GetClass(predicted, row), ShouldEqual, want)
				}
			})
		})
	})
}
开发者ID:CTLife,项目名称:golearn,代码行数:73,代码来源:bernoulli_nb_test.go

示例15: Fit

// Fit trains the Bernoulli Naive Bayes model on X. It records the number of
// training rows, the feature attributes, the number of rows per class
// (classInstances) and, for every class and feature, the conditional
// probability (docsWithFeature+1)/(docsInClass+1) in condProb.
//
// Every call rebuilds classInstances and condProb from scratch, so classes
// seen in an earlier Fit do not linger.
//
// Fit panics when a non-class attribute is not a BinaryAttribute, when X does
// not define exactly one class attribute, or when iterating over the rows
// reports an error.
func (nb *BernoulliNBClassifier) Fit(X base.FixedDataGrid) {

	// Check that all Attributes are binary
	classAttrs := X.AllClassAttributes()
	allAttrs := X.AllAttributes()
	featAttrs := base.AttributeDifference(allAttrs, classAttrs)
	for i := range featAttrs {
		if _, ok := featAttrs[i].(*base.BinaryAttribute); !ok {
			panic(fmt.Sprintf("%v: Should be BinaryAttribute", featAttrs[i]))
		}
	}
	featAttrSpecs := base.ResolveAttributes(X, featAttrs)

	// Check that only one classAttribute is defined
	if len(classAttrs) != 1 {
		panic("Only one class Attribute can be used")
	}

	// Number of features and instances in this training set
	_, nb.trainingInstances = X.Size()
	nb.attrs = featAttrs
	nb.features = len(featAttrs)

	// Number of instances in class
	nb.classInstances = make(map[string]int)

	// Number of documents with given term (by class)
	docsContainingTerm := make(map[string][]int)

	// This algorithm could be vectorized after binarizing the data
	// matrix. Since mat64 doesn't have this function, a iterative
	// version is used.
	err := X.MapOverRows(featAttrSpecs, func(docVector [][]byte, r int) (bool, error) {
		class := base.GetClass(X, r)

		// increment number of instances in class
		nb.classInstances[class]++

		// Allocate the per-class counters as soon as the class is seen.
		// Creating them only on the first non-zero feature left classes whose
		// rows are all zero without a slice, which made the probability loop
		// below index a nil slice and panic.
		termCounts, ok := docsContainingTerm[class]
		if !ok {
			termCounts = make([]int, nb.features)
			docsContainingTerm[class] = termCounts
		}

		// In Bernoulli Naive Bayes the presence and absence of
		// features are considered. All non-zero values are
		// treated as presence.
		for feat, v := range docVector {
			if v[0] > 0 {
				termCounts[feat]++
			}
		}
		return true, nil
	})
	if err != nil {
		panic(err)
	}

	// Pre-calculate conditional probabilities for each class
	nb.condProb = make(map[string][]float64, len(nb.classInstances))
	for c, docsInClass := range nb.classInstances {
		termCounts := docsContainingTerm[c]
		classCondProb := make([]float64, nb.features)
		for feat := range classCondProb {
			// Conditional probability with add-one smoothing
			classCondProb[feat] = float64(termCounts[feat]+1) / float64(docsInClass+1)
		}
		nb.condProb[c] = classCondProb
	}
}
开发者ID:JacobXie,项目名称:golearn,代码行数:77,代码来源:bernoulli_nb.go


注:本文中的github.com/sjwhitworth/golearn/base.GetClass函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。