

Python Params.pop_bool Method Code Examples

This article collects typical usage examples of the Python method allennlp.common.Params.pop_bool. If you are wondering what Params.pop_bool does or how to use it, the curated code examples below should help. You can also explore the other usage examples of allennlp.common.Params for broader context.


The following presents 15 code examples of the Params.pop_bool method, drawn from open-source projects and ordered by popularity.
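Before diving into the examples, here is a minimal sketch of what pop_bool does: it removes a key from the Params object and coerces the value to bool, accepting real booleans as well as the JSON-friendly strings "true"/"false" (behavior as implemented in the AllenNLP versions these examples target; treat the exact coercion details as an assumption for your version):

from allennlp.common import Params

params = Params({"stateful": "true", "bidirectional": False})
stateful = params.pop_bool("stateful", False)           # the string "true" is coerced to True
bidirectional = params.pop_bool("bidirectional", True)  # real booleans pass through unchanged
lowercase = params.pop_bool("lowercase", False)         # missing key: the default is returned
params.assert_empty("demo")  # raises ConfigurationError if any keys were left unconsumed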

Example 1: from_params

# Required import: from allennlp.common import Params
# Alternatively: from allennlp.common.Params import pop_bool
    def from_params(cls, params: Params) -> 'WordSplitter':
        language = params.pop('language', 'en_core_web_sm')
        pos_tags = params.pop_bool('pos_tags', False)
        parse = params.pop_bool('parse', False)
        ner = params.pop_bool('ner', False)
        params.assert_empty(cls.__name__)
        return cls(language, pos_tags, parse, ner)
Author: Jordan-Sauchuk, Project: allennlp, Lines: 9, Source file: word_splitter.py
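For orientation, this snippet appears to come from AllenNLP's SpacyWordSplitter, and would be driven roughly as follows (a hedged sketch; the import path and the string-valued boolean are assumptions based on the pops above):

from allennlp.common import Params
from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter

# Booleans may arrive from a JSON config as strings; pop_bool normalizes them.
config = Params({"pos_tags": "true", "parse": False})
splitter = SpacyWordSplitter.from_params(config)  # 'language' and 'ner' fall back to defaults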

Example 2: from_params

# Required import: from allennlp.common import Params
# Alternatively: from allennlp.common.Params import pop_bool
    def from_params(self, params: Params) -> PytorchSeq2SeqWrapper:
        if not params.pop_bool('batch_first', True):
            raise ConfigurationError("Our encoder semantics assumes batch is always first!")
        if self._module_class in self.PYTORCH_MODELS:
            params['batch_first'] = True
        stateful = params.pop_bool('stateful', False)
        module = self._module_class(**params.as_dict())
        return PytorchSeq2SeqWrapper(module, stateful=stateful)
Author: apmoore1, Project: allennlp, Lines: 10, Source file: __init__.py

Example 3: extend_from_instances

# Required import: from allennlp.common import Params
# Alternatively: from allennlp.common.Params import pop_bool
    def extend_from_instances(self,
                              params: Params,
                              instances: Iterable['adi.Instance'] = ()) -> None:
        """
        Extends an already generated vocabulary using a collection of instances.
        """
        min_count = params.pop("min_count", None)
        max_vocab_size = pop_max_vocab_size(params)
        non_padded_namespaces = params.pop("non_padded_namespaces", DEFAULT_NON_PADDED_NAMESPACES)
        pretrained_files = params.pop("pretrained_files", {})
        min_pretrained_embeddings = params.pop("min_pretrained_embeddings", None)
        only_include_pretrained_words = params.pop_bool("only_include_pretrained_words", False)
        tokens_to_add = params.pop("tokens_to_add", None)
        params.assert_empty("Vocabulary - from dataset")

        logger.info("Fitting token dictionary from dataset.")
        namespace_token_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
        for instance in Tqdm.tqdm(instances):
            instance.count_vocab_items(namespace_token_counts)
        self._extend(counter=namespace_token_counts,
                     min_count=min_count,
                     max_vocab_size=max_vocab_size,
                     non_padded_namespaces=non_padded_namespaces,
                     pretrained_files=pretrained_files,
                     only_include_pretrained_words=only_include_pretrained_words,
                     tokens_to_add=tokens_to_add,
                     min_pretrained_embeddings=min_pretrained_embeddings)
Author: ziaridoy20, Project: allennlp, Lines: 29, Source file: vocabulary.py

Example 4: from_params

# Required import: from allennlp.common import Params
# Alternatively: from allennlp.common.Params import pop_bool
    def from_params(cls, vocab: Vocabulary, params: Params) -> 'ElmoTokenEmbedder':  # type: ignore
        # pylint: disable=arguments-differ
        params.add_file_to_archive('options_file')
        params.add_file_to_archive('weight_file')
        options_file = params.pop('options_file')
        weight_file = params.pop('weight_file')
        requires_grad = params.pop('requires_grad', False)
        do_layer_norm = params.pop_bool('do_layer_norm', False)
        dropout = params.pop_float("dropout", 0.5)
        namespace_to_cache = params.pop("namespace_to_cache", None)
        if namespace_to_cache is not None:
            vocab_to_cache = list(vocab.get_token_to_index_vocabulary(namespace_to_cache).keys())
        else:
            vocab_to_cache = None
        projection_dim = params.pop_int("projection_dim", None)
        scalar_mix_parameters = params.pop('scalar_mix_parameters', None)
        params.assert_empty(cls.__name__)
        return cls(options_file=options_file,
                   weight_file=weight_file,
                   do_layer_norm=do_layer_norm,
                   dropout=dropout,
                   requires_grad=requires_grad,
                   projection_dim=projection_dim,
                   vocab_to_cache=vocab_to_cache,
                   scalar_mix_parameters=scalar_mix_parameters)
Author: apmoore1, Project: allennlp, Lines: 27, Source file: elmo_token_embedder.py

Example 5: from_params

# Required import: from allennlp.common import Params
# Alternatively: from allennlp.common.Params import pop_bool
    def from_params(cls, vocab: Vocabulary, params: Params) -> 'Embedding':
        """
        We need the vocabulary here to know how many items we need to embed, and we look for a
        ``vocab_namespace`` key in the parameter dictionary to know which vocabulary to use.  If
        you know beforehand exactly how many embeddings you need, or aren't using a vocabulary
        mapping for the things getting embedded here, then you can pass in the ``num_embeddings``
        key directly, and the vocabulary will be ignored.
        """
        num_embeddings = params.pop_int('num_embeddings', None)
        vocab_namespace = params.pop("vocab_namespace", "tokens")
        if num_embeddings is None:
            num_embeddings = vocab.get_vocab_size(vocab_namespace)
        embedding_dim = params.pop_int('embedding_dim')
        pretrained_file = params.pop("pretrained_file", None)
        projection_dim = params.pop_int("projection_dim", None)
        trainable = params.pop_bool("trainable", True)
        padding_index = params.pop_int('padding_index', None)
        max_norm = params.pop_float('max_norm', None)
        norm_type = params.pop_float('norm_type', 2.)
        scale_grad_by_freq = params.pop_bool('scale_grad_by_freq', False)
        sparse = params.pop_bool('sparse', False)
        params.assert_empty(cls.__name__)

        if pretrained_file:
            # If we're loading a saved model, we don't want to actually read a pre-trained
            # embedding file - the embeddings will just be in our saved weights, and we might not
            # have the original embedding file anymore, anyway.
            weight = _read_pretrained_embedding_file(pretrained_file,
                                                     embedding_dim,
                                                     vocab,
                                                     vocab_namespace)
        else:
            weight = None

        return cls(num_embeddings=num_embeddings,
                   embedding_dim=embedding_dim,
                   projection_dim=projection_dim,
                   weight=weight,
                   padding_index=padding_index,
                   trainable=trainable,
                   max_norm=max_norm,
                   norm_type=norm_type,
                   scale_grad_by_freq=scale_grad_by_freq,
                   sparse=sparse)
Author: Jordan-Sauchuk, Project: allennlp, Lines: 46, Source file: embedding.py
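Per the docstring above, this from_params supports two sizing modes; here is a sketch of both configurations (the key names mirror the pops in the snippet, while the file name and sizes are purely illustrative):

from allennlp.common import Params

# Usual case: the embedding count is taken from the vocabulary's 'tokens' namespace.
vocab_sized = Params({"embedding_dim": 100, "trainable": "true"})

# Explicit case: num_embeddings is given, so the vocabulary is ignored for sizing.
fixed_sized = Params({"num_embeddings": 400000,
                      "embedding_dim": 300,
                      "pretrained_file": "glove.840B.300d.txt.gz",  # illustrative path
                      "trainable": False})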

Example 6: from_params

# Required import: from allennlp.common import Params
# Alternatively: from allennlp.common.Params import pop_bool
    def from_params(cls, vocab: Vocabulary, params: Params) -> 'ElmoTokenEmbedder':
        params.add_file_to_archive('options_file')
        params.add_file_to_archive('weight_file')
        options_file = params.pop('options_file')
        weight_file = params.pop('weight_file')
        requires_grad = params.pop('requires_grad', False)
        do_layer_norm = params.pop_bool('do_layer_norm', False)
        dropout = params.pop_float("dropout", 0.5)
        params.assert_empty(cls.__name__)
        return cls(options_file, weight_file, do_layer_norm, dropout, requires_grad=requires_grad)
Author: Jordan-Sauchuk, Project: allennlp, Lines: 12, Source file: elmo_token_embedder.py

Example 7: from_params

# Required import: from allennlp.common import Params
# Alternatively: from allennlp.common.Params import pop_bool
    def from_params(cls, vocab: Vocabulary, params: Params) -> 'BiattentiveClassificationNetwork':  # type: ignore
        # pylint: disable=arguments-differ
        embedder_params = params.pop("text_field_embedder")
        text_field_embedder = TextFieldEmbedder.from_params(vocab=vocab, params=embedder_params)
        embedding_dropout = params.pop("embedding_dropout")
        pre_encode_feedforward = FeedForward.from_params(params.pop("pre_encode_feedforward"))
        encoder = Seq2SeqEncoder.from_params(params.pop("encoder"))
        integrator = Seq2SeqEncoder.from_params(params.pop("integrator"))
        integrator_dropout = params.pop("integrator_dropout")

        output_layer_params = params.pop("output_layer")
        if "activations" in output_layer_params:
            output_layer = FeedForward.from_params(output_layer_params)
        else:
            output_layer = Maxout.from_params(output_layer_params)

        elmo = params.pop("elmo", None)
        if elmo is not None:
            elmo = Elmo.from_params(elmo)
        use_input_elmo = params.pop_bool("use_input_elmo", False)
        use_integrator_output_elmo = params.pop_bool("use_integrator_output_elmo", False)

        initializer = InitializerApplicator.from_params(params.pop('initializer', []))
        regularizer = RegularizerApplicator.from_params(params.pop('regularizer', []))
        params.assert_empty(cls.__name__)

        return cls(vocab=vocab,
                   text_field_embedder=text_field_embedder,
                   embedding_dropout=embedding_dropout,
                   pre_encode_feedforward=pre_encode_feedforward,
                   encoder=encoder,
                   integrator=integrator,
                   integrator_dropout=integrator_dropout,
                   output_layer=output_layer,
                   elmo=elmo,
                   use_input_elmo=use_input_elmo,
                   use_integrator_output_elmo=use_integrator_output_elmo,
                   initializer=initializer,
                   regularizer=regularizer)
Author: pyknife, Project: allennlp, Lines: 41, Source file: biattentive_classification_network.py

Example 8: from_params

# Required import: from allennlp.common import Params
# Alternatively: from allennlp.common.Params import pop_bool
    def from_params(cls, params: Params) -> 'Elmo':
        # Add files to archive
        params.add_file_to_archive('options_file')
        params.add_file_to_archive('weight_file')

        options_file = params.pop('options_file')
        weight_file = params.pop('weight_file')
        requires_grad = params.pop('requires_grad', False)
        num_output_representations = params.pop('num_output_representations')
        do_layer_norm = params.pop_bool('do_layer_norm', False)
        keep_sentence_boundaries = params.pop_bool('keep_sentence_boundaries', False)
        dropout = params.pop_float('dropout', 0.5)
        scalar_mix_parameters = params.pop('scalar_mix_parameters', None)
        params.assert_empty(cls.__name__)

        return cls(options_file=options_file,
                   weight_file=weight_file,
                   num_output_representations=num_output_representations,
                   requires_grad=requires_grad,
                   do_layer_norm=do_layer_norm,
                   keep_sentence_boundaries=keep_sentence_boundaries,
                   dropout=dropout,
                   scalar_mix_parameters=scalar_mix_parameters)
Author: ziaridoy20, Project: allennlp, Lines: 25, Source file: elmo.py

Example 9: from_params

# Required import: from allennlp.common import Params
# Alternatively: from allennlp.common.Params import pop_bool
    def from_params(cls, params: Params) -> 'Elmo':
        # Add files to archive
        params.add_file_to_archive('options_file')
        params.add_file_to_archive('weight_file')

        options_file = params.pop('options_file')
        weight_file = params.pop('weight_file')
        requires_grad = params.pop('requires_grad', False)
        num_output_representations = params.pop('num_output_representations')
        do_layer_norm = params.pop_bool('do_layer_norm', False)
        params.assert_empty(cls.__name__)

        return cls(options_file, weight_file, num_output_representations,
                   requires_grad=requires_grad, do_layer_norm=do_layer_norm)
Author: Jordan-Sauchuk, Project: allennlp, Lines: 16, Source file: elmo.py

Example 10: from_params

# Required import: from allennlp.common import Params
# Alternatively: from allennlp.common.Params import pop_bool
    def from_params(cls, params: Params):
        input_dim = params.pop_int('input_dim')
        hidden_dim = params.pop_int('hidden_dim')
        projection_dim = params.pop_int('projection_dim', None)
        feedforward_hidden_dim = params.pop_int("feedforward_hidden_dim")
        num_layers = params.pop_int("num_layers", 2)
        num_attention_heads = params.pop_int('num_attention_heads', 3)
        use_positional_encoding = params.pop_bool('use_positional_encoding', True)
        dropout_prob = params.pop_float("dropout_prob", 0.2)
        params.assert_empty(cls.__name__)

        return cls(input_dim=input_dim,
                   hidden_dim=hidden_dim,
                   feedforward_hidden_dim=feedforward_hidden_dim,
                   projection_dim=projection_dim,
                   num_layers=num_layers,
                   num_attention_heads=num_attention_heads,
                   use_positional_encoding=use_positional_encoding,
                   dropout_prob=dropout_prob)
Author: Jordan-Sauchuk, Project: allennlp, Lines: 21, Source file: stacked_self_attention.py

Example 11: from_params

# Required import: from allennlp.common import Params
# Alternatively: from allennlp.common.Params import pop_bool
    def from_params(cls, params: Params) -> 'AdaptiveIterator':
        adaptive_memory_usage_constant = params.pop_int('adaptive_memory_usage_constant')
        padding_memory_scaling = params.pop('padding_memory_scaling')
        maximum_batch_size = params.pop_int('maximum_batch_size', 10000)
        biggest_batch_first = params.pop_bool('biggest_batch_first', False)
        batch_size = params.pop_int('batch_size', None)
        sorting_keys = params.pop('sorting_keys', None)
        padding_noise = params.pop_float('sorting_noise', 0.2)
        instances_per_epoch = params.pop_int('instances_per_epoch', None)
        max_instances_in_memory = params.pop_int('max_instances_in_memory', None)
        params.assert_empty(cls.__name__)

        return cls(adaptive_memory_usage_constant=adaptive_memory_usage_constant,
                   padding_memory_scaling=padding_memory_scaling,
                   maximum_batch_size=maximum_batch_size,
                   biggest_batch_first=biggest_batch_first,
                   batch_size=batch_size,
                   sorting_keys=sorting_keys,
                   padding_noise=padding_noise,
                   instances_per_epoch=instances_per_epoch,
                   max_instances_in_memory=max_instances_in_memory)
Author: Jordan-Sauchuk, Project: allennlp, Lines: 23, Source file: adaptive_iterator.py

Example 12: from_params

# Required import: from allennlp.common import Params
# Alternatively: from allennlp.common.Params import pop_bool
    def from_params(cls, vocab: Vocabulary, params: Params) -> 'BasicTextFieldEmbedder':  # type: ignore
        # pylint: disable=arguments-differ,bad-super-call

        # The original `from_params` for this class was designed in a way that didn't agree
        # with the constructor. The constructor wants a 'token_embedders' parameter that is a
        # `Dict[str, TokenEmbedder]`, but the original `from_params` implementation expected those
        # key-value pairs to be top-level in the params object.
        #
        # This breaks our 'configuration wizard' and configuration checks. Hence, going forward,
        # the params need a 'token_embedders' key so that they line up with what the constructor wants.
        # For now, the old behavior is still supported, but produces a DeprecationWarning.

        embedder_to_indexer_map = params.pop("embedder_to_indexer_map", None)
        if embedder_to_indexer_map is not None:
            embedder_to_indexer_map = embedder_to_indexer_map.as_dict(quiet=True)
        allow_unmatched_keys = params.pop_bool("allow_unmatched_keys", False)

        token_embedder_params = params.pop('token_embedders', None)

        if token_embedder_params is not None:
            # New way: explicitly specified, so use it.
            token_embedders = {
                    name: TokenEmbedder.from_params(subparams, vocab=vocab)
                    for name, subparams in token_embedder_params.items()
            }

        else:
            # Warn that the original behavior is deprecated
            warnings.warn(DeprecationWarning("the token embedders for BasicTextFieldEmbedder should now "
                                             "be specified as a dict under the 'token_embedders' key, "
                                             "not as top-level key-value pairs"))

            token_embedders = {}
            keys = list(params.keys())
            for key in keys:
                embedder_params = params.pop(key)
                token_embedders[key] = TokenEmbedder.from_params(vocab=vocab, params=embedder_params)

        params.assert_empty(cls.__name__)
        return cls(token_embedders, embedder_to_indexer_map, allow_unmatched_keys)
Author: apmoore1, Project: allennlp, Lines: 42, Source file: basic_text_field_embedder.py
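The two configuration shapes that the comments in this snippet describe look roughly like this (a sketch; the inner embedder params are illustrative):

# New style: embedders nested under the 'token_embedders' key, matching the constructor.
new_style = {
    "token_embedders": {
        "tokens": {"type": "embedding", "embedding_dim": 100}
    }
}

# Deprecated style: the same key-value pairs at the top level (triggers the DeprecationWarning).
old_style = {
    "tokens": {"type": "embedding", "embedding_dim": 100}
}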

Example 13: from_params

# Required import: from allennlp.common import Params
# Alternatively: from allennlp.common.Params import pop_bool
    def from_params(cls, params: Params) -> 'Seq2SeqDatasetReader':
        source_tokenizer_type = params.pop('source_tokenizer', None)
        source_tokenizer = None if source_tokenizer_type is None else Tokenizer.from_params(source_tokenizer_type)
        target_tokenizer_type = params.pop('target_tokenizer', None)
        target_tokenizer = None if target_tokenizer_type is None else Tokenizer.from_params(target_tokenizer_type)
        source_indexers_type = params.pop('source_token_indexers', None)
        source_add_start_token = params.pop_bool('source_add_start_token', True)
        if source_indexers_type is None:
            source_token_indexers = None
        else:
            source_token_indexers = TokenIndexer.dict_from_params(source_indexers_type)
        target_indexers_type = params.pop('target_token_indexers', None)
        if target_indexers_type is None:
            target_token_indexers = None
        else:
            target_token_indexers = TokenIndexer.dict_from_params(target_indexers_type)
        lazy = params.pop('lazy', False)
        params.assert_empty(cls.__name__)
        return Seq2SeqDatasetReader(source_tokenizer=source_tokenizer,
                                    target_tokenizer=target_tokenizer,
                                    source_token_indexers=source_token_indexers,
                                    target_token_indexers=target_token_indexers,
                                    source_add_start_token=source_add_start_token,
                                    lazy=lazy)
Author: Jordan-Sauchuk, Project: allennlp, Lines: 26, Source file: seq2seq.py
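A configuration fragment that this reader's from_params would consume might look like the following (a sketch; the tokenizer and indexer type names are assumptions based on common AllenNLP registrations):

from allennlp.common import Params

reader_params = Params({
    "source_tokenizer": {"type": "word"},
    "source_token_indexers": {"tokens": {"type": "single_id"}},
    "source_add_start_token": "false",  # pop_bool accepts the string form
    "lazy": False
})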

Example 14: fine_tune_model

# Required import: from allennlp.common import Params
# Alternatively: from allennlp.common.Params import pop_bool
def fine_tune_model(model: Model,
                    params: Params,
                    serialization_dir: str,
                    file_friendly_logging: bool = False) -> Model:
    """
    Fine tunes the given model, using a set of parameters that is largely identical to those used
    for :func:`~allennlp.commands.train.train_model`, except that the ``model`` section is ignored,
    if it is present (as we are already given a ``Model`` here).

    The main difference between the logic done here and the logic done in ``train_model`` is that
    here we do not worry about vocabulary construction or creating the model object.  Everything
    else is the same.

    Parameters
    ----------
    model : ``Model``
        A model to fine tune, such as one loaded from a saved model archive.
    params : ``Params``
        A parameter object specifying an AllenNLP experiment; its ``model`` section, if
        present, is ignored.
    serialization_dir : ``str``
        The directory in which to save results and logs.
    file_friendly_logging : ``bool``, optional (default=False)
        If ``True``, we add newlines to tqdm output, even on an interactive terminal, and we slow
        down tqdm's output to only once every 10 seconds.
    """
    prepare_environment(params)
    os.makedirs(serialization_dir)
    prepare_global_logging(serialization_dir, file_friendly_logging)

    serialization_params = deepcopy(params).as_dict(quiet=True)
    with open(os.path.join(serialization_dir, CONFIG_NAME), "w") as param_file:
        json.dump(serialization_params, param_file, indent=4)

    if params.pop('model', None):
        logger.warning("You passed parameters for the model in your configuration file, but we "
                       "are ignoring them, using instead the model parameters in the archive.")

    vocabulary_params = params.pop('vocabulary', {})
    if vocabulary_params.get('directory_path', None):
        logger.warning("You passed `directory_path` in parameters for the vocabulary in "
                       "your configuration file, but it will be ignored. "
                       "Vocabulary from the saved model will be extended with current data.")

    all_datasets = datasets_from_params(params)

    datasets_for_vocab_creation = set(params.pop("datasets_for_vocab_creation", all_datasets))

    for dataset in datasets_for_vocab_creation:
        if dataset not in all_datasets:
            raise ConfigurationError(f"invalid 'dataset_for_vocab_creation' {dataset}")

    logger.info("Extending model vocabulary using %s data.", ", ".join(datasets_for_vocab_creation))
    vocab = model.vocab
    vocab.extend_from_instances(vocabulary_params,
                                (instance for key, dataset in all_datasets.items()
                                 for instance in dataset
                                 if key in datasets_for_vocab_creation))
    vocab.save_to_files(os.path.join(serialization_dir, "vocabulary"))

    iterator = DataIterator.from_params(params.pop("iterator"))
    iterator.index_with(vocab)

    train_data = all_datasets['train']
    validation_data = all_datasets.get('validation')
    test_data = all_datasets.get('test')

    trainer_params = params.pop("trainer")
    no_grad_regexes = trainer_params.pop("no_grad", ())
    for name, parameter in model.named_parameters():
        if any(re.search(regex, name) for regex in no_grad_regexes):
            parameter.requires_grad_(False)

    frozen_parameter_names, tunable_parameter_names = \
                   get_frozen_and_tunable_parameter_names(model)
    logger.info("Following parameters are Frozen  (without gradient):")
    for name in frozen_parameter_names:
        logger.info(name)
    logger.info("Following parameters are Tunable (with gradient):")
    for name in tunable_parameter_names:
        logger.info(name)

    trainer = Trainer.from_params(model,
                                  serialization_dir,
                                  iterator,
                                  train_data,
                                  validation_data,
                                  trainer_params)

    evaluate_on_test = params.pop_bool("evaluate_on_test", False)
    params.assert_empty('base train command')
    try:
        metrics = trainer.train()
    except KeyboardInterrupt:
        # if we have completed an epoch, try to create a model archive.
        if os.path.exists(os.path.join(serialization_dir, _DEFAULT_WEIGHTS)):
            logging.info("Fine-tuning interrupted by the user. Attempting to create "
                         "a model archive using the current best epoch weights.")
            archive_model(serialization_dir, files_to_archive=params.files_to_archive)
#......... remainder of this code omitted .........
Author: pyknife, Project: allennlp, Lines: 103, Source file: fine_tune.py
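A hedged sketch of how fine_tune_model might be driven, with the model taken from a saved archive (the file names here are illustrative; load_archive comes from allennlp.models.archival):

from allennlp.common import Params
from allennlp.models.archival import load_archive

archive = load_archive("model.tar.gz")              # result of a previous 'train' run
params = Params.from_file("fine_tune_config.json")  # illustrative config path
model = fine_tune_model(model=archive.model,
                        params=params,
                        serialization_dir="fine_tune_output",
                        file_friendly_logging=True)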

Example 15: train_model

# Required import: from allennlp.common import Params
# Alternatively: from allennlp.common.Params import pop_bool

#......... earlier part of this code omitted .........
    params.to_file(os.path.join(serialization_dir, CONFIG_NAME))

    all_datasets = datasets_from_params(params)
    datasets_for_vocab_creation = set(params.pop("datasets_for_vocab_creation", all_datasets))

    for dataset in datasets_for_vocab_creation:
        if dataset not in all_datasets:
            raise ConfigurationError(f"invalid 'dataset_for_vocab_creation' {dataset}")

    logger.info("From dataset instances, %s will be considered for vocabulary creation.",
                ", ".join(datasets_for_vocab_creation))
    vocab = Vocabulary.from_params(
            params.pop("vocabulary", {}),
            (instance for key, dataset in all_datasets.items()
             for instance in dataset
             if key in datasets_for_vocab_creation)
    )

    model = Model.from_params(vocab=vocab, params=params.pop('model'))

    # Initializing the model can have side effect of expanding the vocabulary
    vocab.save_to_files(os.path.join(serialization_dir, "vocabulary"))

    iterator = DataIterator.from_params(params.pop("iterator"))
    iterator.index_with(vocab)
    validation_iterator_params = params.pop("validation_iterator", None)
    if validation_iterator_params:
        validation_iterator = DataIterator.from_params(validation_iterator_params)
        validation_iterator.index_with(vocab)
    else:
        validation_iterator = None

    train_data = all_datasets['train']
    validation_data = all_datasets.get('validation')
    test_data = all_datasets.get('test')

    trainer_params = params.pop("trainer")
    no_grad_regexes = trainer_params.pop("no_grad", ())
    for name, parameter in model.named_parameters():
        if any(re.search(regex, name) for regex in no_grad_regexes):
            parameter.requires_grad_(False)

    frozen_parameter_names, tunable_parameter_names = \
                   get_frozen_and_tunable_parameter_names(model)
    logger.info("Following parameters are Frozen  (without gradient):")
    for name in frozen_parameter_names:
        logger.info(name)
    logger.info("Following parameters are Tunable (with gradient):")
    for name in tunable_parameter_names:
        logger.info(name)

    trainer_choice = trainer_params.pop_choice("type",
                                               Trainer.list_available(),
                                               default_to_first_choice=True)
    trainer = Trainer.by_name(trainer_choice).from_params(model=model,
                                                          serialization_dir=serialization_dir,
                                                          iterator=iterator,
                                                          train_data=train_data,
                                                          validation_data=validation_data,
                                                          params=trainer_params,
                                                          validation_iterator=validation_iterator)

    evaluate_on_test = params.pop_bool("evaluate_on_test", False)
    params.assert_empty('base train command')

    try:
        metrics = trainer.train()
    except KeyboardInterrupt:
        # if we have completed an epoch, try to create a model archive.
        if os.path.exists(os.path.join(serialization_dir, _DEFAULT_WEIGHTS)):
            logging.info("Training interrupted by the user. Attempting to create "
                         "a model archive using the current best epoch weights.")
            archive_model(serialization_dir, files_to_archive=params.files_to_archive)
        raise

    # Now tar up results
    archive_model(serialization_dir, files_to_archive=params.files_to_archive)

    logger.info("Loading the best epoch weights.")
    best_model_state_path = os.path.join(serialization_dir, 'best.th')
    best_model_state = torch.load(best_model_state_path)
    best_model = model
    best_model.load_state_dict(best_model_state)

    if test_data and evaluate_on_test:
        logger.info("The model will be evaluated using the best epoch weights.")
        test_metrics = evaluate(
                best_model, test_data, validation_iterator or iterator,
                cuda_device=trainer._cuda_devices[0] # pylint: disable=protected-access
        )
        for key, value in test_metrics.items():
            metrics["test_" + key] = value

    elif test_data:
        logger.info("To evaluate on the test set after training, pass the "
                    "'evaluate_on_test' flag, or use the 'allennlp evaluate' command.")

    dump_metrics(os.path.join(serialization_dir, "metrics.json"), metrics, log=True)

    return best_model
Author: ziaridoy20, Project: allennlp, Lines: 104, Source file: train.py


Note: The allennlp.common.Params.pop_bool examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are selected from open-source projects and remain the copyright of their original authors; consult each project's license before redistributing or reusing the code. Do not reproduce without permission.