本文整理汇总了Java中cc.mallet.fst.CRFTrainerByLabelLikelihood.train方法的典型用法代码示例。如果您正苦于以下问题：Java CRFTrainerByLabelLikelihood.train方法的具体用法？Java CRFTrainerByLabelLikelihood.train怎么用？Java CRFTrainerByLabelLikelihood.train使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类cc.mallet.fst.CRFTrainerByLabelLikelihood的用法示例。
在下文中一共展示了CRFTrainerByLabelLikelihood.train方法的7个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Java代码示例。
示例1: train
import cc.mallet.fst.CRFTrainerByLabelLikelihood; //导入方法依赖的package包/类
/**
 * Builds a fresh CRF with order-1 states over {@code this.train_data} and
 * trains it by label likelihood.
 *
 * <p>Training is driven one iteration at a time and stops early as soon as
 * {@code CRFTrainerByLabelLikelihood.train} returns {@code true} (MALLET uses
 * that return value to signal convergence).
 *
 * @param num_iterations maximum number of single-iteration training rounds;
 *        must not be {@code null}, because it is unboxed in the loop condition
 * @return {@code true} once {@code this.model} has been assigned, which is
 *         always the case when this method returns normally
 */
public Boolean train(Integer num_iterations) {
    this.model = new CRF(this.train_data.getPipe(), (Pipe) null);

    // States must exist before their initial weights can be adjusted: a newly
    // constructed CRF has no states, so running the loop below any earlier
    // would iterate over nothing and leave every state usable as a start.
    String startName = this.model.addOrderNStates(this.train_data, new int[] { 1 }, null, DEFAULT_LABEL, Pattern.compile("\\s"), Pattern.compile(".*"), true);

    // Forbid every state as a starting point, then re-enable only the
    // designated start state.
    for (int i = 0; i < this.model.numStates(); i++) {
        this.model.getState(i).setInitialWeight(Transducer.IMPOSSIBLE_WEIGHT);
    }
    this.model.getState(startName).setInitialWeight(0.0);

    CRFTrainerByLabelLikelihood crft = new CRFTrainerByLabelLikelihood(this.model);
    crft.setGaussianPriorVariance(DEFAULT_PRIOR_VARIANCE);
    crft.setUseSparseWeights(true);
    crft.setUseSomeUnsupportedTrick(true);

    // One iteration per call so training can stop on the first "true" result.
    for (int i = 0; i < num_iterations; i++) {
        if (crft.train(this.train_data, 1)) {
            break;
        }
    }
    return this.model != null;
}
示例2: testDenseFeatureSelection
import cc.mallet.fst.CRFTrainerByLabelLikelihood; //导入方法依赖的package包/类
/**
 * Checks that dense observation weights are not added for "default-feature"
 * edges.
 *
 * <p>Two CRFs are trained for one iteration each with sparse weights turned
 * off: one with order-0 states only, one with order-0 and order-1 states
 * where only the order-1 entry is flagged in the boolean array. The second
 * model is expected to have exactly four more parameters than the first.
 */
public void testDenseFeatureSelection() {
    Pipe pipe = makeSpacePredictionPipe();
    InstanceList trainingSet = new InstanceList(pipe);
    trainingSet.addThruPipe(new ArrayIterator(data));

    // Baseline model: order-0 states only.
    int[] baseOrders = { 0 };
    CRF baseCrf = new CRF(pipe, null);
    baseCrf.addOrderNStates(trainingSet, baseOrders, null, "start", null, null, true);
    CRFTrainerByLabelLikelihood baseTrainer = new CRFTrainerByLabelLikelihood(baseCrf);
    baseTrainer.setUseSparseWeights(false);
    // A single iteration is enough to fix the dimension of the weights.
    baseTrainer.train(trainingSet, 1);
    int baseParamCount = baseTrainer.getOptimizableCRF(trainingSet).getNumParameters();

    // Extended model: order-0 and order-1 states.
    int[] extendedOrders = { 0, 1 };
    boolean[] extendedFlags = { false, true };
    CRF extendedCrf = new CRF(pipe, null);
    extendedCrf.addOrderNStates(trainingSet, extendedOrders, extendedFlags, "start", null, null, true);
    CRFTrainerByLabelLikelihood extendedTrainer = new CRFTrainerByLabelLikelihood(extendedCrf);
    extendedTrainer.setUseSparseWeights(false);
    // Again, one iteration only to fix the dimension of the weights.
    extendedTrainer.train(trainingSet, 1);
    int extendedParamCount = extendedTrainer.getOptimizableCRF(trainingSet).getNumParameters();

    assertEquals(extendedParamCount, baseParamCount + 4);
}
示例3: testXis
import cc.mallet.fst.CRFTrainerByLabelLikelihood; //导入方法依赖的package包/类
/**
 * Checks that the lattice's xi and gamma probabilities agree.
 *
 * <p>After training a fully connected CRF for ten iterations, a sum lattice
 * is built for the first training instance. For every position except the
 * last and for every state, the xi probabilities summed over the state's
 * transitions must equal that state's gamma probability within 1e-5.
 */
public void testXis() {
    Pipe pipe = makeSpacePredictionPipe();
    InstanceList trainingSet = new InstanceList(pipe);
    trainingSet.addThruPipe(new ArrayIterator(data));

    CRF crf = new CRF(pipe, null);
    crf.addFullyConnectedStatesForLabels();
    CRFTrainerByLabelLikelihood trainer = new CRFTrainerByLabelLikelihood(crf);
    // Train for a few iterations so the parameters are not all at their initial values.
    trainer.train(trainingSet, 10);

    Instance firstInstance = trainingSet.get(0);
    Sequence observations = (Sequence) firstInstance.getData();
    Sequence labels = (Sequence) firstInstance.getTarget();
    SumLatticeDefault lattice = new SumLatticeDefault(crf, observations, labels, null, true);

    for (int position = 0; position < lattice.length() - 1; position++) {
        for (int stateIndex = 0; stateIndex < crf.numStates(); stateIndex++) {
            Transducer.State source = crf.getState(stateIndex);
            Transducer.TransitionIterator outgoing = source.transitionIterator(observations, position);
            double gamma = lattice.getGammaProbability(position, source);

            // Accumulate xi over every transition the iterator yields.
            double xiTotal = 0;
            while (outgoing.hasNext()) {
                xiTotal += lattice.getXiProbability(position, source, outgoing.nextState());
            }

            assertEquals(gamma, xiTotal, 1e-5);
        }
    }
}
示例4: testDualSpaceViewer
import cc.mallet.fst.CRFTrainerByLabelLikelihood; //导入方法依赖的package包/类
/**
 * Trains a CRF and an MEMM on the same single-sentence training set, runs
 * both as extractors over the full data, and writes a side-by-side lattice
 * view of the two extractions into {@code htmlDir}.
 *
 * @throws IOException declared for the extraction / viewer calls
 */
public void testDualSpaceViewer() throws IOException {
    // Train on the first sentence only; evaluate and extract on everything.
    String[] trainingText = { TestCRF.data[0] };
    String[] allText = TestCRF.data;

    // ---- CRF side ----
    Pipe crfPipe = TestMEMM.makeSpacePredictionPipe();
    InstanceList crfTraining = new InstanceList(crfPipe);
    crfTraining.addThruPipe(new ArrayIterator(trainingText));
    InstanceList crfTesting = new InstanceList(crfPipe);
    crfTesting.addThruPipe(new ArrayIterator(allText));

    CRF crf = new CRF(crfPipe, null);
    crf.addFullyConnectedStatesForLabels();
    CRFTrainerByLabelLikelihood crfTrainer = new CRFTrainerByLabelLikelihood(crf);
    TokenAccuracyEvaluator crfEvaluator = new TokenAccuracyEvaluator(
            new InstanceList[] { crfTraining, crfTesting },
            new String[] { "Training", "Testing" });
    // Five single-iteration rounds, evaluating after each one.
    for (int round = 0; round < 5; round++) {
        crfTrainer.train(crfTraining, 1);
        crfEvaluator.evaluate(crfTrainer);
    }
    CRFExtractor crfExtractor = hackCrfExtor(crf);
    Extraction crfExtraction = crfExtractor.extract(new ArrayIterator(allText));

    // ---- MEMM side, built on its own pipe ----
    Pipe memmPipe = TestMEMM.makeSpacePredictionPipe();
    InstanceList memmTraining = new InstanceList(memmPipe);
    memmTraining.addThruPipe(new ArrayIterator(trainingText));
    InstanceList memmTesting = new InstanceList(memmPipe);
    memmTesting.addThruPipe(new ArrayIterator(allText));

    MEMM memm = new MEMM(memmPipe, null);
    memm.addFullyConnectedStatesForLabels();
    MEMMTrainer memmTrainer = new MEMMTrainer(memm);
    TransducerEvaluator memmEvaluator = new TokenAccuracyEvaluator(
            new InstanceList[] { memmTraining, memmTesting },
            new String[] { "Training2", "Testing2" });
    // The MEMM gets all five iterations in one call and a single evaluation.
    memmTrainer.train(memmTraining, 5);
    memmEvaluator.evaluate(memmTrainer);
    CRFExtractor memmExtractor = hackCrfExtor(memm);
    Extraction memmExtraction = memmExtractor.extract(new ArrayIterator(allText));

    // Best-effort creation of the output directory; the result is not checked.
    if (!htmlDir.exists()) {
        htmlDir.mkdir();
    }
    LatticeViewer.viewDualResults(htmlDir, crfExtraction, crfExtractor, memmExtraction, memmExtractor);
}
示例5: ignoretestDualSpaceViewer
import cc.mallet.fst.CRFTrainerByLabelLikelihood; //导入方法依赖的package包/类
/**
 * Builds a CRF extraction and an MEMM extraction over the same data and
 * hands both to {@code LatticeViewer.viewDualResults} for output under
 * {@code htmlDir}.
 *
 * <p>NOTE(review): the {@code ignore} prefix presumably keeps name-based
 * test discovery from running this method; confirm against the test runner.
 *
 * @throws IOException declared for the extraction / viewer calls
 */
public void ignoretestDualSpaceViewer() throws IOException {
    // The first sentence is the training set; the whole array is the test set.
    String[] firstSentenceOnly = { TestCRF.data[0] };
    String[] everySentence = TestCRF.data;

    // CRF: five rounds of one iteration each, evaluated after every round.
    Pipe pipeForCrf = TestMEMM.makeSpacePredictionPipe();
    InstanceList trainForCrf = new InstanceList(pipeForCrf);
    trainForCrf.addThruPipe(new ArrayIterator(firstSentenceOnly));
    InstanceList testForCrf = new InstanceList(pipeForCrf);
    testForCrf.addThruPipe(new ArrayIterator(everySentence));
    CRF crfModel = new CRF(pipeForCrf, null);
    crfModel.addFullyConnectedStatesForLabels();
    CRFTrainerByLabelLikelihood likelihoodTrainer = new CRFTrainerByLabelLikelihood(crfModel);
    InstanceList[] crfSets = { trainForCrf, testForCrf };
    String[] crfSetNames = { "Training", "Testing" };
    TokenAccuracyEvaluator crfAccuracy = new TokenAccuracyEvaluator(crfSets, crfSetNames);
    for (int pass = 0; pass < 5; pass++) {
        likelihoodTrainer.train(trainForCrf, 1);
        crfAccuracy.evaluate(likelihoodTrainer);
    }
    CRFExtractor firstExtractor = hackCrfExtor(crfModel);
    Extraction firstResult = firstExtractor.extract(new ArrayIterator(everySentence));

    // MEMM: a separate pipe, one five-iteration training call, one evaluation.
    Pipe pipeForMemm = TestMEMM.makeSpacePredictionPipe();
    InstanceList trainForMemm = new InstanceList(pipeForMemm);
    trainForMemm.addThruPipe(new ArrayIterator(firstSentenceOnly));
    InstanceList testForMemm = new InstanceList(pipeForMemm);
    testForMemm.addThruPipe(new ArrayIterator(everySentence));
    MEMM memmModel = new MEMM(pipeForMemm, null);
    memmModel.addFullyConnectedStatesForLabels();
    MEMMTrainer memmTrainer = new MEMMTrainer(memmModel);
    InstanceList[] memmSets = { trainForMemm, testForMemm };
    String[] memmSetNames = { "Training2", "Testing2" };
    TransducerEvaluator memmAccuracy = new TokenAccuracyEvaluator(memmSets, memmSetNames);
    memmTrainer.train(trainForMemm, 5);
    memmAccuracy.evaluate(memmTrainer);
    CRFExtractor secondExtractor = hackCrfExtor(memmModel);
    Extraction secondResult = secondExtractor.extract(new ArrayIterator(everySentence));

    // Create the output directory if it is missing (return value is ignored).
    if (!htmlDir.exists()) {
        htmlDir.mkdir();
    }
    LatticeViewer.viewDualResults(htmlDir, firstResult, firstExtractor, secondResult, secondExtractor);
}
示例6: TrainCRF
import cc.mallet.fst.CRFTrainerByLabelLikelihood; //导入方法依赖的package包/类
/**
 * Builds a token-feature pipeline, loads gzip-compressed training and
 * testing data through it, and trains a CRF by label likelihood with a
 * Gaussian prior variance of 10.0.
 *
 * <p>Two evaluators (per-class accuracy and token accuracy) are attached to
 * the trainer, both over the testing instances.
 *
 * @param trainingFilename path of the gzip-compressed training file; records
 *        are separated by blank (whitespace-only) lines
 * @param testingFilename path of the gzip-compressed testing file, same layout
 * @throws IOException if either file cannot be opened, decompressed or closed
 */
public TrainCRF(String trainingFilename, String testingFilename) throws IOException {
    ArrayList<Pipe> pipes = new ArrayList<Pipe>();

    // Feature conjunctions with the previous (-1) and the next (+1) token.
    int[][] conjunctions = new int[2][];
    conjunctions[0] = new int[] { -1 };
    conjunctions[1] = new int[] { 1 };

    pipes.add(new SimpleTaggerSentence2TokenSequence());
    pipes.add(new OffsetConjunctions(conjunctions));
    //pipes.add(new FeaturesInWindow("PREV-", -1, 1));
    pipes.add(new TokenTextCharSuffix("C1=", 1));
    pipes.add(new TokenTextCharSuffix("C2=", 2));
    pipes.add(new TokenTextCharSuffix("C3=", 3));
    pipes.add(new RegexMatches("CAPITALIZED", Pattern.compile("^\\p{Lu}.*")));
    pipes.add(new RegexMatches("STARTSNUMBER", Pattern.compile("^[0-9].*")));
    pipes.add(new RegexMatches("HYPHENATED", Pattern.compile(".*\\-.*")));
    pipes.add(new RegexMatches("DOLLARSIGN", Pattern.compile(".*\\$.*")));
    pipes.add(new TokenFirstPosition("FIRSTTOKEN"));
    pipes.add(new TokenSequence2FeatureVectorSequence());
    Pipe pipe = new SerialPipes(pipes);

    InstanceList trainingInstances = new InstanceList(pipe);
    InstanceList testingInstances = new InstanceList(pipe);
    // Each file is read completely and closed before moving on, so no file
    // handle outlives the constructor.
    addGzippedInstances(trainingInstances, trainingFilename);
    addGzippedInstances(testingInstances, testingFilename);

    CRF crf = new CRF(pipe, null);
    //crf.addStatesForLabelsConnectedAsIn(trainingInstances);
    crf.addStatesForThreeQuarterLabelsConnectedAsIn(trainingInstances);
    crf.addStartState();

    CRFTrainerByLabelLikelihood trainer =
            new CRFTrainerByLabelLikelihood(crf);
    trainer.setGaussianPriorVariance(10.0);
    //CRFTrainerByStochasticGradient trainer =
    //new CRFTrainerByStochasticGradient(crf, 1.0);
    //CRFTrainerByL1LabelLikelihood trainer =
    // new CRFTrainerByL1LabelLikelihood(crf, 0.75);

    //trainer.addEvaluator(new PerClassAccuracyEvaluator(trainingInstances, "training"));
    trainer.addEvaluator(new PerClassAccuracyEvaluator(testingInstances, "testing"));
    trainer.addEvaluator(new TokenAccuracyEvaluator(testingInstances, "testing"));
    trainer.train(trainingInstances);
}

/**
 * Reads one gzip-compressed file, splits it into groups at blank lines,
 * adds every group to {@code target} through its pipe, and closes the file.
 *
 * <p>NOTE(review): text is decoded with the platform default charset, as in
 * the original code; confirm whether the data files require a fixed charset.
 *
 * @param target instance list that receives the piped instances
 * @param filename path of the gzip-compressed input file
 * @throws IOException if the file cannot be opened, decompressed or closed
 */
private static void addGzippedInstances(InstanceList target, String filename) throws IOException {
    FileInputStream fileStream = new FileInputStream(filename);
    try {
        BufferedReader reader = new BufferedReader(
                new InputStreamReader(new GZIPInputStream(fileStream)));
        try {
            target.addThruPipe(new LineGroupIterator(reader, Pattern.compile("^\\s*$"), true));
        } finally {
            reader.close();
        }
    } finally {
        // Also covers a failure while wrapping the stream, before any reader
        // existed; closing an already closed FileInputStream is harmless.
        fileStream.close();
    }
}
示例7: TrainWikiCRF
import cc.mallet.fst.CRFTrainerByLabelLikelihood; //导入方法依赖的package包/类
/**
 * Builds a token-feature pipeline, loads gzip-compressed training and
 * testing data through it, and trains a CRF by label likelihood with a
 * Gaussian prior variance of 10.0.
 *
 * <p>Per-class accuracy and token accuracy evaluators are attached to the
 * trainer, both over the testing instances.
 *
 * @param trainingFilename path of the gzip-compressed training file; records
 *        are separated by blank (whitespace-only) lines
 * @param testingFilename path of the gzip-compressed testing file, same layout
 * @throws IOException if either file cannot be opened, decompressed or closed
 */
public TrainWikiCRF(String trainingFilename, String testingFilename) throws IOException {
    ArrayList<Pipe> pipes = new ArrayList<Pipe>();

    // Feature conjunctions with the previous (-1) and the next (+1) token.
    int[][] conjunctions = new int[2][];
    conjunctions[0] = new int[] { -1 };
    conjunctions[1] = new int[] { 1 };

    pipes.add(new SimpleTaggerSentence2TokenSequence());
    pipes.add(new OffsetConjunctions(conjunctions));
    //pipes.add(new FeaturesInWindow("PREV-", -1, 1));
    pipes.add(new TokenTextCharSuffix("C1=", 1));
    pipes.add(new TokenTextCharSuffix("C2=", 2));
    pipes.add(new TokenTextCharSuffix("C3=", 3));
    pipes.add(new RegexMatches("CAPITALIZED", Pattern.compile("^\\p{Lu}.*")));
    pipes.add(new RegexMatches("STARTSNUMBER", Pattern.compile("^[0-9].*")));
    pipes.add(new RegexMatches("HYPHENATED", Pattern.compile(".*\\-.*")));
    pipes.add(new RegexMatches("DOLLARSIGN", Pattern.compile(".*\\$.*")));
    pipes.add(new TokenFirstPosition("FIRSTTOKEN"));
    pipes.add(new TokenSequence2FeatureVectorSequence());
    Pipe pipe = new SerialPipes(pipes);

    InstanceList trainingInstances = new InstanceList(pipe);
    InstanceList testingInstances = new InstanceList(pipe);
    // Each file is fully consumed and then closed, so the constructor does
    // not leave open file handles behind.
    addGzippedWikiInstances(trainingInstances, trainingFilename);
    addGzippedWikiInstances(testingInstances, testingFilename);

    CRF crf = new CRF(pipe, null);
    //crf.addStatesForLabelsConnectedAsIn(trainingInstances);
    crf.addStatesForThreeQuarterLabelsConnectedAsIn(trainingInstances);
    crf.addStartState();

    CRFTrainerByLabelLikelihood trainer =
            new CRFTrainerByLabelLikelihood(crf);
    trainer.setGaussianPriorVariance(10.0);
    //CRFTrainerByStochasticGradient trainer =
    //new CRFTrainerByStochasticGradient(crf, 1.0);
    //CRFTrainerByL1LabelLikelihood trainer =
    // new CRFTrainerByL1LabelLikelihood(crf, 0.75);

    //trainer.addEvaluator(new PerClassAccuracyEvaluator(trainingInstances, "training"));
    trainer.addEvaluator(new PerClassAccuracyEvaluator(testingInstances, "testing"));
    trainer.addEvaluator(new TokenAccuracyEvaluator(testingInstances, "testing"));
    trainer.train(trainingInstances);
}

/**
 * Reads one gzip-compressed file, splits it into groups at blank lines,
 * adds every group to {@code target} through its pipe, and closes the file.
 *
 * <p>NOTE(review): text is decoded with the platform default charset, as in
 * the original code; confirm whether the data files require a fixed charset.
 *
 * @param target instance list that receives the piped instances
 * @param filename path of the gzip-compressed input file
 * @throws IOException if the file cannot be opened, decompressed or closed
 */
private static void addGzippedWikiInstances(InstanceList target, String filename) throws IOException {
    FileInputStream fileStream = new FileInputStream(filename);
    try {
        BufferedReader reader = new BufferedReader(
                new InputStreamReader(new GZIPInputStream(fileStream)));
        try {
            target.addThruPipe(new LineGroupIterator(reader, Pattern.compile("^\\s*$"), true));
        } finally {
            reader.close();
        }
    } finally {
        // Also covers a failure while wrapping the stream, before any reader
        // existed; closing an already closed FileInputStream is harmless.
        fileStream.close();
    }
}