当前位置: 首页>>代码示例>>Python>>正文


Python Tree.set_label方法代码示例

本文整理汇总了 Python 中 nltk.tree.Tree.set_label 方法的典型用法代码示例。如果您正苦于以下问题:Python Tree.set_label 方法的具体用法?Python Tree.set_label 怎么用?Python Tree.set_label 的使用示例?那么恭喜您,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类 nltk.tree.Tree 的用法示例。


在下文中一共展示了 Tree.set_label 方法的 8 个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的 Python 代码示例。

示例1: _strip_functional_tags

# 需要导入模块: from nltk.tree import Tree [as 别名]
# 用法: 在 Tree 实例上调用 tree.set_label(label)(set_label 是实例方法,不能单独导入)
    def _strip_functional_tags(self, tree: Tree) -> None:
        """
        Strip functional tags from constituency labels of an NLTK tree.

        Every visited label is cut at its earliest "=", "-" or "|" character,
        because whatever follows is a functional annotation we don't want.
        A visited label that starts with one of those characters therefore
        becomes the empty string.

        Only phrase-level nodes are descended into: a child whose first
        element is a string (a preterminal) is skipped, so its POS tag is
        kept unchanged.

        The tree is modified in place; nothing is returned.
        """
        label = tree.label()
        # Taking the prefix before each separator in turn yields the prefix
        # before the earliest occurrence of any of them.
        for separator in ("=", "-", "|"):
            label = label.partition(separator)[0]
        tree.set_label(label)

        for subtree in tree:
            # Preterminals hold the word string as their first child.
            if isinstance(subtree[0], str):
                continue
            self._strip_functional_tags(subtree)
开发者ID:ziaridoy20,项目名称:allennlp,代码行数:15,代码来源:penn_tree_bank.py

示例2: build_tree

# 需要导入模块: from nltk.tree import Tree [as 别名]
# 用法: 在 Tree 实例上调用 tree.set_label(label)(set_label 是实例方法,不能单独导入)
def build_tree(node, chain):  # -> handle function tags
    """Return the phrase-structure (PS) tree for ``node``'s projection chain.

    The preterminal is labelled ``node['tag']``. If the node has a lemma
    that appears in ``wh_lemmas`` (case-insensitively) and its tag is not
    'CONJ', the suffix '-WH' is appended. The preterminal's single child is
    ``node['word']``.

    The preterminal is then wrapped in the projections listed in
    ``chain[0]``. Entries are read as ``entry[0]`` = phrase label and
    ``entry[1]`` = how many nested nodes with that label to add. The last
    entry ends up innermost and the first entry outermost.

    Finally the function tag ``chain[1]`` is applied when it is truthy:
    'PRN' puts a new PRN node on top, and any other tag is appended to the
    top node's label as ``'-' + tag``.
    """
    label = node['tag']
    # Nodes without a lemma are trace nodes and never get the WH feature.
    is_wh_word = ('lemma' in node
                  and node['lemma'].lower() in wh_lemmas
                  and node['tag'] != 'CONJ')
    if is_wh_word:
        label += '-WH'
    subtree = Tree(label, [node['word']])

    for step in chain[0][::-1]:
        for _ in range(step[1]):
            subtree = Tree(step[0], [subtree])

    function_tag = chain[1]
    if not function_tag:
        return subtree
    if function_tag == 'PRN':
        return Tree(function_tag, [subtree])
    subtree.set_label(subtree.label() + '-' + function_tag)
    return subtree
开发者ID:luutuntin,项目名称:SynTagRus_DS2PS,代码行数:19,代码来源:syntagrus_ds2ps.py

示例3: forward

# 需要导入模块: from nltk.tree import Tree [as 别名]
# 用法: 在 Tree 实例上调用 tree.set_label(label)(set_label 是实例方法,不能单独导入)

#.........这里部分代码省略.........
    xq_list = []
    qc = QUEUE_ZEROS
    q = QUEUE_ZEROS
    
    for text, wid in reversed(word_list):
      x = self.net_embed(XP.iarray([wid]))
      qc, q = self.net_encoder(qc, x, q)
      xq_list.insert(0, (text, x, q))

    # estimate
    s_list = []
    zc = SRSTATE_ZEROS
    z = SRSTATE_ZEROS
    unary_chain = 0
    if is_training:
      loss = XP.fzeros(())

    for i in itertools.count():
      text, x, q = xq_list[0] if xq_list else ('', EMBED_ZEROS, QUEUE_ZEROS)
      t1, sc1, s1 = s_list[-1] if s_list else (None, STACK_ZEROS, STACK_ZEROS)
      t2, sc2, s2 = s_list[-2] if len(s_list) >= 2 else (None, STACK_ZEROS, STACK_ZEROS)
      t3, sc3, s3 = s_list[-3] if len(s_list) >= 3 else (None, STACK_ZEROS, STACK_ZEROS)

      zc, z = self.net_sr(zc, q, s1, z)  
      o = self.net_operation(z)

      if is_training:
        loss += functions.softmax_cross_entropy(o, XP.iarray([gold_op_list[i][0]]))
        o_argmax = gold_op_list[i][0]
      else:
        o_filter = [0.0 for _ in range(NUM_OP)]
        filtered = 0
        if not xq_list:
          o_filter[OP_SHIFT] = NEG_INF
          filtered += 1
        if not s_list or unary_chain >= unary_limit:
          o_filter[OP_UNARY] = NEG_INF
          filtered += 1
        if len(s_list) < 2:
          o_filter[OP_BINARY] = NEG_INF
          filtered += 1
        if xq_list or len(s_list) > 1:
          o_filter[OP_FINISH] = NEG_INF
        if filtered == NUM_OP:
          raise RuntimeError('No possible operation!')

        o += XP.farray([o_filter])
        o_argmax = int(cuda.to_cpu(o.data.argmax(1)))

      if o_argmax == OP_SHIFT:
        t0 = Tree(None, [text])
        sc0, s0 = (STACK_ZEROS, self.net_shift(x, q, s1, z))
        xq_list.pop(0)
        unary_chain = 0
        label = self.net_semiterminal(s0)
      elif o_argmax == OP_UNARY:
        t0 = Tree(None, [t1])
        sc0, s0 = self.net_unary(sc1, q, s1, s2, z)
        s_list.pop()
        unary_chain += 1
        label = self.net_phrase(s0)
      elif o_argmax == OP_BINARY:
        t0 = Tree(None, [t2, t1])
        sc0, s0 = self.net_binary(sc1, sc2, q, s1, s2, s3, z)
        s_list.pop()
        s_list.pop()
        unary_chain = 0
        label = self.net_phrase(s0)
      else: # OP_FINISH
        break

      if is_training:
        loss += functions.softmax_cross_entropy(label, XP.iarray([gold_op_list[i][1]]))
        label_argmax = gold_op_list[i][1]
      else:
        label_argmax = int(cuda.to_cpu(label.data.argmax(1)))

      t0.set_label(label_argmax)
      s_list.append((t0, sc0, s0))

      '''
      if is_training:
        o_est = int(cuda.to_cpu(o.data.argmax(1)))
        label_est = int(cuda.to_cpu(label.data.argmax(1)))
        trace('%c %c gold=%d-%2d, est=%d-%2d, stack=%2d, queue=%2d' % (
            '*' if o_est == gold_op_list[i][0] else ' ',
            '*' if label_est == gold_op_list[i][1] else ' ',
            gold_op_list[i][0], gold_op_list[i][1],
            o_est, label_est,
            len(s_list), len(xq_list)))
      '''

    if is_training:
      return loss
    else:
      # combine multiple trees if they exists, and return the result.
      t0, _, __ = s_list.pop()
      if s_list:
        raise RuntimeError('There exist multiple subtrees!')
      return t0
开发者ID:odashi,项目名称:nn_parsers,代码行数:104,代码来源:parse10.py

示例4: forward

# 需要导入模块: from nltk.tree import Tree [as 别名]
# 用法: 在 Tree 实例上调用 tree.set_label(label)(set_label 是实例方法,不能单独导入)

#.........这里部分代码省略.........
      (text, x, j, k, a, b) \
      for (text, _), x, (j, k), a, b \
      in zip(word_list, x_list, jk_list, a_list, b_list)]

    # estimate
    s_list = []
    zc = self.SRSTATE_ZEROS
    z = self.SRSTATE_ZEROS
    unary_chain = 0
    if is_training:
      loss = XP.fzeros(())

    for i in itertools.count():
      if is_training:
        gold_op, gold_label, gold_op_vram, gold_label_vram = gold_op_list[i]
      text, x, j, k, a, b = q_list[0] if q_list else self.QUEUE_DEFAULT
      t1, sc1, s1, rc1, r1 = s_list[-1] if s_list else self.STACK_DEFAULT
      t2, sc2, s2, rc2, r2 = s_list[-2] if len(s_list) >= 2 else self.STACK_DEFAULT
      t3, sc3, s3, rc3, r3 = s_list[-3] if len(s_list) >= 3 else self.STACK_DEFAULT

      zc, z = self.net_sr(zc, a, b, s1, r1, s2, r2, z)  
      o = self.net_operation(z)

      if is_training:
        loss += functions.softmax_cross_entropy(o, gold_op_vram)
        o_argmax = gold_op
      else:
        o_filter = [0.0 for _ in range(NUM_OP)]
        filtered = 0
        if not q_list:
          o_filter[OP_SHIFT] = self.NEG_INF
          filtered += 1
        if not s_list or unary_chain >= unary_limit:
          o_filter[OP_UNARY] = self.NEG_INF
          filtered += 1
        if len(s_list) < 2:
          o_filter[OP_BINARY] = self.NEG_INF
          filtered += 1
        if q_list or len(s_list) > 1:
          o_filter[OP_FINISH] = self.NEG_INF
        if filtered == NUM_OP:
          raise RuntimeError('No possible operation!')

        o += XP.farray([o_filter])
        o_argmax = int(cuda.to_cpu(o.data.argmax(1)))

      if o_argmax == OP_SHIFT:
        t0 = Tree(None, [text])
        sc0, s0 = (self.STACK_ZEROS, self.net_shift(x, j, k, a, b, s1))
        rc0, r0 = self.net_stack(rc1, s0, r1)
        q_list.pop(0)
        unary_chain = 0
        label = self.net_semiterminal(s0)
      elif o_argmax == OP_UNARY:
        t0 = Tree(None, [t1])
        sc0, s0 = self.net_unary(sc1, a, b, s1, s2)
        rc0, r0 = self.net_stack(rc2, s0, r2)
        s_list.pop()
        unary_chain += 1
        label = self.net_phrase(s0)
      elif o_argmax == OP_BINARY:
        t0 = Tree(None, [t2, t1])
        sc0, s0 = self.net_binary(sc1, sc2, a, b, s1, s2, s3)
        rc0, r0 = self.net_stack(rc3, s0, r3)
        s_list.pop()
        s_list.pop()
        unary_chain = 0
        label = self.net_phrase(s0)
      else: # OP_FINISH
        break

      if is_training:
        loss += functions.softmax_cross_entropy(label, gold_label_vram)
        label_argmax = gold_label
      else:
        label_argmax = int(cuda.to_cpu(label.data.argmax(1)))

      t0.set_label(label_argmax)
      s_list.append((t0, sc0, s0, rc0, r0))

      '''
      if is_training:
        o_est = int(cuda.to_cpu(o.data.argmax(1)))
        label_est = int(cuda.to_cpu(label.data.argmax(1)))
        trace('%c %c gold=%d-%2d, est=%d-%2d, stack=%2d, queue=%2d' % (
            '*' if o_est == gold_op else ' ',
            '*' if label_est == gold_label else ' ',
            gold_op, gold_label,
            o_est, label_est,
            len(s_list), len(q_list)))
      '''

    if is_training:
      return loss
    else:
      # combine multiple trees if they exists, and return the result.
      t0, *_ = s_list.pop()
      if s_list:
        raise RuntimeError('There exist multiple subtrees!')
      return t0
开发者ID:odashi,项目名称:nn_parsers,代码行数:104,代码来源:parse29.py

示例5: forward

# 需要导入模块: from nltk.tree import Tree [as 别名]
# 用法: 在 Tree 实例上调用 tree.set_label(label)(set_label 是实例方法,不能单独导入)

#.........这里部分代码省略.........
    s_list = [] # [(tree, state)]
    unary_chain = 0
    if is_training:
      loss = my_zeros((), np.float32)

    # estimate
    for i in itertools.count():
      text, x, q = xq_list[0] if xq_list else ('', EMBED_ZEROS, QUEUE_ZEROS)
      t1, s1 = s_list[-1] if s_list else (None, STACK_ZEROS)
      t2, s2 = s_list[-2] if len(s_list) > 1 else (None, STACK_ZEROS)
      t3, s3 = s_list[-3] if len(s_list) > 2 else (None, STACK_ZEROS)
      p1 = (t1.label() == -1) if t1 is not None else False
      p2 = (t2.label() == -1) if t2 is not None else False

      op = self.net_operation(x, q, s1, s2, s3)

      if is_training:
        loss += functions.softmax_cross_entropy(op, my_array([op_list[i][0]], np.int32))
        op_argmax = op_list[i][0]
      else:
        op_filter = [0.0 for _ in range(NUM_OP)]
        filtered = 0
        if not xq_list or p2:
          op_filter[OP_SHIFT] = NEG_INF
          filtered += 1
        if not s_list or unary_chain >= unary_limit or p1:
          op_filter[OP_UNARY] = NEG_INF
          filtered += 1
        if not xq_list or len(s_list) < 2 or p1:
          op_filter[OP_PARTIAL] = NEG_INF
          filtered += 1
        if len(s_list) < 2 or p1:
          op_filter[OP_REDUCE] = NEG_INF
          filtered += 1
        if xq_list or len(s_list) > 1:
          op_filter[OP_FINISH] = NEG_INF
        if filtered == NUM_OP:
          raise RuntimeError('No possible operation!')

        op += my_array([op_filter], np.float32)
        op_argmax = int(cuda.to_cpu(op.data.argmax(1)))

      if op_argmax == OP_SHIFT:
        t0 = Tree(None, [text])
        s0 = self.net_shift(x, q, s1)
        xq_list.pop(0)
        unary_chain = 0
        label = self.net_semi_label(s0)
      elif op_argmax == OP_UNARY:
        t0 = Tree(None, [t1])
        s0 = self.net_unary(q, s1, s2)
        s_list.pop()
        unary_chain += 1
        label = self.net_phrase_label(s0)
      elif op_argmax == OP_PARTIAL:
        t0 = Tree(None, [t2, t1])
        s0 = self.net_partial(q, s1, s2, s3)
        s_list.pop()
        s_list.pop()
        unary_chain = 0
        label = self.net_partial_label(s0)
      elif op_argmax == OP_REDUCE:
        t0 = Tree(None, [t2, t1])
        s0 = self.net_reduce(q, s1, s2, s3)
        s_list.pop()
        s_list.pop()
        unary_chain = 0
        label = self.net_phrase_label(s0)
      else: # OP_FINISH
        break

      if is_training:
        loss += functions.softmax_cross_entropy(label, my_array([op_list[i][1]], np.int32))
        label_argmax = op_list[i][1] if op_argmax != OP_PARTIAL else -1
      else:
        label_argmax = int(cuda.to_cpu(label.data.argmax(1))) if op_argmax != OP_PARTIAL else -1
      t0.set_label(label_argmax)

      s_list.append((t0, s0))

      '''
      if is_training:
        op_est = int(cuda.to_cpu(op.data.argmax(1)))
        label_est = int(cuda.to_cpu(label.data.argmax(1))) if op_argmax != OP_PARTIAL else -1
        trace('%c %c gold=%d-%2d, est=%d-%2d, stack=%2d, queue=%2d' % (
            '*' if op_est == op_list[i][0] else ' ',
            '*' if label_est == op_list[i][1] else ' ',
            op_list[i][0], op_list[i][1],
            op_est, label_est,
            len(s_list), len(xq_list)))
      '''
    
    if is_training:
      return loss
    else:
      # combine multiple trees if they exists, and return the result.
      t0, _ = s_list.pop()
      if s_list:
        raise RuntimeError('There exist multiple subtrees!')
      return unbinarize(t0)
开发者ID:odashi,项目名称:nn_parsers,代码行数:104,代码来源:parse02.py

示例6: forward_test

# 需要导入模块: from nltk.tree import Tree [as 别名]
# 用法: 在 Tree 实例上调用 tree.set_label(label)(set_label 是实例方法,不能单独导入)
  def forward_test(self, word_list, unary_limit):
    """Greedily parse ``word_list`` with the shift-reduce network.

    At every step the network scores all operations. Operations that are
    impossible in the current parser state get a large negative bias
    (``NEG_INF``) added before the argmax is taken. Each SHIFT / UNARY /
    BINARY step creates a new ``Tree`` node and sets its label to the argmax
    index of the matching label classifier. The loop stops at OP_FINISH.

    Args:
      word_list: non-empty list of words; turned into the input queue by
        ``make_queue``.
      unary_limit: non-negative int; the maximum number of consecutive
        OP_UNARY operations allowed.

    Returns:
      The single tree left on the stack. Node labels are integer class
      indices, not strings.

    Raises:
      ValueError: if ``word_list`` is not a non-empty list, or
        ``unary_limit`` is not a non-negative int.
      RuntimeError: if no operation is possible, or if more than one subtree
        remains on the stack at the end.
    """
    # check args
    if not isinstance(word_list, list) or len(word_list) < 1:
      raise ValueError('Word list is empty.')
    if not isinstance(unary_limit, int) or unary_limit < 0:
      raise ValueError('unary_limit must be non-negative integer.')

    # default values: (1, size) zero arrays used as padding when the queue or
    # the stack is shorter than the features the networks read.
    EMBED_ZEROS = XP.fzeros((1, self.n_embed))
    CEMBED_ZEROS = XP.fzeros((1, self.n_char_embed))
    QUEUE_ZEROS = XP.fzeros((1, self.n_queue))
    STACK_ZEROS = XP.fzeros((1, self.n_stack))
    SRSTATE_ZEROS = XP.fzeros((1, self.n_srstate))
    # Stand-in for an empty queue head: (text, x, j, k, a, b).
    QUEUE_DEFAULT = ('', EMBED_ZEROS, CEMBED_ZEROS, CEMBED_ZEROS, QUEUE_ZEROS, QUEUE_ZEROS)
    # Stand-in for a missing stack item: (tree, sc, s, rc, r).
    STACK_DEFAULT = (None, STACK_ZEROS, STACK_ZEROS, STACK_ZEROS, STACK_ZEROS)
    # Bias added to the score of an impossible operation so argmax skips it.
    NEG_INF = -1e20

    q_list = make_queue(word_list, QUEUE_ZEROS)

    # estimate
    s_list = []
    zc = SRSTATE_ZEROS
    z = SRSTATE_ZEROS
    # Number of OP_UNARY operations applied in a row; compared to unary_limit.
    unary_chain = 0

    for i in itertools.count():
      # Queue head and the top three stack items, padded with the defaults.
      # They are read before any pop below, so t1/t2 refer to the pre-step stack.
      text, x, j, k, a, b = q_list[0] if q_list else QUEUE_DEFAULT
      t1, sc1, s1, rc1, r1 = s_list[-1] if s_list else STACK_DEFAULT
      t2, sc2, s2, rc2, r2 = s_list[-2] if len(s_list) >= 2 else STACK_DEFAULT
      t3, sc3, s3, rc3, r3 = s_list[-3] if len(s_list) >= 3 else STACK_DEFAULT

      zc, z = self.net_sr(zc, a, b, s1, r1, s2, r2, z)  
      o = self.net_operation(z)

      # Mask operations that are not legal in the current state.
      o_filter = [0.0 for _ in range(NUM_OP)]
      filtered = 0
      if not q_list:
        # Nothing left to shift.
        o_filter[OP_SHIFT] = NEG_INF
        filtered += 1
      if not s_list or unary_chain >= unary_limit:
        # Nothing to wrap, or too many consecutive unary operations.
        o_filter[OP_UNARY] = NEG_INF
        filtered += 1
      if len(s_list) < 2:
        # Binary reduction needs two stack items.
        o_filter[OP_BINARY] = NEG_INF
        filtered += 1
      if q_list or len(s_list) > 1:
        # Finishing is only allowed once the queue is empty and the stack is
        # down to at most one item.
        # NOTE(review): this branch does not increment `filtered`, so if
        # NUM_OP counts OP_FINISH the check below can never fire - confirm.
        o_filter[OP_FINISH] = NEG_INF
      if filtered == NUM_OP:
        raise RuntimeError('No possible operation!')

      o += XP.farray([o_filter])
      o_argmax = int(cuda.to_cpu(o.data.argmax(1)))

      if o_argmax == OP_SHIFT:
        # SHIFT: push a new node whose only child is the queue head's text.
        t0 = Tree(None, [text])
        sc0, s0 = (STACK_ZEROS, self.net_shift(x, j, k, a, b, s1))
        rc0, r0 = self.net_stack(rc1, s0, r1)
        q_list.pop(0)
        unary_chain = 0
        label = self.net_semiterminal(s0)
      elif o_argmax == OP_UNARY:
        # UNARY: replace the top stack item with a new node wrapping it.
        t0 = Tree(None, [t1])
        sc0, s0 = self.net_unary(sc1, a, b, s1, s2)
        rc0, r0 = self.net_stack(rc2, s0, r2)
        s_list.pop()
        unary_chain += 1
        label = self.net_phrase(s0)
      elif o_argmax == OP_BINARY:
        # BINARY: replace the top two items with one node (t2 left, t1 right).
        t0 = Tree(None, [t2, t1])
        sc0, s0 = self.net_binary(sc1, sc2, a, b, s1, s2, s3)
        rc0, r0 = self.net_stack(rc3, s0, r3)
        s_list.pop()
        s_list.pop()
        unary_chain = 0
        label = self.net_phrase(s0)
      else: # OP_FINISH
        break

      # The node label is the integer index of the best-scoring label.
      label_argmax = int(cuda.to_cpu(label.data.argmax(1)))

      t0.set_label(label_argmax)
      s_list.append((t0, sc0, s0, rc0, r0))

    # Exactly one tree must remain on the stack; return it.
    t0, *_ = s_list.pop()
    if s_list:
      raise RuntimeError('There exist multiple subtrees!')
    return t0
开发者ID:odashi,项目名称:nn_parsers,代码行数:90,代码来源:parse23_beam.py

示例7: forward_train

# 需要导入模块: from nltk.tree import Tree [as 别名]
# 用法: 在 Tree 实例上调用 tree.set_label(label)(set_label 是实例方法,不能单独导入)
  def forward_train(self, word_list, gold_op_list):
    """Compute the training loss of the shift-reduce parser on one sentence.

    The parser is driven by the gold operation sequence: at each step the
    network's operation scores and label scores are compared with the gold
    operation and gold label using softmax cross-entropy, and then the gold
    operation is executed.

    Args:
      word_list: non-empty list of words; turned into the input queue by
        ``make_queue``.
      gold_op_list: sequence of ``(operation, label)`` pairs. It must contain
        exactly ``len(word_list)`` OP_SHIFT and ``len(word_list) - 1``
        OP_BINARY operations, and its last element must equal
        ``(OP_FINISH, None)``.

    Returns:
      The accumulated loss: ``XP.fzeros(())`` plus one cross-entropy term for
      every operation and every label predicted before OP_FINISH.

    Raises:
      ValueError: if ``word_list`` is not a non-empty list, or
        ``gold_op_list`` has the wrong operation counts or does not end with
        ``(OP_FINISH, None)``.
    """
    # check args
    if not isinstance(word_list, list) or len(word_list) < 1:
      raise ValueError('Word list is empty.')
    n_shift = 0
    n_binary = 0
    for op, _ in gold_op_list:
      if op == OP_SHIFT: n_shift += 1
      if op == OP_BINARY: n_binary += 1
    # Require one SHIFT per word and one fewer BINARY reductions.
    if n_shift != len(word_list) or n_binary != len(word_list) - 1:
      # Each observed count is followed by the value it is required to be.
      raise ValueError(
          'Invalid operation number: SHIFT=%d (required: %d), BINARY=%d (required: %d)' %
          (n_shift, len(word_list), n_binary, len(word_list) - 1))
    if gold_op_list[-1] != (OP_FINISH, None):
      raise ValueError('Last operation is not OP_FINISH.')

    # default values: (1, size) zero arrays used as padding when the queue or
    # the stack is shorter than the features the networks read.
    EMBED_ZEROS = XP.fzeros((1, self.n_embed))
    CEMBED_ZEROS = XP.fzeros((1, self.n_char_embed))
    QUEUE_ZEROS = XP.fzeros((1, self.n_queue))
    STACK_ZEROS = XP.fzeros((1, self.n_stack))
    SRSTATE_ZEROS = XP.fzeros((1, self.n_srstate))
    # Stand-in for an empty queue head: (text, x, j, k, a, b).
    QUEUE_DEFAULT = ('', EMBED_ZEROS, CEMBED_ZEROS, CEMBED_ZEROS, QUEUE_ZEROS, QUEUE_ZEROS)
    # Stand-in for a missing stack item: (tree, sc, s, rc, r).
    STACK_DEFAULT = (None, STACK_ZEROS, STACK_ZEROS, STACK_ZEROS, STACK_ZEROS)

    q_list = make_queue(word_list, QUEUE_ZEROS)

    # estimate
    s_list = []
    zc = SRSTATE_ZEROS
    z = SRSTATE_ZEROS
    loss = XP.fzeros(())

    for i in itertools.count():
      # Queue head and the top three stack items, padded with the defaults.
      # They are read before any pop below, so t1/t2 refer to the pre-step stack.
      text, x, j, k, a, b = q_list[0] if q_list else QUEUE_DEFAULT
      t1, sc1, s1, rc1, r1 = s_list[-1] if s_list else STACK_DEFAULT
      t2, sc2, s2, rc2, r2 = s_list[-2] if len(s_list) >= 2 else STACK_DEFAULT
      t3, sc3, s3, rc3, r3 = s_list[-3] if len(s_list) >= 3 else STACK_DEFAULT

      zc, z = self.net_sr(zc, a, b, s1, r1, s2, r2, z)
      o = self.net_operation(z)

      # Score the predicted operation, then follow the gold one.
      loss += functions.softmax_cross_entropy(o, XP.iarray([gold_op_list[i][0]]))
      o_argmax = gold_op_list[i][0]

      if o_argmax == OP_SHIFT:
        # SHIFT: push a new node whose only child is the queue head's text.
        t0 = Tree(None, [text])
        sc0, s0 = (STACK_ZEROS, self.net_shift(x, j, k, a, b, s1))
        rc0, r0 = self.net_stack(rc1, s0, r1)
        q_list.pop(0)
        label = self.net_semiterminal(s0)
      elif o_argmax == OP_UNARY:
        # UNARY: replace the top stack item with a new node wrapping it.
        t0 = Tree(None, [t1])
        sc0, s0 = self.net_unary(sc1, a, b, s1, s2)
        rc0, r0 = self.net_stack(rc2, s0, r2)
        s_list.pop()
        label = self.net_phrase(s0)
      elif o_argmax == OP_BINARY:
        # BINARY: replace the top two items with one node (t2 left, t1 right).
        t0 = Tree(None, [t2, t1])
        sc0, s0 = self.net_binary(sc1, sc2, a, b, s1, s2, s3)
        rc0, r0 = self.net_stack(rc3, s0, r3)
        s_list.pop()
        s_list.pop()
        label = self.net_phrase(s0)
      else: # OP_FINISH
        break

      # Score the predicted label, then use the gold one as the node label.
      loss += functions.softmax_cross_entropy(label, XP.iarray([gold_op_list[i][1]]))
      label_argmax = gold_op_list[i][1]

      t0.set_label(label_argmax)
      s_list.append((t0, sc0, s0, rc0, r0))

    return loss
开发者ID:odashi,项目名称:nn_parsers,代码行数:81,代码来源:parse23_beam.py

示例8: forward

# 需要导入模块: from nltk.tree import Tree [as 别名]
# 用法: 在 Tree 实例上调用 tree.set_label(label)(set_label 是实例方法,不能单独导入)

#.........这里部分代码省略.........
      for (text, _), x, (j, k), b \
      in zip(word_list, x_list, jk_list, b_list)]

    # estimate
    s_list = []
    zc = SRSTATE_ZEROS
    z = SRSTATE_ZEROS
    rf = XP.dropout(STACK_ONES)
    unary_chain = 0
    if is_training:
      loss = XP.fzeros(())

    for i in itertools.count():
      text, x, j, k, b = q_list[0] if q_list else QUEUE_DEFAULT
      t1, sc1, s1, rc1, r1 = s_list[-1] if s_list else STACK_DEFAULT
      t2, sc2, s2, rc2, r2 = s_list[-2] if len(s_list) >= 2 else STACK_DEFAULT
      t3, sc3, s3, rc3, r3 = s_list[-3] if len(s_list) >= 3 else STACK_DEFAULT

      zc, z = self.net_sr(zc, b, r1, z)  
      o = self.net_operation(z)

      if is_training:
        loss += functions.softmax_cross_entropy(o, XP.iarray([gold_op_list[i][0]]))
        o_argmax = gold_op_list[i][0]
      else:
        o_filter = [0.0 for _ in range(NUM_OP)]
        filtered = 0
        if not q_list:
          o_filter[OP_SHIFT] = NEG_INF
          filtered += 1
        if not s_list or unary_chain >= unary_limit:
          o_filter[OP_UNARY] = NEG_INF
          filtered += 1
        if len(s_list) < 2:
          o_filter[OP_BINARY] = NEG_INF
          filtered += 1
        if q_list or len(s_list) > 1:
          o_filter[OP_FINISH] = NEG_INF
        if filtered == NUM_OP:
          raise RuntimeError('No possible operation!')

        o += XP.farray([o_filter])
        o_argmax = int(cuda.to_cpu(o.data.argmax(1)))

      if o_argmax == OP_SHIFT:
        t0 = Tree(None, [text])
        sc0, s0 = (STACK_ZEROS, self.net_shift(x, j, k, b, s1))
        rc0, r0 = self.net_stack(rc1, s0, r1)
        q_list.pop(0)
        unary_chain = 0
        label = self.net_semiterminal(s0)
      elif o_argmax == OP_UNARY:
        t0 = Tree(None, [t1])
        sc0, s0 = self.net_unary(sc1, b, s1, s2)
        rc0, r0 = self.net_stack(rc2, s0, r2)
        s_list.pop()
        unary_chain += 1
        label = self.net_phrase(s0)
      elif o_argmax == OP_BINARY:
        t0 = Tree(None, [t2, t1])
        sc0, s0 = self.net_binary(sc1, sc2, b, s1, s2, s3)
        rc0, r0 = self.net_stack(rc3, s0, r3)
        s_list.pop()
        s_list.pop()
        unary_chain = 0
        label = self.net_phrase(s0)
      else: # OP_FINISH
        break

      r0 *= rf

      if is_training:
        loss += functions.softmax_cross_entropy(label, XP.iarray([gold_op_list[i][1]]))
        label_argmax = gold_op_list[i][1]
      else:
        label_argmax = int(cuda.to_cpu(label.data.argmax(1)))

      t0.set_label(label_argmax)
      s_list.append((t0, sc0, s0, rc0, r0))

      '''
      if is_training:
        o_est = int(cuda.to_cpu(o.data.argmax(1)))
        label_est = int(cuda.to_cpu(label.data.argmax(1)))
        trace('%c %c gold=%d-%2d, est=%d-%2d, stack=%2d, queue=%2d' % (
            '*' if o_est == gold_op_list[i][0] else ' ',
            '*' if label_est == gold_op_list[i][1] else ' ',
            gold_op_list[i][0], gold_op_list[i][1],
            o_est, label_est,
            len(s_list), len(q_list)))
      '''

    if is_training:
      return loss
    else:
      # combine multiple trees if they exists, and return the result.
      t0, *_ = s_list.pop()
      if s_list:
        raise RuntimeError('There exist multiple subtrees!')
      return t0
开发者ID:odashi,项目名称:nn_parsers,代码行数:104,代码来源:parse26.py


注:本文中的 nltk.tree.Tree.set_label 方法示例由纯净天空整理自 GitHub/MSDocs 等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的 License;未经允许,请勿转载。