This article collects typical usage examples of the Python class torchtext.data.NestedField. If you are unsure what data.NestedField is for, how to call it, or what it looks like in real code, the hand-picked examples below should help. You can also explore further usage examples from the containing module, torchtext.data.
The following shows 15 code examples of data.NestedField, sorted by popularity by default.
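Before the individual examples, here is a minimal end-to-end sketch of the typical NestedField workflow in torchtext's legacy data API: a character-level Field is nested inside a word-level NestedField, a vocabulary is built, and a minibatch is padded and numericalized. This sketch is not taken from any single example below, and the names CHAR_NESTING and WORDS are illustrative only.

from torchtext import data

# Inner field: each word is a sequence of characters
CHAR_NESTING = data.Field(tokenize=list, init_token="<w>", eos_token="</w>")
# Outer field: each sentence is a sequence of words
WORDS = data.NestedField(CHAR_NESTING, init_token="<s>", eos_token="</s>")

# A toy minibatch: two sentences, each word already split into characters
minibatch = [
    [list("john"), list("loves"), list("mary")],
    [list("mary"), list("cries")],
]

# build_vocab() fills the vocabularies of both the outer field and the
# nesting field from the same character counts
WORDS.build_vocab(minibatch)

# pad() aligns the sentence and word dimensions; numericalize() then returns
# a LongTensor of shape (batch, padded sentence length, padded word length)
padded = WORDS.pad(minibatch)
tensor = WORDS.numericalize(padded)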
Example 1: test_init_full
# Required import: from torchtext import data
# Or: from torchtext.data import NestedField
# Also uses: import torch
def test_init_full(self):
    nesting_field = data.Field()
    field = data.NestedField(
        nesting_field,
        use_vocab=False,
        init_token="<s>",
        eos_token="</s>",
        fix_length=10,
        dtype=torch.float,
        preprocessing=lambda xs: list(reversed(xs)),
        postprocessing=lambda xs: [x.upper() for x in xs],
        tokenize=list,
        pad_first=True,
    )

    assert not field.use_vocab
    assert field.init_token == "<s>"
    assert field.eos_token == "</s>"
    assert field.fix_length == 10
    assert field.dtype is torch.float
    assert field.preprocessing("a b c".split()) == "c b a".split()
    assert field.postprocessing("a b c".split()) == "A B C".split()
    assert field.tokenize("abc") == ["a", "b", "c"]
    assert field.pad_first
Example 2: test_build_vocab_from_dataset
# Required import: from torchtext import data
# Or: from torchtext.data import NestedField
# Also uses: from collections import Counter
def test_build_vocab_from_dataset(self):
    nesting_field = data.Field(tokenize=list, unk_token="<cunk>", pad_token="<cpad>",
                               init_token="<w>", eos_token="</w>")
    CHARS = data.NestedField(nesting_field, init_token="<s>", eos_token="</s>")
    ex1 = data.Example.fromlist(["aaa bbb c"], [("chars", CHARS)])
    ex2 = data.Example.fromlist(["bbb aaa"], [("chars", CHARS)])
    dataset = data.Dataset([ex1, ex2], [("chars", CHARS)])
    CHARS.build_vocab(dataset, min_freq=2)

    expected = "a b <w> </w> <s> </s> <cunk> <cpad>".split()
    assert len(CHARS.vocab) == len(expected)
    for c in expected:
        assert c in CHARS.vocab.stoi

    expected_freqs = Counter({"a": 6, "b": 6, "c": 1})
    assert CHARS.vocab.freqs == CHARS.nesting_field.vocab.freqs == expected_freqs
Example 3: test_pad_when_no_init_and_eos_tokens
# Required import: from torchtext import data
# Or: from torchtext.data import NestedField
def test_pad_when_no_init_and_eos_tokens(self):
    nesting_field = data.Field(tokenize=list, unk_token="<cunk>", pad_token="<cpad>",
                               init_token="<w>", eos_token="</w>")
    CHARS = data.NestedField(nesting_field)

    minibatch = [
        ["john", "loves", "mary"],
        ["mary", "cries"]
    ]
    expected = [
        [
            ["<w>"] + list("john") + ["</w>", "<cpad>"],
            ["<w>"] + list("loves") + ["</w>"],
            ["<w>"] + list("mary") + ["</w>", "<cpad>"],
        ],
        [
            ["<w>"] + list("mary") + ["</w>", "<cpad>"],
            ["<w>"] + list("cries") + ["</w>"],
            ["<cpad>"] * 7,
        ]
    ]

    assert CHARS.pad(minibatch) == expected
Example 4: test_init_minimal
# Required import: from torchtext import data
# Or: from torchtext.data import NestedField
# Also uses: import torch
def test_init_minimal(self):
    nesting_field = data.Field()
    field = data.NestedField(nesting_field)

    assert isinstance(field, data.Field)
    assert field.nesting_field is nesting_field
    assert field.sequential
    assert field.use_vocab
    assert field.init_token is None
    assert field.eos_token is None
    assert field.unk_token == nesting_field.unk_token
    assert field.fix_length is None
    assert field.dtype is torch.long
    assert field.preprocessing is None
    assert field.postprocessing is None
    assert field.lower == nesting_field.lower
    assert field.tokenize("a b c") == "a b c".split()
    assert not field.include_lengths
    assert field.batch_first
    assert field.pad_token == nesting_field.pad_token
    assert not field.pad_first
Example 5: test_init_when_nesting_field_is_not_sequential
# Required import: from torchtext import data
# Or: from torchtext.data import NestedField
def test_init_when_nesting_field_is_not_sequential(self):
    nesting_field = data.Field(sequential=False)
    field = data.NestedField(nesting_field)

    assert field.pad_token == "<pad>"
Example 6: test_init_when_nesting_field_has_include_lengths_equal_true
# Required import: from torchtext import data
# Or: from torchtext.data import NestedField
# Also uses: import pytest
def test_init_when_nesting_field_has_include_lengths_equal_true(self):
    nesting_field = data.Field(include_lengths=True)

    with pytest.raises(ValueError) as excinfo:
        data.NestedField(nesting_field)
    assert "nesting field cannot have include_lengths=True" in str(excinfo.value)
Example 7: test_init_with_nested_field_as_nesting_field
# Required import: from torchtext import data
# Or: from torchtext.data import NestedField
# Also uses: import pytest
def test_init_with_nested_field_as_nesting_field(self):
    nesting_field = data.NestedField(data.Field())

    with pytest.raises(ValueError) as excinfo:
        data.NestedField(nesting_field)
    assert "nesting field must not be another NestedField" in str(excinfo.value)
Example 8: test_build_vocab_from_iterable
# Required import: from torchtext import data
# Or: from torchtext.data import NestedField
# Also uses: from collections import Counter
def test_build_vocab_from_iterable(self):
    nesting_field = data.Field(unk_token="<cunk>", pad_token="<cpad>")
    CHARS = data.NestedField(nesting_field)
    CHARS.build_vocab(
        [[list("aaa"), list("bbb"), ["c"]], [list("bbb"), list("aaa")]],
        [[list("ccc"), list("bbb")], [list("bbb")]],
    )

    expected = "a b c <cunk> <cpad>".split()
    assert len(CHARS.vocab) == len(expected)
    for c in expected:
        assert c in CHARS.vocab.stoi

    expected_freqs = Counter({"a": 6, "b": 12, "c": 4})
    assert CHARS.vocab.freqs == CHARS.nesting_field.vocab.freqs == expected_freqs
Example 9: test_pad
# Required import: from torchtext import data
# Or: from torchtext.data import NestedField
def test_pad(self):
    nesting_field = data.Field(tokenize=list, unk_token="<cunk>", pad_token="<cpad>",
                               init_token="<w>", eos_token="</w>")
    CHARS = data.NestedField(nesting_field, init_token="<s>", eos_token="</s>")
    minibatch = [
        [list("john"), list("loves"), list("mary")],
        [list("mary"), list("cries")],
    ]
    expected = [
        [
            ["<w>", "<s>", "</w>"] + ["<cpad>"] * 4,
            ["<w>"] + list("john") + ["</w>", "<cpad>"],
            ["<w>"] + list("loves") + ["</w>"],
            ["<w>"] + list("mary") + ["</w>", "<cpad>"],
            ["<w>", "</s>", "</w>"] + ["<cpad>"] * 4,
        ],
        [
            ["<w>", "<s>", "</w>"] + ["<cpad>"] * 4,
            ["<w>"] + list("mary") + ["</w>", "<cpad>"],
            ["<w>"] + list("cries") + ["</w>"],
            ["<w>", "</s>", "</w>"] + ["<cpad>"] * 4,
            ["<cpad>"] * 7,
        ]
    ]
    assert CHARS.pad(minibatch) == expected

    # With include_lengths=True, pad() additionally returns the number of
    # words per example and the number of characters per word
    nesting_field = data.Field(tokenize=list, unk_token="<cunk>", pad_token="<cpad>",
                               init_token="<w>", eos_token="</w>")
    CHARS = data.NestedField(nesting_field, init_token="<s>",
                             eos_token="</s>", include_lengths=True)
    arr, seq_len, words_len = CHARS.pad(minibatch)
    assert arr == expected
    assert seq_len == [5, 4]
    assert words_len == [[3, 6, 7, 6, 3], [3, 6, 7, 3, 0]]
Example 10: test_pad_when_nesting_field_is_not_sequential
# Required import: from torchtext import data
# Or: from torchtext.data import NestedField
def test_pad_when_nesting_field_is_not_sequential(self):
    nesting_field = data.Field(sequential=False, unk_token="<cunk>",
                               pad_token="<cpad>", init_token="<w>", eos_token="</w>")
    CHARS = data.NestedField(nesting_field, init_token="<s>", eos_token="</s>")
    minibatch = [
        ["john", "loves", "mary"],
        ["mary", "cries"]
    ]
    expected = [
        ["<s>", "john", "loves", "mary", "</s>"],
        ["<s>", "mary", "cries", "</s>", "<pad>"],
    ]

    assert CHARS.pad(minibatch) == expected
Example 11: test_pad_when_nesting_field_has_fix_length
# Required import: from torchtext import data
# Or: from torchtext.data import NestedField
def test_pad_when_nesting_field_has_fix_length(self):
    nesting_field = data.Field(tokenize=list, unk_token="<cunk>", pad_token="<cpad>",
                               init_token="<w>", eos_token="</w>", fix_length=5)
    CHARS = data.NestedField(nesting_field, init_token="<s>", eos_token="</s>")
    minibatch = [
        ["john", "loves", "mary"],
        ["mary", "cries"]
    ]
    expected = [
        [
            ["<w>", "<s>", "</w>"] + ["<cpad>"] * 2,
            ["<w>"] + list("joh") + ["</w>"],
            ["<w>"] + list("lov") + ["</w>"],
            ["<w>"] + list("mar") + ["</w>"],
            ["<w>", "</s>", "</w>"] + ["<cpad>"] * 2,
        ],
        [
            ["<w>", "<s>", "</w>"] + ["<cpad>"] * 2,
            ["<w>"] + list("mar") + ["</w>"],
            ["<w>"] + list("cri") + ["</w>"],
            ["<w>", "</s>", "</w>"] + ["<cpad>"] * 2,
            ["<cpad>"] * 5,
        ]
    ]
    assert CHARS.pad(minibatch) == expected

    # With include_lengths=True, pad() additionally returns the number of
    # words per example and the (fix_length-capped) length of each word
    nesting_field = data.Field(tokenize=list, unk_token="<cunk>", pad_token="<cpad>",
                               init_token="<w>", eos_token="</w>", fix_length=5)
    CHARS = data.NestedField(nesting_field, init_token="<s>",
                             eos_token="</s>", include_lengths=True)
    arr, seq_len, words_len = CHARS.pad(minibatch)
    assert arr == expected
    assert seq_len == [5, 4]
    assert words_len == [[3, 5, 5, 5, 3], [3, 5, 5, 3, 0]]
Example 12: test_pad_when_pad_first_is_true
# Required import: from torchtext import data
# Or: from torchtext.data import NestedField
def test_pad_when_pad_first_is_true(self):
    nesting_field = data.Field(tokenize=list, unk_token="<cunk>", pad_token="<cpad>",
                               init_token="<w>", eos_token="</w>")
    CHARS = data.NestedField(nesting_field, init_token="<s>", eos_token="</s>",
                             pad_first=True)
    minibatch = [
        [list("john"), list("loves"), list("mary")],
        [list("mary"), list("cries")],
    ]
    expected = [
        [
            ["<w>", "<s>", "</w>"] + ["<cpad>"] * 4,
            ["<w>"] + list("john") + ["</w>", "<cpad>"],
            ["<w>"] + list("loves") + ["</w>"],
            ["<w>"] + list("mary") + ["</w>", "<cpad>"],
            ["<w>", "</s>", "</w>"] + ["<cpad>"] * 4,
        ],
        [
            ["<cpad>"] * 7,
            ["<w>", "<s>", "</w>"] + ["<cpad>"] * 4,
            ["<w>"] + list("mary") + ["</w>", "<cpad>"],
            ["<w>"] + list("cries") + ["</w>"],
            ["<w>", "</s>", "</w>"] + ["<cpad>"] * 4,
        ]
    ]
    assert CHARS.pad(minibatch) == expected

    # With include_lengths=True and pad_first=True, the leading padding word
    # shows up as a zero in the per-word length list
    nesting_field = data.Field(tokenize=list, unk_token="<cunk>", pad_token="<cpad>",
                               init_token="<w>", eos_token="</w>")
    CHARS = data.NestedField(nesting_field, init_token="<s>",
                             eos_token="</s>", include_lengths=True,
                             pad_first=True)
    arr, seq_len, words_len = CHARS.pad(minibatch)
    assert arr == expected
    assert seq_len == [5, 4]
    assert words_len == [[3, 6, 7, 6, 3], [0, 3, 6, 7, 3]]
Example 13: test_serialization
# Required import: from torchtext import data
# Or: from torchtext.data import NestedField
# Also uses: import os, import torch
def test_serialization(self):
    nesting_field = data.Field(batch_first=True)
    field = data.NestedField(nesting_field)
    ex1 = data.Example.fromlist(["john loves mary"], [("words", field)])
    ex2 = data.Example.fromlist(["mary cries"], [("words", field)])
    dataset = data.Dataset([ex1, ex2], [("words", field)])
    field.build_vocab(dataset)
    examples_data = [
        [
            ["<w>", "<s>", "</w>"] + ["<cpad>"] * 4,
            ["<w>"] + list("john") + ["</w>", "<cpad>"],
            ["<w>"] + list("loves") + ["</w>"],
            ["<w>"] + list("mary") + ["</w>", "<cpad>"],
            ["<w>", "</s>", "</w>"] + ["<cpad>"] * 4,
        ],
        [
            ["<w>", "<s>", "</w>"] + ["<cpad>"] * 4,
            ["<w>"] + list("mary") + ["</w>", "<cpad>"],
            ["<w>"] + list("cries") + ["</w>"],
            ["<w>", "</s>", "</w>"] + ["<cpad>"] * 4,
            ["<cpad>"] * 7,
        ]
    ]

    field_pickle_filename = "char_field.pl"
    field_pickle_path = os.path.join(self.test_dir, field_pickle_filename)
    torch.save(field, field_pickle_path)

    loaded_field = torch.load(field_pickle_path)
    assert loaded_field == field

    original_numericalization = field.numericalize(examples_data)
    pickled_numericalization = loaded_field.numericalize(examples_data)

    assert torch.all(torch.eq(original_numericalization, pickled_numericalization))
Example 14: create_fields
# Required import: from torchtext import data
# Or: from torchtext.data import NestedField
def create_fields(self, seq_input=True, seq_ner=True, seq_cat=False):
    if self.level == "word":
        sentence_field = data.Field(sequential=seq_input, preprocessing=self.preprocessor, fix_length=self.fix_length,
                                    init_token="<start>", eos_token="<end>")
    elif self.level == "char":
        sentence_field = data.Field(sequential=seq_input, tokenize=self.evil_workaround_tokenizer, fix_length=1014)
        # sentence_field = data.NestedField(nested_field)
    else:
        raise KeyError("Sentence_field is undefined!")
    ner_label_field = data.Field(sequential=seq_ner, init_token="<start>", eos_token="<end>", unk_token=None)
    category_label_field = data.LabelField(sequential=seq_cat)
    return sentence_field, ner_label_field, category_label_field
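The commented-out NestedField line in the char branch above hints at a nested character-level alternative. Below is a minimal sketch of what such a setup could look like; it is an assumption based on that comment rather than part of the original code, and the names char_nesting and char_sentence_field are illustrative.

from torchtext import data

# Inner field: split every word into its characters
char_nesting = data.Field(tokenize=list, init_token="<w>", eos_token="</w>")

# Outer field: the default tokenizer splits the sentence into words,
# the nesting field then splits each word into characters
char_sentence_field = data.NestedField(char_nesting,
                                       init_token="<start>",
                                       eos_token="<end>")

Unlike the flat character Field above, which maps the whole sentence onto a single fixed-length character sequence of 1014 positions, a NestedField keeps word boundaries and yields a (batch, words, characters) tensor after numericalization.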
Example 15: __init__
# Required import: from torchtext import data
# Or: from torchtext.data import NestedField
def __init__(self):
    super(DMNIterator, self).__init__()

    # Define text nested field
    self.text_sent = data.Field(sequential=True,
                                lower=True,
                                tokenize=lambda x: x.split(" "))
    self.text_doc = data.NestedField(self.text_sent,
                                     tokenize=lambda x: x.split("<EOS>"),
                                     include_lengths=True)

    # Define entity nested field
    self.entity_sent = data.Field(sequential=True,
                                  tokenize=lambda x: x.split(" "),
                                  unk_token=None)
    self.entity_doc = data.NestedField(self.entity_sent,
                                       tokenize=lambda x: x.split("<EOS>"))

    # Define label nested field
    self.label_sent = data.Field(sequential=True,
                                 tokenize=lambda x: x.split(" "),
                                 unk_token=None)
    self.label_doc = data.NestedField(self.label_sent,
                                      tokenize=lambda x: x.split("<EOS>"))

    # Define offset nested field
    self.offset_sent = self.InfoField(sequential=True,
                                      tokenize=lambda x: x.split(" "),
                                      use_vocab=False)
    self.offset_doc = self.NestedInfoField(self.offset_sent,
                                           tokenize=lambda x: x.split("<EOS>"),
                                           use_vocab=False)

    # Define length nested field
    self.length_sent = self.InfoField(sequential=True,
                                      tokenize=lambda x: x.split(" "),
                                      use_vocab=False,
                                      pad_token=None)
    self.length_doc = self.NestedInfoField(self.length_sent,
                                           tokenize=lambda x: x.split("<EOS>"),
                                           use_vocab=False)

    # Define word attention field
    self.word_attn_sent = self.InfoField(sequential=True,
                                         tokenize=lambda x: x.split(" "),
                                         use_vocab=False)
    self.word_attn_doc = self.NestedInfoField(self.word_attn_sent,
                                              tokenize=lambda x: x.split("<EOS>"),
                                              use_vocab=False)

    # Define sentence attention field
    self.sent_attn_doc = self.InfoField(sequential=True,
                                        tokenize=lambda x: x.split("<EOS>"),
                                        use_vocab=False)

    # Define doc id field
    self.doc_id = self.InfoField(sequential=False, use_vocab=False)
    self.vectors = None
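As a rough illustration of how document-level nested fields like the ones defined above can be bound to raw text, the sketch below mirrors the earlier test examples; the example document, the field binding, and the call to pad() are assumptions and not part of the original class.

from torchtext import data

text_sent = data.Field(sequential=True, lower=True,
                       tokenize=lambda x: x.split(" "))
text_doc = data.NestedField(text_sent,
                            tokenize=lambda x: x.split("<EOS>"),
                            include_lengths=True)

# One document: sentences separated by the <EOS> marker
raw_doc = "john loves mary<EOS>mary cries"
ex = data.Example.fromlist([raw_doc], [("text", text_doc)])
dataset = data.Dataset([ex], [("text", text_doc)])
text_doc.build_vocab(dataset)

# With include_lengths=True, pad() returns the padded batch together with
# the number of sentences per document and the number of words per sentence
padded, sents_per_doc, words_per_sent = text_doc.pad([ex.text])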