當前位置: 首頁>>代碼示例>>Python>>正文


Python updater.Updater方法代碼示例

本文整理匯總了Python中updater.Updater方法的典型用法代碼示例。如果您正苦於以下問題:Python updater.Updater方法的具體用法?Python updater.Updater怎麽用?Python updater.Updater使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在updater的用法示例。


在下文中一共展示了updater.Updater方法的2個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: updater_wrapper

# 需要導入模塊: import updater [as 別名]
# 或者: from updater import Updater [as 別名]
def updater_wrapper(self, cnt_round, dic_agent_conf, dic_exp_conf, dic_traffic_env_conf, dic_path, best_round=None,
                        bar_round=None):

        updater = Updater(
            cnt_round=cnt_round,
            dic_agent_conf=dic_agent_conf,
            dic_exp_conf=dic_exp_conf,
            dic_traffic_env_conf=dic_traffic_env_conf,
            dic_path=dic_path,
            best_round=best_round,
            bar_round=bar_round
        )

        updater.load_sample_for_agents()
        updater.update_network_for_agents()
        print("updater_wrapper end")
        return 
開發者ID:multi-commander,項目名稱:Multi-Commander,代碼行數:19,代碼來源:pipeline.py

示例2: main

# 需要導入模塊: import updater [as 別名]
# 或者: from updater import Updater [as 別名]
def main(resume, gpu, load_path, data_path):
	dataset = Dataset(data_path)


	GenNetwork = MultiScaleGenerator(c.SCALE_FMS_G, c.SCALE_KERNEL_SIZES_G)
	DisNetwork = MultiScaleDiscriminator(c.SCALE_CONV_FMS_D, c.SCALE_KERNEL_SIZES_D, c.SCALE_FC_LAYER_SIZES_D)

	optimizers = {}
	optimizers["GeneratorNetwork"] = chainer.optimizers.SGD(c.LRATE_G)
	optimizers["DiscriminatorNetwork"] = chainer.optimizers.SGD(c.LRATE_D)

	iterator = chainer.iterators.SerialIterator(dataset, 1)
	params = {'LAM_ADV': 0.05, 'LAM_LP': 1, 'LAM_GDL': .1}
	updater = Updater(iterators=iterator, optimizers=optimizers,
	                  GeneratorNetwork=GenNetwork,
	                  DiscriminatorNetwork=DisNetwork,
	                  params=params,
	                  device=gpu
	                  )
	if gpu>=0:
		updater.GenNetwork.to_gpu()
		updater.DisNetwork.to_gpu()

	trainer = chainer.training.Trainer(updater, (500000, 'iteration'), out='result')
	trainer.extend(extensions.snapshot(filename='snapshot'), trigger=(1, 'iteration'))
	trainer.extend(extensions.snapshot_object(trainer.updater.GenNetwork, "GEN"))
	trainer.extend(saveGen)

	log_keys = ['epoch', 'iteration', 'GeneratorNetwork/L2Loss', 'GeneratorNetwork/GDL',
	            'DiscriminatorNetwork/DisLoss', 'GeneratorNetwork/CompositeGenLoss']
	print_keys = ['GeneratorNetwork/CompositeGenLoss','DiscriminatorNetwork/DisLoss']
	trainer.extend(extensions.LogReport(keys=log_keys, trigger=(10, 'iteration')))
	trainer.extend(extensions.PrintReport(print_keys), trigger=(10, 'iteration'))
	trainer.extend(extensions.PlotReport(['DiscriminatorNetwork/DisLoss'], 'iteration', (10, 'iteration'), file_name="DisLoss.png"))
	trainer.extend(extensions.PlotReport(['GeneratorNetwork/CompositeGenLoss'], 'iteration', (10, 'iteration'), file_name="GenLoss.png"))
	trainer.extend(extensions.PlotReport(['GeneratorNetwork/AdvLoss'], 'iteration', (10, 'iteration'), file_name="AdvGenLoss.png"))
	trainer.extend(extensions.PlotReport(['GeneratorNetwork/AdvLoss','DiscriminatorNetwork/DisLoss'], 'iteration', (10, 'iteration'), file_name="AdversarialLosses.png"))
	trainer.extend(extensions.PlotReport(['GeneratorNetwork/L2Loss'], 'iteration', (10, 'iteration'),file_name="L2Loss.png"))
	trainer.extend(extensions.PlotReport(['GeneratorNetwork/GDL'], 'iteration', (10, 'iteration'),file_name="GDL.png"))

	trainer.extend(extensions.ProgressBar(update_interval=10))
	if resume:
		# Resume from a snapshot
		chainer.serializers.load_npz(load_path, trainer)
	print(trainer.updater.__dict__)
	trainer.run() 
開發者ID:alokwhitewolf,項目名稱:Video-frame-prediction-by-multi-scale-GAN,代碼行數:48,代碼來源:train.py


注:本文中的updater.Updater方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。