当前位置: 首页>>代码示例>>Python>>正文


Python tf_util.huber_loss方法代码示例

本文整理汇总了Python中baselines.common.tf_util.huber_loss方法的典型用法代码示例。如果您正苦于以下问题:Python tf_util.huber_loss方法的具体用法?Python tf_util.huber_loss怎么用?Python tf_util.huber_loss使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在baselines.common.tf_util的用法示例。


在下文中一共展示了tf_util.huber_loss方法的1个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: qmap_build_train

# 需要导入模块: from baselines.common import tf_util [as 别名]
# 或者: from baselines.common.tf_util import huber_loss [as 别名]
def qmap_build_train(observation_space, coords_shape, model, n_actions, optimizer, grad_norm_clip, scope='q_map'):
    with tf.variable_scope(scope):
        ob_shape = observation_space.shape
        observations = tf.placeholder(tf.float32, [None] + list(ob_shape), name='observations')
        actions = tf.placeholder(tf.int32, [None], name='actions')
        target_qs = tf.placeholder(tf.float32, [None] + list(coords_shape), name='targets')
        weights = tf.placeholder(tf.float32, [None], name='weights')

        q_values = model(inpt=observations, n_actions=n_actions, scope='q_func')
        q_func_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=tf.get_variable_scope().name + "/q_func")

        target_q_values = model(inpt=observations, n_actions=n_actions, scope='target_q_func')
        target_q_func_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=tf.get_variable_scope().name + "/target_q_func")

        action_masks = tf.expand_dims(tf.expand_dims(tf.one_hot(actions, n_actions), axis=1), axis=1)
        qs_selected = tf.reduce_sum(q_values * action_masks, 3)

        td_errors = 1 * (qs_selected - target_qs) # TODO: coefficient?
        losses = tf.reduce_mean(tf.square(td_errors), [1, 2]) # TODO: find best, was U.huber_loss
        weighted_loss = tf.reduce_mean(weights * losses)

        if grad_norm_clip is not None:
            gradients = optimizer.compute_gradients(weighted_loss, var_list=q_func_vars)
            for i, (grad, var) in enumerate(gradients):
                if grad is not None:
                    gradients[i] = (tf.clip_by_norm(grad, grad_norm_clip), var)
            optimize = optimizer.apply_gradients(gradients)
            grad_norms = [tf.norm(grad) for grad in gradients]
        else:
            optimize = optimizer.minimize(weighted_loss, var_list=q_func_vars)
            grad_norms = None

        update_target_expr = []
        for var, var_target in zip(sorted(q_func_vars, key=lambda v: v.name),
                                   sorted(target_q_func_vars, key=lambda v: v.name)):
            update_target_expr.append(var_target.assign(var))
        update_target_expr = tf.group(*update_target_expr)

    errors = tf.reduce_mean(tf.abs(td_errors), [1, 2]) # TODO: try with the losses directly
    compute_q_values = U.function(inputs=[observations], outputs=q_values)
    compute_double_q_values = U.function(inputs=[observations], outputs=[q_values, target_q_values])
    train = U.function(inputs=[observations, actions, target_qs, weights], outputs=errors, updates=[optimize])
    update_target = U.function([], [], updates=[update_target_expr])
    trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)
    train_debug = U.function(inputs=[observations, actions, target_qs, weights], outputs=[errors, weighted_loss, grad_norms, trainable_vars], updates=[optimize])

    return compute_q_values, compute_double_q_values, train, update_target, train_debug 
开发者ID:fabiopardo,项目名称:qmap,代码行数:49,代码来源:q_map_dqn_agent.py


注:本文中的baselines.common.tf_util.huber_loss方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。