

Java Dirichlet Class Code Examples

This article collects typical usage examples of cc.mallet.types.Dirichlet in Java. The examples below may help if you want to know how the Dirichlet class is used in practice or are looking for concrete uses of it.


Dirichlet belongs to the cc.mallet.types package. Thirteen examples of the Dirichlet class are shown below, sorted by popularity by default.

Example 1: createRandomChain

import cc.mallet.types.Dirichlet; // import the required package/class
public static FactorGraph createRandomChain (cc.mallet.util.Randoms r, int length)
{
  Variable[] vars = new Variable[length];
  for (int i = 0; i < length; i++)
    vars[i] = new Variable (2);

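  // Dirichlet(1, 1, 1, 1): uniform over 4-entry tables, one entry per joint state of two binary variables
  Dirichlet dirichlet = new Dirichlet (new double[] { 1, 1, 1, 1 });

  FactorGraph mdl = new FactorGraph (vars);
  for (int i = 0; i < length - 1; i++) {
    // draw a fresh pairwise potential for each adjacent pair in the chain
    Multinomial m = dirichlet.randomMultinomial (r);
    double[] probs = m.getValues ();
    mdl.addFactor (vars[i], vars[i + 1], probs);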
  }

  return mdl;
}
 
Developer ID: mimno, Project: GRMM, Lines of code: 18, Source: RandomGraphs.java
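Example 1 draws each pairwise table with randomMultinomial, so every factor is a random point on the probability simplex. The sketch below shows the standard way to make such a draw: sample independent Gamma(alpha_i, 1) variables and then normalize them. It illustrates the technique in plain Java and is not MALLET's source. The class DirichletSketch and its methods are hypothetical names, and the Gamma sampler is the Marsaglia-Tsang method.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical sketch: Dirichlet sampling by normalizing independent Gamma draws. */
final class DirichletSketch {

    /** Draws from Gamma(shape, 1) with the Marsaglia-Tsang method (shape > 0). */
    static double sampleGamma(double shape, Random rng) {
        if (shape <= 0) {
            throw new IllegalArgumentException("shape must be > 0");
        }
        if (shape < 1.0) {
            // Boost: if Y ~ Gamma(shape + 1) and U ~ Uniform(0, 1), then Y * U^(1/shape) ~ Gamma(shape)
            return sampleGamma(shape + 1.0, rng) * Math.pow(rng.nextDouble(), 1.0 / shape);
        }
        double d = shape - 1.0 / 3.0;
        double c = 1.0 / Math.sqrt(9.0 * d);
        while (true) {
            double x = rng.nextGaussian();
            double v = 1.0 + c * x;
            if (v <= 0) {
                continue; // reject and retry
            }
            v = v * v * v;
            double u = rng.nextDouble();
            if (Math.log(u) < 0.5 * x * x + d - d * v + d * Math.log(v)) {
                return d * v;
            }
        }
    }

    /** Returns one probability vector drawn from Dirichlet(alpha). */
    static double[] sample(double[] alpha, Random rng) {
        double[] p = new double[alpha.length];
        double sum = 0.0;
        for (int i = 0; i < alpha.length; i++) {
            p[i] = sampleGamma(alpha[i], rng);
            sum += p[i];
        }
        for (int i = 0; i < p.length; i++) {
            p[i] /= sum; // normalize onto the simplex
        }
        return p;
    }

    public static void main(String[] args) {
        Random rng = new Random(42);
        // Same prior as Example 1: four entries, one per joint state of two binary variables
        double[] table = sample(new double[] {1, 1, 1, 1}, rng);
        System.out.println(Arrays.toString(table));
    }
}
```

With alpha equal to 1 in every entry the draw is uniform over the simplex. Larger alphas concentrate draws near the mean alpha / sum(alpha).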

Example 2: createDirectedModel

import cc.mallet.types.Dirichlet; // import the required package/class
private DirectedModel createDirectedModel ()
{
  int NUM_OUTCOMES = 2;
  cc.mallet.util.Randoms random = new cc.mallet.util.Randoms (13413);

  Dirichlet dirichlet = new Dirichlet (NUM_OUTCOMES, 1.0);
  double[] pA = dirichlet.randomVector (random);
  double[] pB = dirichlet.randomVector (random);

  // one conditional distribution over C for each of the NUM_OUTCOMES^2 joint values of (A, B),
  // concatenated into a flat table of NUM_OUTCOMES^3 entries
  TDoubleArrayList pC = new TDoubleArrayList (NUM_OUTCOMES * NUM_OUTCOMES * NUM_OUTCOMES);
  for (int i = 0; i < (NUM_OUTCOMES * NUM_OUTCOMES); i++) {
    pC.add (dirichlet.randomVector (random));
  }

  Variable[] vars = new Variable[] { new Variable (NUM_OUTCOMES), new Variable (NUM_OUTCOMES),
          new Variable (NUM_OUTCOMES) };
  DirectedModel mdl = new DirectedModel ();
  mdl.addFactor (new CPT (new TableFactor (vars[0], pA), vars[0]));
  mdl.addFactor (new CPT (new TableFactor (vars[1], pB), vars[1]));
  // toNativeArray() is the Trove 2.x method name; Trove 3 calls it toArray() (as in Example 3)
  mdl.addFactor (new CPT (new TableFactor (vars, pC.toNativeArray ()), vars[2]));

  return mdl;
}
 
Developer ID: mimno, Project: GRMM, Lines of code: 24, Source: TestInference.java

Example 3: createDirectedModel

import cc.mallet.types.Dirichlet; // import the required package/class
private DirectedModel createDirectedModel ()
{
  int NUM_OUTCOMES = 2;
  cc.mallet.util.Randoms random = new cc.mallet.util.Randoms (13413);

  Dirichlet dirichlet = new Dirichlet (NUM_OUTCOMES, 1.0);
  double[] pA = dirichlet.randomVector (random);
  double[] pB = dirichlet.randomVector (random);

  TDoubleArrayList pC = new TDoubleArrayList (NUM_OUTCOMES * NUM_OUTCOMES * NUM_OUTCOMES);
  for (int i = 0; i < (NUM_OUTCOMES * NUM_OUTCOMES); i++) {
    pC.add (dirichlet.randomVector (random));
  }

  Variable[] vars = new Variable[] { new Variable (NUM_OUTCOMES), new Variable (NUM_OUTCOMES),
          new Variable (NUM_OUTCOMES) };
  DirectedModel mdl = new DirectedModel ();
  mdl.addFactor (new CPT (new TableFactor (vars[0], pA), vars[0]));
  mdl.addFactor (new CPT (new TableFactor (vars[1], pB), vars[1]));
  mdl.addFactor (new CPT (new TableFactor (vars, pC.toArray ()), vars[2]));

  return mdl;
}
 
Developer ID: iamxiatian, Project: wikit, Lines of code: 24, Source: TestInference.java

Example 4: createDirectedModel

import cc.mallet.types.Dirichlet; // import the required package/class
private DirectedModel createDirectedModel ()
{
  int NUM_OUTCOMES = 2;
  cc.mallet.util.Randoms random = new cc.mallet.util.Randoms (13413);

  Dirichlet dirichlet = new Dirichlet (NUM_OUTCOMES, 1.0);
  double[] pA = dirichlet.randomVector (random);
  double[] pB = dirichlet.randomVector (random);

  DoubleArrayList pC = new DoubleArrayList (NUM_OUTCOMES * NUM_OUTCOMES * NUM_OUTCOMES);
  for (int i = 0; i < (NUM_OUTCOMES * NUM_OUTCOMES); i++) {
    pC.add (dirichlet.randomVector (random));
  }

  Variable[] vars = new Variable[] { new Variable (NUM_OUTCOMES), new Variable (NUM_OUTCOMES),
          new Variable (NUM_OUTCOMES) };
  DirectedModel mdl = new DirectedModel ();
  mdl.addFactor (new CPT (new TableFactor (vars[0], pA), vars[0]));
  mdl.addFactor (new CPT (new TableFactor (vars[1], pB), vars[1]));
  mdl.addFactor (new CPT (new TableFactor (vars, pC.toArray ()), vars[2]));

  return mdl;
}
 
Developer ID: cmoen, Project: mallet, Lines of code: 24, Source: TestInference.java
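Examples 2 to 4 build the CPT for C by appending NUM_OUTCOMES^2 vectors of length NUM_OUTCOMES to one list. The flat array therefore has NUM_OUTCOMES^3 entries, and each consecutive block of NUM_OUTCOMES entries is one conditional distribution. The plain-Java sketch below mimics that layout so the indexing can be checked. The class name FlatCpt is hypothetical. We also assume that block i belongs to the parent configuration i = a * NUM_OUTCOMES + b, with the first variable most significant; this comes from us, not from GRMM's TableFactor documentation.

```java
import java.util.Arrays;

/** Hypothetical sketch of a flattened CPT P(C | A, B), laid out like pC in Examples 2 to 4. */
final class FlatCpt {
    final int numOutcomes;
    final double[] values; // numOutcomes^3 entries, one block of numOutcomes per (a, b)

    FlatCpt(int numOutcomes, double[][] rows) {
        if (rows.length != numOutcomes * numOutcomes) {
            throw new IllegalArgumentException("need one row per parent configuration");
        }
        this.numOutcomes = numOutcomes;
        this.values = new double[numOutcomes * numOutcomes * numOutcomes];
        int pos = 0;
        for (double[] row : rows) {
            if (row.length != numOutcomes) {
                throw new IllegalArgumentException("each row must have numOutcomes entries");
            }
            System.arraycopy(row, 0, values, pos, numOutcomes); // append, like pC.add(...)
            pos += numOutcomes;
        }
    }

    /** P(C = c | A = a, B = b), assuming block index a * numOutcomes + b. */
    double conditional(int a, int b, int c) {
        return values[(a * numOutcomes + b) * numOutcomes + c];
    }

    /** Checks that every block of numOutcomes consecutive entries sums to 1. */
    boolean rowsNormalized(double tolerance) {
        for (int start = 0; start < values.length; start += numOutcomes) {
            double sum = 0.0;
            for (int c = 0; c < numOutcomes; c++) {
                sum += values[start + c];
            }
            if (Math.abs(sum - 1.0) > tolerance) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        double[][] rows = { {0.9, 0.1}, {0.4, 0.6}, {0.3, 0.7}, {0.5, 0.5} };
        FlatCpt cpt = new FlatCpt(2, rows);
        System.out.println(Arrays.toString(cpt.values));
        System.out.println("P(C=1 | A=1, B=0) = " + cpt.conditional(1, 0, 1)); // prints 0.7
    }
}
```

The fixed rows here stand in for the rows that the examples draw with dirichlet.randomVector. Each such draw sums to 1 by construction, so rowsNormalized should hold for them as well.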

Example 5: testRandomTrainedOn

import cc.mallet.types.Dirichlet; // import the required package/class
private double testRandomTrainedOn (InstanceList training)
{
  ClassifierTrainer trainer = new MaxEntTrainer ();

  Alphabet fd = dictOfSize (3);
  String[] classNames = new String[] {"class0", "class1", "class2"};

  Randoms r = new Randoms (1);
  Iterator<Instance> iter = new RandomTokenSequenceIterator (r,  new Dirichlet(fd, 2.0),
        30, 0, 10, 200, classNames);
  training.addThruPipe (iter);

  InstanceList testing = new InstanceList (training.getPipe ());
  testing.addThruPipe (new RandomTokenSequenceIterator (r,  new Dirichlet(fd, 2.0),
        30, 0, 10, 200, classNames));

  System.out.println ("Training set size = "+training.size());
  System.out.println ("Testing set size = "+testing.size());

  Classifier classifier = trainer.train (training);

  System.out.println ("Accuracy on training set:");
  System.out.println (classifier.getClass().getName()
                        + ": " + new Trial (classifier, training).getAccuracy());

  System.out.println ("Accuracy on testing set:");
  double testAcc = new Trial (classifier, testing).getAccuracy();
  System.out.println (classifier.getClass().getName()
                        + ": " + testAcc);

  return testAcc;
}
 
Developer ID: kostagiolasn, Project: NucleosomePatternClassifier, Lines of code: 33, Source: TestPagedInstanceList.java
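Read against the RandomTokenSequenceIterator constructor signature shown in Example 7, the positional arguments 30, 0, 10, 200 set the following:

- classCentroidAvergeAlphaMean = 30
- classCentroidAvergeAlphaVariance = 0
- featureVectorSizePoissonLambda = 10
- classInstanceCountPoissonLamba = 200

Judging by their names, the last two are Poisson rates for the number of tokens per instance and the number of instances per class. The iterator's internals are not shown here, so that reading is an assumption. The sketch below draws counts of that kind with Knuth's multiplication method. The class and method names are hypothetical.

```java
import java.util.Random;

/** Hypothetical sketch: Poisson draws of the kind the iterator's lambda parameters suggest. */
final class PoissonSketch {

    /** Knuth's method; fine for moderate lambda (exp(-lambda) underflows above roughly 745). */
    static int samplePoisson(double lambda, Random rng) {
        if (lambda < 0) {
            throw new IllegalArgumentException("lambda must be >= 0");
        }
        double limit = Math.exp(-lambda);
        int k = 0;
        double product = rng.nextDouble();
        while (product > limit) { // keep multiplying uniforms until the product drops to exp(-lambda)
            k++;
            product *= rng.nextDouble();
        }
        return k;
    }

    public static void main(String[] args) {
        Random rng = new Random(1);
        String[] classNames = {"class0", "class1", "class2"};
        for (String name : classNames) {
            int numInstances = samplePoisson(200, rng); // instances for this class
            int firstLength = samplePoisson(10, rng);   // tokens in one instance
            System.out.println(name + ": " + numInstances + " instances, first has "
                    + firstLength + " tokens");
        }
    }
}
```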

Example 6: RandomFeatureVectorIterator

import cc.mallet.types.Dirichlet; // import the required package/class
public RandomFeatureVectorIterator (Randoms r,    // the generator of all randomness used here
        Dirichlet classCentroidDistribution,     // includes an Alphabet
        double classCentroidAvergeAlphaMean,     // Gaussian mean on the sum of alphas
        double classCentroidAvergeAlphaVariance, // Gaussian variance on the sum of alphas
        double featureVectorSizePoissonLambda,
        double classInstanceCountPoissonLamba,
        String[] classNames)
{
	this.r = r;
	this.classCentroidDistribution = classCentroidDistribution;
	assert (classCentroidDistribution.getAlphabet() instanceof Alphabet);
	this.classCentroidAvergeAlphaMean = classCentroidAvergeAlphaMean;
	this.classCentroidAvergeAlphaVariance = classCentroidAvergeAlphaVariance;
	this.featureVectorSizePoissonLambda = featureVectorSizePoissonLambda;
	this.classInstanceCountPoissonLamba = classInstanceCountPoissonLamba;
	this.classNames = classNames;
	this.numInstancesPerClass = new int[classNames.length];
	this.classCentroid = new Dirichlet[classNames.length];
	for (int i = 0; i < classNames.length; i++) {
		logger.fine ("classCentroidAvergeAlphaMean = "+classCentroidAvergeAlphaMean);
		double aveAlpha = r.nextGaussian (classCentroidAvergeAlphaMean,
																			classCentroidAvergeAlphaVariance);
		logger.fine ("aveAlpha = "+aveAlpha);
		classCentroid[i] = classCentroidDistribution.randomDirichlet (r, aveAlpha);
		//logger.fine ("Dirichlet for class "+classNames[i]);	classCentroid[i].print();
	}
	reset ();
}
 
Developer ID: kostagiolasn, Project: NucleosomePatternClassifier, Lines of code: 33, Source: RandomFeatureVectorIterator.java

Example 7: RandomTokenSequenceIterator

import cc.mallet.types.Dirichlet; // import the required package/class
public RandomTokenSequenceIterator (Randoms r,    // the generator of all randomness used here
        Dirichlet classCentroidDistribution,     // includes an Alphabet
        double classCentroidAvergeAlphaMean,     // Gaussian mean on the sum of alphas
        double classCentroidAvergeAlphaVariance, // Gaussian variance on the sum of alphas
        double featureVectorSizePoissonLambda,
        double classInstanceCountPoissonLamba,
        String[] classNames)
{
	this.r = r;
	this.classCentroidDistribution = classCentroidDistribution;
	assert (classCentroidDistribution.getAlphabet() instanceof Alphabet);
	this.classCentroidAvergeAlphaMean = classCentroidAvergeAlphaMean;
	this.classCentroidAvergeAlphaVariance = classCentroidAvergeAlphaVariance;
	this.featureVectorSizePoissonLambda = featureVectorSizePoissonLambda;
	this.classInstanceCountPoissonLamba = classInstanceCountPoissonLamba;
	this.classNames = classNames;
	this.numInstancesPerClass = new int[classNames.length];
	this.classCentroid = new Dirichlet[classNames.length];
	for (int i = 0; i < classNames.length; i++) {
		logger.fine ("classCentroidAvergeAlphaMean = "+classCentroidAvergeAlphaMean);
		double aveAlpha = r.nextGaussian (classCentroidAvergeAlphaMean,
																			classCentroidAvergeAlphaVariance);
		logger.fine ("aveAlpha = "+aveAlpha);
		classCentroid[i] = classCentroidDistribution.randomDirichlet (r, aveAlpha);
		//logger.fine ("Dirichlet for class "+classNames[i]);	classCentroid[i].print();
	}
	reset ();
}
 
Developer ID: kostagiolasn, Project: NucleosomePatternClassifier, Lines of code: 33, Source: RandomTokenSequenceIterator.java
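Examples 6 and 7 both draw one aveAlpha per class with r.nextGaussian(mean, variance). The parameter comments call the second argument a variance, so the standard deviation is its square root. The plain-Java sketch below makes that conversion explicit. Its class name is hypothetical, and it does not call MALLET's Randoms.

With the arguments used in Examples 5 and 8 (mean 30, variance 0), every class receives aveAlpha = 30. The per-class centroids then differ only through the randomDirichlet draw.

```java
import java.util.Random;

/** Hypothetical sketch of the per-class draw aveAlpha ~ Normal(mean, variance). */
final class AveAlphaSketch {

    /** Second argument is a variance, so it is square-rooted before scaling the N(0, 1) draw. */
    static double sampleGaussian(double mean, double variance, Random rng) {
        if (variance < 0) {
            throw new IllegalArgumentException("variance must be >= 0");
        }
        return mean + Math.sqrt(variance) * rng.nextGaussian();
    }

    public static void main(String[] args) {
        Random rng = new Random(1);
        String[] classNames = {"class0", "class1", "class2"};
        for (String name : classNames) {
            // mean 30, variance 0 as in Examples 5 and 8: always prints 30.0
            System.out.println(name + ": aveAlpha = " + sampleGaussian(30, 0, rng));
        }
    }
}
```

With a large variance the draw can turn negative. A negative value is not a valid Dirichlet magnitude, so a caller may want to guard against it.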

Example 8: testRandomTrainedOn

import cc.mallet.types.Dirichlet; // import the required package/class
private double testRandomTrainedOn (InstanceList training)
{
  ClassifierTrainer trainer = new MaxEntTrainer();

  Alphabet fd = dictOfSize (3);
  String[] classNames = new String[] {"class0", "class1", "class2"};

  Randoms r = new Randoms (1);
  Iterator<Instance> iter = new RandomTokenSequenceIterator (r,  new Dirichlet(fd, 2.0),
        30, 0, 10, 200, classNames);
  training.addThruPipe (iter);

  InstanceList testing = new InstanceList (training.getPipe ());
  testing.addThruPipe (new RandomTokenSequenceIterator (r,  new Dirichlet(fd, 2.0),
        30, 0, 10, 200, classNames));

  System.out.println ("Training set size = "+training.size());
  System.out.println ("Testing set size = "+testing.size());

  Classifier classifier = trainer.train (training);

  System.out.println ("Accuracy on training set:");
  System.out.println (classifier.getClass().getName()
                        + ": " + new Trial(classifier, training).getAccuracy());

  System.out.println ("Accuracy on testing set:");
  double testAcc = new Trial (classifier, testing).getAccuracy();
  System.out.println (classifier.getClass().getName()
                        + ": " + testAcc);

  return testAcc;
}
 
Developer ID: mimno, Project: Mallet, Lines of code: 33, Source: TestPagedInstanceList.java

Example 9: updateSBPWeights

import cc.mallet.types.Dirichlet; // import the required package/class
private void updateSBPWeights() {
    // one weight per existing topic, plus one slot for a not-yet-created topic
    if (sbpWeights.size() != topicWords.getNumComponents() + 1) {
        throw new RuntimeException("Mismatch: " + sbpWeights.size()
                + " vs. " + (topicWords.getNumComponents() + 1));
    }

    SparseCount counts = new SparseCount();
    for (int k : topicWords.getIndices()) {
        for (int dd = 0; dd < D; dd++) {
            int count = docTopics[dd].getCount(k);
            // sample the table count; with 0 or 1 customers it simply equals the count
            if (count > 1) {
                int c = SamplerUtils.randAntoniak(
                        hyperparams.get(ALPHA_LOCAL) * sbpWeights.get(k),
                        count);
                counts.changeCount(k, c);
            } else {
                counts.changeCount(k, count);
            }
        }
    }
    double[] dirPrior = new double[topicWords.getNumComponents() + 1];
    ArrayList<Integer> indices = new ArrayList<Integer>();

    int idx = 0;
    for (int kk : topicWords.getIndices()) {
        indices.add(kk);
        dirPrior[idx++] = counts.getCount(kk);
    }

    indices.add(NEW_COMPONENT_INDEX);
    dirPrior[idx] = hyperparams.get(ALPHA_GLOBAL);

    Dirichlet dir = new Dirichlet(dirPrior);
    double[] wts = dir.nextDistribution();
    this.sbpWeights = new SparseVector();
    for (int ii = 0; ii < wts.length; ii++) {
        this.sbpWeights.set(indices.get(ii), wts[ii]);
    }
}
 
Developer ID: vietansegan, Project: segan, Lines of code: 40, Source: SHDP.java
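SamplerUtils.randAntoniak is segan's own helper, and its body is not shown. The textbook Antoniak draw it is named after counts the tables occupied when count customers enter a Chinese restaurant process with concentration a = ALPHA_LOCAL * weight_k. Customer i (counting from 0) starts a new table with probability a / (a + i). The first customer always opens a table, so one customer gives exactly one table and zero customers give none. That is why the code adds count directly when count is at most 1. The sketch below is a hypothetical stand-in, not segan's implementation.

```java
import java.util.Random;

/** Hypothetical stand-in for an Antoniak table-count sampler (not segan's SamplerUtils). */
final class AntoniakSketch {

    /** Number of occupied tables after n customers join a CRP with concentration a > 0. */
    static int sampleTables(double a, int n, Random rng) {
        if (a <= 0) {
            throw new IllegalArgumentException("concentration must be > 0");
        }
        int tables = 0;
        for (int i = 0; i < n; i++) {
            // customer i opens a new table with probability a / (a + i); always true for i = 0
            if (rng.nextDouble() < a / (a + i)) {
                tables++;
            }
        }
        return tables;
    }

    /** Expected table count: sum over i = 0..n-1 of a / (a + i). */
    static double expectedTables(double a, int n) {
        double expected = 0.0;
        for (int i = 0; i < n; i++) {
            expected += a / (a + i);
        }
        return expected;
    }

    public static void main(String[] args) {
        Random rng = new Random(11);
        double alphaLocal = 1.0, weightK = 0.3; // hypothetical hyperparameter and topic weight
        int count = 50;                         // hypothetical count of topic k in one document
        System.out.println("tables = " + sampleTables(alphaLocal * weightK, count, rng)
                + ", expected = " + expectedTables(alphaLocal * weightK, count));
    }
}
```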

Example 10: sampleGlobalWeights

import cc.mallet.types.Dirichlet; // import the required package/class
/**
 * Sample global distribution over topics.
 */
private void sampleGlobalWeights() {
    // one weight per existing topic, plus one slot for a not-yet-created topic
    if (globalWeights.size() != topicWords.getNumComponents() + 1) {
        throw new RuntimeException("Mismatch: " + globalWeights.size()
                + " vs. " + (topicWords.getNumComponents() + 1));
    }

    SparseCount counts = new SparseCount();
    for (int k : topicWords.getIndices()) {
        for (int ii = 0; ii < D; ii++) {
            int count = docTopics[ii].getCount(k);
            if (count > 1) {
                int c = SamplerUtils.randAntoniak(
                        hyperparams.get(ALPHA_LOCAL) * globalWeights.get(k),
                        count);
                counts.changeCount(k, c);
            } else {
                counts.changeCount(k, count);
            }
        }
    }
    double[] dirPrior = new double[topicWords.getNumComponents() + 1];
    ArrayList<Integer> indices = new ArrayList<Integer>();

    int idx = 0;
    for (int kk : topicWords.getIndices()) {
        indices.add(kk);
        dirPrior[idx++] = counts.getCount(kk);
    }

    indices.add(NEW_COMPONENT_INDEX);
    dirPrior[idx] = hyperparams.get(ALPHA_GLOBAL);

    Dirichlet dir = new Dirichlet(dirPrior);
    double[] wts = dir.nextDistribution();
    this.globalWeights = new SparseVector();
    for (int ii = 0; ii < wts.length; ii++) {
        this.globalWeights.set(indices.get(ii), wts[ii]);
    }
}
 
Developer ID: vietansegan, Project: segan, Lines of code: 43, Source: HDP.java
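Examples 9 and 10 carry out the same Gibbs step for the global topic weights of a hierarchical Dirichlet process. The symbols below are our own labels, not names from the source:

- n_{dk} is docTopics[d].getCount(k)
- \alpha_0 is ALPHA_LOCAL
- \gamma is ALPHA_GLOBAL
- \beta is the weight vector being resampled

With these labels, the code corresponds to:

```latex
\begin{aligned}
m_{dk} &= \sum_{i=0}^{n_{dk}-1} b_{dki}, \qquad
  b_{dki} \sim \operatorname{Bernoulli}\!\left(\frac{\alpha_0 \beta_k}{\alpha_0 \beta_k + i}\right), \\
m_{\cdot k} &= \sum_{d=1}^{D} m_{dk}, \\
(\beta_1, \ldots, \beta_K, \beta_{\text{new}}) &\sim \operatorname{Dir}\left(m_{\cdot 1}, \ldots, m_{\cdot K}, \gamma\right).
\end{aligned}
```

The last coordinate, \beta_{\text{new}}, is the one stored under NEW_COMPONENT_INDEX. It is the probability mass reserved for topics that have not been created yet.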

Example 11: initializeModelStructure

import cc.mallet.types.Dirichlet; // import the required package/class
private void initializeModelStructure() {
    // DirichletMultinomialModel dmModel = new DirichletMultinomialModel(V, betas[0], uniform);
    SparseCount count = new SparseCount();
    // root topic: magnitude betas[0] * V spread over the mean vector `uniform`
    Dirichlet dir = new Dirichlet(betas[0] * V, uniform);
    this.globalTreeRoot = new RCRPNode(0, 0, count, dir.nextDistribution(), null);

    this.localRestaurants = new Restaurant[D];
    for (int d = 0; d < D; d++) {
        this.localRestaurants[d] = new Restaurant<RCRPTable, Integer, RCRPNode>();
    }

    // this.emptyModels = new DirichletMultinomialModel[L-1];
    // for(int l=0; l<emptyModels.length; l++)
    //     this.emptyModels[l] = new DirichletMultinomialModel(V, betas[l+1], uniform);
}
 
Developer ID: vietansegan, Project: segan, Lines of code: 16, Source: RCRPSampler.java
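Example 11 calls a two-argument Dirichlet constructor with a scalar and a vector. We assume the scalar is the magnitude m (the sum of the alphas) and the vector is the mean p, so that alpha_i = m * p_i. We also assume that uniform holds 1/V in every entry, since its definition is not shown. Under both assumptions, the root topic is drawn from a symmetric Dirichlet in which every alpha_i equals betas[0]. The helper below, with hypothetical names, performs that conversion.

```java
import java.util.Arrays;

/** Hypothetical helper: turns a (magnitude, mean) Dirichlet description into per-entry alphas. */
final class DirichletParams {

    /** alpha_i = magnitude * mean_i, so the alphas sum to magnitude when mean sums to 1. */
    static double[] toAlphas(double magnitude, double[] mean) {
        double[] alphas = new double[mean.length];
        for (int i = 0; i < mean.length; i++) {
            alphas[i] = magnitude * mean[i];
        }
        return alphas;
    }

    /** A uniform mean vector of the given size (assumed contents of `uniform` in Example 11). */
    static double[] uniform(int size) {
        double[] u = new double[size];
        Arrays.fill(u, 1.0 / size);
        return u;
    }

    public static void main(String[] args) {
        int V = 4;          // hypothetical vocabulary size
        double beta0 = 0.5; // hypothetical betas[0]
        System.out.println(Arrays.toString(toAlphas(beta0 * V, uniform(V)))); // [0.5, 0.5, 0.5, 0.5]
    }
}
```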

Example 12: createGlobalNode

import cc.mallet.types.Dirichlet; // import the required package/class
/**
 * Create a child node of a global node.
 *
 * @param parentNode The parent node
 * @return The newly created child node
 */
private RCRPNode createGlobalNode(RCRPNode parentNode) {
    int childIndex = parentNode.getNextChildIndex();
    int childLevel = parentNode.getLevel() + 1;
    // DirichletMultinomialModel llhModel = new DirichletMultinomialModel(V, betas[childLevel], uniform);
    SparseCount count = new SparseCount();
    // child topic is drawn around the parent's topic, with concentration betas[childLevel]
    Dirichlet dir = new Dirichlet(betas[childLevel], parentNode.getTopic());
    RCRPNode childNode = new RCRPNode(childIndex, childLevel, count, dir.nextDistribution(), parentNode);
    parentNode.addChild(childIndex, childNode);
    return childNode;
}
 
Developer ID: vietansegan, Project: segan, Lines of code: 17, Source: RCRPSampler.java
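Example 12 centers a child's topic on its parent's, and Example 13 reuses the same pattern for its pseudo-child draw. Assume the two-argument constructor takes a magnitude m and a mean vector p, so that alpha = m p. Then the child topic is \phi^{\text{child}} \sim \operatorname{Dir}(\beta_\ell \, \phi^{\text{parent}}), where \beta_\ell stands for betas[childLevel] (our notation). The standard Dirichlet moments show what that level's beta controls:

```latex
\begin{aligned}
\mathbb{E}\left[\phi^{\text{child}}_v\right] &= \phi^{\text{parent}}_v, \\
\operatorname{Var}\left[\phi^{\text{child}}_v\right] &= \frac{\phi^{\text{parent}}_v \left(1 - \phi^{\text{parent}}_v\right)}{\beta_\ell + 1}.
\end{aligned}
```

A larger betas[childLevel] keeps children close to their parent, while a smaller one lets them drift further. Unlike Example 11, the magnitude here is betas[childLevel] itself rather than betas[0] * V.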

Example 13: computeLogLikelihoods

import cc.mallet.types.Dirichlet; // import the required package/class
/**
 * Recursively compute the log likelihoods of each node in the global tree
 * given a set of observations.
 *
 * @param nodeLlhs The hash table to store the result
 * @param curNode The current node
 * @param observations The set of observations
 */
private void computeLogLikelihoods(HashMap<String, Double> nodeLlhs,
        RCRPNode curNode,
        HashMap<Integer, Integer> observations) {
    // double curNodeLlh = curNode.getContent().getLogLikelihood(observations);
    double curNodeLlh = curNode.computeLogLikelihood(observations);
    nodeLlhs.put(curNode.getPathString(), curNodeLlh);

    if (!this.isLeafNode(curNode)) {
        // double[] pseudoPrior = new double[V];
        // for(int v=0; v<V; v++)
        //     pseudoPrior[v] = betas[curNode.getLevel()] / V * (curNode.getContent().getCount(v) +
        //             curNode.getContent().getConcentration() * curNode.getContent().getCenterElement(v));
        // double pseudoChildLlh = SamplerUtils.computeLogLhood(obsCounts, obsCountSum, pseudoPrior);

        // double pseudoChildLlh = this.emptyModels[curNode.getLevel()].getLogLikelihood(observations);
        // System.out.println(curNode.toString());
        // System.out.println(MiscUtils.arrayToString(curNode.getTopic()));

        // score a not-yet-created child: draw one candidate topic around this node's topic
        Dirichlet dir = new Dirichlet(betas[curNode.getLevel() + 1], curNode.getTopic());
        double[] newTopic = dir.nextDistribution();
        double pseudoChildLlh = computeLogLikelihood(newTopic, observations);
        nodeLlhs.put(curNode.getPseudoChildPathString(), pseudoChildLlh);

        for (RCRPNode child : curNode.getChildren()) {
            computeLogLikelihoods(nodeLlhs, child, observations);
        }
    }
}
 
Developer ID: vietansegan, Project: segan, Lines of code: 37, Source: RCRPSampler.java
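Example 13 scores each existing node with curNode.computeLogLikelihood(observations), and a hypothetical new child with computeLogLikelihood(newTopic, observations). Neither helper's body is shown. Given a topic vector and a map from word index to count, a natural reading is the multinomial log-likelihood: the sum over observed words v of n_v * log(phi_v), with the multinomial coefficient dropped. That coefficient is the same for every node, so dropping it does not change how the nodes rank. The sketch below implements this reading under hypothetical names. Note that the pseudo-child's score rests on a single fresh Dirichlet draw, so it varies from call to call.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical sketch of a topic's log-likelihood for a bag of word counts. */
final class TopicLogLikelihood {

    /** Sum over observed word indices v of count(v) * log(topic[v]); a zero probability gives -Infinity. */
    static double logLikelihood(double[] topic, Map<Integer, Integer> observations) {
        double llh = 0.0;
        for (Map.Entry<Integer, Integer> e : observations.entrySet()) {
            llh += e.getValue() * Math.log(topic[e.getKey()]);
        }
        return llh;
    }

    public static void main(String[] args) {
        Map<Integer, Integer> obs = new HashMap<>();
        obs.put(0, 2); // word 0 observed twice (hypothetical data)
        obs.put(2, 1); // word 2 observed once
        double[] topicA = {0.5, 0.3, 0.2};
        double[] topicB = {0.1, 0.1, 0.8};
        System.out.println("node A: " + logLikelihood(topicA, obs));
        System.out.println("node B: " + logLikelihood(topicB, obs));
    }
}
```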


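Note: The cc.mallet.types.Dirichlet examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by various developers, and copyright of the source code remains with the original authors. Please consult each project's license before distributing or using the code, and do not republish this article without permission.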