當前位置: 首頁>>代碼示例>>Golang>>正文


Golang base.ResolveAttributes函數代碼示例

本文整理匯總了Golang中github.com/sjwhitworth/golearn/base.ResolveAttributes函數的典型用法代碼示例。如果您正苦於以下問題:Golang ResolveAttributes函數的具體用法?Golang ResolveAttributes怎麼用?Golang ResolveAttributes使用的例子?那麼, 這裡精選的函數代碼示例或許可以為您提供幫助。


在下文中一共展示了ResolveAttributes函數的15個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Golang代碼示例。

示例1: Predict

// Predict asks every per-class classifier for a score on each row of what
// and writes, for each row, the index of the class with the largest score
// into the returned prediction grid.
//
// Scores are read as floats from the first class Attribute of each
// classifier's output. Class 0 is the default: another class only wins when
// its score is strictly greater than both zero and every score seen so far,
// so on equal scores the lowest-numbered class is kept.
//
// The first error reported by an underlying classifier aborts the call.
func (m *OneVsAllModel) Predict(what base.FixedDataGrid) (base.FixedDataGrid, error) {
	ret := base.GeneratePredictionVector(what)

	// One prediction grid and one resolved class spec per class value.
	slots := m.maxClassVal + 1
	vecs := make([]base.FixedDataGrid, slots)
	specs := make([]base.AttributeSpec, slots)
	for cls := uint64(0); cls <= m.maxClassVal; cls++ {
		filtered := base.NewLazilyFilteredInstances(what, m.filters[cls])
		scores, err := m.classifiers[cls].Predict(filtered)
		if err != nil {
			return nil, err
		}
		vecs[cls] = scores
		specs[cls] = base.ResolveAttributes(scores, scores.AllClassAttributes())[0]
	}

	_, rows := ret.Size()
	outSpec := base.ResolveAttributes(ret, ret.AllClassAttributes())[0]
	for row := 0; row < rows; row++ {
		var winner uint64
		top := 0.0
		for cls := uint64(0); cls <= m.maxClassVal; cls++ {
			score := base.UnpackBytesToFloat(vecs[cls].Get(specs[cls], row))
			if score > top {
				winner, top = cls, score
			}
		}
		ret.Set(outSpec, row, base.PackU64ToBytes(winner))
	}
	return ret, nil
}
開發者ID:CTLife,項目名稱:golearn,代碼行數:35,代碼來源:one_v_all.go

示例2: Predict

// Predict feeds every row of X through the package-level Predict function
// using lr.model and stores the result, packed as a float, in the class
// Attribute of the returned prediction grid.
//
// Only the Attributes returned by base.NonClassFloatAttributes are passed
// to the model. Predict panics unless X has exactly one class Attribute.
func (lr *LogisticRegression) Predict(X base.FixedDataGrid) base.FixedDataGrid {
	targets := X.AllClassAttributes()
	if n := len(targets); n != 1 {
		panic(fmt.Sprintf("%d Wrong number of classes", n))
	}

	// Output grid and the spec of its class column.
	out := base.GeneratePredictionVector(X)
	outSpecs := base.ResolveAttributes(out, targets)

	// Input columns handed to the model.
	featureSpecs := base.ResolveAttributes(X, base.NonClassFloatAttributes(X))

	// A single scratch buffer is reused for every row.
	features := make([]float64, len(featureSpecs))
	X.MapOverRows(featureSpecs, func(raw [][]byte, rowNo int) (bool, error) {
		for col := range raw {
			features[col] = base.UnpackBytesToFloat(raw[col])
		}
		packed := base.PackFloatToBytes(Predict(lr.model, features))
		out.Set(outSpecs[0], rowNo, packed)
		return true, nil
	})

	return out
}
開發者ID:Gudym,項目名稱:golearn,代碼行數:28,代碼來源:logistic.go

示例3: Predict

// Predict walks the tree rooted at d once for every row of what and returns
// a prediction grid whose class Attribute holds the Class of the leaf that
// was reached.
//
// At each inner node the split Attribute is looked up in what. For a
// FloatAttribute the branch "1" is taken when the row's value is greater
// than the node's SplitVal and "0" otherwise; for any other Attribute the
// branch named by the value's string form is taken. When no child carries
// that name, the child with the smallest key greater than the name is used,
// or the child with the largest key if none is greater. The original code
// picked this fallback from Go's randomised map iteration order, which made
// predictions vary between runs; the selection here is deterministic.
//
// An error is returned when the class Attribute cannot be resolved in the
// prediction grid or when MapOverRows reports a failure. A split Attribute
// missing from what still panics, as before.
func (d *DecisionTreeNode) Predict(what base.FixedDataGrid) (base.FixedDataGrid, error) {
	predictions := base.GeneratePredictionVector(what)
	classAttr := getClassAttr(predictions)
	classAttrSpec, err := predictions.GetAttribute(classAttr)
	if err != nil {
		return nil, err
	}
	predAttrs := base.AttributeDifferenceReferences(what.AllAttributes(), predictions.AllClassAttributes())
	predAttrSpecs := base.ResolveAttributes(what, predAttrs)
	err = what.MapOverRows(predAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
		cur := d
		// Descend until a node without children (a leaf) is reached.
		for cur.Children != nil {
			splitVal := cur.SplitRule.SplitVal
			ats, lookupErr := what.GetAttribute(cur.SplitRule.SplitAttr)
			if lookupErr != nil {
				panic(lookupErr)
			}

			var classVar string
			if _, ok := ats.GetAttribute().(*base.FloatAttribute); ok {
				// Numeric split: compare the row's value with the threshold.
				if base.UnpackBytesToFloat(what.Get(ats, rowNo)) > splitVal {
					classVar = "1"
				} else {
					classVar = "0"
				}
			} else {
				classVar = ats.GetAttribute().GetStringFromSysVal(what.Get(ats, rowNo))
			}

			if next, ok := cur.Children[classVar]; ok {
				cur = next
				continue
			}

			// No exact branch: choose deterministically, independent of
			// map iteration order.
			var above, largest string
			haveAbove, haveAny := false, false
			for key := range cur.Children {
				if !haveAny || key > largest {
					largest, haveAny = key, true
				}
				if key > classVar && (!haveAbove || key < above) {
					above, haveAbove = key, true
				}
			}
			if haveAbove {
				cur = cur.Children[above]
			} else {
				cur = cur.Children[largest]
			}
		}
		predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class))
		return true, nil
	})
	if err != nil {
		return nil, err
	}
	return predictions, nil
}
開發者ID:tanduong,項目名稱:golearn,代碼行數:58,代碼來源:id3.go

示例4: Predict

// Predict classifies each row of what by handing the row's feature bytes
// to PredictOne and storing the answer in the returned prediction grid.
// The feature columns are the Attributes remembered in nb.attrs.
//
// NOTE(review): the previous comment stated that Predict panics if Fit was
// not called or if the document vector and training matrix differ in column
// count; that behaviour is not visible in this block — confirm against
// PredictOne.
func (nb *BernoulliNBClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid {
	predictions := base.GeneratePredictionVector(what)
	featureSpecs := base.ResolveAttributes(what, nb.attrs)

	// classify stores the predicted class for a single row.
	classify := func(docVector [][]byte, rowNo int) (bool, error) {
		base.SetClass(predictions, rowNo, nb.PredictOne(docVector))
		return true, nil
	}
	what.MapOverRows(featureSpecs, classify)

	return predictions
}
開發者ID:JacobXie,項目名稱:golearn,代碼行數:18,代碼來源:bernoulli_nb.go

示例5: TestChiMergeDiscretization

// TestChiMergeDiscretization checks the values chiMerge produces for the
// first Attribute of two CSV datasets: chim.csv as loaded, and
// iris_headers.csv after sorting ascending on its first Attribute.
func TestChiMergeDiscretization(t *testing.T) {
	Convey("Chi-Merge Discretization", t, func() {
		chimPath := "../examples/datasets/chim.csv"

		Convey(fmt.Sprintf("With the '%s' dataset", chimPath), func() {
			data, err := base.ParseCSVToInstances(chimPath, true)
			So(err, ShouldBeNil)

			_, rowCount := data.Size()

			table := chiMerge(data, data.AllAttributes()[0], 0.9, 0, rowCount)
			cuts := make([]float64, len(table))
			for i, entry := range table {
				cuts[i] = entry.Value
			}

			Convey("Computes frequencies correctly", func() {
				So(cuts, ShouldResemble, []float64{1.3, 56.2, 87.1})
			})
		})

		irisPath := "../examples/datasets/iris_headers.csv"

		Convey(fmt.Sprintf("With the '%s' dataset", irisPath), func() {
			data, err := base.ParseCSVToInstances(irisPath, true)
			So(err, ShouldBeNil)

			Convey("Sorting the instances first", func() {
				// Sort on the first Attribute only.
				firstSpec := base.ResolveAttributes(data, data.AllAttributes())[0:1]
				ordered, err := base.Sort(data, base.Ascending, firstSpec)
				So(err, ShouldBeNil)

				_, rowCount := ordered.Size()

				table := chiMerge(ordered, ordered.AllAttributes()[0], 0.9, 0, rowCount)
				cuts := make([]float64, len(table))
				for i, entry := range table {
					cuts[i] = entry.Value
				}

				Convey("Computes frequencies correctly", func() {
					So(cuts, ShouldResemble, []float64{4.3, 5.5, 5.8, 6.3, 7.1})
				})
			})
		})
	})
}
開發者ID:CTLife,項目名稱:golearn,代碼行數:48,代碼來源:chimerge_test.go

示例6: processData

// processData turns every row of x into an instance made of the row's class
// (as returned by base.GetClass) and its non-class float values.
//
// It panics unless x has exactly one class Attribute, and that Attribute is
// either a CategoricalAttribute with exactly two values or a
// BinaryAttribute.
func processData(x base.FixedDataGrid) instances {
	_, rowCount := x.Size()
	out := make(instances, rowCount)

	// Columns that become the numeric part of each instance.
	featureSpecs := base.ResolveAttributes(x, base.NonClassFloatAttributes(x))

	targets := x.AllClassAttributes()
	if len(targets) != 1 {
		panic("Only one classAttribute supported!")
	}

	// Validate the kind of the single class Attribute.
	switch target := targets[0].(type) {
	case *base.CategoricalAttribute:
		if len(target.GetValues()) != 2 {
			panic("To many values for Attribute!")
		}
	case *base.BinaryAttribute:
		// Accepted without further checks.
	default:
		panic("Wrong class Attribute type!")
	}

	x.MapOverRows(featureSpecs, func(raw [][]byte, rowNo int) (bool, error) {
		// Each instance owns a freshly allocated slice of values.
		features := make([]float64, len(featureSpecs))
		for col := range featureSpecs {
			features[col] = base.UnpackBytesToFloat(raw[col])
		}
		out[rowNo] = instance{base.GetClass(x, rowNo), features}
		return true, nil
	})
	return out
}
開發者ID:CTLife,項目名稱:golearn,代碼行數:45,代碼來源:average.go

示例7: Predict

// Predict walks the tree rooted at d once for every row of what and returns
// a prediction grid whose class Attribute holds the Class of the node where
// the walk stopped.
//
// The walk stops at a node without children, or at an inner node whose
// split Attribute cannot be found in what (that node's Class is then used).
// Otherwise the child named by the string form of the row's value is
// followed. When no child carries that name, the child with the smallest
// key greater than the name is used, or the child with the largest key if
// none is greater. The original code picked this fallback from Go's
// randomised map iteration order, which made predictions vary between
// runs; the selection here is deterministic.
//
// Predict panics if the class Attribute cannot be resolved in the
// prediction grid.
func (d *DecisionTreeNode) Predict(what base.FixedDataGrid) base.FixedDataGrid {
	predictions := base.GeneratePredictionVector(what)
	classAttr := getClassAttr(predictions)
	classAttrSpec, err := predictions.GetAttribute(classAttr)
	if err != nil {
		panic(err)
	}
	predAttrs := base.AttributeDifferenceReferences(what.AllAttributes(), predictions.AllClassAttributes())
	predAttrSpecs := base.ResolveAttributes(what, predAttrs)
	what.MapOverRows(predAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
		cur := d
		for cur.Children != nil {
			ats, lookupErr := what.GetAttribute(cur.SplitAttr)
			if lookupErr != nil {
				// Unknown split Attribute: settle for this node's class.
				break
			}

			classVar := ats.GetAttribute().GetStringFromSysVal(what.Get(ats, rowNo))
			if next, ok := cur.Children[classVar]; ok {
				cur = next
				continue
			}

			// No exact branch: choose deterministically, independent of
			// map iteration order.
			var above, largest string
			haveAbove, haveAny := false, false
			for key := range cur.Children {
				if !haveAny || key > largest {
					largest, haveAny = key, true
				}
				if key > classVar && (!haveAbove || key < above) {
					above, haveAbove = key, true
				}
			}
			if haveAbove {
				cur = cur.Children[above]
			} else {
				cur = cur.Children[largest]
			}
		}
		predictions.Set(classAttrSpec, rowNo, classAttr.GetSysValFromString(cur.Class))
		return true, nil
	})
	return predictions
}
開發者ID:JacobXie,項目名稱:golearn,代碼行數:43,代碼來源:id3.go

示例8: TestChiMerge2

// TestChiMerge2 sorts the iris dataset on its first Attribute, runs
// chiMerge over that Attribute at significance 0.90 and checks that the
// five resulting values match the expected ones.
//
// See http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf
//   Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992
func TestChiMerge2(testEnv *testing.T) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		testEnv.Fatal(err)
	}

	// Sort the instances
	allAttrs := inst.AllAttributes()
	sortAttrSpecs := base.ResolveAttributes(inst, allAttrs)[0:1]
	instSorted, err := base.Sort(inst, base.Ascending, sortAttrSpecs)
	if err != nil {
		testEnv.Fatal(err)
	}

	// Perform Chi-Merge
	_, rows := inst.Size()
	freq := chiMerge(instSorted, allAttrs[0], 0.90, 0, rows)

	expected := []float64{4.3, 5.5, 5.8, 6.3, 7.1}
	if len(freq) != len(expected) {
		// Stop here: indexing freq below would panic on a short result.
		testEnv.Fatalf("Wrong length (%d): %v", len(freq), freq)
	}
	for i, want := range expected {
		if freq[i].Value != want {
			testEnv.Error(freq[i])
		}
	}
}
開發者ID:Gudym,項目名稱:golearn,代碼行數:40,代碼來源:chimerge_test.go

示例9: convertInstancesToLabelVec

// convertInstancesToLabelVec returns one float64 per row of X, unpacked
// from the single class Attribute. It panics unless X has exactly one class
// Attribute and that Attribute is a FloatAttribute.
func convertInstancesToLabelVec(X base.FixedDataGrid) []float64 {
	targets := X.AllClassAttributes()
	if n := len(targets); n != 1 {
		panic(fmt.Sprintf("%d ClassAttributes (1 expected)", n))
	}
	target := targets[0]
	if _, isFloat := target.(*base.FloatAttribute); !isFloat {
		panic(fmt.Sprintf("%s: ClassAttribute must be a FloatAttribute", target))
	}

	_, rowCount := X.Size()
	labels := make([]float64, rowCount)

	// Only the class column is visited, so each row carries one value.
	targetSpecs := base.ResolveAttributes(X, targets)
	X.MapOverRows(targetSpecs, func(raw [][]byte, rowNo int) (bool, error) {
		labels[rowNo] = base.UnpackBytesToFloat(raw[0])
		return true, nil
	})
	return labels
}
開發者ID:Gudym,項目名稱:golearn,代碼行數:22,代碼來源:logistic.go

示例10: convertInstancesToProblemVec

// convertInstancesToProblemVec returns, for every row of X, a freshly
// allocated slice holding the row's values for the Attributes reported by
// base.NonClassFloatAttributes.
func convertInstancesToProblemVec(X base.FixedDataGrid) [][]float64 {
	_, rowCount := X.Size()
	problem := make([][]float64, rowCount)

	featureSpecs := base.ResolveAttributes(X, base.NonClassFloatAttributes(X))
	width := len(featureSpecs)

	X.MapOverRows(featureSpecs, func(raw [][]byte, rowNo int) (bool, error) {
		features := make([]float64, width)
		for col := 0; col < width; col++ {
			features[col] = base.UnpackBytesToFloat(raw[col])
		}
		problem[rowNo] = features
		return true, nil
	})
	return problem
}
開發者ID:Gudym,項目名稱:golearn,代碼行數:23,代碼來源:logistic.go

示例11: Predict

// Predict computes, for every row of X, the disturbance term plus the sum
// of each value in lr.attrs multiplied by its regression coefficient, and
// stores the result in the lr.cls column of the returned prediction grid.
//
// It returns NoTrainingDataError when the model has not been fitted, the
// lookup error when lr.cls cannot be resolved in the prediction grid, and
// any error reported by MapOverRows (previously discarded).
func (lr *LinearRegression) Predict(X base.FixedDataGrid) (base.FixedDataGrid, error) {
	if !lr.fitted {
		return nil, NoTrainingDataError
	}

	ret := base.GeneratePredictionVector(X)
	attrSpecs := base.ResolveAttributes(X, lr.attrs)
	clsSpec, err := ret.GetAttribute(lr.cls)
	if err != nil {
		return nil, err
	}

	err = X.MapOverRows(attrSpecs, func(row [][]byte, i int) (bool, error) {
		prediction := lr.disturbance
		for j, r := range row {
			prediction += base.UnpackBytesToFloat(r) * lr.regressionCoefficients[j]
		}

		ret.Set(clsSpec, i, base.PackFloatToBytes(prediction))
		return true, nil
	})
	if err != nil {
		return nil, err
	}

	return ret, nil
}
開發者ID:CTLife,項目名稱:golearn,代碼行數:24,代碼來源:linear_regression.go

示例12: Fit

// Fit trains the Bernoulli Naive Bayes model on X. It records the number of
// training rows, the feature Attributes, the number of rows per class and,
// for every class, the smoothed conditional probability of each feature
// being present.
//
// Fit panics if any non-class Attribute is not a BinaryAttribute or if X
// does not have exactly one class Attribute.
//
// Compared with the previous version, the per-class counters are created
// as soon as a class is first seen. Before, a class whose rows had no
// feature set never received a counter slice and the probability loop
// panicked on a nil slice index. nb.condProb is also created when it is
// nil.
func (nb *BernoulliNBClassifier) Fit(X base.FixedDataGrid) {

	// Check that all Attributes are binary
	classAttrs := X.AllClassAttributes()
	allAttrs := X.AllAttributes()
	featAttrs := base.AttributeDifference(allAttrs, classAttrs)
	for _, attr := range featAttrs {
		if _, ok := attr.(*base.BinaryAttribute); !ok {
			panic(fmt.Sprintf("%v: Should be BinaryAttribute", attr))
		}
	}
	featAttrSpecs := base.ResolveAttributes(X, featAttrs)

	// Check that only one classAttribute is defined
	if len(classAttrs) != 1 {
		panic("Only one class Attribute can be used")
	}

	// Number of features and instances in this training set
	_, nb.trainingInstances = X.Size()
	nb.attrs = featAttrs
	nb.features = len(featAttrs)

	// Number of instances in class
	nb.classInstances = make(map[string]int)

	// Number of documents with given term (by class)
	docsContainingTerm := make(map[string][]int)

	// This algorithm could be vectorized after binarizing the data
	// matrix. Since mat64 doesn't have this function, a iterative
	// version is used.
	X.MapOverRows(featAttrSpecs, func(docVector [][]byte, r int) (bool, error) {
		class := base.GetClass(X, r)

		// Increment number of instances in class (missing keys read as 0).
		nb.classInstances[class]++

		// Make sure every observed class owns a counter slice, even when
		// none of its rows has a feature set.
		counts, ok := docsContainingTerm[class]
		if !ok {
			counts = make([]int, nb.features)
			docsContainingTerm[class] = counts
		}

		// In Bernoulli Naive Bayes the presence and absence of features
		// are considered. All non-zero values are treated as presence.
		for feat, v := range docVector {
			if v[0] > 0 {
				counts[feat]++
			}
		}
		return true, nil
	})

	if nb.condProb == nil {
		nb.condProb = make(map[string][]float64)
	}

	// Pre-calculate conditional probabilities for each class
	for c, docsInClass := range nb.classInstances {
		classTerms := docsContainingTerm[c]
		classCondProb := make([]float64, nb.features)
		for feat := range classCondProb {
			// Conditional probability with the same +1 smoothing as before.
			classCondProb[feat] = float64(classTerms[feat]+1) / float64(docsInClass+1)
		}
		nb.condProb[c] = classCondProb
	}
}
開發者ID:JacobXie,項目名稱:golearn,代碼行數:77,代碼來源:bernoulli_nb.go

示例13: Predict

// Predict gathers predictions from all the classifiers
// and outputs the most common (majority) class for each row.
//
// The models are evaluated by runtime.NumCPU() worker goroutines; a single
// collector goroutine tallies the predicted class of every row, so the
// tally map is only ever touched by one goroutine at a time.
//
// IMPORTANT: in the event of a tie, the lexicographically smallest class
// name among the tied classes is output. The previous implementation
// depended on Go's randomised map iteration order, so ties were resolved
// differently from run to run.
func (b *BaggedModel) Predict(from base.FixedDataGrid) base.FixedDataGrid {
	n := runtime.NumCPU()
	// Channel to receive the results as they come in
	votes := make(chan base.DataGrid, n)
	// Count the votes for each class, keyed by row number
	voting := make(map[int](map[string]int))

	// Create a goroutine to collect the votes
	var votingwait sync.WaitGroup
	votingwait.Add(1)
	go func() {
		defer votingwait.Done()
		// Runs until the votes channel is closed and drained.
		for incoming := range votes {
			cSpecs := base.ResolveAttributes(incoming, incoming.AllClassAttributes())
			incoming.MapOverRows(cSpecs, func(row [][]byte, predRow int) (bool, error) {
				// Create the tally for this row on first sight
				if _, ok := voting[predRow]; !ok {
					voting[predRow] = make(map[string]int)
				}
				voting[predRow][base.GetClass(incoming, predRow)]++
				return true, nil
			})
		}
	}()

	// Create workers to process the predictions
	processpipe := make(chan int, n)
	var processwait sync.WaitGroup
	for w := 0; w < n; w++ {
		processwait.Add(1)
		go func() {
			defer processwait.Done()
			// Runs until processpipe is closed and drained.
			for i := range processpipe {
				c := b.Models[i]
				l := b.generatePredictionInstances(i, from)
				votes <- c.Predict(l)
			}
		}()
	}

	// Send all the models to the workers for prediction
	for i := range b.Models {
		processpipe <- i
	}
	close(processpipe) // Finished sending models to be predicted
	processwait.Wait() // Predictors all finished processing
	close(votes)       // Close the vote channel and allow it to drain
	votingwait.Wait()  // All the votes are in

	// Generate the overall consensus
	ret := base.GeneratePredictionVector(from)
	for row, tally := range voting {
		maxClass := ""
		maxCount := 0
		// Find the most popular class; every recorded count is at least 1,
		// so the first class visited always replaces the empty default.
		for class, count := range tally {
			if count > maxCount || (count == maxCount && class < maxClass) {
				maxClass = class
				maxCount = count
			}
		}
		base.SetClass(ret, row, maxClass)
	}
	return ret
}
開發者ID:JacobXie,項目名稱:golearn,代碼行數:82,代碼來源:bagging.go

示例14: Predict

// Predict returns a classification for every row of what using the KNN
// algorithm: the distance from the row to each training row is computed
// over the shared FloatAttributes, the KNN.NearestNeighbours closest
// training rows are passed to KNN.vote, and the winning class is stored in
// the returned prediction grid.
//
// It panics on an unsupported KNN.DistanceFunc and returns nil when
// base.CheckCompatible reports that what and the training data do not
// share the same Attributes. When optimisations are allowed, the distance
// is euclidean and canUseOptimisations agrees, the work is delegated to
// optimisedEuclideanPredict.
//
// If NearestNeighbours exceeds the number of training rows, all training
// rows take part in the vote; previously this case panicked with a slice
// bounds error.
func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid {
	// Check what distance function we are using
	var distanceFunc pairwise.PairwiseDistanceFunc
	switch KNN.DistanceFunc {
	case "euclidean":
		distanceFunc = pairwise.NewEuclidean()
	case "manhattan":
		distanceFunc = pairwise.NewManhattan()
	default:
		panic("unsupported distance function")
	}
	// Check Compatibility
	allAttrs := base.CheckCompatible(what, KNN.TrainingData)
	if allAttrs == nil {
		// Don't have the same Attributes
		return nil
	}

	// Use optimised version if permitted
	if KNN.AllowOptimisations && KNN.DistanceFunc == "euclidean" && KNN.canUseOptimisations(what) {
		return KNN.optimisedEuclideanPredict(what.(*base.DenseInstances))
	}
	fmt.Println("Optimisations are switched off")

	// Remove the Attributes which aren't numeric
	allNumericAttrs := make([]base.Attribute, 0)
	for _, a := range allAttrs {
		if fAttr, ok := a.(*base.FloatAttribute); ok {
			allNumericAttrs = append(allNumericAttrs, fAttr)
		}
	}

	// Generate return vector
	ret := base.GeneratePredictionVector(what)

	// Resolve Attribute specifications for both
	whatAttrSpecs := base.ResolveAttributes(what, allNumericAttrs)
	trainAttrSpecs := base.ResolveAttributes(KNN.TrainingData, allNumericAttrs)

	// Distance from the current prediction row to each training row,
	// keyed by training row number and overwritten for every row.
	distances := make(map[int]float64)

	// Reserve storage for voting map
	maxmap := make(map[string]int)

	// Reserve storage for row computations
	trainRowBuf := make([]float64, len(allNumericAttrs))
	predRowBuf := make([]float64, len(allNumericAttrs))

	_, maxRow := what.Size()
	curRow := 0

	// Iterate over all outer rows
	what.MapOverRows(whatAttrSpecs, func(predRow [][]byte, predRowNo int) (bool, error) {

		// Progress is reported for every row after the first (the former
		// `curRow%1 == 0` test was always true).
		if curRow > 0 {
			fmt.Printf("KNN: %.2f %% done\n", float64(curRow)*100.0/float64(maxRow))
		}
		curRow++

		// Read the float values out
		for i := range allNumericAttrs {
			predRowBuf[i] = base.UnpackBytesToFloat(predRow[i])
		}

		predMat := utilities.FloatsToMatrix(predRowBuf)

		// Measure the distance to every training row
		KNN.TrainingData.MapOverRows(trainAttrSpecs, func(trainRow [][]byte, srcRowNo int) (bool, error) {
			// Read the float values out
			for i := range allNumericAttrs {
				trainRowBuf[i] = base.UnpackBytesToFloat(trainRow[i])
			}

			// Compute the distance
			trainMat := utilities.FloatsToMatrix(trainRowBuf)
			distances[srcRowNo] = distanceFunc.Distance(predMat, trainMat)
			return true, nil
		})

		sorted := utilities.SortIntMap(distances)

		// Never ask for more neighbours than there are training rows.
		k := KNN.NearestNeighbours
		if k > len(sorted) {
			k = len(sorted)
		}
		values := sorted[:k]

		maxClass := KNN.vote(maxmap, values)

		base.SetClass(ret, predRowNo, maxClass)
		return true, nil

	})

	return ret
}
開發者ID:nickpoorman,項目名稱:golearn,代碼行數:97,代碼來源:knn.go

示例15: Predict

// Predict uses the underlying network to produce predictions for the
// class variables of X.
//
// Can only predict one CategoricalAttribute at a time, or up to n
// FloatAttributes. Set or unset ClassAttributes to work around this
// limitation. Predict panics on any other class Attribute type, on a mix of
// float and categorical class Attributes, on more than one categorical
// class Attribute, and when an input Attribute is unknown to the network.
//
// In float mode each class Attribute receives the activation stored at its
// index in m.attrs. Otherwise the class is the position, relative to
// m.classAttrOffset, of the largest activation among the m.classAttrCount
// class nodes; on equal activations the first one wins.
//
// The arg-max now starts from the first class node. Previously it started
// from index 0 with a best value of 0.0, so when no class activation
// exceeded zero the offset subtraction produced a negative index that was
// converted to a very large unsigned class value.
func (m *MultiLayerNet) Predict(X base.FixedDataGrid) base.FixedDataGrid {

	// Create the return vector
	ret := base.GeneratePredictionVector(X)

	// Make sure everything's a FloatAttribute
	insts := m.convertToFloatInsts(X)

	// Get the input/output Attributes
	inputAttrs := base.NonClassAttributes(insts)
	outputAttrs := ret.AllClassAttributes()

	// Compute layers
	layers := 2 + len(m.layers)

	// Check that we're operating in a singular mode
	floatMode := 0
	categoricalMode := 0
	for _, attr := range outputAttrs {
		switch attr.(type) {
		case *base.CategoricalAttribute:
			categoricalMode++
		case *base.FloatAttribute:
			floatMode++
		default:
			panic("Unsupported output Attribute type!")
		}
	}

	if floatMode > 0 && categoricalMode > 0 {
		panic("Can't predict a mix of float and categorical Attributes")
	} else if categoricalMode > 1 {
		panic("Can't predict more than one categorical class Attribute")
	}

	// Create the activation vector
	a := mat64.NewDense(m.network.size, 1, make([]float64, m.network.size))

	// Resolve the input AttributeSpecs
	inputAs := base.ResolveAttributes(insts, inputAttrs)

	// Resolve the output Attributespecs
	outputAs := base.ResolveAttributes(ret, outputAttrs)

	// Map over each input row
	insts.MapOverRows(inputAs, func(row [][]byte, rc int) (bool, error) {
		// Clear the activation vector
		for i := 0; i < m.network.size; i++ {
			a.Set(i, 0, 0.0)
		}
		// Build the activation vector
		for i, vb := range row {
			cIndex, ok := m.attrs[inputAs[i].GetAttribute()]
			if !ok {
				panic("Can't resolve the Attribute!")
			}
			a.Set(cIndex, 0, base.UnpackBytesToFloat(vb))
		}
		// Robots, activate!
		m.network.Activate(a, layers)

		// Decide which class to set
		if floatMode > 0 {
			for _, as := range outputAs {
				cIndex := m.attrs[as.GetAttribute()]
				ret.Set(as, rc, base.PackFloatToBytes(a.At(cIndex, 0)))
			}
			return true, nil
		}

		// Arg-max over the class nodes, seeded with the first class node
		// so the result always lies inside the class range.
		first := m.classAttrOffset
		maxIndex := first
		maxVal := 0.0
		for i := first; i < first+m.classAttrCount; i++ {
			if val := a.At(i, 0); i == first || val > maxVal {
				maxIndex = i
				maxVal = val
			}
		}
		ret.Set(outputAs[0], rc, base.PackU64ToBytes(uint64(maxIndex-first)))
		return true, nil
	})

	return ret

}
開發者ID:nickpoorman,項目名稱:golearn,代碼行數:91,代碼來源:layered.go


注:本文中的github.com/sjwhitworth/golearn/base.ResolveAttributes函數示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。