本文整理汇总了Python中train.Train方法的典型用法代码示例。如果您正苦于以下问题：Python train.Train方法的具体用法？Python train.Train怎么用？Python train.Train使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块train的用法示例。
在下文中一共展示了train.Train方法的4个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: main
# 需要导入模块: import train [as 别名]
# 或者: from train import Train [as 别名]
def main():
    """Train and/or test MobileNetV2 as directed by the JSON configuration.

    Parses the configuration with ``parse_args()``, creates the experiment
    directories, and builds ``MobileNetV2``. When ``config_args.cuda`` is set,
    the model is moved to the GPU and cuDNN is enabled. Data is loaded with
    ``CIFAR10Data``. Training runs when ``config_args.to_train`` is set and
    testing when ``config_args.to_test`` is set.

    A ``KeyboardInterrupt`` during training stops training but still lets
    the test phase run.
    """
    # Parse the JSON arguments.
    config_args = parse_args()

    # Create the experiment directories (the first returned value is unused).
    _, config_args.summary_dir, config_args.checkpoint_dir = create_experiment_dirs(
        config_args.experiment_dir)

    model = MobileNetV2(config_args)
    if config_args.cuda:
        model.cuda()
        cudnn.enabled = True
        cudnn.benchmark = True

    print("Loading Data...")
    data = CIFAR10Data(config_args)
    print("Data loaded successfully\n")

    trainer = Train(model, data.trainloader, data.testloader, config_args)

    if config_args.to_train:
        try:
            print("Training...")
            trainer.train()
            print("Training Finished\n")
        except KeyboardInterrupt:
            # Deliberately fall through to the test phase. Report the
            # interruption instead of swallowing it silently, so the user
            # knows training did not finish.
            print("Training interrupted\n")

    if config_args.to_test:
        print("Testing...")
        trainer.test(data.testloader)
        print("Testing Finished\n")
示例2: main
# 需要导入模块: import train [as 别名]
# 或者: from train import Train [as 别名]
def main():
    """Train and/or test MobileNet (TensorFlow graph mode) from a JSON config.

    Parses the configuration with ``parse_args()``. If that fails, prints a
    hint about the ``--config`` flag and exits with status 1.

    Otherwise it:

    - creates the experiment directories;
    - resets the default graph and opens a ``tf.Session``;
    - loads the data, storing the image dimensions and dataset sizes on
      ``config_args``;
    - builds the model and the summarizer;
    - trains when ``config_args.to_train`` is set and runs a final test when
      ``config_args.to_test`` is set.

    A ``KeyboardInterrupt`` during training calls ``trainer.save_model()``
    before moving on.
    """
    # Parse the JSON arguments.
    try:
        config_args = parse_args()
    except (Exception, SystemExit) as exc:
        # A SystemExit with code 0/None is a deliberate clean exit (argparse,
        # for example, exits with 0 after --help), so let it through rather
        # than reporting a config error. KeyboardInterrupt is not caught here.
        if isinstance(exc, SystemExit) and not exc.code:
            raise
        print("Add a config file using '--config file_name.json'")
        raise SystemExit(1) from exc

    # Create the experiment directories (the first returned value is unused).
    _, config_args.summary_dir, config_args.checkpoint_dir = create_experiment_dirs(config_args.experiment_dir)

    # Reset the default TensorFlow graph.
    tf.reset_default_graph()

    # TensorFlow-specific configuration.
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    # Data loading.
    data = DataLoader(config_args.batch_size, config_args.shuffle)
    print("Loading Data...")
    (config_args.img_height, config_args.img_width, config_args.num_channels,
     config_args.train_data_size, config_args.test_data_size) = data.load_data()
    print("Data loaded\n\n")

    # Model creation.
    print("Building the model...")
    model = MobileNet(config_args)
    print("Model is built successfully\n\n")

    # Summarizer creation.
    summarizer = Summarizer(sess, config_args.summary_dir)

    # Train class.
    trainer = Train(sess, model, data, summarizer)

    if config_args.to_train:
        try:
            print("Training...")
            trainer.train()
            print("Training Finished\n\n")
        except KeyboardInterrupt:
            # On interrupt, call save_model() before continuing.
            trainer.save_model()

    if config_args.to_test:
        print("Final test!")
        trainer.test('val')
        print("Testing Finished\n\n")
示例3: main
# 需要导入模块: import train [as 别名]
# 或者: from train import Train [as 别名]
def main(args):
    """Evaluate every configured model and print robust/clean error statistics.

    For each model in the config:

    - the global ``eval_params`` are deep-copied and overridden by that
      model's own ``eval_params`` (if present);
    - the model is converted with ``BoundSequential.convert`` and moved to
      the GPU;
    - it is evaluated once on the test loader via ``Train`` under
      ``torch.no_grad()``.

    The file at the model's ``eval_log`` path is opened for writing, handed
    to ``Logger``, and closed when that model's evaluation finishes.
    Afterwards, min/max/median/mean summaries of the collected robust and
    clean errors are printed.

    Args:
        args: Passed unchanged to ``load_config``.

    Raises:
        ValueError: If no model was evaluated, so there is nothing to
            summarize.
    """
    config = load_config(args)
    global_eval_config = config["eval_params"]
    models, model_names = config_modelloader(config, load_pretrain=True)

    robust_errs = []
    errs = []
    for model, model_id, model_config in zip(models, model_names, config["models"]):
        # Copy the global evaluation config so this model's overrides do not
        # leak into the settings used for the next model.
        eval_config = copy.deepcopy(global_eval_config)
        if "eval_params" in model_config:
            eval_config.update(model_config["eval_params"])

        model = BoundSequential.convert(model, eval_config["method_params"]["bound_opts"])
        model = model.cuda()

        # Read the evaluation parameters from the config.
        method = eval_config["method"]
        verbose = eval_config["verbose"]
        eps = eval_config["epsilon"]
        # Parameters specific to the chosen method.
        method_param = eval_config["method_params"]
        norm = float(eval_config["norm"])
        # Only the test loader is used for evaluation.
        _, test_data = config_dataloader(config, **eval_config["loader_params"])

        model_name = get_path(config, model_id, "model", load=False)
        print(model_name)
        model_log = get_path(config, model_id, "eval_log")

        # Close the log file once this model is done, instead of leaking one
        # open handle per evaluated model.
        with open(model_log, "w") as log_file:
            logger = Logger(log_file)
            logger.log("evaluation configurations:", eval_config)
            logger.log("Evaluating...")
            with torch.no_grad():
                # NOTE(review): the leading 0 (presumably an epoch index), the
                # scheduler whose start and end eps are both `eps`, and the
                # positional False/None presumably select evaluation-only
                # mode. Confirm against Train's signature.
                robust_err, err = Train(model, 0, test_data, EpsilonScheduler("linear", 0, 0, eps, eps, 1), eps, norm, logger, verbose, False, None, method, **method_param)
        robust_errs.append(robust_err)
        errs.append(err)

    if not errs:
        # The numpy reductions below would otherwise fail with an opaque error.
        raise ValueError("no models were evaluated; check the 'models' section of the config")

    print('model robust errors (for robustly trained models, not valid for naturally trained models):')
    print(robust_errs)
    robust_errs = np.array(robust_errs)
    print(_error_stats(robust_errs))

    print('clean errors for models with min, max and median robust errors')
    i_min = np.argmin(robust_errs)
    i_max = np.argmax(robust_errs)
    # With an even number of models this picks the upper of the two middle
    # entries, which need not equal np.median's averaged value.
    i_median = np.argsort(robust_errs)[len(robust_errs) // 2]
    print('for min: {:.4f}, for max: {:.4f}, for median: {:.4f}'.format(errs[i_min], errs[i_max], errs[i_median]))

    print('model clean errors:')
    print(errs)
    print(_error_stats(errs))


def _error_stats(values):
    """Format min/max/median/mean of *values* as one summary line (4 decimals)."""
    return 'min: {:.4f}, max: {:.4f}, median: {:.4f}, mean: {:.4f}'.format(
        np.min(values), np.max(values), np.median(values), np.mean(values))
示例4: main
# 需要导入模块: import train [as 别名]
# 或者: from train import Train [as 别名]
def main():
    """Train or test ShuffleNet (TensorFlow graph mode) from a JSON config.

    ``config_args.train_or_test`` selects the mode and must be ``"train"`` or
    ``"test"``. It is validated immediately after parsing, so an invalid value
    fails before any directories are created, data is loaded or the graph is
    built.

    In test mode the data loader uses a batch size of 1, to simulate the real
    experiment. In train mode, a ``KeyboardInterrupt`` calls
    ``trainer.save_model()``. In test mode, the ``'val'`` split is evaluated.

    Raises:
        ValueError: If ``config_args.train_or_test`` is neither ``"train"``
            nor ``"test"``.
    """
    # Parse the JSON arguments.
    config_args = parse_args()
    mode = config_args.train_or_test
    # Reject an unknown mode up front rather than after all the expensive
    # setup below.
    if mode not in ("train", "test"):
        raise ValueError("Train or Test options only are allowed")

    # Create the experiment directories (the first returned value is unused).
    _, config_args.summary_dir, config_args.checkpoint_dir = create_experiment_dirs(config_args.experiment_dir)

    # Reset the default TensorFlow graph.
    tf.reset_default_graph()

    # TensorFlow-specific configuration.
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    # Data loading. The batch size is 1 when testing, to simulate the real
    # experiment.
    data_batch_size = config_args.batch_size if mode == "train" else 1
    data = DataLoader(data_batch_size, config_args.shuffle)
    print("Loading Data...")
    (config_args.img_height, config_args.img_width, config_args.num_channels,
     config_args.train_data_size, config_args.test_data_size) = data.load_data()
    print("Data loaded\n\n")

    # Model creation.
    print("Building the model...")
    model = ShuffleNet(config_args)
    print("Model is built successfully\n\n")

    # Parameters visualization.
    show_parameters()

    # Summarizer creation.
    summarizer = Summarizer(sess, config_args.summary_dir)

    # Train class.
    trainer = Train(sess, model, data, summarizer)

    if mode == "train":
        try:
            # print("FLOPs for batch size = " + str(config_args.batch_size) + "\n")
            # calculate_flops()
            print("Training...")
            trainer.train()
            print("Training Finished\n\n")
        except KeyboardInterrupt:
            # On interrupt, call save_model() before returning.
            trainer.save_model()
    else:
        # print("FLOPs for single inference \n")
        # calculate_flops()
        # The split can be 'val', 'test' or even 'train' according to the needs.
        print("Testing...")
        trainer.test('val')
        print("Testing Finished\n\n")