当前位置: 首页>>代码示例>>Python>>正文


Python callbacks.ModelIntervalCheckpoint方法代码示例

本文整理汇总了Python中rl.callbacks.ModelIntervalCheckpoint方法的典型用法代码示例。如果您正苦于以下问题:Python callbacks.ModelIntervalCheckpoint方法的具体用法?Python callbacks.ModelIntervalCheckpoint怎么用?Python callbacks.ModelIntervalCheckpoint使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在rl.callbacks的用法示例。


在下文中一共展示了callbacks.ModelIntervalCheckpoint方法的3个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: build_callbacks

# 需要导入模块: from rl import callbacks [as 别名]
# 或者: from rl.callbacks import ModelIntervalCheckpoint [as 别名]
def build_callbacks(env_name):
    checkpoint_weights_filename = 'dqn_' + env_name + '_weights_{step}.h5f'
    log_filename = 'dqn_{}_log.json'.format(env_name)
    callbacks = [ModelIntervalCheckpoint(checkpoint_weights_filename, interval=5000)]
    callbacks += [FileLogger(log_filename, interval=100)]
    return callbacks 
开发者ID:PacktPublishing,项目名称:Deep-Learning-Quick-Reference,代码行数:8,代码来源:dqn_lunar_lander.py

示例2: build_callbacks

# 需要导入模块: from rl import callbacks [as 别名]
# 或者: from rl.callbacks import ModelIntervalCheckpoint [as 别名]
def build_callbacks(env_name):
    checkpoint_weights_filename = 'dqn_' + env_name + '_weights_{step}.h5f'
    log_filename = 'dqn_{}_log.json'.format(env_name)
    callbacks = [ModelIntervalCheckpoint(checkpoint_weights_filename, interval=250000)]
    callbacks += [FileLogger(log_filename, interval=100)]
    return callbacks 
开发者ID:PacktPublishing,项目名称:Deep-Learning-Quick-Reference,代码行数:8,代码来源:dqn_breakout.py

示例3: training_game

# 需要导入模块: from rl import callbacks [as 别名]
# 或者: from rl.callbacks import ModelIntervalCheckpoint [as 别名]
def training_game():
    env = Environment(map_name="HallucinIce", visualize=True, game_steps_per_episode=150, agent_interface_format=features.AgentInterfaceFormat(
        feature_dimensions=features.Dimensions(screen=64, minimap=32)
    ))

    input_shape = (_SIZE, _SIZE, 1)
    nb_actions = _SIZE * _SIZE  # Should this be an integer

    model = neural_network_model(input_shape, nb_actions)
    # memory : how many subsequent observations should be provided to the network?
    memory = SequentialMemory(limit=5000, window_length=_WINDOW_LENGTH)

    processor = SC2Proc()

    ### Policy
    # Agent´s behaviour function. How the agent pick actions
    # LinearAnnealedPolicy is a wrapper that transforms the policy into a linear incremental linear solution . Then why im not see LAP with other than not greedy ?
    # EpsGreedyQPolicy is a way of selecting random actions with uniform distributions from a set of actions . Select an action that can give max or min rewards
    # BolztmanQPolicy . Assumption that it follows a Boltzman distribution. gives the probability that a system will be in a certain state as a function of that state´s energy??

    policy = LinearAnnealedPolicy(EpsGreedyQPolicy(), attr="eps", value_max=1, value_min=0.7, value_test=.0,
                                  nb_steps=1e6)
    # policy = (BoltzmanQPolicy( tau=1., clip= (-500,500)) #clip defined in between -500 / 500


    ### Agent
    # Double Q-learning ( combines Q-Learning with a deep Neural Network )
    # Q Learning -- Bellman equation

    dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory,
                   nb_steps_warmup=500, target_model_update=1e-2, policy=policy,
                   batch_size=150, processor=processor)

    dqn.compile(Adam(lr=.001), metrics=["mae"])


    ## Save the parameters and upload them when needed

    name = "HallucinIce"
    w_file = "dqn_{}_weights.h5f".format(name)
    check_w_file = "train_w" + name + "_weights.h5f"

    if SAVE_MODEL:
        check_w_file = "train_w" + name + "_weights_{step}.h5f"

    log_file = "training_w_{}_log.json".format(name)
    callbacks = [ModelIntervalCheckpoint(check_w_file, interval=1000)]
    callbacks += [FileLogger(log_file, interval=100)]

    if LOAD_MODEL:
        dqn.load_weights(w_file)

    dqn.fit(env, callbacks=callbacks, nb_steps=1e7, action_repetition=2,
            log_interval=1e4, verbose=2)

    dqn.save_weights(w_file, overwrite=True)
    dqn.test(env, action_repetition=2, nb_episodes=30, visualize=False) 
开发者ID:SoyGema,项目名称:Startcraft_pysc2_minigames,代码行数:59,代码来源:DQN_Agent.py


注:本文中的rl.callbacks.ModelIntervalCheckpoint方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。