本文整理汇总了 Python 中 pystruct.learners.FrankWolfeSSVM.train 方法的典型用法代码示例。如果您正苦于以下问题:Python FrankWolfeSSVM.train 方法的具体用法?Python FrankWolfeSSVM.train 怎么用?Python FrankWolfeSSVM.train 使用的例子?那么,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类 pystruct.learners.FrankWolfeSSVM 的用法示例。
在下文中一共展示了 FrankWolfeSSVM.train 方法的 1 个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的 Python 代码示例。
示例1: PassageTagger
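# Required import: from pystruct.learners import FrankWolfeSSVM [as alias]
# The page also suggests: from pystruct.learners.FrankWolfeSSVM import train [as alias]
# -- but FrankWolfeSSVM is a class, not a module, so that import form fails; import the class instead.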
class PassageTagger(object):
def __init__(self, do_train=False, trained_model_name="passage_crf_model", algorithm="crf"):
    """Set up the tagger for either training or tagging.

    Args:
        do_train: When true, a trainer object is created; otherwise a
            tagger is created (``"crf"``) or unpickled (any other algorithm).
        trained_model_name: Model file name. It is stored on the instance
            and, in non-``"crf"`` tagging mode, the pickled tagger is read
            from it.
        algorithm: ``"crf"`` selects the ``Trainer``/``Tagger`` pair; any
            other value selects the ``ChainCRF`` + ``FrankWolfeSSVM`` path.
    """
    self.trained_model_name = trained_model_name
    self.fp = FeatureProcessing()
    self.do_train = do_train
    self.algorithm = algorithm
    if algorithm == "crf":
        if do_train:
            self.trainer = Trainer()
        else:
            self.tagger = Tagger()
    elif do_train:
        self.trainer = FrankWolfeSSVM(model=ChainCRF())
        # Feature/label -> integer index maps; they start out empty here.
        self.feat_index = {}
        self.label_index = {}
    else:
        # NOTE(security): pickle.load can execute arbitrary code while
        # deserializing; only point these paths at files you trust.
        # The `with` blocks make sure each file handle is closed promptly
        # instead of being left for the garbage collector.
        with open(self.trained_model_name, "rb") as model_file:
            self.tagger = pickle.load(model_file)
        with open("ssvm_feat_index.pkl", "rb") as feat_file:
            self.feat_index = pickle.load(feat_file)
        with open("ssvm_label_index.pkl", "rb") as label_file:
            label_index = pickle.load(label_file)
        # Invert label -> index so predicted indices map back to labels.
        self.rev_label_index = {i: x for x, i in label_index.items()}
def read_input(self, filename):
    """Read blank-line-separated sequences of clauses from a UTF-8 file.

    Each non-blank line is one clause. When ``self.do_train`` is true the
    line must be ``clause<TAB>label``; otherwise the whole stripped line is
    the clause. Runs of blank lines separate sequences, and empty sequences
    are never emitted.

    Args:
        filename: Path of the file to read.

    Returns:
        A tuple ``(str_seqs, feat_seqs, label_seqs)`` of parallel lists with
        one entry per sequence: the clause strings, the per-clause
        ``{feature: count}`` dicts built from ``self.fp.get_features``, and
        the labels (empty lists when not training).

    Raises:
        ValueError: In training mode, if a non-blank line does not split
            into exactly two tab-separated fields.
    """
    str_seqs, feat_seqs, label_seqs = [], [], []
    str_seq, feat_seq, label_seq = [], [], []
    # `with` closes the file even if a malformed line raises midway.
    with codecs.open(filename, "r", "utf-8") as infile:
        for line in infile:
            lnstrp = line.strip()
            if not lnstrp:
                # A blank line ends the current sequence, if there is one.
                if str_seq:
                    str_seqs.append(str_seq)
                    feat_seqs.append(feat_seq)
                    label_seqs.append(label_seq)
                    str_seq, feat_seq, label_seq = [], [], []
                continue
            if self.do_train:
                clause, label = lnstrp.split("\t")
                label_seq.append(label)
            else:
                clause = lnstrp
            str_seq.append(clause)
            # Count how many times each feature occurs in this clause.
            feat_dict = {}
            for f in self.fp.get_features(clause):
                feat_dict[f] = feat_dict.get(f, 0) + 1
            feat_seq.append(feat_dict)
    # The file may end without a trailing blank line; flush what is left.
    if str_seq:
        str_seqs.append(str_seq)
        feat_seqs.append(feat_seq)
        label_seqs.append(label_seq)
    return str_seqs, feat_seqs, label_seqs
def predict(self, feat_seqs):
    """Tag every sequence in ``feat_seqs`` and return the predicted labels.

    Args:
        feat_seqs: A list of sequences, each a list of
            ``{feature: count}`` dicts (one dict per clause).

    Returns:
        A list with one entry per input sequence. For ``"crf"`` it is
        whatever ``self.tagger.tag`` returns for that sequence; otherwise it
        is a list of labels looked up in ``self.rev_label_index``.
    """
    # sys.stderr.write behaves the same on Python 2 and 3, unlike the
    # Python 2-only `print >>stream` form, which raises TypeError on 3.
    sys.stderr.write("Tagging %d sequences\n" % len(feat_seqs))
    if self.algorithm == "crf":
        self.tagger.open(self.trained_model_name)
        preds = [self.tagger.tag(ItemSequence(feat_seq)) for feat_seq in feat_seqs]
    else:
        Xs = []
        for fs in feat_seqs:
            X = []
            for feat_dict in fs:
                # Dense count vector; features missing from feat_index
                # are silently dropped.
                x = [0] * len(self.feat_index)
                for f in feat_dict:
                    if f in self.feat_index:
                        x[self.feat_index[f]] = feat_dict[f]
                X.append(x)
            Xs.append(numpy.asarray(X))
        pred_ind_seqs = self.tagger.predict(Xs)
        # Map predicted label indices back to label strings.
        preds = [
            [self.rev_label_index[pred_ind] for pred_ind in ps]
            for ps in pred_ind_seqs
        ]
    return preds
def train(self, feat_seqs, label_seqs):
print >>sys.stderr, "Training on %d sequences"%len(feat_seqs)
if self.algorithm == "crf":
for feat_seq, label_seq in zip(feat_seqs, label_seqs):
self.trainer.append(ItemSequence(feat_seq), label_seq)
self.trainer.train(self.trained_model_name)
else:
for fs in feat_seqs:
for feat_dict in fs:
for f in feat_dict:
#.........这里部分代码省略.........