当前位置: 首页>>代码示例>>Python>>正文


Python core.config方法代码示例

本文整理汇总了Python中core.config（core包中的config模块）的典型用法代码示例。如果您正苦于以下问题：Python core.config的具体用法？Python core.config怎么用？Python core.config使用的例子？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解config所在的core包的用法示例。


在下文中一共展示了core.config的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: cfg_merge_dicts

# 需要导入模块: import core [as 别名]
# 或者: from core import config [as 别名]
def cfg_merge_dicts(dict_a, dict_b):
    """
    Recursively merge config dict ``dict_a`` into ``dict_b`` in place.

    Every key of ``dict_a`` must already exist in ``dict_b``. String values are
    parsed with ``ast.literal_eval`` when they are valid Python literals, and
    the resulting value must have exactly the same type as the existing entry
    in ``dict_b``. A ``None`` value is accepted for any key, and a key whose
    current value in ``dict_b`` is ``None`` accepts a value of any type (the
    same rule ``cfg_from_list`` / ``cfg_from_dict`` apply). Plain nested dicts
    in ``dict_a`` are converted to ``AttrDict`` and merged recursively.

    :param dict_a: config values to merge (e.g. loaded from a YAML file); its
        plain nested dicts are replaced by ``AttrDict`` instances.
    :param dict_b: target config, updated in place.
    :raises KeyError: if ``dict_a`` contains a key that ``dict_b`` lacks.
    :raises ValueError: if a value's type does not match the existing one.
    :raises Exception: wrapping any error raised while merging a nested section.
    """
    from ast import literal_eval

    for key, value in dict_a.items():
        if key not in dict_b:
            raise KeyError('Invalid key in config file: {}'.format(key))
        if type(value) is dict:
            # overwriting an existing key does not resize dict_a, so this is
            # safe while iterating over its items
            dict_a[key] = value = AttrDict(value)
        if isinstance(value, str):
            try:
                value = literal_eval(value)
            except (ValueError, TypeError, SyntaxError):
                # not a Python literal (e.g. a bare word): keep the raw string
                pass
        old_value = dict_b[key]
        # the types must match, too; None on either side is accepted
        if value is not None and old_value is not None and type(old_value) is not type(value):
            raise ValueError('Type mismatch ({} vs. {}) for config key: {}'.format(type(dict_b[key]), type(value), key))
        # recursively merge into an existing section; a None default is replaced whole
        if isinstance(value, AttrDict) and old_value is not None:
            try:
                cfg_merge_dicts(dict_a[key], dict_b[key])
            except Exception:
                raise Exception('Error under config key: {}'.format(key))
        else:
            dict_b[key] = value
开发者ID:CMU-CREATE-Lab,项目名称:deep-smoke-machine,代码行数:27,代码来源:config_utils.py

示例2: cfg_from_dict

# 需要导入模块: import core [as 别名]
# 或者: from core import config [as 别名]
def cfg_from_dict(args_dict):
    """
    Set config keys from a dict mapping dotted key paths to values.

    Each key is a dotted path into the global config (e.g. ``'TRAIN.BATCH_SIZE'``).
    Values are parsed with ``ast.literal_eval`` when possible; otherwise the raw
    value is used. The new value must be an instance of the type of the value it
    replaces, unless that existing value is ``None``.

    :param args_dict: mapping of dotted config keys to new values.
    :raises AssertionError: if an intermediate key of a path is missing.
    :raises Exception: if the final key is missing or the types do not match.
    """
    from ast import literal_eval

    # .items() works on Python 2 and 3 (.iteritems() is Python 2 only)
    for key, value in args_dict.items():
        key_list = key.split('.')
        cfg = __C
        for subkey in key_list[:-1]:
            assert subkey in cfg, 'Config key {} not found'.format(subkey)
            cfg = cfg[subkey]
        subkey = key_list[-1]
        if subkey not in cfg:
            raise Exception('Config key {} not found'.format(subkey))
        try:
            # handle the case when value is a string literal
            val = literal_eval(value)
        except Exception:
            # not a literal (or not a string at all): use the raw value
            val = value
        if not (isinstance(val, type(cfg[subkey])) or cfg[subkey] is None):
            type1 = type(val)
            type2 = type(cfg[subkey])
            msg = 'type {} does not match original type {}'.format(type1, type2)
            raise Exception(msg)
        cfg[subkey] = val
开发者ID:CMU-CREATE-Lab,项目名称:deep-smoke-machine,代码行数:27,代码来源:config_utils.py

示例3: cfg_from_list

# 需要导入模块: import core [as 别名]
# 或者: from core import config [as 别名]
def cfg_from_list(args_list):
    """
    Override config entries from a flat ``[key, value, key, value, ...]`` list,
    such as command-line arguments.

    Keys are dotted paths into the global config. Each value is parsed with
    ``ast.literal_eval`` when possible, and must be an instance of the type of
    the value it replaces unless that value is ``None``. Violations raise
    ``AssertionError``.
    """
    from ast import literal_eval

    assert len(args_list) % 2 == 0, 'Specify values or keys for args'
    keys = args_list[0::2]
    raw_values = args_list[1::2]
    for dotted_key, raw in zip(keys, raw_values):
        path = dotted_key.split('.')
        node = __C
        # walk down to the section that owns the final key
        for part in path[:-1]:
            assert part in node, 'Config key {} not found'.format(part)
            node = node[part]
        leaf = path[-1]
        assert leaf in node, 'Config key {} not found'.format(leaf)
        try:
            # strings such as "(1, 2)" or "0.5" become Python values
            new_value = literal_eval(raw)
        except BaseException:
            new_value = raw
        current = node[leaf]
        msg = 'type {} does not match original type {}'.format(type(new_value), type(current))
        assert isinstance(new_value, type(current)) or current is None, msg
        node[leaf] = new_value
开发者ID:CMU-CREATE-Lab,项目名称:deep-smoke-machine,代码行数:25,代码来源:config_utils.py

示例4: do_reval

# 需要导入模块: import core [as 别名]
# 或者: from core import config [as 别名]
def do_reval(dataset_name, output_dir, args):
    """
    Re-run evaluation on detections previously saved in ``output_dir``.

    Loads ``detections.pkl`` from ``output_dir``, merges the config stored with
    the detections into the current config, evaluates the saved boxes, segms
    and keypoints on ``dataset_name`` and logs the results.

    :param dataset_name: name of the dataset to evaluate against.
    :param output_dir: directory holding ``detections.pkl``; also passed to the
        evaluator as its output directory.
    :param args: parsed CLI arguments; ``cfg_file`` and ``matlab_eval`` are read.
    """
    dataset = JsonDataset(dataset_name)
    # NOTE(security): pickle.load and the full YAML Loader below can construct
    # arbitrary Python objects; only use this on detection files you produced.
    with open(os.path.join(output_dir, 'detections.pkl'), 'rb') as f:
        dets = pickle.load(f)
    # yaml.Loader reproduces the old implicit default of yaml.load (the saved
    # cfg may carry Python-object tags -- assumption, verify); passing it
    # explicitly is required on PyYAML >= 6.
    saved_cfg = yaml.load(dets['cfg'], Loader=yaml.Loader)
    # Override config with the one saved in the detections file
    if args.cfg_file is not None:
        core.config.merge_cfg_from_cfg(saved_cfg)
    else:
        core.config._merge_a_into_b(saved_cfg, cfg)
    results = task_evaluation.evaluate_all(
        dataset,
        dets['all_boxes'],
        dets['all_segms'],
        dets['all_keyps'],
        output_dir,
        use_matlab=args.matlab_eval
    )
    task_evaluation.log_copy_paste_friendly_results(results)
开发者ID:ronghanghu,项目名称:seg_every_thing,代码行数:20,代码来源:reval.py

示例5: __config_gpu_for_keras

# 需要导入模块: import core [as 别名]
# 或者: from core import config [as 别名]
def __config_gpu_for_keras():
    """
    Bind the Keras/TensorFlow session to the selected GPU with on-demand memory
    growth, and record the chosen device id in ``const.GPU_CORE_ID``.
    """
    import tensorflow as tf
    import keras.backend as K

    device_id = __parse_gpu_id()

    # reset Keras' global state before installing the new session
    K.clear_session()

    session_config = tf.ConfigProto()
    gpu_options = session_config.gpu_options
    gpu_options.visible_device_list = str(device_id)
    gpu_options.allow_growth = True
    K.set_session(tf.Session(config=session_config))

    # set which device to be used
    const.GPU_CORE_ID = device_id
开发者ID:CMU-CREATE-Lab,项目名称:deep-smoke-machine,代码行数:17,代码来源:config_utils.py

示例6: cfg_from_file

# 需要导入模块: import core [as 别名]
# 或者: from core import config [as 别名]
def cfg_from_file(file_path, is_check=True):
    """
    Merge the YAML config stored at ``file_path`` into the project config.

    :param file_path: path of the YAML config file to read.
    :param is_check: when True, run ``cfg_sanity_check()`` after merging.
    """
    cfg_merge_dicts(utils.yaml_load(file_path), __C)

    if not is_check:
        return
    cfg_sanity_check()
开发者ID:CMU-CREATE-Lab,项目名称:deep-smoke-machine,代码行数:16,代码来源:config_utils.py

示例7: __define_loader

# 需要导入模块: import core [as 别名]
# 或者: from core import config [as 别名]
def __define_loader(is_training):
    """
    Build the PyTorch data loader for the training or the test split.

    :param is_training: when truthy use TRAIN.BATCH_SIZE, otherwise
        TEST.BATCH_SIZE; also forwarded to the dataset class.
    :return: tuple ``(data_loader, n_samples, n_batches)``.
    """
    cfg = config.cfg
    model_cfg = cfg.MODEL

    n_classes = model_cfg.N_CLASSES
    dataset_name = cfg.DATASET_NAME
    backbone_cnn = model_cfg.BACKBONE_CNN
    backbone_feature = model_cfg.BACKBONE_FEATURE
    n_timesteps = model_cfg.N_TC_TIMESTEPS
    n_workers = cfg.TRAIN.N_WORKERS

    train_batch_size = cfg.TRAIN.BATCH_SIZE
    test_batch_size = cfg.TEST.BATCH_SIZE
    batch_size = train_batch_size if is_training else test_batch_size

    # name and (C, T, H, W) size of the precomputed backbone features
    feature_name = 'features_%s_%s_%sf' % (backbone_cnn, backbone_feature, n_timesteps)
    n_channels, height, width = utils.get_model_feat_maps_info(backbone_cnn, backbone_feature)
    feature_dim = (n_channels, n_timesteps, height, width)

    dataset_class = data_utils.PYTORCH_DATASETS_DICT[dataset_name]
    dataset = dataset_class(batch_size=batch_size, n_classes=n_classes, feature_name=feature_name, feature_dim=feature_dim, is_training=is_training)
    n_samples, n_batches = dataset.n_samples, dataset.n_batches

    # NOTE(review): the test split is shuffled as well -- confirm this is intended
    data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=n_workers, shuffle=True)
    return data_loader, n_samples, n_batches
开发者ID:noureldien,项目名称:timeception,代码行数:34,代码来源:train_pytorch.py

示例8: __define_timeception_model

# 需要导入模块: import core [as 别名]
# 或者: from core import config [as 别名]
def __define_timeception_model(device):
    """
    Define model, optimizer, loss function and metric function.

    :param device: torch device the model is moved to.
    :return: tuple ``(model, optimizer, loss_fn, metric_fn, metric_fn_name)``.
    """
    # some configurations
    classification_type = config.cfg.MODEL.CLASSIFICATION_TYPE
    solver_name = config.cfg.SOLVER.NAME
    solver_lr = config.cfg.SOLVER.LR
    adam_epsilon = config.cfg.SOLVER.ADAM_EPSILON

    # define model
    model = Model().to(device)
    model_param = model.parameters()

    # define the optimizer; the configured SOLVER.LR applies to every solver
    # instead of SGD silently using a hard-coded rate
    if solver_name == 'sgd':
        optimizer = SGD(model_param, lr=solver_lr)
    else:
        optimizer = Adam(model_param, lr=solver_lr, eps=adam_epsilon)

    # loss and evaluation function for either multi-label "ml" or single-label "sl" classification
    if classification_type == 'ml':
        loss_fn = torch.nn.BCELoss()
        metric_fn = metrics.map_charades
        metric_fn_name = 'map'
    else:
        loss_fn = torch.nn.NLLLoss()
        metric_fn = metrics.accuracy
        metric_fn_name = 'acc'

    return model, optimizer, loss_fn, metric_fn, metric_fn_name
开发者ID:noureldien,项目名称:timeception,代码行数:30,代码来源:train_pytorch.py

示例9: __init__

# 需要导入模块: import core [as 别名]
# 或者: from core import config [as 别名]
def __init__(self):
        """
        Build a Timeception feature extractor followed by a two-layer classifier.

        Sizes come from ``config.cfg.MODEL``; the output activation is Sigmoid
        for multi-label ("ml") classification and LogSoftmax otherwise.
        """
        super(Model, self).__init__()

        model_cfg = config.cfg.MODEL
        n_tc_timesteps = model_cfg.N_TC_TIMESTEPS
        backbone_name = model_cfg.BACKBONE_CNN
        feature_name = model_cfg.BACKBONE_FEATURE
        n_tc_layers = model_cfg.N_TC_LAYERS
        n_classes = model_cfg.N_CLASSES
        # NOTE(review): MULTISCALE_TYPE is passed as Timeception's 4th positional
        # argument (named is_dilated here); if the config stores a non-empty
        # string it is always truthy -- confirm the expected value type.
        is_dilated = model_cfg.MULTISCALE_TYPE
        if model_cfg.CLASSIFICATION_TYPE == 'ml':
            OutputActivation = Sigmoid
        else:
            OutputActivation = LogSoftmax

        n_channels_in, channel_h, channel_w = utils.get_model_feat_maps_info(backbone_name, feature_name)
        # one group per 128 input channels
        n_groups = int(n_channels_in / 128.0)

        # (N, C, T, H, W) with the batch dimension left as None
        self._input_shape = (None, n_channels_in, n_tc_timesteps, channel_h, channel_w)

        # stack of timeception layers over the (C, T, H, W) feature maps
        self.timeception = timeception_pytorch.Timeception(self._input_shape, n_tc_layers, n_groups, is_dilated)
        n_features = self.timeception.n_channels_out

        # classifier head; assignment order fixes nn.Module registration order
        self.do1 = Dropout(0.5)
        self.l1 = Linear(n_features, 512)
        self.bn1 = BatchNorm1d(512)
        self.ac1 = LeakyReLU(0.2)
        self.do2 = Dropout(0.25)
        self.l2 = Linear(512, n_classes)
        self.ac2 = OutputActivation()
开发者ID:noureldien,项目名称:timeception,代码行数:33,代码来源:train_pytorch.py

示例10: __main

# 需要导入模块: import core [as 别名]
# 或者: from core import config [as 别名]
def __main():
    """
    Run this script to train Timeception.

    Reads the YAML file named by ``-c/--config_file`` from ``./configs``, merges
    it into the project config and starts the training scheme selected by
    ``TRAIN.SCHEME`` ('tco' runs ``train_tco``, anything else ``train_ete``).
    If the file does not exist, an error is logged and nothing is trained.
    """
    # alternative: 'charades_i3d_tc4_f1024.yaml'
    default_config_file = 'charades_i3d_tc2_f256.yaml'

    # Parse the arguments
    parser = OptionParser()
    parser.add_option('-c', '--config_file', dest='config_file', default=default_config_file, help='Yaml config file that contains all training details.')
    (options, _) = parser.parse_args()
    config_file = options.config_file

    # fall back to the default when an empty value was passed explicitly;
    # assign first so the warning reports the file actually used
    if config_file is None or config_file == '':
        config_file = default_config_file
        msg = 'Config file not passed, default config is used: %s' % (config_file)
        logging.warning(msg)

    # path of config file
    config_path = './configs/%s' % (config_file)

    # check if file exist
    if not os.path.exists(config_path):
        msg = 'Sorry, could not find config file with the following path: %s' % (config_path)
        logging.error(msg)
        return

    # read the config from file and copy it to the project configuration "cfg"
    config_utils.cfg_from_file(config_path)

    # choose which training scheme, either 'ete' or 'tco'
    training_scheme = config.cfg.TRAIN.SCHEME

    # start training
    if training_scheme == 'tco':
        train_tco()
    else:
        train_ete()
开发者ID:noureldien,项目名称:timeception,代码行数:41,代码来源:train_pytorch.py

示例11: train_tco

# 需要导入模块: import core [as 别名]
# 或者: from core import config [as 别名]
def train_tco():
    """
    Train Timeception layers based on the given configurations.
    This train scheme is Timeception-only (TCO).

    Builds train/test data generators, a save-model callback and the Keras
    model, then fits for TRAIN.N_EPOCHS epochs, logging start and finish times.
    """

    # get some configs for the training
    n_workers = config.cfg.TRAIN.N_WORKERS
    n_epochs = config.cfg.TRAIN.N_EPOCHS
    dataset_name = config.cfg.DATASET_NAME
    model_name = '%s_%s' % (config.cfg.MODEL.NAME, utils.timestamp())

    # data generators
    data_generator_tr = __define_data_generator(is_training=True)
    data_generator_te = __define_data_generator(is_training=False)

    logger.info('--- start time')
    logger.info(datetime.datetime.now())
    logger.info('... [tr]: n_samples, n_batch, batch_size: %d, %d, %d' % (data_generator_tr.n_samples, data_generator_tr.n_batches, config.cfg.TRAIN.BATCH_SIZE))
    logger.info('... [te]: n_samples, n_batch, batch_size: %d, %d, %d' % (data_generator_te.n_samples, data_generator_te.n_batches, config.cfg.TEST.BATCH_SIZE))

    # callback to save the model
    save_callback = keras_utils.SaveCallback(dataset_name, model_name)

    # load model
    model = __define_timeception_model()
    # Keras' summary() prints and returns None, so logging its return value
    # would log "None"; route each summary line through the logger instead
    model.summary(print_fn=logger.info)

    # train the model
    model.fit_generator(epochs=n_epochs, generator=data_generator_tr, validation_data=data_generator_te, use_multiprocessing=True, workers=n_workers, callbacks=[save_callback], verbose=2)

    logger.info('--- finish time')
    logger.info(datetime.datetime.now())
开发者ID:noureldien,项目名称:timeception,代码行数:35,代码来源:train_keras.py

示例12: test_merge_cfg_from_file

# 需要导入模块: import core [as 别名]
# 或者: from core import config [as 别名]
def test_merge_cfg_from_file(self):
        """Dump the global cfg to YAML, mutate cfg, and check the merge restores it."""
        # text mode: yaml.dump writes str, which a binary file rejects on Python 3
        with tempfile.NamedTemporaryFile(mode='w') as f:
            yaml.dump(cfg, f)
            # the file is re-opened by name below, so push buffered data to disk
            f.flush()
            s = cfg.MODEL.TYPE
            cfg.MODEL.TYPE = 'dummy'
            assert cfg.MODEL.TYPE != s
            core.config.merge_cfg_from_file(f.name)
            assert cfg.MODEL.TYPE == s
开发者ID:ronghanghu,项目名称:seg_every_thing,代码行数:10,代码来源:test_cfg.py

示例13: test_merge_cfg_from_list

# 需要导入模块: import core [as 别名]
# 或者: from core import config [as 别名]
def test_merge_cfg_from_list(self):
        """merge_cfg_from_list parses literal strings and overrides cfg values."""
        opts = ['TRAIN.SCALES', '(100, )', 'MODEL.TYPE', u'foobar', 'NUM_GPUS', 2]

        # preconditions: none of the target values is already set
        scales_before = cfg.TRAIN.SCALES
        assert len(scales_before) > 0
        assert scales_before[0] != 100
        assert cfg.MODEL.TYPE != 'foobar'
        assert cfg.NUM_GPUS != 2

        core.config.merge_cfg_from_list(opts)

        # the string '(100, )' must have been parsed into a 1-tuple
        scales_after = cfg.TRAIN.SCALES
        assert type(scales_after) is tuple
        assert len(scales_after) == 1
        assert scales_after[0] == 100
        assert cfg.MODEL.TYPE == 'foobar'
        assert cfg.NUM_GPUS == 2
开发者ID:ronghanghu,项目名称:seg_every_thing,代码行数:16,代码来源:test_cfg.py

示例14: test_deprecated_key_from_list

# 需要导入模块: import core [as 别名]
# 或者: from core import config [as 别名]
def test_deprecated_key_from_list(self):
        """Deprecated keys passed via a list are ignored rather than set on cfg."""
        # You should see logger messages like:
        #   "Deprecated config key (ignoring): MODEL.DILATION"
        opts = ['FINAL_MSG', 'foobar', 'MODEL.DILATION', 2]

        def assert_keys_absent():
            with self.assertRaises(AttributeError):
                _ = cfg.FINAL_MSG  # noqa
            with self.assertRaises(AttributeError):
                _ = cfg.MODEL.DILATION  # noqa

        assert_keys_absent()
        core.config.merge_cfg_from_list(opts)
        assert_keys_absent()
开发者ID:ronghanghu,项目名称:seg_every_thing,代码行数:15,代码来源:test_cfg.py

示例15: test_deprecated_key_from_file

# 需要导入模块: import core [as 别名]
# 或者: from core import config [as 别名]
def test_deprecated_key_from_file(self):
        """A deprecated key present in a YAML file is ignored when merged."""
        # You should see logger messages like:
        #   "Deprecated config key (ignoring): MODEL.DILATION"
        # text mode: yaml.dump writes str, which a binary file rejects on Python 3
        with tempfile.NamedTemporaryFile(mode='w') as f:
            cfg2 = copy.deepcopy(cfg)
            cfg2.MODEL.DILATION = 2
            yaml.dump(cfg2, f)
            # the file is re-opened by name below, so push buffered data to disk
            f.flush()
            with self.assertRaises(AttributeError):
                _ = cfg.MODEL.DILATION  # noqa
            core.config.merge_cfg_from_file(f.name)
            with self.assertRaises(AttributeError):
                _ = cfg.MODEL.DILATION  # noqa
开发者ID:ronghanghu,项目名称:seg_every_thing,代码行数:14,代码来源:test_cfg.py


注:本文中的core.config方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。