当前位置: 首页>>代码示例>>Python>>正文


Python logger.dump_tabular方法代码示例

本文整理汇总了 Python 中 rllab.misc.logger.dump_tabular 方法的典型用法代码示例。如果您正苦于以下问题：Python logger.dump_tabular 方法的具体用法？Python logger.dump_tabular 怎么用？Python logger.dump_tabular 使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 rllab.misc.logger 的用法示例。


在下文中一共展示了 logger.dump_tabular 方法的 7 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的 Python 代码示例。

示例1: train

# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import dump_tabular [as 别名]
def train(self):
        """Run the batch policy-optimization loop.

        Calls ``self.start_worker()``, then iterates over
        ``range(self.current_itr, self.n_itr)``. Each iteration obtains and
        processes samples through ``self.sampler``, logs diagnostics,
        optimizes the policy, saves a snapshot via ``logger.save_itr_params``
        and flushes the tabular log with ``logger.dump_tabular``. After the
        loop ``self.shutdown_worker()`` is called.
        """
        self.start_worker()

        for iteration in range(self.current_itr, self.n_itr):
            with logger.prefix('itr #%d | ' % iteration):
                logger.log('Obtaining samples...')
                trajectories = self.sampler.obtain_samples(iteration)

                logger.log('Processing samples...')
                batch = self.sampler.process_samples(iteration, trajectories)

                logger.log('Logging diagnostics...')
                self.log_diagnostics(trajectories)

                logger.log('Optimizing policy...')
                self.optimize_policy(iteration, batch)

                logger.log('Saving snapshot...')
                snapshot = self.get_itr_snapshot(iteration, batch)
                # current_itr is bumped before the algorithm object itself is
                # placed in the snapshot, so the stored object already points
                # at the next iteration.
                self.current_itr = iteration + 1
                snapshot['algo'] = self
                if self.store_paths:
                    # Keep the sampled trajectories alongside the parameters.
                    snapshot['paths'] = batch['paths']
                logger.save_itr_params(iteration, snapshot)
                logger.log('Saved')
                logger.dump_tabular(with_prefix=False)

                if not self.plot:
                    continue
                self.update_plot()
                if self.pause_for_plot:
                    input('Plotting evaluation run: Press Enter to '
                          'continue...')

        self.shutdown_worker()
开发者ID:nosyndicate,项目名称:pytorchrl,代码行数:31,代码来源:batch_polopt.py

示例2: train

# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import dump_tabular [as 别名]
def train(self):
        """Run the IRL batch policy-optimization loop.

        If ``self.init_pol_params`` / ``self.init_irl_params`` are not None
        they are loaded into the policy and the IRL model first. Then, for
        every iteration in ``range(self.start_itr, self.n_itr)``, samples are
        obtained, passed through ``self.compute_irl``, processed, used to
        optimize the policy, and a snapshot is saved. The elapsed total time
        ('Time') and per-iteration time ('ItrTime') are recorded before the
        tabular log is dumped.

        Returns:
            None.
        """
        if self.init_pol_params is not None:
            self.policy.set_param_values(self.init_pol_params)
        if self.init_irl_params is not None:
            self.irl_model.set_params(self.init_irl_params)

        self.start_worker()
        train_began = time.time()

        # NOTE(review): `returns` is never read in this method; the only
        # statement that would have filled it is disabled below.
        returns = []
        for iteration in range(self.start_itr, self.n_itr):
            iteration_began = time.time()
            with logger.prefix('itr #%d | ' % iteration):
                logger.log('Obtaining samples...')
                trajectories = self.sampler.obtain_samples(iteration)

                logger.log('Processing samples...')
                # compute_irl updates the reward function and returns the
                # paths that are processed afterwards.
                trajectories = self.compute_irl(trajectories, itr=iteration)
                # returns.append(self.log_avg_returns(paths))
                batch = self.sampler.process_samples(iteration, trajectories)

                logger.log('Logging diagnostics...')
                self.log_diagnostics(trajectories)

                logger.log('Optimizing policy...')
                self.optimize_policy(iteration, batch)

                logger.log('Saving snapshot...')
                snapshot = self.get_itr_snapshot(iteration, batch)
                if self.store_paths:
                    snapshot['paths'] = batch['paths']
                logger.save_itr_params(iteration, snapshot)
                logger.log('Saved')

                logger.record_tabular('Time', time.time() - train_began)
                logger.record_tabular('ItrTime',
                                      time.time() - iteration_began)
                logger.dump_tabular(with_prefix=False)

                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        input('Plotting evaluation run: Press Enter to '
                              'continue...')

        self.shutdown_worker()
        return None
开发者ID:nosyndicate,项目名称:pytorchrl,代码行数:42,代码来源:irl_batch_polopt.py

示例3: train

# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import dump_tabular [as 别名]
def train(self, sess=None):
        """Run the batch policy-optimization loop inside a TensorFlow session.

        Args:
            sess: Session to run in. When None, a new ``tf.Session`` is
                created and entered here, and closed again before this
                method returns or propagates an exception. A session passed
                in by the caller is never closed by this method.

        For every iteration in ``range(self.start_itr, self.n_itr)`` the
        method obtains and processes samples, logs diagnostics, optimizes the
        policy, saves a snapshot, records 'Time' / 'ItrTime' and dumps the
        tabular log.
        """
        owns_session = sess is None
        if owns_session:
            sess = tf.Session()
            sess.__enter__()

        # The try/finally guarantees that a session created above is closed
        # even when an iteration raises; previously it leaked in that case.
        try:
            sess.run(tf.global_variables_initializer())
            self.start_worker()
            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                with logger.prefix('itr #%d | ' % itr):
                    logger.log("Obtaining samples...")
                    paths = self.obtain_samples(itr)
                    logger.log("Processing samples...")
                    samples_data = self.process_samples(itr, paths)
                    logger.log("Logging diagnostics...")
                    self.log_diagnostics(paths)
                    logger.log("Optimizing policy...")
                    self.optimize_policy(itr, samples_data)
                    logger.log("Saving snapshot...")
                    params = self.get_itr_snapshot(itr, samples_data)
                    if self.store_paths:
                        params["paths"] = samples_data["paths"]
                    logger.save_itr_params(itr, params)
                    logger.log("Saved")
                    logger.record_tabular('Time', time.time() - start_time)
                    logger.record_tabular('ItrTime',
                                          time.time() - itr_start_time)
                    logger.dump_tabular(with_prefix=False)
                    if self.plot:
                        rollout(self.env, self.policy, animated=True,
                                max_path_length=self.max_path_length)
                        if self.pause_for_plot:
                            input("Plotting evaluation run: Press Enter to "
                                  "continue...")
            # As before, workers are shut down only after a complete run.
            self.shutdown_worker()
        finally:
            if owns_session:
                sess.close()
开发者ID:ahq1993,项目名称:inverse_rl,代码行数:39,代码来源:batch_polopt.py

示例4: train

# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import dump_tabular [as 别名]
def train(self):
        """Run the epoch-based off-policy training loop.

        A ``ReplayMem`` sized from the environment's flat observation/action
        dimensions is filled one environment step at a time. Once it holds at
        least ``self.memory_start_size`` samples, every step is followed by
        ``self.n_updates_per_sample`` calls to ``self.do_update``. At the end
        of each epoch ``self.evaluate`` is called (if the memory is large
        enough) and the tabular log is dumped.
        """
        replay = ReplayMem(
            obs_dim=self.env.observation_space.flat_dim,
            act_dim=self.env.action_space.flat_dim,
            memory_size=self.memory_size)

        step_count = 0
        episode_steps = 0
        episode_return = 0
        done = False
        observation = self.env.reset()

        for epoch in range(self.n_epochs):
            logger.push_prefix("epoch #%d | " % epoch)
            logger.log("Training started")

            for _ in pyprind.prog_bar(range(self.epoch_length)):
                if done:
                    # The previous episode ended: reset the environment and
                    # the exploration strategy, and record its return.
                    observation = self.env.reset()
                    self.strategy.reset()
                    # self.policy.reset()
                    self.strategy_path_returns.append(episode_return)
                    episode_steps = 0
                    episode_return = 0

                # The action comes from the strategy applied to the policy,
                # not from the target policy.
                action = self.strategy.get_action(observation, self.policy)
                next_observation, reward, done, _ = self.env.step(action)

                episode_steps += 1
                episode_return += reward

                # An episode that is still running when it reaches
                # max_path_length is force-terminated; that sample is stored
                # only if include_horizon_terminal is set.
                hit_horizon = (not done
                               and episode_steps >= self.max_path_length)
                if hit_horizon:
                    done = True
                if not hit_horizon or self.include_horizon_terminal:
                    replay.add_sample(observation, action, reward, done)

                observation = next_observation

                if replay.size >= self.memory_start_size:
                    for _ in range(self.n_updates_per_sample):
                        minibatch = replay.get_batch(self.batch_size)
                        self.do_update(step_count, minibatch)

                step_count += 1

            logger.log("Training finished")
            if replay.size >= self.memory_start_size:
                self.evaluate(epoch, replay)
            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()

        # self.env.terminate()
        # self.policy.terminate()
开发者ID:awslabs,项目名称:dynamic-training-with-apache-mxnet-on-aws,代码行数:59,代码来源:ddpg.py

示例5: train

# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import dump_tabular [as 别名]
def train(self):
        """Run the IRL training loop with q-variational and empowerment models.

        Initializes all TensorFlow global variables in the default session,
        loads any non-None initial parameters (policy, qvar model, IRL model,
        empowerment model), and then iterates over
        ``range(self.start_itr, self.n_itr)``. Each iteration obtains samples,
        runs ``compute_irl`` / ``compute_qvar`` / ``compute_empw`` on them,
        optimizes the policy, saves a snapshot and dumps the tabular log.
        After each iteration the target empowerment parameters may be
        refreshed.

        Returns:
            None.
        """
        session = tf.get_default_session()
        session.run(tf.global_variables_initializer())

        if self.init_pol_params is not None:
            self.policy.set_param_values(self.init_pol_params)

        if self.init_qvar_params is not None:
            self.qvar_model.set_params(self.init_qvar_params)

        if self.init_irl_params is not None:
            self.irl_model.set_params(self.init_irl_params)

        if self.init_empw_params is not None:
            self.empw.set_params(self.init_empw_params)

        self.start_worker()
        train_began = time.time()

        avg_returns = []
        irl_scores = []  # second value returned by compute_irl, per iteration
        for iteration in range(self.start_itr, self.n_itr):
            iteration_began = time.time()

            with logger.prefix('itr #%d | ' % iteration):
                logger.log("Obtaining samples...")
                trajectories = self.obtain_samples(iteration)

                logger.log("Processing samples...")
                trajectories, score = self.compute_irl(trajectories,
                                                       itr=iteration)
                irl_scores.append(score)
                avg_returns.append(self.log_avg_returns(trajectories))
                self.compute_qvar(trajectories, itr=iteration)
                self.compute_empw(trajectories, itr=iteration)
                batch = self.process_samples(iteration, trajectories)

                logger.log("Logging diagnostics...")
                self.log_diagnostics(trajectories)

                logger.log("Optimizing policy...")
                self.optimize_policy(iteration, batch)

                logger.log("Saving snapshot...")
                snapshot = self.get_itr_snapshot(iteration, batch)
                if self.store_paths:
                    snapshot["paths"] = batch["paths"]
                logger.save_itr_params(iteration, snapshot)
                logger.log("Saved")

                logger.record_tabular('Time', time.time() - train_began)
                logger.record_tabular('ItrTime',
                                      time.time() - iteration_began)
                logger.dump_tabular(with_prefix=False)

                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")

            # Every target_empw_update iterations (when train_empw is set),
            # copy parameters into the target empowerment model.
            # NOTE(review): inside a class body `self.__empw_params` is
            # name-mangled by Python; confirm the attribute is assigned in
            # the same class.
            if iteration % self.target_empw_update == 0 and self.train_empw:
                print('updating target empowerment parameters')
                self.tempw.set_params(self.__empw_params)

        # pickle.dump(rew, open("rewards.p", "wb" )) # uncomment to store rewards in every iteration
        self.shutdown_worker()
        return None
开发者ID:ahq1993,项目名称:inverse_rl,代码行数:62,代码来源:irl_batch_polopt.py

示例6: train

# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import dump_tabular [as 别名]
def train(self):
        """Run the IRL batch policy-optimization loop in the default session.

        Initializes all TensorFlow global variables, loads the initial policy
        and IRL-model parameters when they are not None, and then iterates
        over ``range(self.start_itr, self.n_itr)``. Each iteration obtains
        samples, passes them through ``self.compute_irl``, logs the average
        returns, optimizes the policy, saves a snapshot, and dumps the tabular
        log together with 'Time' and 'ItrTime'.

        Returns:
            None.
        """
        session = tf.get_default_session()
        session.run(tf.global_variables_initializer())

        if self.init_pol_params is not None:
            self.policy.set_param_values(self.init_pol_params)
        if self.init_irl_params is not None:
            self.irl_model.set_params(self.init_irl_params)

        self.start_worker()
        train_began = time.time()

        avg_returns = []
        for iteration in range(self.start_itr, self.n_itr):
            iteration_began = time.time()
            with logger.prefix('itr #%d | ' % iteration):
                logger.log("Obtaining samples...")
                trajectories = self.obtain_samples(iteration)

                logger.log("Processing samples...")
                trajectories = self.compute_irl(trajectories, itr=iteration)
                avg_returns.append(self.log_avg_returns(trajectories))
                batch = self.process_samples(iteration, trajectories)

                logger.log("Logging diagnostics...")
                self.log_diagnostics(trajectories)

                logger.log("Optimizing policy...")
                self.optimize_policy(iteration, batch)

                logger.log("Saving snapshot...")
                snapshot = self.get_itr_snapshot(iteration, batch)
                if self.store_paths:
                    snapshot["paths"] = batch["paths"]
                logger.save_itr_params(iteration, snapshot)
                logger.log("Saved")

                logger.record_tabular('Time', time.time() - train_began)
                logger.record_tabular('ItrTime',
                                      time.time() - iteration_began)
                logger.dump_tabular(with_prefix=False)

                if not self.plot:
                    continue
                self.update_plot()
                if self.pause_for_plot:
                    input("Plotting evaluation run: Press Enter to "
                          "continue...")

        self.shutdown_worker()
        return None
开发者ID:justinjfu,项目名称:inverse_rl,代码行数:44,代码来源:irl_batch_polopt.py

示例7: train

# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import dump_tabular [as 别名]
def train(self):
        """Run the epoch-based off-policy training loop.

        A ``ReplayMem`` sized from the environment's flat observation/action
        dimensions is filled one environment step at a time. Once it holds at
        least ``self.memory_start_size`` samples, every step is followed by
        ``self.n_updates_per_sample`` calls to ``self.do_update``. At the end
        of each epoch ``self.evaluate`` is called (if the memory is large
        enough) and the tabular log is dumped.
        """
        memory = ReplayMem(
            obs_dim=self.env.observation_space.flat_dim,
            act_dim=self.env.action_space.flat_dim,
            memory_size=self.memory_size)

        itr = 0
        path_length = 0
        path_return = 0
        end = False
        obs = self.env.reset()

        # `range` is used throughout: `xrange` does not exist on Python 3 and
        # the inner progress-bar loop already relied on `range`.
        for epoch in range(self.n_epochs):
            logger.push_prefix("epoch #%d | " % epoch)
            logger.log("Training started")
            for _ in pyprind.prog_bar(range(self.epoch_length)):
                # run the policy
                if end:
                    # reset the environment and strategy when an episode ends
                    obs = self.env.reset()
                    self.strategy.reset()
                    # self.policy.reset()
                    self.strategy_path_returns.append(path_return)
                    path_length = 0
                    path_return = 0
                # note action is sampled from the policy not the target policy
                act = self.strategy.get_action(obs, self.policy)
                nxt, rwd, end, _ = self.env.step(act)

                path_length += 1
                path_return += rwd

                # An episode still running at max_path_length is cut off; the
                # cut-off sample is stored only if include_horizon_terminal.
                if not end and path_length >= self.max_path_length:
                    end = True
                    if self.include_horizon_terminal:
                        memory.add_sample(obs, act, rwd, end)
                else:
                    memory.add_sample(obs, act, rwd, end)

                obs = nxt

                if memory.size >= self.memory_start_size:
                    for _ in range(self.n_updates_per_sample):
                        batch = memory.get_batch(self.batch_size)
                        self.do_update(itr, batch)

                itr += 1

            logger.log("Training finished")
            if memory.size >= self.memory_start_size:
                self.evaluate(epoch, memory)
            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()

        # self.env.terminate()
        # self.policy.terminate()
开发者ID:mahyarnajibi,项目名称:SNIPER-mxnet,代码行数:59,代码来源:ddpg.py


注:本文中的rllab.misc.logger.dump_tabular方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。