當前位置: 首頁>>代碼示例>>Python>>正文


Python callbacks.ModelIntervalCheckpoint方法代碼示例

本文整理匯總了Python中rl.callbacks.ModelIntervalCheckpoint方法的典型用法代碼示例。如果您正苦於以下問題:Python callbacks.ModelIntervalCheckpoint方法的具體用法?Python callbacks.ModelIntervalCheckpoint怎麽用?Python callbacks.ModelIntervalCheckpoint使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在rl.callbacks的用法示例。


在下文中一共展示了callbacks.ModelIntervalCheckpoint方法的3個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: build_callbacks

# 需要導入模塊: from rl import callbacks [as 別名]
# 或者: from rl.callbacks import ModelIntervalCheckpoint [as 別名]
def build_callbacks(env_name):
    checkpoint_weights_filename = 'dqn_' + env_name + '_weights_{step}.h5f'
    log_filename = 'dqn_{}_log.json'.format(env_name)
    callbacks = [ModelIntervalCheckpoint(checkpoint_weights_filename, interval=5000)]
    callbacks += [FileLogger(log_filename, interval=100)]
    return callbacks 
開發者ID:PacktPublishing,項目名稱:Deep-Learning-Quick-Reference,代碼行數:8,代碼來源:dqn_lunar_lander.py

示例2: build_callbacks

# 需要導入模塊: from rl import callbacks [as 別名]
# 或者: from rl.callbacks import ModelIntervalCheckpoint [as 別名]
def build_callbacks(env_name):
    checkpoint_weights_filename = 'dqn_' + env_name + '_weights_{step}.h5f'
    log_filename = 'dqn_{}_log.json'.format(env_name)
    callbacks = [ModelIntervalCheckpoint(checkpoint_weights_filename, interval=250000)]
    callbacks += [FileLogger(log_filename, interval=100)]
    return callbacks 
開發者ID:PacktPublishing,項目名稱:Deep-Learning-Quick-Reference,代碼行數:8,代碼來源:dqn_breakout.py

示例3: training_game

# 需要導入模塊: from rl import callbacks [as 別名]
# 或者: from rl.callbacks import ModelIntervalCheckpoint [as 別名]
def training_game():
    env = Environment(map_name="HallucinIce", visualize=True, game_steps_per_episode=150, agent_interface_format=features.AgentInterfaceFormat(
        feature_dimensions=features.Dimensions(screen=64, minimap=32)
    ))

    input_shape = (_SIZE, _SIZE, 1)
    nb_actions = _SIZE * _SIZE  # Should this be an integer

    model = neural_network_model(input_shape, nb_actions)
    # memory : how many subsequent observations should be provided to the network?
    memory = SequentialMemory(limit=5000, window_length=_WINDOW_LENGTH)

    processor = SC2Proc()

    ### Policy
    # Agent´s behaviour function. How the agent pick actions
    # LinearAnnealedPolicy is a wrapper that transforms the policy into a linear incremental linear solution . Then why im not see LAP with other than not greedy ?
    # EpsGreedyQPolicy is a way of selecting random actions with uniform distributions from a set of actions . Select an action that can give max or min rewards
    # BolztmanQPolicy . Assumption that it follows a Boltzman distribution. gives the probability that a system will be in a certain state as a function of that state´s energy??

    policy = LinearAnnealedPolicy(EpsGreedyQPolicy(), attr="eps", value_max=1, value_min=0.7, value_test=.0,
                                  nb_steps=1e6)
    # policy = (BoltzmanQPolicy( tau=1., clip= (-500,500)) #clip defined in between -500 / 500


    ### Agent
    # Double Q-learning ( combines Q-Learning with a deep Neural Network )
    # Q Learning -- Bellman equation

    dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory,
                   nb_steps_warmup=500, target_model_update=1e-2, policy=policy,
                   batch_size=150, processor=processor)

    dqn.compile(Adam(lr=.001), metrics=["mae"])


    ## Save the parameters and upload them when needed

    name = "HallucinIce"
    w_file = "dqn_{}_weights.h5f".format(name)
    check_w_file = "train_w" + name + "_weights.h5f"

    if SAVE_MODEL:
        check_w_file = "train_w" + name + "_weights_{step}.h5f"

    log_file = "training_w_{}_log.json".format(name)
    callbacks = [ModelIntervalCheckpoint(check_w_file, interval=1000)]
    callbacks += [FileLogger(log_file, interval=100)]

    if LOAD_MODEL:
        dqn.load_weights(w_file)

    dqn.fit(env, callbacks=callbacks, nb_steps=1e7, action_repetition=2,
            log_interval=1e4, verbose=2)

    dqn.save_weights(w_file, overwrite=True)
    dqn.test(env, action_repetition=2, nb_episodes=30, visualize=False) 
開發者ID:SoyGema,項目名稱:Startcraft_pysc2_minigames,代碼行數:59,代碼來源:DQN_Agent.py


注:本文中的rl.callbacks.ModelIntervalCheckpoint方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。