当前位置: 首页>>代码示例>>Python>>正文


Python tensorflow_datasets.builder方法代码示例

本文整理汇总了Python中tensorflow_datasets.builder方法的典型用法代码示例。如果您正苦于以下问题：Python tensorflow_datasets.builder方法的具体用法是什么？Python tensorflow_datasets.builder怎么用？有哪些Python tensorflow_datasets.builder的使用例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow_datasets的其他用法示例。


在下文中一共展示了tensorflow_datasets.builder方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: build_dataset

# 需要导入模块: import tensorflow_datasets [as 别名]
# 或者: from tensorflow_datasets import builder [as 别名]
def build_dataset(
    shape: Tuple[int, int],
    name: str="mnist",
    train_batch_size: int=32,
    valid_batch_size: int=32
    ):
    """Build shuffled, repeated and batched train/test pipelines for a TFDS dataset.

    Args:
        shape: Target image shape, forwarded to `_parse_function`.
        name: TFDS dataset name. The dataset must provide "train" and "test"
            splits and "image" / "label" features.
        train_batch_size: Batch size of the training pipeline.
        valid_batch_size: Batch size of the test pipeline.

    Returns:
        A dict with keys "num_train", "num_test", "num_classes", "channels",
        "train" and "test"; the last two are batched, infinitely repeated
        `tf.data` datasets.
    """
    # `with_info=True` already returns the DatasetInfo of the prepared dataset,
    # so a second `tfds.builder(name)` (whose info may be read before the data
    # is prepared) is unnecessary.
    [ds_train, ds_test], info = tfds.load(name=name, split=["train", "test"], with_info=True)

    dataset = {
        "num_train": info.splits["train"].num_examples,
        "num_test": info.splits["test"].num_examples,
        "num_classes": info.features["label"].num_classes,
        # Read the channel count from the static feature spec (plain ints)
        # instead of `Dataset.output_shapes[...].value`, which only exists in
        # the TF1 API.
        "channels": info.features["image"].shape[-1],
    }

    def _prepare(ds, batch_size):
        # Shuffle, repeat forever, apply `_parse_function`, then batch.
        ds = ds.shuffle(1024).repeat()
        ds = ds.map(lambda data: _parse_function(
            data, shape, dataset["num_classes"], dataset["channels"]))
        return ds.batch(batch_size)

    dataset["train"] = _prepare(ds_train, train_batch_size)
    dataset["test"] = _prepare(ds_test, valid_batch_size)

    return dataset
开发者ID:Bisonai,项目名称:mobilenetv3-tensorflow,代码行数:27,代码来源:datasets.py

示例2: count_max_boxes

# 需要导入模块: import tensorflow_datasets [as 别名]
# 或者: from tensorflow_datasets import builder [as 别名]
def count_max_boxes(cls, builder):
        """Return the largest object count found in any single example of `builder`.

        Iterates (in TF1 graph mode) over every split in `builder.info.splits`
        and tracks the maximum of `data["objects"]["label"].shape[0]`.

        Args:
            builder: A TFDS builder whose examples contain an
                `objects/label` feature.

        Returns:
            The maximum number of labels seen in one example (0 if every split
            is empty).
        """
        max_boxes = 0

        # Use the session as a context manager so it is always closed and its
        # resources released, even if iteration raises.
        with tf.compat.v1.Session() as sess:
            for split in builder.info.splits:
                tf_dataset = builder.as_dataset(split=split)
                iterator = tf.compat.v1.data.make_one_shot_iterator(tf_dataset)
                next_batch = iterator.get_next()

                while True:
                    try:
                        data = sess.run(next_batch)
                    except tf.errors.OutOfRangeError:
                        # End of this split.
                        break
                    max_boxes = max(max_boxes, data["objects"]["label"].shape[0])

        return max_boxes
开发者ID:blue-oil,项目名称:blueoil,代码行数:20,代码来源:tfds.py

示例3: _get_full_names

# 需要导入模块: import tensorflow_datasets [as 别名]
# 或者: from tensorflow_datasets import builder [as 别名]
def _get_full_names(datasets: Optional[List[str]] = None) -> List[str]:
  """Return the builder names (`ds/version` or `ds/config/version`) to process.

  Args:
    datasets: Dataset names to expand. When None, every registered dataset is
      used (current versions only).

  Returns:
    The flat list of full builder names, in the order of `datasets`.
  """
  registered = tfds.core.registered
  if datasets is None:
    return registered.list_full_names(current_version_only=True)
  return [
      full_name
      for builder_name in datasets
      for full_name in registered.single_full_names(builder_name)
  ]
开发者ID:tensorflow,项目名称:datasets,代码行数:21,代码来源:generate_visualization.py

示例4: document_single_builder

# 需要导入模块: import tensorflow_datasets [as 别名]
# 或者: from tensorflow_datasets import builder [as 别名]
def document_single_builder(builder):
  """Render the documentation for one builder, including its configs if any.

  When the builder has configs, one builder per entry of
  `builder.BUILDER_CONFIGS` is created in a thread pool. The output of the
  'schema_org' template is placed before the output of the 'dataset' template,
  separated by a newline.

  Args:
    builder: The dataset builder to document.

  Returns:
    The rendered documentation string.
  """
  print('Document builder %s...' % builder.name)

  def _config_builder(config):
    return tfds.builder(builder.name, config=config)

  config_builders = []
  if builder.builder_configs:
    with futures.ThreadPoolExecutor(max_workers=WORKER_COUNT_CONFIGS) as pool:
      config_builders = list(pool.map(_config_builder, builder.BUILDER_CONFIGS))

  dataset_template = get_mako_template('dataset')
  visu_doc_util = VisualizationDocUtil()
  dataset_doc = dataset_template.render_unicode(
      builder=builder,
      config_builders=config_builders,
      visu_doc_util=visu_doc_util,
      nightly_doc_util=NightlyDocUtil(),
  ).strip()

  schema_org_doc = get_mako_template('schema_org').render_unicode(
      builder=builder,
      config_builders=config_builders,
      visu_doc_util=visu_doc_util,
  ).strip()
  return '\n'.join((schema_org_doc, dataset_doc))
开发者ID:tensorflow,项目名称:datasets,代码行数:27,代码来源:document_datasets.py

示例5: _representative_dataset_gen

# 需要导入模块: import tensorflow_datasets [as 别名]
# 或者: from tensorflow_datasets import builder [as 别名]
def _representative_dataset_gen():
  """Gets a python generator of numpy arrays for the given dataset.

  Reads FLAGS.dataset_name, dataset_dir, dataset_split, image_size, num_steps,
  model_name and use_model_specific_preprocessing.

  Yields:
    `FLAGS.num_steps` times, a one-element list holding the evaluated image
    reshaped to `[1, image_size, image_size, 3]`.
  """
  image_size = FLAGS.image_size
  dataset = tfds.builder(FLAGS.dataset_name, data_dir=FLAGS.dataset_dir)
  dataset.download_and_prepare()
  # Build only the requested split instead of building a pipeline for every
  # split with `as_dataset()` and then discarding all but one.
  data = dataset.as_dataset(split=FLAGS.dataset_split)
  iterator = tf.data.make_one_shot_iterator(data)
  if FLAGS.use_model_specific_preprocessing:
    preprocess_fn = functools.partial(
        preprocessing_factory.get_preprocessing(name=FLAGS.model_name),
        output_height=image_size,
        output_width=image_size)
  else:
    preprocess_fn = functools.partial(
        _preprocess_for_quantization, image_size=image_size)
  features = iterator.get_next()
  image = features["image"]
  image = preprocess_fn(image)
  # NOTE(review): the hard-coded 3 assumes RGB input — confirm for the dataset.
  image = tf.reshape(image, [1, image_size, image_size, 3])
  for _ in range(FLAGS.num_steps):
    # `eval()` needs a default session (presumably opened by the caller —
    # TODO confirm); each run executes `get_next` and so advances the iterator.
    yield [image.eval()]
开发者ID:tensorflow,项目名称:models,代码行数:23,代码来源:post_training_quantization.py

示例6: train_and_eval_dataset

# 需要导入模块: import tensorflow_datasets [as 别名]
# 或者: from tensorflow_datasets import builder [as 别名]
def train_and_eval_dataset(dataset_name, data_dir):
  """Return train and evaluation datasets, feature info and supervised keys.

  Args:
    dataset_name: a string, the name of the dataset; if it starts with "v1_"
      then we'll search T2T Problem registry for it, otherwise we assume it
      is a dataset from TFDS and load it from there.
    data_dir: directory where the data is located.

  Returns:
    a 4-tuple consisting of:
     * the train tf.data.Dataset
     * the eval tf.data.Dataset
     * information about features: a python dictionary with feature names
         as keys and an object as value that provides .shape and .num_classes.
     * supervised_keys: information what's the input and what's the target,
         ie., a pair of lists with input and target feature names.

  Raises:
    ValueError: if the dataset has no train split, or neither a validation
      nor a test split.
  """
  if dataset_name.startswith("v1_"):
    return _train_and_eval_dataset_v1(dataset_name[3:], data_dir)
  dataset_builder = tfds.builder(dataset_name, data_dir=data_dir)
  info = dataset_builder.info
  splits = info.splits
  if tfds.Split.TRAIN not in splits:
    raise ValueError("To train we require a train split in the dataset.")
  if tfds.Split.VALIDATION not in splits and tfds.Split.TEST not in splits:
    raise ValueError("We require a validation or test split in the dataset.")
  # Prefer the validation split for evaluation; fall back to test.
  if tfds.Split.VALIDATION in splits:
    eval_split = tfds.Split.VALIDATION
  else:
    eval_split = tfds.Split.TEST
  # Load from the same `data_dir` the metadata above was read from; without
  # it `tfds.load` reads from (and may download into) the default TFDS dir.
  train, valid = tfds.load(
      name=dataset_name,
      split=[tfds.Split.TRAIN, eval_split],
      data_dir=data_dir)
  keys = None
  if info.supervised_keys:
    keys = ([info.supervised_keys[0]], [info.supervised_keys[1]])
  return train, valid, info.features, keys
开发者ID:yyht,项目名称:BERT,代码行数:38,代码来源:t2t.py

示例7: download_and_prepare

# 需要导入模块: import tensorflow_datasets [as 别名]
# 或者: from tensorflow_datasets import builder [as 别名]
def download_and_prepare(dataset_name, data_dir):
  """Downloads and prepares T2T or TFDS dataset.

  When `data_dir` is given, nothing is downloaded and the user-expanded path is
  returned. Otherwise, the data is generated under `~/tensorflow_datasets/`
  (a T2T problem goes into a `<dataset_name>` subdirectory of it). Downloads
  go to its `download` subdirectory.

  Args:
    dataset_name: tfds dataset or t2t problem name prefixed by 't2t_'.
    data_dir: location of existing dataset or None.

  Returns:
    data_dir: path string of downloaded data.
  """
  if data_dir:
    # An existing dataset location was supplied; nothing to generate.
    return os.path.expanduser(data_dir)

  data_dir = os.path.expanduser('~/tensorflow_datasets/')
  dl_dir = os.path.join(data_dir, 'download')
  logging.info(
      'No dataset directory provided. '
      'Downloading and generating dataset for %s inside data directory %s '
      'For large datasets it is better to prepare datasets manually!',
      dataset_name, data_dir)
  if dataset_name.startswith('t2t_'):
    # Download and run dataset generator for T2T problem.
    data_dir = os.path.join(data_dir, dataset_name)
    tf.io.gfile.makedirs(data_dir)
    tf.io.gfile.makedirs(dl_dir)
    t2t_problems().problem(
        dataset_name[len('t2t_'):]).generate_data(data_dir, dl_dir)
  else:
    # Download and prepare TFDS dataset. Pass `data_dir` explicitly so the
    # data is written to the directory this function returns, not to
    # wherever TFDS's default data dir points.
    tfds_builder = tfds.builder(dataset_name, data_dir=data_dir)
    tfds_builder.download_and_prepare(download_dir=dl_dir)
  return data_dir
开发者ID:google,项目名称:trax,代码行数:34,代码来源:tf_inputs.py

示例8: __init__

# 需要导入模块: import tensorflow_datasets [as 别名]
# 或者: from tensorflow_datasets import builder [as 别名]
def __init__(
            self,
            name,
            data_dir,
            image_size,
            download=False,
            num_max_boxes=None,
            tfds_pre_processor=None,
            tfds_augmentor=None,
            *args,
            **kwargs
    ):
        """Open a TFDS dataset and prepare the split selected by `self.subset`.

        Args:
            name: Dataset name. Names listed by `tfds.list_builders()` are
                opened with `tfds.builder`; any other name is opened with
                `self.builder_class` and must exist as `<data_dir>/<name>`.
            data_dir: Root directory of the TFDS data.
            image_size: Stored as `self._image_size`.
            download: If True and `name` is a registered TFDS builder, call
                `download_and_prepare()` on it first.
            num_max_boxes: Stored as `self._num_max_boxes`.
            tfds_pre_processor: Stored as `self.tfds_pre_processor`.
            tfds_augmentor: Stored as `self.tfds_augmentor`.
            *args, **kwargs: Forwarded to the parent constructor.

        Raises:
            ValueError: If `name` is not a registered builder and
                `<data_dir>/<name>` does not exist.
        """
        super().__init__(*args, **kwargs)

        if name not in tfds.list_builders():
            local_path = os.path.join(data_dir, name)
            if not tf.io.gfile.exists(local_path):
                raise ValueError("Dataset directory does not exist: {}\n"
                                 "Please run `python blueoil/cmd/build_tfds.py -c <config file>` before training."
                                 .format(local_path))
            self._builder = self.builder_class(name, data_dir=data_dir)
        else:
            self._builder = tfds.builder(name, data_dir=data_dir)
            if download:
                self._builder.download_and_prepare()

        self.info = self._builder.info
        self._init_available_splits()
        self._validate_feature_structure()

        split_name = self.available_splits[self.subset]
        self.tf_dataset = self._builder.as_dataset(split=split_name)
        self.tfds_pre_processor = tfds_pre_processor
        self.tfds_augmentor = tfds_augmentor
        self._image_size = image_size
        self._num_max_boxes = num_max_boxes
        self._format_dataset()
开发者ID:blue-oil,项目名称:blueoil,代码行数:41,代码来源:tfds.py

示例9: download_and_prepare

# 需要导入模块: import tensorflow_datasets [as 别名]
# 或者: from tensorflow_datasets import builder [as 别名]
def download_and_prepare(builder):
  """Generate data for a given dataset.

  Beam-based builders skip statistics computation and get their pipeline
  options from `FLAGS.beam_pipeline_options`. With `FLAGS.add_name_to_manual_dir`,
  the builder name is appended to the manual directory. After preparation,
  the DatasetInfo proto is printed. With `FLAGS.debug`, a `pdb` session is
  opened with the train split loaded as `dataset`.
  """
  logging.info("download_and_prepare for dataset %s...", builder.info.full_name)

  dl_config = download_config()

  if isinstance(builder, tfds.core.BeamBasedBuilder):
    beam = tfds.core.lazy_imports.apache_beam
    # TODO(b/129149715): Restore compute stats. Currently skipped because not
    # beam supported.
    dl_config.compute_stats = tfds.download.ComputeStatsMode.SKIP
    pipeline_flags = ["--%s" % opt for opt in FLAGS.beam_pipeline_options]
    pipeline_options_cls = beam.options.pipeline_options.PipelineOptions
    dl_config.beam_options = pipeline_options_cls(flags=pipeline_flags)

  if FLAGS.add_name_to_manual_dir:
    # NOTE(review): fails if `dl_config.manual_dir` is None — confirm that
    # `download_config()` always sets it.
    dl_config.manual_dir = os.path.join(dl_config.manual_dir, builder.name)

  builder.download_and_prepare(
      download_dir=FLAGS.download_dir,
      download_config=dl_config,
  )
  termcolor.cprint(str(builder.info.as_proto), attrs=["bold"])

  if not FLAGS.debug:
    return

  dataset = builder.as_dataset(split=tfds.Split.TRAIN)
  pdb.set_trace()
  del dataset
开发者ID:tensorflow,项目名称:datasets,代码行数:29,代码来源:download_and_prepare.py

示例10: _get_name

# 需要导入模块: import tensorflow_datasets [as 别名]
# 或者: from tensorflow_datasets import builder [as 别名]
def _get_name(self, builder):
    """Return the visualization file name: full name with '/' -> '-', plus '.png'."""
    slug = builder.info.full_name.replace('/', '-')
    return f'{slug}.png'
开发者ID:tensorflow,项目名称:datasets,代码行数:4,代码来源:document_datasets.py

示例11: get_url

# 需要导入模块: import tensorflow_datasets [as 别名]
# 或者: from tensorflow_datasets import builder [as 别名]
def get_url(self, builder):
    """Return `self.BASE_URL` followed by the visualization file name for `builder`."""
    base_url = self.BASE_URL
    image_name = self._get_name(builder)
    return base_url + image_name
开发者ID:tensorflow,项目名称:datasets,代码行数:4,代码来源:document_datasets.py

示例12: has_visualization

# 需要导入模块: import tensorflow_datasets [as 别名]
# 或者: from tensorflow_datasets import builder [as 别名]
def has_visualization(self, builder):
    """Return whether the visualization file for `builder` exists under `self.BASE_PATH`."""
    base_dir = self.BASE_PATH
    image_path = os.path.join(base_dir, self._get_name(builder))
    return tf.io.gfile.exists(image_path)
开发者ID:tensorflow,项目名称:datasets,代码行数:5,代码来源:document_datasets.py

示例13: _split_full_name

# 需要导入模块: import tensorflow_datasets [as 别名]
# 或者: from tensorflow_datasets import builder [as 别名]
def _split_full_name(full_name: str) -> Tuple[str, str, str]:
  """Extracts the `(ds name, config, version)` from the full_name.

  Args:
    full_name: Name of the form `dataset_name[/config_name]/version`.

  Returns:
    `(ds_name, config, version)`; `config` is '' when the name has no config.

  Raises:
    ValueError: if `full_name` is not a valid full builder name.
  """
  if not tfds.core.registered.is_full_name(full_name):
    # The adjacent literals are concatenated, so each needs its own trailing
    # space to keep the sentences separated in the final message.
    raise ValueError(
        f'Parsing builder name string {full_name} failed. '
        'The builder name string must be of the following format: '
        '`dataset_name[/config_name]/version`')
  ds_name, *optional_config, version = full_name.split('/')
  # Explicit check instead of `assert`, which is stripped under `python -O`.
  if len(optional_config) > 1:
    raise ValueError(
        f'Builder name string {full_name} has more than one config component.')
  config = optional_config[0] if optional_config else ''
  return ds_name, config, version
开发者ID:tensorflow,项目名称:datasets,代码行数:13,代码来源:document_datasets.py

示例14: is_builder_nightly

# 需要导入模块: import tensorflow_datasets [as 别名]
# 或者: from tensorflow_datasets import builder [as 别名]
def is_builder_nightly(
      self,
      builder: Union[tfds.core.DatasetBuilder, str],
  ) -> bool:
    """Returns `True` if the builder is new.

    Args:
      builder: A `DatasetBuilder` instance or a dataset name.

    Returns:
      `True` only when `self._nightly_dict[name]` is the literal `True`
      (a merely truthy value, e.g. a non-empty mapping, yields `False`).
    """
    is_builder_instance = isinstance(builder, tfds.core.DatasetBuilder)
    builder_name = builder.name if is_builder_instance else builder
    return self._nightly_dict[builder_name] is True  # pylint: disable=g-bool-id-comparison
开发者ID:tensorflow,项目名称:datasets,代码行数:12,代码来源:document_datasets.py

示例15: is_config_nightly

# 需要导入模块: import tensorflow_datasets [as 别名]
# 或者: from tensorflow_datasets import builder [as 别名]
def is_config_nightly(self, builder: tfds.core.DatasetBuilder) -> bool:
    """Returns `True` if the config is new.

    Returns `False` when the whole builder is already flagged as new by
    `is_builder_nightly`. Otherwise checks whether the nightly entry for
    `[ds_name][config]` is the literal `True`.
    """
    ds_name, config_name, _ = _split_full_name(builder.info.full_name)
    whole_builder_is_new = self.is_builder_nightly(builder)
    if whole_builder_is_new:
      return False
    ds_entry = self._nightly_dict[ds_name]
    return ds_entry[config_name] is True  # pylint: disable=g-bool-id-comparison
开发者ID:tensorflow,项目名称:datasets,代码行数:8,代码来源:document_datasets.py


注：本文中的tensorflow_datasets.builder方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。