当前位置: 首页>>代码示例>>Python>>正文


Python Predictor.from_archive方法代码示例

本文整理汇总了Python中allennlp.predictors.Predictor.from_archive方法的典型用法代码示例。如果您正苦于以下问题:Python Predictor.from_archive方法的具体用法?Python Predictor.from_archive怎么用?Python Predictor.from_archive使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在allennlp.predictors.Predictor的用法示例。


在下文中一共展示了Predictor.from_archive方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_simple_gradient_basic_text

# 需要导入模块: from allennlp.predictors import Predictor [as 别名]
# 或者: from allennlp.predictors.Predictor import from_archive [as 别名]
def test_simple_gradient_basic_text(self):
    """Check SimpleGradient saliency output on the basic classifier fixture.

    Verifies that the interpretation contains one gradient entry per input
    word and that interpreting the same input twice gives matching gradients.
    """
    model_path = self.FIXTURES_ROOT / "basic_classifier" / "serialization" / "model.tar.gz"
    predictor = Predictor.from_archive(load_archive(model_path), "text_classifier")
    interpreter = SimpleGradient(predictor)
    inputs = {"sentence": "It was the ending that I hated"}

    first_pass = interpreter.saliency_interpret_from_json(inputs)
    assert first_pass is not None
    assert "instance_1" in first_pass
    assert "grad_input_1" in first_pass["instance_1"]
    first_grads = first_pass["instance_1"]["grad_input_1"]
    assert len(first_grads) == 7  # 7 words in input

    # two interpretations should be identical for gradient
    second_pass = interpreter.saliency_interpret_from_json(inputs)
    second_grads = second_pass["instance_1"]["grad_input_1"]
    for expected, repeated in zip(first_grads, second_grads):
        assert expected == approx(repeated)
开发者ID:allenai,项目名称:allennlp,代码行数:22,代码来源:simple_gradient_test.py

示例2: test_hotflip

# 需要导入模块: from allennlp.predictors import Predictor [as 别名]
# 或者: from allennlp.predictors.Predictor import from_archive [as 别名]
def test_hotflip(self):
    """Run a Hotflip attack against the basic classifier fixture.

    Checks that the attack result exposes the "final", "original" and
    "outputs" entries and that the first final token list is as long as the
    original one.
    """
    model_path = self.FIXTURES_ROOT / "basic_classifier" / "serialization" / "model.tar.gz"
    predictor = Predictor.from_archive(load_archive(model_path))

    attacker = Hotflip(predictor)
    attacker.initialize()
    attack = attacker.attack_from_json(
        {"sentence": "I always write unit tests for my code."}, "tokens", "grad_input_1"
    )

    assert attack is not None
    for expected_key in ("final", "original", "outputs"):
        assert expected_key in attack

    # hotflip replaces words without removing
    flipped_tokens, original_tokens = attack["final"][0], attack["original"]
    assert len(flipped_tokens) == len(original_tokens)
开发者ID:allenai,项目名称:allennlp,代码行数:20,代码来源:hotflip_test.py

示例3: test_with_token_characters_indexer

# 需要导入模块: from allennlp.predictors import Predictor [as 别名]
# 或者: from allennlp.predictors.Predictor import from_archive [as 别名]
def test_with_token_characters_indexer(self):
    """Run Hotflip after adding a character-level indexer to the predictor.

    A "chars" TokenCharactersIndexer is registered on the dataset reader and a
    matching EmptyEmbedder on the model's text field embedder before the
    attack is executed; the result is then checked the same way as for the
    plain Hotflip test.
    """
    model_path = self.FIXTURES_ROOT / "basic_classifier" / "serialization" / "model.tar.gz"
    predictor = Predictor.from_archive(load_archive(model_path))

    # Patch the loaded predictor in place so it also produces a "chars" field.
    char_indexer = TokenCharactersIndexer(min_padding_length=1)
    predictor._dataset_reader._token_indexers["chars"] = char_indexer
    predictor._model._text_field_embedder._token_embedders["chars"] = EmptyEmbedder()

    attacker = Hotflip(predictor)
    attacker.initialize()
    attack = attacker.attack_from_json(
        {"sentence": "I always write unit tests for my code."}, "tokens", "grad_input_1"
    )

    assert attack is not None
    for expected_key in ("final", "original", "outputs"):
        assert expected_key in attack

    # hotflip replaces words without removing
    flipped_tokens, original_tokens = attack["final"][0], attack["original"]
    assert len(flipped_tokens) == len(original_tokens)
开发者ID:allenai,项目名称:allennlp,代码行数:25,代码来源:hotflip_test.py

示例4: test_predictions_to_labeled_instances

# 需要导入模块: from allennlp.predictors import Predictor [as 别名]
# 或者: from allennlp.predictors.Predictor import from_archive [as 别名]
def test_predictions_to_labeled_instances(self):
    """Check that model outputs are converted into a single labeled instance.

    The text classifier predictor is asked to turn the forward-pass outputs
    for one instance back into labeled instances; exactly one instance with a
    non-None "label" field is expected.
    """
    model_path = self.FIXTURES_ROOT / "basic_classifier" / "serialization" / "model.tar.gz"
    predictor = Predictor.from_archive(load_archive(model_path), "text_classifier")

    instance = predictor._json_to_instance(
        {"sentence": "It was the ending that I hated. I was disappointed that it was so bad."}
    )
    model_outputs = predictor._model.forward_on_instance(instance)
    labeled = predictor.predictions_to_labeled_instances(instance, model_outputs)

    first_fields = labeled[0].fields
    assert "label" in first_fields
    assert labeled[0].fields["label"] is not None
    assert len(labeled) == 1
开发者ID:allenai,项目名称:allennlp,代码行数:18,代码来源:text_classifier_test.py

示例5: test_loads_correct_dataset_reader

# 需要导入模块: from allennlp.predictors import Predictor [as 别名]
# 或者: from allennlp.predictors.Predictor import from_archive [as 别名]
def test_loads_correct_dataset_reader(self):
    """Check which dataset reader configuration ``from_archive`` picks up.

    This model has a different dataset reader configuration for train and
    validation. The parameter that differs is the token indexer's namespace.
    """
    archive = load_archive(
        self.FIXTURES_ROOT / "simple_tagger_with_span_f1" / "serialization" / "model.tar.gz"
    )

    # (extra keyword arguments, namespace the "tokens" indexer should report).
    # The call without ``dataset_reader_to_load`` is asserted to give the same
    # namespace as the explicit "validation" request.
    cases = [
        ({}, "test_tokens"),
        ({"dataset_reader_to_load": "train"}, "tokens"),
        ({"dataset_reader_to_load": "validation"}, "test_tokens"),
    ]
    for extra_kwargs, expected_namespace in cases:
        predictor = Predictor.from_archive(archive, "sentence_tagger", **extra_kwargs)
        token_indexer = predictor._dataset_reader._token_indexers["tokens"]
        assert token_indexer.namespace == expected_namespace
开发者ID:allenai,项目名称:allennlp,代码行数:21,代码来源:predictor_test.py

示例6: test_get_gradients_when_requires_grad_is_false

# 需要导入模块: from allennlp.predictors import Predictor [as 别名]
# 或者: from allennlp.predictors.Predictor import from_archive [as 别名]
def test_get_gradients_when_requires_grad_is_false(self):
    """Check ``get_gradients`` on a model whose embedding has requires_grad False.

    Gradients must still be returned for every labeled instance, and the
    embedding weight's ``requires_grad`` flag must be False again afterwards.
    """
    model_path = (
        self.FIXTURES_ROOT
        / "basic_classifier"
        / "embedding_with_trainable_is_false"
        / "model.tar.gz"
    )
    predictor = Predictor.from_archive(load_archive(model_path))

    # ensure that requires_grad is initially False on the embedding layer
    embedding = util.find_embedding_layer(predictor._model)
    assert not embedding.weight.requires_grad

    instance = predictor._json_to_instance({"sentence": "I always write unit tests"})
    model_outputs = predictor._model.forward_on_instance(instance)
    labeled_instances = predictor.predictions_to_labeled_instances(instance, model_outputs)

    # ensure that gradients are always present, despite requires_grad being
    # false on the embedding layer
    for labeled in labeled_instances:
        gradients = predictor.get_gradients([labeled])[0]
        assert bool(gradients)

    # ensure that no side effects remain
    assert not embedding.weight.requires_grad
开发者ID:allenai,项目名称:allennlp,代码行数:27,代码来源:predictor_test.py

示例7: test_uses_named_inputs

# 需要导入模块: from allennlp.predictors import Predictor [as 别名]
# 或者: from allennlp.predictors.Predictor import from_archive [as 别名]
def test_uses_named_inputs(self):
    """Check the constituency parser predictor's JSON output for one sentence.

    Verifies the number of spans and class distributions, the tokenization,
    the type of the rendered tree, and that every class distribution sums to
    approximately one.
    """
    model_path = self.FIXTURES_ROOT / u'constituency_parser' / u'serialization' / u'model.tar.gz'
    predictor = Predictor.from_archive(load_archive(model_path), u'constituency-parser')

    result = predictor.predict_json({u"sentence": u"What a great test sentence."})

    expected_tokens = [u"What", u"a", u"great", u"test", u"sentence", u"."]
    assert len(result[u"spans"]) == 21  # number of possible substrings of the sentence.
    assert len(result[u"class_probabilities"]) == 21
    assert result[u"tokens"] == expected_tokens
    assert isinstance(result[u"trees"], unicode)

    for distribution in result[u"class_probabilities"]:
        self.assertAlmostEqual(sum(distribution), 1.0, places=4)
开发者ID:plasticityai,项目名称:magnitude,代码行数:19,代码来源:constituency_parser_test.py

示例8: test_uses_named_inputs

# 需要导入模块: from allennlp.predictors import Predictor [as 别名]
# 或者: from allennlp.predictors.Predictor import from_archive [as 别名]
def test_uses_named_inputs(self):
    """Check the coreference predictor's JSON output for a short document.

    Verifies the tokenized document and that every predicted cluster is a
    list of mentions whose start and end are integer token indices that fall
    inside the document.
    """
    inputs = {u"document": u"This is a single string document about a test. Sometimes it "
                           u"contains coreferent parts."}
    archive = load_archive(self.FIXTURES_ROOT / u'coref' / u'serialization' / u'model.tar.gz')
    predictor = Predictor.from_archive(archive, u'coreference-resolution')

    result = predictor.predict_json(inputs)

    document = result[u"document"]
    assert document == [u'This', u'is', u'a', u'single', u'string',
                        u'document', u'about', u'a', u'test', u'.', u'Sometimes',
                        u'it', u'contains', u'coreferent', u'parts', u'.']

    clusters = result[u"clusters"]
    assert isinstance(clusters, list)
    num_tokens = len(document)
    for cluster in clusters:
        assert isinstance(cluster, list)
        for mention in cluster:
            start, end = mention[0], mention[1]
            # Spans should be integer indices.
            assert isinstance(start, int)
            assert isinstance(end, int)
            # Spans should be inside document. Mentions are 0-based inclusive
            # token indices, so index 0 is valid and len(document) is not.
            assert 0 <= start < num_tokens
            assert 0 <= end < num_tokens
开发者ID:plasticityai,项目名称:magnitude,代码行数:26,代码来源:coref_test.py

示例9: test_uses_named_inputs

# 需要导入模块: from allennlp.predictors import Predictor [as 别名]
# 或者: from allennlp.predictors.Predictor import from_archive [as 别名]
def test_uses_named_inputs(self):
    """Smoke-test the WikiTables parser predictor on a tiny question/table pair.

    Assertions are only made when the predictor returns a non-empty
    "best_action_sequence".
    """
    archive_path = self.FIXTURES_ROOT / u'semantic_parsing' / u'wikitables' / u'serialization' / u'model.tar.gz'
    predictor = Predictor.from_archive(load_archive(archive_path), u'wikitables-parser')

    result = predictor.predict_json({
        u"question": u"names",
        u"table": u"name\tdate\nmatt\t2017\npradeep\t2018",
    })

    actions = result.get(u"best_action_sequence")
    if not actions:
        # We don't currently disallow endless loops in the decoder, and an untrained seq2seq
        # model will easily get itself into a loop.  An endless loop isn't a finished logical
        # form, so decoding doesn't return any finished states, which means no actions.  So,
        # sadly, we don't have a great test here.  This is just testing that the predictor
        # runs, basically.
        return

    assert len(actions) > 1
    assert all(isinstance(action, unicode) for action in actions)
    assert result.get(u"logical_form") is not None
开发者ID:plasticityai,项目名称:magnitude,代码行数:26,代码来源:wikitables_parser_test.py

示例10: test_atis_parser_uses_named_inputs

# 需要导入模块: from allennlp.predictors import Predictor [as 别名]
# 或者: from allennlp.predictors.Predictor import from_archive [as 别名]
def test_atis_parser_uses_named_inputs(self):
    """Smoke-test the ATIS parser predictor on a single utterance.

    Assertions are only made when the predictor returns a non-empty
    "best_action_sequence".
    """
    model_path = self.FIXTURES_ROOT / "atis" / "serialization" / "model.tar.gz"
    predictor = Predictor.from_archive(load_archive(model_path), "atis-parser")

    result = predictor.predict_json({"utterance": "show me the flights to seattle"})

    actions = result.get("best_action_sequence")
    if not actions:
        # An untrained model will likely get into a loop, and not produce at finished states.
        # When the model gets into a loop it will not produce any valid SQL, so we don't get
        # any actions. This basically just tests if the model runs.
        return

    assert len(actions) > 1
    assert all(isinstance(action, str) for action in actions)
    assert result.get("predicted_sql_query") is not None
开发者ID:allenai,项目名称:allennlp-semparse,代码行数:19,代码来源:atis_parser_test.py

示例11: test_uses_named_inputs

# 需要导入模块: from allennlp.predictors import Predictor [as 别名]
# 或者: from allennlp.predictors.Predictor import from_archive [as 别名]
def test_uses_named_inputs(self):
    """Smoke-test the WikiTables parser predictor on a tiny question/table pair.

    Assertions are only made when the predictor returns a non-empty
    "best_action_sequence".
    """
    model_path = self.FIXTURES_ROOT / "wikitables" / "serialization" / "model.tar.gz"
    predictor = Predictor.from_archive(load_archive(model_path), "wikitables-parser")

    result = predictor.predict_json(
        {"question": "names", "table": "name\tdate\nmatt\t2017\npradeep\t2018"}
    )

    actions = result.get("best_action_sequence")
    if not actions:
        # We don't currently disallow endless loops in the decoder, and an untrained seq2seq
        # model will easily get itself into a loop.  An endless loop isn't a finished logical
        # form, so decoding doesn't return any finished states, which means no actions.  So,
        # sadly, we don't have a great test here.  This is just testing that the predictor
        # runs, basically.
        return

    assert len(actions) > 1
    assert all(isinstance(action, str) for action in actions)
    assert result.get("logical_form") is not None
开发者ID:allenai,项目名称:allennlp-semparse,代码行数:23,代码来源:wikitables_parser_test.py

示例12: test_predictor

# 需要导入模块: from allennlp.predictors import Predictor [as 别名]
# 或者: from allennlp.predictors.Predictor import from_archive [as 别名]
def test_predictor():
    """Smoke-run the 'predictor-qa-mc-with-know-visualize' predictor on one question.

    The hard-coded question is converted to predictor input, printed as JSON,
    fed to a predictor loaded from a local archive, and the raw result is
    printed. Nothing is asserted.
    """
    # Pre-tokenized question with its answer choices, fact list, gold
    # annotations and previously stored prediction values.
    sample_question = {"id": "1700", "question_tokens": ["@start@", "For", "what", "does", "a", "stove", "generally", "generate", "heat", "?", "@end@"], "choice_tokens_list": [["@start@", "warming", "the", "air", "in", "the", "area", "@end@"], ["@start@", "heating", "nutrients", "to", "appropriate", "temperatures", "@end@"], ["@start@", "entertaining", "various", "visitors", "and", "guests", "@end@"], ["@start@", "to", "create", "electrical", "charges", "@end@"]], "facts_tokens_list": [["@start@", "UML", "can", "generate", "code", "@end@"], ["@start@", "generate", "is", "a", "synonym", "of", "beget", "@end@"], ["@start@", "Heat", "is", "generated", "by", "a", "stove", "@end@"], ["@start@", "A", "sonnet", "is", "generally", "very", "structured", "@end@"], ["@start@", "A", "fundamentalist", "is", "generally", "right", "-", "wing", "@end@"], ["@start@", "menstruation", "is", "generally", "crampy", "@end@"], ["@start@", "an", "erection", "is", "generally", "pleasurable", "@end@"], ["@start@", "gunfire", "is", "generally", "lethal", "@end@"], ["@start@", "ejaculating", "is", "generally", "pleasurable", "@end@"], ["@start@", "Huddersfield", "is", "generally", "urban", "@end@"], ["@start@", "warming", "is", "a", "synonym", "of", "calefacient", "@end@"], ["@start@", "heat", "is", "related", "to", "warming", "air", "@end@"], ["@start@", "a", "stove", "is", "for", "warming", "food", "@end@"], ["@start@", "an", "air", "conditioning", "is", "for", "warming", "@end@"], ["@start@", "The", "earth", "is", "warming", "@end@"], ["@start@", "a", "heat", "source", "is", "for", "warming", "up", "@end@"], ["@start@", "A", "foyer", "is", "an", "enterance", "area", "@end@"], ["@start@", "Being", "nosey", "is", "not", "appropriate", "@end@"], ["@start@", "seize", "is", "a", "synonym", "of", "appropriate", "@end@"], ["@start@", "a", "fitting", "room", "is", "used", "for", "something", "appropriate", "@end@"], ["@start@", "appropriate", "is", "a", "synonym", "of", "allow", "@end@"], 
["@start@", "appropriate", "is", "similar", "to", "befitting", "@end@"], ["@start@", "appropriate", "is", "similar", "to", "grade", "-", "appropriate", "@end@"], ["@start@", "grade", "-", "appropriate", "is", "similar", "to", "appropriate", "@end@"], ["@start@", "A", "parlor", "is", "used", "for", "entertaining", "guests", "@end@"], ["@start@", "a", "back", "courtyard", "is", "for", "entertaining", "guests", "@end@"], ["@start@", "guest", "is", "a", "type", "of", "visitor", "@end@"], ["@start@", "a", "family", "room", "is", "for", "entertaining", "guests", "@end@"], ["@start@", "cooking", "a", "meal", "is", "for", "entertaining", "guests", "@end@"], ["@start@", "buying", "a", "house", "is", "for", "entertaining", "guests", "@end@"], ["@start@", "having", "a", "party", "is", "for", "entertaining", "guests", "@end@"], ["@start@", "a", "dining", "area", "is", "used", "for", "entertaining", "guests", "@end@"], ["@start@", "visitor", "is", "related", "to", "guest", "@end@"], ["@start@", "guest", "is", "related", "to", "visitor", "@end@"], ["@start@", "Electrical", "charges", "are", "additive", "@end@"], ["@start@", "Lightning", "is", "an", "electrical", "charge", "@end@"], ["@start@", "electrons", "have", "electrical", "charge", "@end@"], ["@start@", "A", "judge", "is", "in", "charge", "in", "a", "courtroom", "@end@"], ["@start@", "charge", "is", "a", "synonym", "of", "accusation", "@end@"], ["@start@", "A", "consultant", "can", "charge", "a", "fee", "to", "a", "client", "@end@"], ["@start@", "charge", "is", "a", "synonym", "of", "commission", "@end@"], ["@start@", "charge", "is", "a", "synonym", "of", "cathexis", "@end@"], ["@start@", "charge", "is", "not", "cash", "@end@"], ["@start@", "arraign", "entails", "charge", "@end@"], ["@start@", "a", "stove", "generates", "heat", "for", "cooking", "usually", "@end@"], ["@start@", "preferences", "are", "generally", "learned", "characteristics", "@end@"], ["@start@", "a", "windmill", "does", "not", "create", "pollution", 
"@end@"], ["@start@", "temperature", "is", "a", "measure", "of", "heat", "energy", "@end@"], ["@start@", "a", "hot", "something", "is", "a", "source", "of", "heat", "@end@"], ["@start@", "the", "moon", "does", "not", "contain", "water", "@end@"], ["@start@", "sunlight", "produces", "heat", "@end@"], ["@start@", "an", "oven", "is", "a", "source", "of", "heat", "@end@"], ["@start@", "a", "hot", "substance", "is", "a", "source", "of", "heat", "@end@"], ["@start@", "a", "car", "engine", "is", "a", "source", "of", "heat", "@end@"], ["@start@", "as", "the", "amount", "of", "rainfall", "increases", "in", "an", "area", ",", "the", "amount", "of", "available", "water", "in", "that", "area", "will", "increase", "@end@"], ["@start@", "sound", "can", "travel", "through", "air", "@end@"], ["@start@", "the", "greenhouse", "effect", "is", "when", "carbon", "in", "the", "air", "heats", "a", "planet", "'s", "atmosphere", "@end@"], ["@start@", "a", "community", "is", "made", "of", "many", "types", "of", "organisms", "in", "an", "area", "@end@"], ["@start@", "air", "is", "a", "vehicle", "for", "sound", "@end@"], ["@start@", "rainfall", "is", "the", "amount", "of", "rain", "an", "area", "receives", "@end@"], ["@start@", "an", "animal", "requires", "air", "for", "survival", "@end@"], ["@start@", "humidity", "is", "the", "amount", "of", "water", "vapor", "in", "the", "air", "@end@"], ["@start@", "if", "some", "nutrients", "are", "in", "the", "soil", "then", "those", "nutrients", "are", "in", "the", "food", "chain", "@end@"], ["@start@", "as", "heat", "is", "transferred", "from", "something", "to", "something", "else", ",", "the", "temperature", "of", "that", "something", "will", "decrease", "@end@"], ["@start@", "uneven", "heating", "causes", "convection", "@end@"], ["@start@", "as", "temperature", "during", "the", "day", "increases", ",", "the", "temperature", "in", "an", "environment", "will", "increase", "@end@"], ["@start@", "uneven", "heating", "of", "the", "Earth", "'s", 
"surface", "cause", "wind", "@end@"], ["@start@", "an", "animal", "needs", "to", "eat", "food", "for", "nutrients", "@end@"], ["@start@", "soil", "contains", "nutrients", "for", "plants", "@end@"], ["@start@", "if", "two", "objects", "have", "the", "same", "charge", "then", "those", "two", "materials", "will", "repel", "each", "other", "@end@"], ["@start@", "water", "is", "an", "electrical", "conductor", "@end@"], ["@start@", "a", "battery", "is", "a", "source", "of", "electrical", "energy", "@end@"], ["@start@", "metal", "is", "an", "electrical", "energy", "conductor", "@end@"], ["@start@", "when", "an", "electrical", "circuit", "is", "working", "properly", ",", "electrical", "current", "runs", "through", "the", "wires", "in", "that", "circuit", "@end@"], ["@start@", "brick", "is", "an", "electrical", "insulator", "@end@"], ["@start@", "wood", "is", "an", "electrical", "energy", "insulator", "@end@"], ["@start@", "a", "toaster", "converts", "electrical", "energy", "into", "heat", "energy", "for", "toasting", "@end@"]], "gold_label": 1, "gold_facts": {"fact1": "a stove generates heat for cooking usually", "fact2": "cooking involves heating nutrients to higher temperatures"}, "label_probs": [0.002615198493003845, 0.9686304330825806, 0.008927381597459316, 0.01982697658240795], "label_ranks": [3, 0, 2, 1], "predicted_label": 1, }

    # Two-step conversion: raw question -> predictor input -> input carrying
    # the full question text.
    predictor_input = predictor_input_to_pred_input_with_full_question_text(
        question_to_predictor_input(sample_question)
    )
    print(json.dumps(predictor_input, indent=4))

    # NOTE(review): the archive path is relative, so this only runs where the
    # trained model has been placed under the working directory.
    predictor = Predictor.from_archive(
        load_archive('_trained_models/model_CN5_1202.tar.gz'),
        'predictor-qa-mc-with-know-visualize',
    )

    print(predictor.predict_json(predictor_input))
开发者ID:allenai,项目名称:OpenBookQA,代码行数:15,代码来源:predictor_qa_mc_with_know_visualize_test.py

示例13: test_smooth_gradient

# 需要导入模块: from allennlp.predictors import Predictor [as 别名]
# 或者: from allennlp.predictors.Predictor import from_archive [as 别名]
def test_smooth_gradient(self):
    """Check SmoothGradient saliency output on the basic classifier fixture.

    Verifies that the interpretation contains one gradient entry per input
    word for the first instance.
    """
    model_path = self.FIXTURES_ROOT / "basic_classifier" / "serialization" / "model.tar.gz"
    predictor = Predictor.from_archive(load_archive(model_path), "text_classifier")

    interpretation = SmoothGradient(predictor).saliency_interpret_from_json(
        {"sentence": "It was the ending that I hated"}
    )

    assert interpretation is not None
    assert "instance_1" in interpretation
    assert "grad_input_1" in interpretation["instance_1"]
    gradients = interpretation["instance_1"]["grad_input_1"]
    assert len(gradients) == 7  # 7 words in input
开发者ID:allenai,项目名称:allennlp,代码行数:15,代码来源:smooth_gradient_test.py

示例14: test_input_reduction

# 需要导入模块: from allennlp.predictors import Predictor [as 别名]
# 或者: from allennlp.predictors.Predictor import from_archive [as 别名]
def test_input_reduction(self):
    """Run InputReduction against a classifier and a tagger fixture.

    In both cases the reduced token lists must be non-empty, no longer than
    the original input, and contain only words from the original input.
    """
    # test using classification model
    classifier_path = self.FIXTURES_ROOT / "basic_classifier" / "serialization" / "model.tar.gz"
    predictor = Predictor.from_archive(load_archive(classifier_path))

    reduced = InputReduction(predictor).attack_from_json(
        {"sentence": "I always write unit tests for my code."}, "tokens", "grad_input_1"
    )
    assert reduced is not None
    assert "final" in reduced
    assert "original" in reduced
    assert reduced["final"][0]  # always at least one token
    # input reduction removes tokens
    assert len(reduced["final"][0]) <= len(reduced["original"])
    for token in reduced["final"][0]:  # no new words entered
        assert token in reduced["original"]

    # test using NER model (tests different underlying logic)
    tagger_path = self.FIXTURES_ROOT / "simple_tagger" / "serialization" / "model.tar.gz"
    predictor = Predictor.from_archive(load_archive(tagger_path), "sentence_tagger")

    reduced = InputReduction(predictor).attack_from_json(
        {"sentence": "Eric Wallace was an intern at AI2"}, "tokens", "grad_input_1"
    )
    assert reduced is not None
    assert "final" in reduced
    assert "original" in reduced
    for candidate in reduced["final"]:
        assert candidate  # always at least one token
        assert len(candidate) <= len(reduced["original"])  # input reduction removes tokens
        for token in candidate:  # no new words entered
            assert token in reduced["original"]
开发者ID:allenai,项目名称:allennlp,代码行数:41,代码来源:input_reduction_test.py

示例15: test_uses_named_inputs

# 需要导入模块: from allennlp.predictors import Predictor [as 别名]
# 或者: from allennlp.predictors.Predictor import from_archive [as 别名]
def test_uses_named_inputs(self):
    """Check the text classifier predictor's JSON output for one input.

    Verifies that "logits" and "probs" are two-element float lists, that the
    probabilities are non-negative and sum to one, that the label is in the
    model's "labels" vocabulary, and that the probabilities equal the softmax
    of the logits.
    """
    model_path = self.FIXTURES_ROOT / "basic_classifier" / "serialization" / "model.tar.gz"
    predictor = Predictor.from_archive(load_archive(model_path), "text_classifier")
    result = predictor.predict_json(
        {"sentence": "It was the ending that I hated. I was disappointed that it was so bad."}
    )

    # Shape and type checks shared by both score vectors.
    for key in ("logits", "probs"):
        scores = result.get(key)
        assert scores is not None
        assert isinstance(scores, list)
        assert len(scores) == 2
        assert all(isinstance(x, float) for x in scores)

    logits = result.get("logits")
    probs = result.get("probs")
    assert all(x >= 0 for x in probs)
    assert sum(probs) == approx(1.0)

    label = result.get("label")
    assert label is not None
    label_vocab = predictor._model.vocab.get_token_to_index_vocabulary(namespace="labels")
    assert label in label_vocab

    # The probabilities must be the softmax of the logits.
    exponentials = [math.exp(x) for x in logits]
    normalizer = sum(exponentials)
    for exponential, prob in zip(exponentials, probs):
        assert exponential / normalizer == approx(prob)
开发者ID:allenai,项目名称:allennlp,代码行数:35,代码来源:text_classifier_test.py


注:本文中的allennlp.predictors.Predictor.from_archive方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。