This page compiles typical usage examples of core.config in Python. If you are unsure how core.config is used in practice, the curated code examples below may help; you can also look further into the other members of the core module.
The following shows 15 code examples that use core.config, ordered by popularity.
Example 1: cfg_merge_dicts
# Required import: import core [as alias]
# Or: from core import config [as alias]
def cfg_merge_dicts(dict_a, dict_b):
    from ast import literal_eval
    for key, value in dict_a.items():
        if key not in dict_b:
            raise KeyError('Invalid key in config file: {}'.format(key))
        if type(value) is dict:
            dict_a[key] = value = AttrDict(value)
        if isinstance(value, str):
            try:
                value = literal_eval(value)
            except BaseException:
                pass
        # the types must match, too
        old_type = type(dict_b[key])
        if old_type is not type(value) and value is not None:
            raise ValueError('Type mismatch ({} vs. {}) for config key: {}'.format(type(dict_b[key]), type(value), key))
        # recursively merge dicts
        if isinstance(value, AttrDict):
            try:
                cfg_merge_dicts(dict_a[key], dict_b[key])
            except BaseException:
                raise Exception('Error under config key: {}'.format(key))
        else:
            dict_b[key] = value
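To make the merge semantics concrete, here is a minimal, hedged sketch, assuming cfg_merge_dicts from Example 1 is defined in the same script. The AttrDict below is a toy stand-in for the project's real AttrDict (which lives in core.config), and the keys and values are invented for illustration: string values are parsed with literal_eval, type mismatches raise, and nested dicts are merged recursively in place.

class AttrDict(dict):
    # minimal stand-in: a dict whose keys can also be read and written as attributes
    def __getattr__(self, name):
        return self[name]
    def __setattr__(self, name, value):
        self[name] = value

# defaults plays the role of dict_b, the YAML override plays the role of dict_a
defaults = AttrDict({'DATASET_NAME': 'charades', 'TRAIN': AttrDict({'BATCH_SIZE': 32})})
override = {'TRAIN': {'BATCH_SIZE': '64'}}  # note: the value arrives as a string

cfg_merge_dicts(override, defaults)
print(defaults.TRAIN.BATCH_SIZE)  # -> 64, parsed to int and type-checked against the default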
Example 2: cfg_from_dict
# Required import: import core [as alias]
# Or: from core import config [as alias]
def cfg_from_dict(args_dict):
    """Set config keys via a dict of dotted keys (e.g., parsed from the command line)."""
    from ast import literal_eval
    for key, value in args_dict.items():
        key_list = key.split('.')
        cfg = __C
        for subkey in key_list[:-1]:
            assert subkey in cfg, 'Config key {} not found'.format(subkey)
            cfg = cfg[subkey]
        subkey = key_list[-1]
        if subkey not in cfg:
            raise Exception('Config key {} not found'.format(subkey))
        try:
            # handle the case when the value is a string literal
            val = literal_eval(value)
        except BaseException:
            val = value
        if isinstance(val, type(cfg[subkey])) or cfg[subkey] is None:
            pass
        else:
            type1 = type(val)
            type2 = type(cfg[subkey])
            msg = 'type {} does not match original type {}'.format(type1, type2)
            raise Exception(msg)
        cfg[subkey] = val
Example 3: cfg_from_list
# Required import: import core [as alias]
# Or: from core import config [as alias]
def cfg_from_list(args_list):
    """
    Set config keys via list (e.g., from command line).
    """
    from ast import literal_eval
    assert len(args_list) % 2 == 0, 'Specify values or keys for args'
    for key, value in zip(args_list[0::2], args_list[1::2]):
        key_list = key.split('.')
        cfg = __C
        for subkey in key_list[:-1]:
            assert subkey in cfg, 'Config key {} not found'.format(subkey)
            cfg = cfg[subkey]
        subkey = key_list[-1]
        assert subkey in cfg, 'Config key {} not found'.format(subkey)
        try:
            # handle the case when v is a string literal
            val = literal_eval(value)
        except BaseException:
            val = value
        msg = 'type {} does not match original type {}'.format(type(val), type(cfg[subkey]))
        assert isinstance(val, type(cfg[subkey])) or cfg[subkey] is None, msg
        cfg[subkey] = val
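A quick, hedged usage sketch for cfg_from_list: the key/value pairs mirror the ones exercised by the tests in Example 13, and every value is passed as a string so that literal_eval can recover tuples and numbers. The call mutates the module-level default config __C in place, so the listed keys must already exist in it.

# alternating key/value pairs, e.g. collected from the command line
opts = ['TRAIN.SCALES', '(100,)', 'MODEL.TYPE', 'foobar', 'NUM_GPUS', '2']
cfg_from_list(opts)  # __C.TRAIN.SCALES == (100,), __C.MODEL.TYPE == 'foobar', __C.NUM_GPUS == 2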
Example 4: do_reval
# Required import: import core [as alias]
# Or: from core import config [as alias]
def do_reval(dataset_name, output_dir, args):
    dataset = JsonDataset(dataset_name)
    with open(os.path.join(output_dir, 'detections.pkl'), 'rb') as f:
        dets = pickle.load(f)
    # Override config with the one saved in the detections file
    if args.cfg_file is not None:
        core.config.merge_cfg_from_cfg(yaml.load(dets['cfg']))
    else:
        core.config._merge_a_into_b(yaml.load(dets['cfg']), cfg)
    results = task_evaluation.evaluate_all(
        dataset,
        dets['all_boxes'],
        dets['all_segms'],
        dets['all_keyps'],
        output_dir,
        use_matlab=args.matlab_eval
    )
    task_evaluation.log_copy_paste_friendly_results(results)
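One practical caveat, not part of the original code: PyYAML 5.1 and newer warn when yaml.load is called without an explicit Loader, and PyYAML 6 makes the Loader argument mandatory, so the two yaml.load calls above would typically be adjusted as sketched below. Using UnsafeLoader here is an assumption, made because the saved config may carry project-specific object tags.

# PyYAML >= 5.1: pass an explicit Loader when deserializing the saved config
saved_cfg = yaml.load(dets['cfg'], Loader=yaml.UnsafeLoader)
core.config.merge_cfg_from_cfg(saved_cfg)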
Example 5: __config_gpu_for_keras
# Required import: import core [as alias]
# Or: from core import config [as alias]
def __config_gpu_for_keras():
    import tensorflow as tf
    import keras.backend as K
    gpu_core_id = __parse_gpu_id()
    K.clear_session()
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(gpu_core_id)
    config.gpu_options.allow_growth = True
    session = tf.Session(config=config)
    K.set_session(session)
    # set which device to be used
    const.GPU_CORE_ID = gpu_core_id
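The snippet above targets TensorFlow 1.x (tf.ConfigProto / tf.Session). As a hedged aside, on TensorFlow 2.x the equivalent device selection and on-demand memory growth look roughly like the sketch below; gpu_core_id is assumed to come from the same __parse_gpu_id helper and to be an integer index.

import tensorflow as tf

gpu_core_id = __parse_gpu_id()  # assumed helper, as in the snippet above
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    # expose only the selected GPU and let its memory grow on demand
    tf.config.set_visible_devices(gpus[gpu_core_id], 'GPU')
    tf.config.experimental.set_memory_growth(gpus[gpu_core_id], True)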
Example 6: cfg_from_file
# Required import: import core [as alias]
# Or: from core import config [as alias]
def cfg_from_file(file_path, is_check=True):
    """
    Load a config file and merge it into the default options.
    """
    # read from file
    yaml_config = utils.yaml_load(file_path)
    # merge to project config
    cfg_merge_dicts(yaml_config, __C)
    # make sure everything is okay
    if is_check:
        cfg_sanity_check()
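Typical usage, hedged and mirroring how Example 10 drives this function: point it at one of the YAML files under ./configs and the merged values become available on the global config object. The import line assumes the module layout implied by Example 10 (config_utils exposing cfg_from_file, config exposing cfg), and the config file name is the default referenced later on this page.

from core import config, config_utils

config_utils.cfg_from_file('./configs/charades_i3d_tc4_f1024.yaml')
print(config.cfg.TRAIN.BATCH_SIZE, config.cfg.MODEL.N_CLASSES)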
Example 7: __define_loader
# Required import: import core [as alias]
# Or: from core import config [as alias]
def __define_loader(is_training):
    """
    Define data loader.
    """
    # get some configs for the training
    n_classes = config.cfg.MODEL.N_CLASSES
    dataset_name = config.cfg.DATASET_NAME
    backbone_model_name = config.cfg.MODEL.BACKBONE_CNN
    backbone_feature_name = config.cfg.MODEL.BACKBONE_FEATURE
    n_timesteps = config.cfg.MODEL.N_TC_TIMESTEPS
    n_workers = config.cfg.TRAIN.N_WORKERS
    batch_size_tr = config.cfg.TRAIN.BATCH_SIZE
    batch_size_te = config.cfg.TEST.BATCH_SIZE
    batch_size = batch_size_tr if is_training else batch_size_te
    # size and name of feature
    feature_name = 'features_%s_%s_%sf' % (backbone_model_name, backbone_feature_name, n_timesteps)
    c, h, w = utils.get_model_feat_maps_info(backbone_model_name, backbone_feature_name)
    feature_dim = (c, n_timesteps, h, w)
    # data generators
    params = {'batch_size': batch_size, 'n_classes': n_classes, 'feature_name': feature_name, 'feature_dim': feature_dim, 'is_training': is_training}
    dataset_class = data_utils.PYTORCH_DATASETS_DICT[dataset_name]
    dataset = dataset_class(**params)
    n_samples = dataset.n_samples
    n_batches = dataset.n_batches
    data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=n_workers, shuffle=True)
    return data_loader, n_samples, n_batches
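A hedged sketch of consuming the returned loader; the exact batch contents depend on the dataset class looked up in PYTORCH_DATASETS_DICT, so the (x, y) unpacking and the shapes in the comment are assumptions rather than guaranteed behavior.

loader_tr, n_samples_tr, n_batches_tr = __define_loader(is_training=True)
for x, y in loader_tr:
    # assumed layout: x = backbone feature maps (batch, C, T, H, W), y = labels
    print(x.shape, y.shape)
    break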
Example 8: __define_timeception_model
# Required import: import core [as alias]
# Or: from core import config [as alias]
def __define_timeception_model(device):
    """
    Define model, optimizer, loss function and metric function.
    """
    # some configurations
    classification_type = config.cfg.MODEL.CLASSIFICATION_TYPE
    solver_name = config.cfg.SOLVER.NAME
    solver_lr = config.cfg.SOLVER.LR
    adam_epsilon = config.cfg.SOLVER.ADAM_EPSILON
    # define model
    model = Model().to(device)
    model_param = model.parameters()
    # define the optimizer
    optimizer = SGD(model_param, lr=0.01) if solver_name == 'sgd' else Adam(model_param, lr=solver_lr, eps=adam_epsilon)
    # loss and evaluation function for either multi-label "ml" or single-label "sl" classification
    if classification_type == 'ml':
        loss_fn = torch.nn.BCELoss()
        metric_fn = metrics.map_charades
        metric_fn_name = 'map'
    else:
        loss_fn = torch.nn.NLLLoss()
        metric_fn = metrics.accuracy
        metric_fn_name = 'acc'
    return model, optimizer, loss_fn, metric_fn, metric_fn_name
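Putting Examples 7 and 8 together, a single training pass would look roughly like the sketch below. The loop body is a generic PyTorch pattern rather than code from the project, and it reuses the assumed (x, y) batch layout from the previous sketch.

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, optimizer, loss_fn, metric_fn, metric_fn_name = __define_timeception_model(device)
loader_tr, _, _ = __define_loader(is_training=True)

model.train()
for x, y in loader_tr:
    x, y = x.to(device), y.to(device)
    optimizer.zero_grad()
    y_pred = model(x)          # Sigmoid or LogSoftmax output, per CLASSIFICATION_TYPE
    loss = loss_fn(y_pred, y)  # BCELoss expects float multi-hot targets, NLLLoss class indices
    loss.backward()
    optimizer.step()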
Example 9: __init__
# Required import: import core [as alias]
# Or: from core import config [as alias]
def __init__(self):
    super(Model, self).__init__()
    # some configurations for the model
    n_tc_timesteps = config.cfg.MODEL.N_TC_TIMESTEPS
    backbone_name = config.cfg.MODEL.BACKBONE_CNN
    feature_name = config.cfg.MODEL.BACKBONE_FEATURE
    n_tc_layers = config.cfg.MODEL.N_TC_LAYERS
    n_classes = config.cfg.MODEL.N_CLASSES
    is_dilated = config.cfg.MODEL.MULTISCALE_TYPE
    OutputActivation = Sigmoid if config.cfg.MODEL.CLASSIFICATION_TYPE == 'ml' else LogSoftmax
    n_channels_in, channel_h, channel_w = utils.get_model_feat_maps_info(backbone_name, feature_name)
    n_groups = int(n_channels_in / 128.0)
    input_shape = (None, n_channels_in, n_tc_timesteps, channel_h, channel_w)  # (C, T, H, W)
    self._input_shape = input_shape
    # define 4 layers of timeception
    self.timeception = timeception_pytorch.Timeception(input_shape, n_tc_layers, n_groups, is_dilated)  # (C, T, H, W)
    # get number of output channels after timeception
    n_channels_in = self.timeception.n_channels_out
    # define layers for classifier
    self.do1 = Dropout(0.5)
    self.l1 = Linear(n_channels_in, 512)
    self.bn1 = BatchNorm1d(512)
    self.ac1 = LeakyReLU(0.2)
    self.do2 = Dropout(0.25)
    self.l2 = Linear(512, n_classes)
    self.ac2 = OutputActivation()
Example 10: __main
# Required import: import core [as alias]
# Or: from core import config [as alias]
def __main():
    """
    Run this script to train Timeception.
    """
    default_config_file = 'charades_i3d_tc4_f1024.yaml'
    default_config_file = 'charades_i3d_tc2_f256.yaml'
    # Parse the arguments
    parser = OptionParser()
    parser.add_option('-c', '--config_file', dest='config_file', default=default_config_file, help='Yaml config file that contains all training details.')
    (options, args) = parser.parse_args()
    config_file = options.config_file
    # check if exist
    if config_file is None or config_file == '':
        msg = 'Config file not passed, default config is used: %s' % (config_file)
        logging.warning(msg)
        config_file = default_config_file
    # path of config file
    config_path = './configs/%s' % (config_file)
    # check if file exist
    if not os.path.exists(config_path):
        msg = 'Sorry, could not find config file with the following path: %s' % (config_path)
        logging.error(msg)
    else:
        # read the config from file and copy it to the project configuration "cfg"
        config_utils.cfg_from_file(config_path)
        # choose which training scheme, either 'ete' or 'tco'
        training_scheme = config.cfg.TRAIN.SCHEME
        # start training
        if training_scheme == 'tco':
            train_tco()
        else:
            train_ete()
Example 11: train_tco
# Required import: import core [as alias]
# Or: from core import config [as alias]
def train_tco():
    """
    Train Timeception layers based on the given configurations.
    This train scheme is Timeception-only (TCO).
    """
    # get some configs for the training
    n_workers = config.cfg.TRAIN.N_WORKERS
    n_epochs = config.cfg.TRAIN.N_EPOCHS
    dataset_name = config.cfg.DATASET_NAME
    model_name = '%s_%s' % (config.cfg.MODEL.NAME, utils.timestamp())
    # data generators
    data_generator_tr = __define_data_generator(is_training=True)
    data_generator_te = __define_data_generator(is_training=False)
    logger.info('--- start time')
    logger.info(datetime.datetime.now())
    logger.info('... [tr]: n_samples, n_batch, batch_size: %d, %d, %d' % (data_generator_tr.n_samples, data_generator_tr.n_batches, config.cfg.TRAIN.BATCH_SIZE))
    logger.info('... [te]: n_samples, n_batch, batch_size: %d, %d, %d' % (data_generator_te.n_samples, data_generator_te.n_batches, config.cfg.TEST.BATCH_SIZE))
    # callback to save the model
    save_callback = keras_utils.SaveCallback(dataset_name, model_name)
    # load model
    model = __define_timeception_model()
    logger.info(model.summary())
    # train the model
    model.fit_generator(epochs=n_epochs, generator=data_generator_tr, validation_data=data_generator_te, use_multiprocessing=True, workers=n_workers, callbacks=[save_callback], verbose=2)
    logger.info('--- finish time')
    logger.info(datetime.datetime.now())
Example 12: test_merge_cfg_from_file
# Required import: import core [as alias]
# Or: from core import config [as alias]
def test_merge_cfg_from_file(self):
    with tempfile.NamedTemporaryFile() as f:
        yaml.dump(cfg, f)
        s = cfg.MODEL.TYPE
        cfg.MODEL.TYPE = 'dummy'
        assert cfg.MODEL.TYPE != s
        core.config.merge_cfg_from_file(f.name)
        assert cfg.MODEL.TYPE == s
Example 13: test_merge_cfg_from_list
# Required import: import core [as alias]
# Or: from core import config [as alias]
def test_merge_cfg_from_list(self):
    opts = [
        'TRAIN.SCALES', '(100, )', 'MODEL.TYPE', u'foobar', 'NUM_GPUS', 2
    ]
    assert len(cfg.TRAIN.SCALES) > 0
    assert cfg.TRAIN.SCALES[0] != 100
    assert cfg.MODEL.TYPE != 'foobar'
    assert cfg.NUM_GPUS != 2
    core.config.merge_cfg_from_list(opts)
    assert type(cfg.TRAIN.SCALES) is tuple
    assert len(cfg.TRAIN.SCALES) == 1
    assert cfg.TRAIN.SCALES[0] == 100
    assert cfg.MODEL.TYPE == 'foobar'
    assert cfg.NUM_GPUS == 2
Example 14: test_deprecated_key_from_list
# Required import: import core [as alias]
# Or: from core import config [as alias]
def test_deprecated_key_from_list(self):
    # You should see logger messages like:
    # "Deprecated config key (ignoring): MODEL.DILATION"
    opts = ['FINAL_MSG', 'foobar', 'MODEL.DILATION', 2]
    with self.assertRaises(AttributeError):
        _ = cfg.FINAL_MSG  # noqa
    with self.assertRaises(AttributeError):
        _ = cfg.MODEL.DILATION  # noqa
    core.config.merge_cfg_from_list(opts)
    with self.assertRaises(AttributeError):
        _ = cfg.FINAL_MSG  # noqa
    with self.assertRaises(AttributeError):
        _ = cfg.MODEL.DILATION  # noqa
Example 15: test_deprecated_key_from_file
# Required import: import core [as alias]
# Or: from core import config [as alias]
def test_deprecated_key_from_file(self):
    # You should see logger messages like:
    # "Deprecated config key (ignoring): MODEL.DILATION"
    with tempfile.NamedTemporaryFile() as f:
        cfg2 = copy.deepcopy(cfg)
        cfg2.MODEL.DILATION = 2
        yaml.dump(cfg2, f)
        with self.assertRaises(AttributeError):
            _ = cfg.MODEL.DILATION  # noqa
        core.config.merge_cfg_from_file(f.name)
        with self.assertRaises(AttributeError):
            _ = cfg.MODEL.DILATION  # noqa