当前位置: 首页>>代码示例>>Python>>正文


Python modules.TextFieldEmbedder类代码示例

本文整理汇总了Python中allennlp.modules.TextFieldEmbedder类的典型用法代码示例。如果您正苦于以下问题:Python modules.TextFieldEmbedder类的具体用法?Python modules.TextFieldEmbedder怎么用?Python modules.TextFieldEmbedder使用的例子?那么,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块allennlp.modules的用法示例。


在下文中一共展示了modules.TextFieldEmbedder类的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: __init__

# 需要导入模块: from allennlp import modules [as 别名]
# 或者: from allennlp.modules import TextFieldEmbedder [as 别名]
def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 verbose_metrics: bool = False,
                 dropout: float = 0.2,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 ) -> None:
        """Set up a classifier that maps the embedder output directly onto the labels.

        Parameters
        ----------
        vocab:
            Forwarded to the parent constructor; also used to look up the
            string form of every index in the ``"labels"`` namespace.
        text_field_embedder:
            Its ``get_output_dim()`` is the input width of the linear layer.
        verbose_metrics:
            Stored unchanged on the instance.
        dropout:
            Probability handed to ``torch.nn.Dropout``.
        initializer:
            Called with the finished model as the very last step.
        regularizer:
            Forwarded to the parent constructor.
        """
        super().__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.dropout = torch.nn.Dropout(dropout)

        label_count = self.vocab.get_vocab_size("labels")
        embedding_dim = self.text_field_embedder.get_output_dim()
        self.num_classes = label_count
        self.classifier_feedforward = torch.nn.Linear(embedding_dim, label_count)

        self.label_accuracy = CategoricalAccuracy()
        self.verbose_metrics = verbose_metrics

        # One F1 tracker per label index, keyed by the label's token string.
        self.label_f1_metrics = {
            vocab.get_token_from_index(index=label_id, namespace="labels"): F1Measure(positive_label=label_id)
            for label_id in range(label_count)
        }
        self.loss = torch.nn.CrossEntropyLoss()

        initializer(self)
开发者ID:allenai,项目名称:scibert,代码行数:26,代码来源:bert_text_classifier.py

示例2: __init__

# 需要导入模块: from allennlp import modules [as 别名]
# 或者: from allennlp.modules import TextFieldEmbedder [as 别名]
def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 text_encoder: Seq2SeqEncoder,
                 classifier_feedforward: FeedForward,
                 verbose_metrics: bool = False,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 ) -> None:
        """Set up an embed -> encode -> feed-forward -> linear text classifier.

        Parameters
        ----------
        vocab:
            Forwarded to the parent constructor; also used to look up the
            string form of every index in the ``"labels"`` namespace.
        text_field_embedder:
            Stored as ``self.text_field_embedder``.
        text_encoder:
            Stored as ``self.text_encoder``.
        classifier_feedforward:
            Its ``get_output_dim()`` is the input width of the final linear
            prediction layer.
        verbose_metrics:
            Stored unchanged on the instance. The annotation used to be the
            literal ``False`` with no default, which made the flag a required
            argument; it is now a proper ``bool`` defaulting to ``False``.
        initializer:
            Called with the finished model as the very last step.
        regularizer:
            Forwarded to the parent constructor.
        """
        super(TextClassifier, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.num_classes = self.vocab.get_vocab_size("labels")
        self.text_encoder = text_encoder
        self.classifier_feedforward = classifier_feedforward
        self.prediction_layer = torch.nn.Linear(self.classifier_feedforward.get_output_dim(), self.num_classes)

        self.label_accuracy = CategoricalAccuracy()
        self.verbose_metrics = verbose_metrics

        # One F1 tracker per label index, keyed by the label's token string.
        self.label_f1_metrics = {
            vocab.get_token_from_index(index=i, namespace="labels"): F1Measure(positive_label=i)
            for i in range(self.num_classes)
        }
        self.loss = torch.nn.CrossEntropyLoss()

        # Pooling helper; note that bidirectional=True is hard-coded here.
        self.pool = lambda text, mask: util.get_final_encoder_states(text, mask, bidirectional=True)

        initializer(self)
开发者ID:allenai,项目名称:scibert,代码行数:30,代码来源:text_classifier.py

示例3: __init__

# 需要导入模块: from allennlp import modules [as 别名]
# 或者: from allennlp.modules import TextFieldEmbedder [as 别名]
def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 include_start_end_transitions: bool = True,
                 dropout: Optional[float] = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        """Set up an embed -> encode -> project -> CRF sequence tagger.

        Parameters
        ----------
        vocab, regularizer:
            Forwarded to the parent constructor.
        text_field_embedder:
            Stored as ``self.text_field_embedder``.
        encoder:
            Its ``get_output_dim()`` is the input width of the tag projection.
        include_start_end_transitions:
            Passed through to ``ConditionalRandomField``.
        dropout:
            When truthy, a ``torch.nn.Dropout`` with this probability is
            created; otherwise ``self.dropout`` is ``None``.
        initializer:
            Called with the finished model as the very last step.
        """
        super().__init__(vocab, regularizer)

        namespace = 'labels'
        self.label_namespace = namespace
        self.num_tags = self.vocab.get_vocab_size(namespace)

        # Text encoding pipeline.
        self.text_field_embedder = text_field_embedder
        self.encoder = encoder
        if dropout:
            self.dropout = torch.nn.Dropout(dropout)
        else:
            self.dropout = None

        # Per-token tag scores followed by the CRF layer.
        projection = Linear(self.encoder.get_output_dim(), self.num_tags)
        self.tag_projection_layer = TimeDistributed(projection)
        self.crf = ConditionalRandomField(
            self.num_tags,
            constraints=None,
            include_start_end_transitions=include_start_end_transitions,
        )

        # Overall accuracies first, then one "F1_<label>" entry per label.
        metrics = {
            "accuracy": CategoricalAccuracy(),
            "accuracy3": CategoricalAccuracy(top_k=3),
        }
        index_to_label = self.vocab.get_index_to_token_vocabulary(namespace)
        metrics.update(
            ('F1_' + label, F1Measure(positive_label=index))
            for index, label in index_to_label.items()
        )
        self.metrics = metrics

        initializer(self)
开发者ID:allenai,项目名称:scibert,代码行数:32,代码来源:pico_crf_tagger.py

示例4: __init__

# 需要导入模块: from allennlp import modules [as 别名]
# 或者: from allennlp.modules import TextFieldEmbedder [as 别名]
def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 predictor_dropout=0.0,
                 labels_namespace: str = "labels",
                 detect_namespace: str = "d_tags",
                 verbose_metrics: bool = False,
                 label_smoothing: float = 0.0,
                 confidence: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        """Set up a tagger with two projection heads over a BERT token embedder.

        Parameters
        ----------
        vocab, regularizer:
            Forwarded to the parent constructor.
        text_field_embedder:
            Must expose a ``_token_embedders`` mapping with a ``'bert'`` entry;
            that entry's output dimension feeds both projection heads.
        predictor_dropout:
            Probability for the time-distributed dropout layer.
        labels_namespace, detect_namespace:
            Vocabulary namespaces whose sizes set the width of the two heads.
        verbose_metrics, label_smoothing, confidence:
            Stored unchanged on the instance.
        initializer:
            Called with the finished model as the very last step.
        """
        super().__init__(vocab, regularizer)

        self.label_namespaces = [labels_namespace, detect_namespace]
        self.text_field_embedder = text_field_embedder

        label_class_count = self.vocab.get_vocab_size(labels_namespace)
        detect_class_count = self.vocab.get_vocab_size(detect_namespace)
        self.num_labels_classes = label_class_count
        self.num_detect_classes = detect_class_count

        self.label_smoothing = label_smoothing
        self.confidence = confidence
        # Index of the "INCORRECT" token inside the detection namespace.
        self.incorr_index = self.vocab.get_token_index("INCORRECT", namespace=detect_namespace)

        self._verbose_metrics = verbose_metrics
        self.predictor_dropout = TimeDistributed(torch.nn.Dropout(predictor_dropout))

        # Both heads read the same BERT embedding width.
        bert_dim = text_field_embedder._token_embedders['bert'].get_output_dim()
        self.tag_labels_projection_layer = TimeDistributed(Linear(bert_dim, label_class_count))
        self.tag_detect_projection_layer = TimeDistributed(Linear(bert_dim, detect_class_count))

        self.metrics = {"accuracy": CategoricalAccuracy()}

        initializer(self)
开发者ID:plkmo,项目名称:NLP_Toolkit,代码行数:36,代码来源:seq2labels_model.py

示例5: __init__

# 需要导入模块: from allennlp import modules [as 别名]
# 或者: from allennlp.modules import TextFieldEmbedder [as 别名]
def __init__(
        self,
        vocab: Vocabulary,
        sentence_embedder: TextFieldEmbedder,
        action_embedding_dim: int,
        encoder: Seq2SeqEncoder,
        attention: Attention,
        decoder_beam_search: BeamSearch,
        max_decoding_steps: int,
        dropout: float = 0.0,
    ) -> None:
        """Set up the parser's decoder trainer, transition function and beam search.

        ``vocab``, ``sentence_embedder``, ``action_embedding_dim``, ``encoder``
        and ``dropout`` are forwarded by keyword to the parent constructor.
        ``attention`` becomes the input attention of the transition function,
        while ``decoder_beam_search`` and ``max_decoding_steps`` are stored
        unchanged on the instance.
        """
        super().__init__(
            vocab=vocab,
            sentence_embedder=sentence_embedder,
            action_embedding_dim=action_embedding_dim,
            encoder=encoder,
            dropout=dropout,
        )
        self._decoder_trainer = MaximumMarginalLikelihood()

        # self._encoder is read here without being assigned in this method.
        encoder_dim = self._encoder.get_output_dim()
        tanh = Activation.by_name("tanh")()
        self._decoder_step = BasicTransitionFunction(
            encoder_output_dim=encoder_dim,
            action_embedding_dim=action_embedding_dim,
            input_attention=attention,
            activation=tanh,
            add_action_bias=False,
            dropout=dropout,
        )

        self._decoder_beam_search = decoder_beam_search
        self._max_decoding_steps = max_decoding_steps
        self._action_padding_index = -1
开发者ID:allenai,项目名称:allennlp-semparse,代码行数:32,代码来源:nlvr_direct_semantic_parser.py

示例6: __init__

# 需要导入模块: from allennlp import modules [as 别名]
# 或者: from allennlp.modules import TextFieldEmbedder [as 别名]
def __init__(
        self,
        vocab: Vocabulary,
        sentence_embedder: TextFieldEmbedder,
        action_embedding_dim: int,
        encoder: Seq2SeqEncoder,
        dropout: float = 0.0,
        rule_namespace: str = "rule_labels",
    ) -> None:
        """Store the shared components of the semantic parser.

        Parameters
        ----------
        vocab:
            Forwarded to the parent constructor; the size of ``rule_namespace``
            in it sets the number of action embeddings.
        sentence_embedder, encoder:
            Stored unchanged on the instance.
        action_embedding_dim:
            Width of each action embedding and of the first-action vector.
        dropout:
            When greater than zero a ``torch.nn.Dropout`` is created; otherwise
            an identity function is stored in its place.
        rule_namespace:
            Vocabulary namespace used to size the action embedder.
        """
        super().__init__(vocab=vocab)

        self._sentence_embedder = sentence_embedder
        self._denotation_accuracy = Average()
        self._consistency = Average()
        self._encoder = encoder
        self._dropout = torch.nn.Dropout(p=dropout) if dropout > 0 else (lambda x: x)
        self._rule_namespace = rule_namespace

        rule_count = vocab.get_vocab_size(self._rule_namespace)
        self._action_embedder = Embedding(
            num_embeddings=rule_count,
            embedding_dim=action_embedding_dim,
        )

        # Learned input for the first decoding step, where no previous action
        # exists yet; filled in place with normally distributed values.
        first_action = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
        torch.nn.init.normal_(first_action)
        self._first_action_embedding = first_action
开发者ID:allenai,项目名称:allennlp-semparse,代码行数:32,代码来源:nlvr_semantic_parser.py

示例7: __init__

# 需要导入模块: from allennlp import modules [as 别名]
# 或者: from allennlp.modules import TextFieldEmbedder [as 别名]
def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 projection_feedforward: FeedForward,
                 inference_encoder: Seq2SeqEncoder,
                 output_feedforward: FeedForward,
                 output_logit: FeedForward,
                 final_feedforward: FeedForward,
                 coverage_loss: CoverageLoss,
                 similarity_function: SimilarityFunction = DotProductSimilarity(),
                 dropout: float = 0.5,
                 contextualize_pair_comparators: bool = False,
                 pair_context_encoder: Seq2SeqEncoder = None,
                 pair_feedforward: FeedForward = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        """Forward every argument to the parent model, then add loss and accuracy.

        All sixteen constructor arguments are handed to the parent constructor
        by keyword under the same names. Afterwards a cross-entropy answer loss
        and a categorical accuracy metric are attached to the instance.
        """
        # The signature is repeated in full and passed on one-to-one because,
        # per the original author, FromParams does not work properly otherwise.
        super().__init__(
            vocab=vocab,
            text_field_embedder=text_field_embedder,
            encoder=encoder,
            projection_feedforward=projection_feedforward,
            inference_encoder=inference_encoder,
            output_feedforward=output_feedforward,
            output_logit=output_logit,
            final_feedforward=final_feedforward,
            coverage_loss=coverage_loss,
            similarity_function=similarity_function,
            dropout=dropout,
            contextualize_pair_comparators=contextualize_pair_comparators,
            pair_context_encoder=pair_context_encoder,
            pair_feedforward=pair_feedforward,
            initializer=initializer,
            regularizer=regularizer,
        )

        self._answer_loss = torch.nn.CrossEntropyLoss()
        self._accuracy = CategoricalAccuracy()
开发者ID:StonyBrookNLP,项目名称:multee,代码行数:38,代码来源:single_correct_mcq_multee_esim.py

示例8: __init__

# 需要导入模块: from allennlp import modules [as 别名]
# 或者: from allennlp.modules import TextFieldEmbedder [as 别名]
def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 projection_feedforward: FeedForward,
                 inference_encoder: Seq2SeqEncoder,
                 output_feedforward: FeedForward,
                 output_logit: FeedForward,
                 final_feedforward: FeedForward,
                 coverage_loss: CoverageLoss,
                 similarity_function: SimilarityFunction = DotProductSimilarity(),
                 dropout: float = 0.5,
                 contextualize_pair_comparators: bool = False,
                 pair_context_encoder: Seq2SeqEncoder = None,
                 pair_feedforward: FeedForward = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        """Forward every argument to the parent model, then add losses and metrics.

        All sixteen constructor arguments are handed to the parent constructor
        by keyword under the same names. Afterwards this method attaches a
        cross-entropy answer loss that ignores index ``-1``, keeps a reference
        to ``coverage_loss``, and creates an accuracy metric plus an F1 metric
        for the ``"entailment"`` label.
        """
        super().__init__(
            vocab=vocab,
            text_field_embedder=text_field_embedder,
            encoder=encoder,
            projection_feedforward=projection_feedforward,
            inference_encoder=inference_encoder,
            output_feedforward=output_feedforward,
            output_logit=output_logit,
            final_feedforward=final_feedforward,
            coverage_loss=coverage_loss,
            similarity_function=similarity_function,
            dropout=dropout,
            contextualize_pair_comparators=contextualize_pair_comparators,
            pair_context_encoder=pair_context_encoder,
            pair_feedforward=pair_feedforward,
            initializer=initializer,
            regularizer=regularizer,
        )

        # Targets equal to this value are skipped by the answer loss.
        ignored_target = -1
        self._ignore_index = ignored_target
        self._answer_loss = torch.nn.CrossEntropyLoss(ignore_index=ignored_target)
        self._coverage_loss = coverage_loss

        self._accuracy = CategoricalAccuracy()
        # self._label2idx is not assigned in this method — presumably set by
        # the parent constructor; TODO confirm.
        entailment_index = self._label2idx["entailment"]
        self._entailment_f1 = F1Measure(entailment_index)
开发者ID:StonyBrookNLP,项目名称:multee,代码行数:40,代码来源:multiple_correct_mcq_multee_esim.py

示例9: __init__

# 需要导入模块: from allennlp import modules [as 别名]
# 或者: from allennlp.modules import TextFieldEmbedder [as 别名]
def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 phrase_layer: Seq2SeqEncoder,
                 projected_layer: Seq2SeqEncoder,
                 flow_layer: Seq2SeqEncoder,
                 contextual_passage: Seq2SeqEncoder,
                 contextual_question: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 regularizer: Optional[RegularizerApplicator] = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 ):
        """Wire up the encoders, attention layers, predictors and metrics.

        Parameters
        ----------
        vocab, regularizer:
            Forwarded to the parent constructor.
        text_field_embedder:
            Stored as ``self._text_field_embedder``.
        phrase_layer:
            Its ``get_output_dim()`` defines the encoding width that every
            layer created here is sized with.
        projected_layer, flow_layer, contextual_passage, contextual_question:
            Stored as ``projected_lstm``, ``flow``, ``contextual_layer_p`` and
            ``contextual_layer_q`` respectively.
        dropout:
            Probability handed to ``InputVariationalDropout``.
        initializer:
            Called with the finished model as the very last step.
        """
        super().__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._phrase_layer = phrase_layer
        width = self._phrase_layer.get_output_dim()
        self._encoding_dim = width

        # NOTE(review): the extra 1024 input features are hard-coded —
        # presumably the width of an additional embedding; TODO confirm.
        self.projected_layer = torch.nn.Linear(width + 1024, width)
        self.fuse = FusionLayer(width)
        self.projected_lstm = projected_layer
        self.flow = flow_layer
        self.contextual_layer_p = contextual_passage
        self.contextual_layer_q = contextual_question
        self.linear_self_align = torch.nn.Linear(width, 1)
        self.bilinear_layer_s = BilinearSeqAtt(width, width)
        self.bilinear_layer_e = BilinearSeqAtt(width, width)
        self.yesno_predictor = torch.nn.Linear(width, 3)
        self.relu = torch.nn.ReLU()

        self._max_span_length = 30

        # Metrics.
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        self._span_yesno_accuracy = CategoricalAccuracy()
        self._official_f1 = Average()

        self._variational_dropout = InputVariationalDropout(dropout)
        self._loss = torch.nn.CrossEntropyLoss()

        initializer(self)
开发者ID:SparkJiao,项目名称:SLQA,代码行数:42,代码来源:slqa_h.py

示例10: __init__

# 需要导入模块: from allennlp import modules [as 别名]
# 或者: from allennlp.modules import TextFieldEmbedder [as 别名]
def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 attention: Attention,
                 beam_size: int,
                 max_decoding_steps: int,
                 target_embedding_dim: int = None,
                 copy_token: str = "@COPY@",
                 source_namespace: str = "tokens",
                 target_namespace: str = "target_tokens",
                 tensor_based_metric: Metric = None,
                 token_based_metric: Metric = None,
                 tie_embeddings: bool = False) -> None:
        """Initialise ``CopyNetSeq2Seq`` and optionally tie source/target embeddings.

        Parameters
        ----------
        target_embedding_dim:
            When falsy, the source embedder's output dimension is used instead.
        tie_embeddings:
            When true, the source and target namespaces must be equal and the
            source embedder must have a child named ``"token_embedder_tokens"``;
            that child's ``weight`` is then shared with the target embedder.
        tensor_based_metric:
            When ``None``, ``self._tensor_based_metric`` is explicitly reset
            to ``None`` after the parent constructor has run.

        All remaining arguments are passed positionally to
        ``CopyNetSeq2Seq.__init__`` in the order they are declared here.
        """
        if not target_embedding_dim:
            target_embedding_dim = source_embedder.get_output_dim()

        CopyNetSeq2Seq.__init__(
            self,
            vocab,
            source_embedder,
            encoder,
            attention,
            beam_size,
            max_decoding_steps,
            target_embedding_dim,
            copy_token,
            source_namespace,
            target_namespace,
            tensor_based_metric,
            token_based_metric
        )

        self._tie_embeddings = tie_embeddings
        if tie_embeddings:
            # NOTE(review): these checks are plain asserts, so they vanish
            # when Python runs with -O.
            assert source_namespace == target_namespace
            source_children = dict(self._source_embedder.named_children())
            assert "token_embedder_tokens" in source_children
            self._target_embedder.weight = source_children["token_embedder_tokens"].weight

        if tensor_based_metric is None:
            self._tensor_based_metric = None
开发者ID:IlyaGusev,项目名称:summarus,代码行数:42,代码来源:copynet.py

示例11: __init__

# 需要导入模块: from allennlp import modules [as 别名]
# 或者: from allennlp.modules import TextFieldEmbedder [as 别名]
def __init__(self,
                 vocab: Vocabulary,
                 source_text_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 decoder: SeqDecoder,
                 tied_source_embedder_key: Optional[str] = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        """Compose an embedder, encoder and decoder, optionally tying embeddings.

        Parameters
        ----------
        vocab, regularizer:
            Forwarded to the parent constructor.
        source_text_embedder:
            Stored as ``self._source_text_embedder``.
        encoder, decoder:
            Stored on the instance; their output dimensions must be equal.
        tied_source_embedder_key:
            When truthy, the token embedder stored under this key in the
            source text embedder is replaced by the decoder's target embedder.
        initializer:
            Called with the finished model as the very last step.

        Raises
        ------
        ConfigurationError
            If the encoder and decoder output dimensions differ, or if tying
            is requested but the source text embedder is not a
            ``BasicTextFieldEmbedder``, the selected token embedder is not an
            ``Embedding``, or its output dimension differs from the target
            embedder's.
        """
        super(CustomComposedSeq2Seq, self).__init__(vocab, regularizer)

        self._source_text_embedder = source_text_embedder
        self._encoder = encoder
        self._decoder = decoder

        if self._encoder.get_output_dim() != self._decoder.get_output_dim():
            raise ConfigurationError(f"Encoder output dimension {self._encoder.get_output_dim()} should be"
                                     f" equal to decoder dimension {self._decoder.get_output_dim()}.")
        if tied_source_embedder_key:
            # The adjacent string literals below were previously joined
            # without a separating space, which garbled the messages.
            if not isinstance(self._source_text_embedder, BasicTextFieldEmbedder):
                raise ConfigurationError("Unable to tie embeddings, "
                                         "Source text embedder is not an instance of `BasicTextFieldEmbedder`.")
            source_embedder = self._source_text_embedder._token_embedders[tied_source_embedder_key]
            if not isinstance(source_embedder, Embedding):
                raise ConfigurationError("Unable to tie embeddings, "
                                         "Selected source embedder is not an instance of `Embedding`.")
            if source_embedder.get_output_dim() != self._decoder.target_embedder.get_output_dim():
                raise ConfigurationError("Output Dimensions mismatch between "
                                         "source embedder and target embedder.")
            # Tie by making the source side reuse the decoder's embedder object.
            self._source_text_embedder._token_embedders[tied_source_embedder_key] = self._decoder.target_embedder
        initializer(self)
开发者ID:IlyaGusev,项目名称:summarus,代码行数:33,代码来源:custom_composed_seq2seq.py

示例12: __init__

# 需要导入模块: from allennlp import modules [as 别名]
# 或者: from allennlp.modules import TextFieldEmbedder [as 别名]
def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 question_encoder: Seq2SeqEncoder,
                 passage_encoder: Seq2SeqEncoder,
                 pair_encoder: AttentionEncoder,
                 self_encoder: AttentionEncoder,
                 output_layer: QAOutputLayer,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 share_encoder: bool = False):
        """Store the model's components and create its metrics and loss.

        ``vocab`` and ``regularizer`` go to the parent constructor. The
        embedder, the four encoders, the output layer and ``share_encoder`` are
        stored on the instance under their argument names. ``initializer`` is
        called with the finished model as the very last step.
        """
        super().__init__(vocab, regularizer)

        # Components supplied by the caller.
        self.text_field_embedder = text_field_embedder
        self.question_encoder = question_encoder
        self.passage_encoder = passage_encoder
        self.pair_encoder = pair_encoder
        self.self_encoder = self_encoder
        self.output_layer = output_layer
        self.share_encoder = share_encoder

        # Span metrics.
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()

        self.loss = torch.nn.CrossEntropyLoss()

        initializer(self)
开发者ID:matthew-z,项目名称:R-net,代码行数:29,代码来源:rnet.py

示例13: __init__

# 需要导入模块: from allennlp import modules [as 别名]
# 或者: from allennlp.modules import TextFieldEmbedder [as 别名]
def __init__(self,
                 vocab: Vocabulary,
                 input_embedder: TextFieldEmbedder,
                 encoder: Encoder = None,
                 dropout: float = None,
                 initializer: InitializerApplicator = InitializerApplicator()
                ) -> None:
        """
        Build a classifier with an optional encoder in front of a linear layer.

        Parameters
        ----------
        vocab: `Vocabulary`
            vocabulary; the size of its "labels" namespace is the number of outputs
        input_embedder: `TextFieldEmbedder`
            generic token embedder
        encoder: `Encoder`, optional (default = None)
            Seq2Vec or Seq2Seq encoder wrapper. Without an encoder the input is
            assumed to be a bag of word counts (linear classification), and the
            embedder's output dimension sizes the classification layer.
        dropout: `float`, optional (default = None)
            when truthy, a dropout layer with this probability is created
        initializer: `InitializerApplicator`
            called on the finished model as the last step
        """
        super().__init__(vocab)

        self._input_embedder = input_embedder
        self._dropout = torch.nn.Dropout(dropout) if dropout else None
        self._encoder = encoder
        self._num_labels = vocab.get_vocab_size(namespace="labels")

        # The classifier reads from the encoder when there is one, and
        # straight from the embedder otherwise.
        feature_source = self._encoder if self._encoder else self._input_embedder
        self._clf_input_dim = feature_source.get_output_dim()
        self._classification_layer = torch.nn.Linear(self._clf_input_dim, self._num_labels)

        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()

        initializer(self)
开发者ID:allenai,项目名称:vampire,代码行数:41,代码来源:classifier.py

示例14: __init__

# 需要导入模块: from allennlp import modules [as 别名]
# 或者: from allennlp.modules import TextFieldEmbedder [as 别名]
def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 calculate_span_f1: bool = None,
                 label_encoding: Optional[str] = None,
                 label_namespace: str = "labels",
                 verbose_metrics: bool = False,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        """Set up an embed -> encode -> per-token projection tagger.

        Parameters
        ----------
        vocab, regularizer:
            Forwarded to the parent constructor.
        text_field_embedder:
            Its output dimension must match the encoder's input dimension.
        encoder:
            Its ``get_output_dim()`` is the input width of the tag projection.
        calculate_span_f1:
            Kept for API consistency with the CrfTagger; redundant here because
            ``label_encoding`` serves the same purpose.
        label_encoding:
            When this or ``calculate_span_f1`` is truthy, a span-based F1
            metric is created with this encoding.
        label_namespace:
            Vocabulary namespace that defines the number of classes.
        verbose_metrics:
            Stored unchanged on the instance.
        initializer:
            Called with the finished model as the very last step.

        Raises
        ------
        ConfigurationError
            If ``calculate_span_f1`` is truthy but ``label_encoding`` is not.
        """
        super().__init__(vocab, regularizer)

        self.label_namespace = label_namespace
        self.text_field_embedder = text_field_embedder
        self.num_classes = self.vocab.get_vocab_size(label_namespace)
        self.encoder = encoder
        self._verbose_metrics = verbose_metrics

        projection = Linear(self.encoder.get_output_dim(), self.num_classes)
        self.tag_projection_layer = TimeDistributed(projection)

        check_dimensions_match(text_field_embedder.get_output_dim(), encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")

        if calculate_span_f1 and not label_encoding:
            raise ConfigurationError("calculate_span_f1 is True, but "
                                     "no label_encoding was specified.")

        self.metrics = {
            "accuracy": CategoricalAccuracy(),
            "accuracy3": CategoricalAccuracy(top_k=3)
        }

        if not (calculate_span_f1 or label_encoding):
            self._f1_metric = None
        else:
            self._f1_metric = SpanBasedF1Measure(vocab,
                                                 tag_namespace=label_namespace,
                                                 label_encoding=label_encoding)

        initializer(self)
开发者ID:DreamerDeo,项目名称:HIT-SCIR-CoNLL2019,代码行数:43,代码来源:simple_tagger.py

示例15: __init__

# 需要导入模块: from allennlp import modules [as 别名]
# 或者: from allennlp.modules import TextFieldEmbedder [as 别名]
def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 aggregate_feedforward: FeedForward,
                 premise_encoder: Optional[Seq2SeqEncoder] = None,
                 hypothesis_encoder: Optional[Seq2SeqEncoder] = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 aggregate_premise: Optional[str] = "max",
                 aggregate_hypothesis: Optional[str] = "max",
                 embeddings_dropout_value: Optional[float] = 0.0,
                 share_encoders: Optional[bool] = False) -> None:
        """Store the components and validate that their dimensions line up.

        Parameters
        ----------
        vocab:
            Forwarded to the parent constructor; the size of its ``"labels"``
            namespace is the expected output width of the feed-forward.
        text_field_embedder:
            Its output dimension is used for a side that has no encoder.
        aggregate_feedforward:
            Its input width must be four times the per-side output dimension
            and its output width must equal the number of labels.
        premise_encoder, hypothesis_encoder:
            Optional encoders; when given, their output dimensions replace the
            embedder's for the respective side.
        initializer:
            Called with the finished model as the very last step.
        aggregate_premise, aggregate_hypothesis:
            Stored unchanged on the instance.
        embeddings_dropout_value:
            When greater than zero a ``torch.nn.Dropout`` is created; for zero,
            negative values or ``None`` an identity function is used instead.
        share_encoders:
            Not read anywhere in this constructor.

        Raises
        ------
        ConfigurationError
            If the two sides' output dimensions differ, if the feed-forward
            input width is not four times that dimension, or if its output
            width differs from the number of labels.
        """
        super(StackedNNAggregateCustom, self).__init__(vocab)

        self._text_field_embedder = text_field_embedder
        # The parameter is typed Optional, so guard against None before the
        # numeric comparison instead of raising a TypeError.
        if embeddings_dropout_value is not None and embeddings_dropout_value > 0.0:
            self._embeddings_dropout = torch.nn.Dropout(p=embeddings_dropout_value)
        else:
            self._embeddings_dropout = lambda x: x

        self._aggregate_feedforward = aggregate_feedforward
        self._premise_encoder = premise_encoder
        self._hypothesis_encoder = hypothesis_encoder

        self._premise_aggregate = aggregate_premise
        self._hypothesis_aggregate = aggregate_hypothesis

        self._num_labels = vocab.get_vocab_size(namespace="labels")

        # Each side falls back to the embedder's width when it has no encoder.
        premise_output_dim = self._text_field_embedder.get_output_dim()
        if self._premise_encoder is not None:
            premise_output_dim = self._premise_encoder.get_output_dim()

        hypothesis_output_dim = self._text_field_embedder.get_output_dim()
        if self._hypothesis_encoder is not None:
            hypothesis_output_dim = self._hypothesis_encoder.get_output_dim()

        if premise_output_dim != hypothesis_output_dim:
            # A space was missing between "(dim: {})" and "must match!".
            raise ConfigurationError("Output dimension of the premise_encoder (dim: {}), "
                                     "plus hypothesis_encoder (dim: {}) "
                                     "must match! "
                                     .format(premise_output_dim,
                                             hypothesis_output_dim))

        if premise_output_dim * 4 != \
                aggregate_feedforward.get_input_dim():
            raise ConfigurationError("The output of aggregate_feedforward input dim ({2})  "
                                     "should be {3} = 4 x {0} ({1} = premise_output_dim == hypothesis_output_dim)!"
                                     .format(premise_output_dim,
                                             hypothesis_output_dim,
                                             aggregate_feedforward.get_input_dim(),
                                             4 * premise_output_dim))

        if aggregate_feedforward.get_output_dim() != self._num_labels:
            raise ConfigurationError("Final output dimension (%d) must equal num labels (%d)" %
                                     (aggregate_feedforward.get_output_dim(), self._num_labels))

        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()

        initializer(self)
开发者ID:allenai,项目名称:OpenBookQA,代码行数:61,代码来源:stacked_nn_aggregate_custom.py


注:本文中的allennlp.modules.TextFieldEmbedder方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。