本文整理汇总了Python中updater.Updater类的典型用法代码示例。如果您正苦于以下问题：Python updater.Updater的具体用法？Python updater.Updater怎么用？Python updater.Updater使用的例子？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块updater的用法示例。
下文一共展示了updater.Updater类的2个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: updater_wrapper
# 需要导入模块: import updater [as 别名]
# 或者: from updater import Updater [as 别名]
def updater_wrapper(self, cnt_round, dic_agent_conf, dic_exp_conf, dic_traffic_env_conf, dic_path, best_round=None,
                    bar_round=None):
    """Create an ``Updater`` for one round and run its two update steps.

    Every argument except ``self`` is forwarded unchanged, by keyword of the
    same name, to the ``Updater`` constructor. The new object then has
    ``load_sample_for_agents()`` and ``update_network_for_agents()`` called
    on it, in that order, and a completion marker is printed.

    Args:
        cnt_round: Forwarded to ``Updater`` (presumably the round counter;
            confirm against ``Updater``).
        dic_agent_conf: Forwarded to ``Updater``.
        dic_exp_conf: Forwarded to ``Updater``.
        dic_traffic_env_conf: Forwarded to ``Updater``.
        dic_path: Forwarded to ``Updater``.
        best_round: Forwarded to ``Updater``; defaults to ``None``.
        bar_round: Forwarded to ``Updater``; defaults to ``None``.

    Returns:
        None.
    """
    updater_kwargs = {
        "cnt_round": cnt_round,
        "dic_agent_conf": dic_agent_conf,
        "dic_exp_conf": dic_exp_conf,
        "dic_traffic_env_conf": dic_traffic_env_conf,
        "dic_path": dic_path,
        "best_round": best_round,
        "bar_round": bar_round,
    }
    round_updater = Updater(**updater_kwargs)

    # Samples are loaded before the networks are updated, as in the original.
    round_updater.load_sample_for_agents()
    round_updater.update_network_for_agents()

    print("updater_wrapper end")
示例2: main
# 需要导入模块: import updater [as 别名]
# 或者: from updater import Updater [as 别名]
def main(resume, gpu, load_path, data_path):
    """Assemble the generator/discriminator training run and start it.

    Builds the dataset, both networks, one SGD optimizer per network and a
    serial iterator with batch size 1, wires them into an ``Updater``, and
    runs a ``chainer.training.Trainer`` for 500000 iterations with
    snapshot, logging, plotting and progress-bar extensions attached.

    Args:
        resume: When truthy, trainer state is loaded from ``load_path``
            before training starts.
        gpu: Passed to ``Updater`` as ``device``; when ``>= 0`` both
            networks held by the updater are moved with ``to_gpu()``.
        load_path: File read by ``chainer.serializers.load_npz``; only
            used when ``resume`` is truthy.
        data_path: Passed to ``Dataset``.

    Returns:
        None.
    """
    dataset = Dataset(data_path)
    generator = MultiScaleGenerator(c.SCALE_FMS_G, c.SCALE_KERNEL_SIZES_G)
    discriminator = MultiScaleDiscriminator(
        c.SCALE_CONV_FMS_D, c.SCALE_KERNEL_SIZES_D, c.SCALE_FC_LAYER_SIZES_D
    )

    # One plain SGD optimizer per network, keyed by the names Updater expects.
    optimizers = {
        "GeneratorNetwork": chainer.optimizers.SGD(c.LRATE_G),
        "DiscriminatorNetwork": chainer.optimizers.SGD(c.LRATE_D),
    }
    iterator = chainer.iterators.SerialIterator(dataset, 1)
    updater_params = {'LAM_ADV': 0.05, 'LAM_LP': 1, 'LAM_GDL': .1}

    gan_updater = Updater(
        iterators=iterator,
        optimizers=optimizers,
        GeneratorNetwork=generator,
        DiscriminatorNetwork=discriminator,
        params=updater_params,
        device=gpu,
    )
    if gpu >= 0:
        gan_updater.GenNetwork.to_gpu()
        gan_updater.DisNetwork.to_gpu()

    trainer = chainer.training.Trainer(gan_updater, (500000, 'iteration'), out='result')

    # Snapshots: full trainer state every iteration, plus the generator alone.
    trainer.extend(extensions.snapshot(filename='snapshot'), trigger=(1, 'iteration'))
    trainer.extend(extensions.snapshot_object(trainer.updater.GenNetwork, "GEN"))
    trainer.extend(saveGen)

    # All reporting extensions share the same 10-iteration interval.
    every_ten = (10, 'iteration')
    logged_keys = [
        'epoch',
        'iteration',
        'GeneratorNetwork/L2Loss',
        'GeneratorNetwork/GDL',
        'DiscriminatorNetwork/DisLoss',
        'GeneratorNetwork/CompositeGenLoss',
    ]
    printed_keys = ['GeneratorNetwork/CompositeGenLoss', 'DiscriminatorNetwork/DisLoss']
    trainer.extend(extensions.LogReport(keys=logged_keys, trigger=every_ten))
    trainer.extend(extensions.PrintReport(printed_keys), trigger=every_ten)

    # (y keys, output file) for each plot; registered in this exact order.
    plot_specs = (
        (['DiscriminatorNetwork/DisLoss'], "DisLoss.png"),
        (['GeneratorNetwork/CompositeGenLoss'], "GenLoss.png"),
        (['GeneratorNetwork/AdvLoss'], "AdvGenLoss.png"),
        (['GeneratorNetwork/AdvLoss', 'DiscriminatorNetwork/DisLoss'], "AdversarialLosses.png"),
        (['GeneratorNetwork/L2Loss'], "L2Loss.png"),
        (['GeneratorNetwork/GDL'], "GDL.png"),
    )
    for y_keys, png_name in plot_specs:
        trainer.extend(extensions.PlotReport(y_keys, 'iteration', every_ten, file_name=png_name))

    trainer.extend(extensions.ProgressBar(update_interval=10))

    if resume:
        # Resume from a snapshot
        chainer.serializers.load_npz(load_path, trainer)
        # NOTE(review): indentation was lost in the source listing; this debug
        # print is assumed to belong to the resume branch - confirm.
        print(trainer.updater.__dict__)

    trainer.run()