当前位置: 首页>>代码示例>>Python>>正文


Python Serializable.quick_init方法代码示例

本文整理汇总了 Python 中 rllab.core.serializable.Serializable.quick_init 方法的典型用法代码示例。如果您正在寻找以下问题的答案:Python Serializable.quick_init 方法的具体用法是什么?Python Serializable.quick_init 怎么用?有哪些 Python Serializable.quick_init 的使用例子?那么这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类 rllab.core.serializable.Serializable 的用法示例。


下文一共展示了 Serializable.quick_init 方法的 15 个代码示例,这些示例默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞,您的评价将有助于系统推荐出更棒的 Python 代码示例。

示例1: __init__

# 需要导入模块: from rllab.core.serializable import Serializable [as 别名]
# 或者: from rllab.core.serializable.Serializable import quick_init [as 别名]
def __init__(
            self,
            env,
            scale_reward=1.,
            normalize_obs=False,
            normalize_reward=False,
            obs_alpha=0.001,
            reward_alpha=0.001,
    ):
        """Wrap ``env`` and set up the scaling / normalization bookkeeping.

        :param env: environment to wrap; forwarded to ``ProxyEnv.__init__``.
            Its ``observation_space.flat_dim`` sizes the observation statistics.
        :param scale_reward: stored as ``_scale_reward``.
        :param normalize_obs: stored as ``_normalize_obs``.
        :param normalize_reward: stored as ``_normalize_reward``.
        :param obs_alpha: stored as ``_obs_alpha`` (presumably the update rate
            of the observation statistics -- confirm against the update code).
        :param reward_alpha: stored as ``_reward_alpha`` (presumably the update
            rate of the reward statistics -- confirm against the update code).
        """
        # Snapshot the constructor arguments first, while locals() still holds
        # nothing but the parameters.
        Serializable.quick_init(self, locals())
        ProxyEnv.__init__(self, env)

        # Plain configuration values.
        self._scale_reward = scale_reward
        self._normalize_obs = normalize_obs
        self._normalize_reward = normalize_reward
        self._obs_alpha = obs_alpha
        self._reward_alpha = reward_alpha

        # Observation statistics: one entry per flattened observation
        # dimension, starting at mean 0 and variance 1.
        flat_dim = env.observation_space.flat_dim
        self._obs_mean = np.zeros(flat_dim)
        self._obs_var = np.ones(flat_dim)

        # Reward statistics are scalars with the same starting values.
        self._reward_mean = 0.
        self._reward_var = 1.
开发者ID:nosyndicate,项目名称:pytorchrl,代码行数:22,代码来源:normalized_env.py

示例2: clone

# 需要导入模块: from rllab.core.serializable import Serializable [as 别名]
# 或者: from rllab.core.serializable.Serializable import quick_init [as 别名]
def clone(self, out=None):
        """Clone the state of this environment, optionally into an existing one.

        :param out: environment to copy the state into. When ``None``, a new
            instance of ``type(self)`` is built from the constructor arguments
            recorded by ``Serializable.quick_init``.
        :return: ``out`` (or the newly created instance) after the copy.
        :raises Exception: if ``out`` is not exactly of this object's type, or
            if it was constructed with different arguments.
        """
        if out is None:
            # The recorded arguments live under Serializable's name-mangled
            # attributes. Writing ``self.__args`` here would be mangled with
            # *this* class's name instead and raise AttributeError.
            out = type(self)(*self._Serializable__args,
                             **self._Serializable__kwargs)

        # Exact type match is intended: a subclass instance is not accepted.
        if type(out) is not type(self):
            raise Exception("out has the wrong type")
        if (out._Serializable__args != self._Serializable__args
                or out._Serializable__kwargs != self._Serializable__kwargs):
            raise Exception("out was constructed with the wrong arguments")

        out.num_steps = self.num_steps
        out.terminated = self.terminated
        out.state['__last_action_name'] = self.state['__last_action_name']

        # Each module copies its own share of the state into its counterpart.
        for module, out_module in zip(self.modules, out.modules):
            module.clone(self.state, out_module, out.state)
        return out
开发者ID:vicariousinc,项目名称:pixelworld,代码行数:20,代码来源:modular_env.py

示例3: __init__

# 需要导入模块: from rllab.core.serializable import Serializable [as 别名]
# 或者: from rllab.core.serializable.Serializable import quick_init [as 别名]
def __init__(self, env, record_video=True, video_schedule=None,
            log_dir=None, timestep_limit=9999):
        """Wrap a gym-style ``env`` and configure video recording and logging.

        :param env: environment to wrap; its observation and action spaces are
            converted with ``to_rllab_space``.
        :param record_video: when falsy a ``NoVideoSchedule`` is installed.
        :param video_schedule: schedule to use when recording; ``None`` selects
            a ``CappedCubicVideoSchedule``.
        :param log_dir: passed to ``set_log_dir``.
        :param timestep_limit: stored as ``_horizon``.
        """
        # Ensure the version saved to disk doesn't monitor into our log_dir:
        # the serialized arguments get the monitoring options switched off.
        serialized_args = dict(
            locals(),
            log_dir=None,
            record_video=False,
            video_schedule=None,
        )
        Serializable.quick_init(self, serialized_args)

        self.env = env
        self._observation_space = to_rllab_space(env.observation_space)
        self._action_space = to_rllab_space(env.action_space)
        self.env.spec = EnvSpec('GymEnv-v0')

        monitor.logger.setLevel(logging.WARNING)

        # Pick the schedule: none at all, the default, or the caller's own.
        if not record_video:
            chosen_schedule = NoVideoSchedule()
        elif video_schedule is None:
            chosen_schedule = CappedCubicVideoSchedule()
        else:
            chosen_schedule = video_schedule
        self.video_schedule = chosen_schedule
        self.set_log_dir(log_dir)

        self._horizon = timestep_limit
开发者ID:vicariousinc,项目名称:pixelworld,代码行数:27,代码来源:gym_env.py

示例4: __init__

# 需要导入模块: from rllab.core.serializable import Serializable [as 别名]
# 或者: from rllab.core.serializable.Serializable import quick_init [as 别名]
def __init__(
            self,
            env,
            policy,
            baseline,
            optimizer=None,
            optimizer_args=None,
            **kwargs):
        """Set up the optimizer and delegate the rest to the parent class.

        :param env: forwarded to the parent constructor.
        :param policy: forwarded to the parent constructor.
        :param baseline: forwarded to the parent constructor.
        :param optimizer: optimizer instance to use; when ``None`` a
            ``FirstOrderOptimizer`` is created.
        :param optimizer_args: keyword arguments for the default optimizer.
            They are layered over ``batch_size=None, max_epochs=1`` and are
            only consulted when ``optimizer`` is ``None``.
        :param kwargs: forwarded unchanged to the parent constructor.
        """
        # Snapshot the arguments before any of them is rebound below.
        Serializable.quick_init(self, locals())

        if optimizer is None:
            settings = dict(batch_size=None, max_epochs=1)
            if optimizer_args is not None:
                # Caller-supplied values win over the defaults.
                settings.update(optimizer_args)
            optimizer = FirstOrderOptimizer(**settings)

        self.optimizer = optimizer
        self.opt_info = None
        super(VPG, self).__init__(env=env, policy=policy, baseline=baseline, **kwargs)
开发者ID:thanard,项目名称:me-trpo,代码行数:24,代码来源:vpg.py

示例5: __init__

# 需要导入模块: from rllab.core.serializable import Serializable [as 别名]
# 或者: from rllab.core.serializable.Serializable import quick_init [as 别名]
def __init__(self):
        """Initialize the point-mass environment with its fixed parameters.

        The constructor takes no arguments; every attribute below is set to a
        hard-coded value. ``qpos``, ``qvel`` and ``goal`` start as ``None``.
        """
        Serializable.quick_init(self, locals())
        self.qpos = None
        self.qvel = None
        self.mass = 0.1
        self.dt = 0.01
        self.frame_skip = 5
        self.boundary = np.array([-10, 10])
        self.vel_bounds = [-np.inf, np.inf]
        # In 1 frame forward,
        #     qpos' = qpos + qvel*dt
        #     qvel' = qvel + u/m*dt
        self.A = np.identity(2)
        self.B = np.array([[0.2, -0.04], [.3, .9]])
        self.c = np.array([0.0, 0.0])
        self.goal = None
        self.init_mean = np.zeros(2)
        self.init_std = 0.1
        self.ctrl_cost_coeff = 0.01
开发者ID:thanard,项目名称:me-trpo,代码行数:24,代码来源:point_mass_env.py

示例6: __init__

# 需要导入模块: from rllab.core.serializable import Serializable [as 别名]
# 或者: from rllab.core.serializable.Serializable import quick_init [as 别名]
def __init__(
            self,
            env_spec,
            subsample_factor=1.,
            num_seq_inputs=1,
            regressor_args=None,
    ):
        """Build the baseline around a ``GaussianMLPRegressor`` named ``"vf"``.

        :param env_spec: forwarded to the parent constructor; its
            ``observation_space.flat_dim`` determines the regressor input size.
        :param subsample_factor: not used in this constructor beyond being
            captured by ``Serializable.quick_init``.
        :param num_seq_inputs: multiplier applied to the flat observation
            dimension to get the regressor's input width.
        :param regressor_args: extra keyword arguments for the regressor;
            ``None`` means no extras.
        """
        Serializable.quick_init(self, locals())
        super(GaussianMLPBaseline, self).__init__(env_spec)

        extra_kwargs = dict() if regressor_args is None else regressor_args
        input_width = env_spec.observation_space.flat_dim * num_seq_inputs

        self._regressor = GaussianMLPRegressor(
            input_shape=(input_width,),
            output_dim=1,
            name="vf",
            **extra_kwargs
        )
开发者ID:sisl,项目名称:hgail,代码行数:20,代码来源:gaussian_mlp_baseline.py

示例7: __init__

# 需要导入模块: from rllab.core.serializable import Serializable [as 别名]
# 或者: from rllab.core.serializable.Serializable import quick_init [as 别名]
def __init__(self, game_name, agent_num, action_num=12):
        """Set up a static multi-agent game with discrete actions.

        :param game_name: name of the game, stored as ``game``. ``'lemonade'``
            installs a distance-based payoff function and requires exactly
            three agents.
        :param agent_num: number of agents.
        :param action_num: number of actions available to each agent.
        """
        Serializable.quick_init(self, locals())

        self.game = game_name
        self.agent_num = agent_num
        self.action_num = action_num

        # Every agent gets ``action_num`` actions and a one-value observation.
        self.action_spaces = MADiscrete([action_num] * self.agent_num)
        self.observation_spaces = MADiscrete([1] * self.agent_num)
        self.env_specs = MAEnvSpec(self.observation_spaces, self.action_spaces)

        self.t = 0
        self.numplots = 0
        self.payoff = {}

        if self.game == 'lemonade':
            assert self.agent_num == 3

            def get_distance(a_n, i):
                # Work on a copy with agent ``i`` moved to slot 0, then sum
                # its distances to the two other entries.
                assert len(a_n) == 3
                reordered = np.copy(a_n)
                reordered[0], reordered[i] = reordered[i], reordered[0]
                to_second = np.abs(reordered[0] - reordered[1])
                to_third = np.abs(reordered[0] - reordered[2])
                return to_second + to_third

            self.payoff = lambda a_n, i: get_distance(a_n, i)
开发者ID:ml3705454,项目名称:mapr2,代码行数:22,代码来源:discrete_static_game.py

示例8: __init__

# 需要导入模块: from rllab.core.serializable import Serializable [as 别名]
# 或者: from rllab.core.serializable.Serializable import quick_init [as 别名]
def __init__(self,
                 ctrl_cost_coeff=1e-2,
                 rew_speed=False,
                 rew_dir=None,
                 ego_obs=False,
                 no_contact=False,
                 sparse=False,
                 *args, **kwargs):
        """Store the reward / observation options, then run the parent setup.

        :param ctrl_cost_coeff: stored as ``ctrl_cost_coeff`` (original note:
            gym has 1 here!).
        :param rew_speed: stored as ``rew_speed`` (original note: if True the
            dot product is taken with the speed instead of the position).
        :param rew_dir: stored as ``reward_dir`` (original note: (x,y,z) ->
            reward is the dot product of the CoM speed with this direction;
            otherwise, distance to 0).
        :param ego_obs: stored as ``ego_obs``.
        :param no_contact: stored as ``no_cntct``.
        :param sparse: stored as ``sparse``.
        :param args: forwarded to the parent constructor.
        :param kwargs: forwarded to the parent constructor.
        """
        # The options are set before the parent constructor runs.
        self.ctrl_cost_coeff = ctrl_cost_coeff
        self.rew_speed = rew_speed
        self.reward_dir = rew_dir
        self.ego_obs = ego_obs
        self.no_cntct = no_contact
        self.sparse = sparse

        super(AntEnv, self).__init__(*args, **kwargs)

        # No extra locals are bound above, so locals() still holds exactly the
        # constructor arguments.
        Serializable.quick_init(self, locals())
开发者ID:florensacc,项目名称:snn4hrl,代码行数:19,代码来源:ant_env.py

示例9: __init__

# 需要导入模块: from rllab.core.serializable import Serializable [as 别名]
# 或者: from rllab.core.serializable.Serializable import quick_init [as 别名]
def __init__(
            self,
            *args,
            **kwargs):
        """Initialize the maze environment and cache its observation spaces.

        :param args: forwarded to ``MazeEnv.__init__``.
        :param kwargs: forwarded to ``MazeEnv.__init__``.
        """
        Serializable.quick_init(self, locals())
        MazeEnv.__init__(self, *args, **kwargs)

        self._blank_maze = False
        blank_half = np.zeros(self._n_bins)
        self.blank_maze_obs = np.concatenate([blank_half, blank_half])

        def box_like(sample_obs):
            # Box with bounds -BIG .. BIG, shaped like the sample observation.
            bound = BIG * np.ones(sample_obs.shape)
            return spaces.Box(bound * -1, bound)

        # The following caches the spaces so they are not re-instantiated
        # every time.
        self._observation_space = box_like(self.get_current_obs())
        self._robot_observation_space = box_like(self.get_current_robot_obs())
        self._maze_observation_space = box_like(self.get_current_maze_obs())
开发者ID:florensacc,项目名称:snn4hrl,代码行数:24,代码来源:fast_maze_env.py

示例10: __init__

# 需要导入模块: from rllab.core.serializable import Serializable [as 别名]
# 或者: from rllab.core.serializable.Serializable import quick_init [as 别名]
def __init__(self):
        """Pass this constructor's locals (only ``self``) to ``Serializable.quick_init``.

        No other state is initialized here.
        """
        Serializable.quick_init(self, locals())
开发者ID:nosyndicate,项目名称:pytorchrl,代码行数:4,代码来源:instrument.py

示例11: __init__

# 需要导入模块: from rllab.core.serializable import Serializable [as 别名]
# 或者: from rllab.core.serializable.Serializable import quick_init [as 别名]
def __init__(self, env_name, record_video=True, video_schedule=None, log_dir=None, record_log=True,
                 force_reset=False):
        """Create the gym environment ``env_name`` and optionally monitor it.

        :param env_name: id passed to ``gym.envs.make``.
        :param record_video: when falsy, monitoring uses a ``NoVideoSchedule``.
            Must not be truthy while ``record_log`` is falsy.
        :param video_schedule: callable handed to the monitor; ``None`` selects
            a ``CappedCubicVideoSchedule`` when recording video.
        :param log_dir: monitor output directory. ``None`` falls back to
            ``<snapshot_dir>/gym_log`` when the logger has a snapshot dir.
        :param record_log: ``False`` disables monitoring.
        :param force_reset: stored as ``_force_reset``.
        """
        # Resolve the default log directory *before* the snapshot below, so the
        # resolved value is what quick_init receives. No helper locals are
        # bound here, to keep locals() limited to the constructor arguments.
        if log_dir is None:
            if logger.get_snapshot_dir() is not None:
                log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
            else:
                logger.log("Warning: skipping Gym environment monitoring since snapshot_dir not configured.")
        Serializable.quick_init(self, locals())

        env = gym.envs.make(env_name)
        self.env = env
        self.env_id = env.spec.id

        # Recording video requires logging to be enabled.
        assert record_log or not record_video

        monitoring_enabled = log_dir is not None and record_log is not False
        if monitoring_enabled:
            if not record_video:
                video_schedule = NoVideoSchedule()
            elif video_schedule is None:
                video_schedule = CappedCubicVideoSchedule()
            self.env = gym.wrappers.Monitor(self.env, log_dir, video_callable=video_schedule, force=True)
        self.monitoring = monitoring_enabled

        # Spaces and horizon are read from the unwrapped environment.
        self._observation_space = convert_gym_space(env.observation_space)
        logger.log("observation space: {}".format(self._observation_space))
        self._action_space = convert_gym_space(env.action_space)
        logger.log("action space: {}".format(self._action_space))
        self._horizon = env.spec.tags['wrapper_config.TimeLimit.max_episode_steps']
        self._log_dir = log_dir
        self._force_reset = force_reset
开发者ID:nosyndicate,项目名称:pytorchrl,代码行数:35,代码来源:gym_env.py

示例12: __init__

# 需要导入模块: from rllab.core.serializable import Serializable [as 别名]
# 或者: from rllab.core.serializable.Serializable import quick_init [as 别名]
def __init__(self, goal_reward=10):
        """Set up the 2-D multi-goal point environment.

        :param goal_reward: stored as ``goal_reward``.
        """
        super().__init__()
        Serializable.quick_init(self, locals())

        # Dynamics and initial-state parameters.
        self.dynamics = PointDynamics(dim=2, sigma=0)
        self.init_mu = np.zeros(2, dtype=np.float32)
        self.init_sigma = 0

        # Four goals, one on each axis direction at distance 5.
        self.goal_positions = np.array(
            [(5, 0), (-5, 0), (0, 5), (0, -5)],
            dtype=np.float32,
        )
        self.goal_threshold = 1.
        self.goal_reward = goal_reward
        self.action_cost_coeff = 30.

        # Axis limits and velocity bound.
        self.xlim = (-7, 7)
        self.ylim = (-7, 7)
        self.vel_bound = 1.

        # ``observation`` is cleared only after ``reset()`` has run.
        self.reset()
        self.observation = None

        # Plot handles start out empty.
        self.fig = None
        self.ax = None
        self.fixed_plots = None
        self.dynamic_plots = []
开发者ID:nosyndicate,项目名称:pytorchrl,代码行数:31,代码来源:multigoal_env.py

示例13: __init__

# 需要导入模块: from rllab.core.serializable import Serializable [as 别名]
# 或者: from rllab.core.serializable.Serializable import quick_init [as 别名]
def __init__(
            self,
            env_spec,
            hidden_sizes=(32, 32),
            hidden_nonlinearity=NL.tanh,
            output_b_init=None,
            weight_signal=1.0,
            weight_nonsignal=1.0,
            weight_smc=1.0):
        """Build a categorical MLP policy with an initialized output bias.

        :param env_spec: A spec for the mdp; its action space must be
            ``Discrete``.
        :param hidden_sizes: list of sizes for the fully connected hidden layers
        :param hidden_nonlinearity: nonlinearity used for each hidden layer
        :param output_b_init: passed to ``compute_output_b_init``; the result
            becomes the output-layer bias initializer.
        :param weight_signal: passed to ``compute_output_b_init``.
        :param weight_nonsignal: passed to ``compute_output_b_init``.
        :param weight_smc: passed to ``compute_output_b_init``.
        """
        # Snapshot the arguments before ``output_b_init`` is rebound below.
        Serializable.quick_init(self, locals())

        action_space = env_spec.action_space
        assert isinstance(action_space, Discrete)

        output_b_init = compute_output_b_init(
            action_space.names,
            output_b_init,
            weight_signal,
            weight_nonsignal,
            weight_smc,
        )

        # Softmax output: one probability per discrete action.
        prob_network = MLP(
            input_shape=(env_spec.observation_space.flat_dim,),
            output_dim=action_space.n,
            hidden_sizes=hidden_sizes,
            hidden_nonlinearity=hidden_nonlinearity,
            output_nonlinearity=NL.softmax,
            output_b_init=output_b_init
        )
        super(InitCategoricalMLPPolicy, self).__init__(
            env_spec, hidden_sizes, hidden_nonlinearity, prob_network)


# Modified from RLLab GRUNetwork 
开发者ID:vicariousinc,项目名称:pixelworld,代码行数:35,代码来源:init_policy.py

示例14: __init__

# 需要导入模块: from rllab.core.serializable import Serializable [as 别名]
# 或者: from rllab.core.serializable.Serializable import quick_init [as 别名]
def __init__(
            self,
            name,
            max_opt_itr=20,
            initial_penalty=1.0,
            min_penalty=1e-2,
            max_penalty=1e6,
            increase_penalty_factor=2,
            decrease_penalty_factor=0.5,
            max_penalty_itr=10,
            adapt_penalty=True):
        """Store the optimizer configuration; nothing is compiled yet.

        :param name: stored as ``_name``.
        :param max_opt_itr: stored as ``_max_opt_itr``.
        :param initial_penalty: starting value of ``_penalty``; also kept as
            ``_initial_penalty``.
        :param min_penalty: stored as ``_min_penalty``.
        :param max_penalty: stored as ``_max_penalty``.
        :param increase_penalty_factor: stored as ``_increase_penalty_factor``.
        :param decrease_penalty_factor: stored as ``_decrease_penalty_factor``.
        :param max_penalty_itr: stored as ``_max_penalty_itr``.
        :param adapt_penalty: stored as ``_adapt_penalty``.
        """
        Serializable.quick_init(self, locals())

        self._name = name
        self._max_opt_itr = max_opt_itr

        # Penalty configuration; the current penalty starts at the initial one.
        self._penalty = self._initial_penalty = initial_penalty
        self._min_penalty, self._max_penalty = min_penalty, max_penalty
        self._increase_penalty_factor = increase_penalty_factor
        self._decrease_penalty_factor = decrease_penalty_factor
        self._max_penalty_itr = max_penalty_itr
        self._adapt_penalty = adapt_penalty

        # Placeholders that start out unset.
        self._opt_fun = None
        self._target = None
        self._max_constraint_val = None
        self._constraint_name = None
开发者ID:ahq1993,项目名称:inverse_rl,代码行数:29,代码来源:penalty_lbfgs_optimizer.py

示例15: __init__

# 需要导入模块: from rllab.core.serializable import Serializable [as 别名]
# 或者: from rllab.core.serializable.Serializable import quick_init [as 别名]
def __init__(self, env_name, gym_wrappers=(),
                 register_fn=None, wrapper_args = (), record_log=False, record_video=False,
                 post_create_env_seed=None, force_reset=True):
        """Register the custom environments, then build the wrapped gym env.

        :param env_name: environment id; stored as ``env_name`` and passed to
            the parent constructor.
        :param gym_wrappers: forwarded to the parent as ``wrappers``.
        :param register_fn: zero-argument callable invoked before the parent
            constructor; defaults to ``inverse_rl.envs.register_custom_envs``.
        :param wrapper_args: forwarded to the parent constructor.
        :param record_log: forwarded to the parent constructor.
        :param record_video: forwarded to the parent constructor.
        :param post_create_env_seed: forwarded to the parent constructor.
        :param force_reset: forwarded to the parent constructor.
        """
        # Snapshot the arguments before ``register_fn`` is rebound below.
        Serializable.quick_init(self, locals())

        if register_fn is None:
            # Imported lazily, only when the default registration is needed.
            from inverse_rl.envs import register_custom_envs as register_fn
        register_fn()  # Force register

        self.env_name = env_name
        super(CustomGymEnv, self).__init__(
            env_name,
            wrappers=gym_wrappers,
            wrapper_args=wrapper_args,
            record_log=record_log,
            record_video=record_video,
            post_create_env_seed=post_create_env_seed,
            video_schedule=FixedIntervalVideoSchedule(50),
            force_reset=force_reset,
        )
开发者ID:ahq1993,项目名称:inverse_rl,代码行数:16,代码来源:env_utils.py


注:本文中的rllab.core.serializable.Serializable.quick_init方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。