当前位置: 首页>>代码示例>>Python>>正文


Python ray.tune方法代码示例

本文整理汇总了Python中ray.tune模块的典型用法代码示例。如果您正苦于以下问题:Python ray.tune的具体用法?Python ray.tune怎么用?Python ray.tune使用的例子?那么,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解ray.tune所在的ray包的用法示例。


在下文中一共展示了ray.tune模块的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: ray_trainable

# 需要导入模块: import ray [as 别名]
# 或者: from ray import tune [as 别名]
def ray_trainable(config, reporter):
    '''
    Trainable function for ray tune that runs one lab Trial per call.

    `config` must carry two extra entries next to the regular search values:
    config = {
        'spec': spec,
        'trial_index': tune.sample_from(lambda spec: gen_trial_index()),
        ... # normal ray config with sample, grid search etc.
    }
    Both entries are popped (so `config` is mutated in place); the remaining
    values are merged into the spec through `inject_config`.

    @param dict:config Ray config holding 'spec', 'trial_index' and the sampled values
    @param callable:reporter Ray reporter; called once with `trial_data={trial_index: metrics}`
    '''
    from convlab.experiment.control import Trial

    # pull out the bookkeeping entries so only search values remain in config
    base_spec, trial_index = config.pop('spec'), config.pop('trial_index')
    base_spec['meta']['trial'] = trial_index
    trial_spec = inject_config(base_spec, config)

    # run the trial and attach the sampled config to its metrics for later analysis
    trial_metrics = Trial(trial_spec).run()
    trial_metrics.update(config)

    # report through ray so the data is available in trial.last_result
    reporter(trial_data={trial_index: trial_metrics})
开发者ID:ConvLab,项目名称:ConvLab,代码行数:23,代码来源:search.py

示例2: testTrainableCallable

# 需要导入模块: import ray [as 别名]
# 或者: from ray import tune [as 别名]
def testTrainableCallable(self):
        """Check that a ``functools.partial`` trainable works via both entry points.

        A partial-wrapped function is first registered under the name "test"
        and launched with ``run_experiments``; an equivalent partial is then
        passed straight to ``tune.run``. Each run must produce exactly one
        trial that terminated and whose last result reports the bound number
        of steps.
        """
        from functools import partial

        def dummy_fn(config, reporter, steps):
            reporter(timesteps_total=steps, done=True)

        steps = 500

        # Entry point 1: registered name + run_experiments.
        register_trainable("test", partial(dummy_fn, steps=steps))
        experiments = {"foo": {"run": "test"}}
        [trial] = run_experiments(experiments)
        self.assertEqual(trial.status, Trial.TERMINATED)
        self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], steps)

        # Entry point 2: the callable handed directly to tune.run.
        analysis = tune.run(partial(dummy_fn, steps=steps))
        [trial] = analysis.trials
        self.assertEqual(trial.status, Trial.TERMINATED)
        self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], steps)
开发者ID:ray-project,项目名称:ray,代码行数:19,代码来源:test_api.py

示例3: testLongFilename

# 需要导入模块: import ray [as 别名]
# 或者: from ray import tune [as 别名]
def testLongFilename(self):
        """Run a trial whose config has 50-character keys and a 160-character value.

        The trainable asserts that its working directory contains
        ``<user temp dir>/logdir/foo``, i.e. the ``local_dir`` plus the
        experiment name used below.
        """

        def train(config, reporter):
            expected_dir = os.path.join(ray.utils.get_user_temp_dir(),
                                        "logdir", "foo")
            assert expected_dir in os.getcwd(), os.getcwd()
            reporter(timesteps_total=1)

        register_trainable("f1", train)

        long_config = {
            "a" * 50: tune.sample_from(lambda spec: 5.0 / 7),
            "b" * 50: tune.sample_from(lambda spec: "long" * 40),
        }
        experiment = {
            "run": "f1",
            "local_dir": os.path.join(ray.utils.get_user_temp_dir(), "logdir"),
            "config": long_config,
        }
        run_experiments({"foo": experiment})
开发者ID:ray-project,项目名称:ray,代码行数:20,代码来源:test_api.py

示例4: testNestedStoppingReturn

# 需要导入模块: import ray [as 别名]
# 或者: from ray import tune [as 别名]
def testNestedStoppingReturn(self):
        """Check how ``stop`` criteria on nested result values must be written.

        A nested ``stop`` dict is expected to raise ``TuneError``, while the
        flattened "test/test1/test2" key is accepted and ends the trial at
        training iteration 7.
        """

        def train(config, reporter):
            for step in range(10):
                reporter(test={"test1": {"test2": step}})

        nested_stop = {"test": {"test1": {"test2": 6}}}
        with self.assertRaises(TuneError):
            [trial] = tune.run(train, stop=nested_stop).trials

        flat_stop = {"test/test1/test2": 6}
        analysis = tune.run(train, stop=flat_stop)
        [trial] = analysis.trials
        self.assertEqual(trial.last_result["training_iteration"], 7)
开发者ID:ray-project,项目名称:ray,代码行数:18,代码来源:test_api.py

示例5: testStopper

# 需要导入模块: import ray [as 别名]
# 或者: from ray import tune [as 别名]
def testStopper(self):
        """Run five samples under a custom ``Stopper`` and check the outcome.

        The stopper ends a trial once its "test" value exceeds 6 and asks to
        stop everything after it has been called more than five times. All
        trials must end TERMINATED, and at least one must have no
        "training_iteration" in its last result (presumably because it was
        stopped before reporting anything; confirm against Tune's semantics).
        """

        def train(config, reporter):
            for value in range(10):
                reporter(test=value)

        class CustomStopper(Stopper):
            def __init__(self):
                # Number of times __call__ has been invoked across all trials.
                self._count = 0

            def __call__(self, trial_id, result):
                print("called")
                self._count += 1
                return result["test"] > 6

            def stop_all(self):
                return self._count > 5

        analysis = tune.run(train, num_samples=5, stop=CustomStopper())
        trials = analysis.trials

        statuses = [t.status for t in trials]
        self.assertTrue(all(status == Trial.TERMINATED for status in statuses))

        last_iterations = [
            t.last_result.get("training_iteration") for t in trials
        ]
        self.assertTrue(any(it is None for it in last_iterations))
开发者ID:ray-project,项目名称:ray,代码行数:25,代码来源:test_api.py

示例6: testBadStoppingFunction

# 需要导入模块: import ray [as 别名]
# 或者: from ray import tune [as 别名]
def testBadStoppingFunction(self):
        """Check that ``stop`` callables taking only ``result`` are rejected.

        Both a bound method and a plain function with the single-argument
        signature ``(result)`` must make ``tune.run`` raise ``TuneError``.
        """

        def train(config, reporter):
            for value in range(10):
                reporter(test=value)

        class CustomStopper:
            def stop(self, result):
                return result["test"] > 6

        def stop(result):
            return result["test"] > 6

        # Bound method first, then the free function, as two separate runs.
        for bad_stop in (CustomStopper().stop, stop):
            with self.assertRaises(TuneError):
                tune.run(train, stop=bad_stop)
开发者ID:ray-project,项目名称:ray,代码行数:18,代码来源:test_api.py

示例7: testLotsOfStops

# 需要导入模块: import ray [as 别名]
# 或者: from ray import tune [as 别名]
def testLotsOfStops(self):
        """Check that ``cleanup`` ran for each of ten trials stopped after one iteration.

        Each trainable's ``cleanup`` sleeps for two seconds and then creates a
        "marker" file in its log directory; after ``ray.shutdown()`` every
        trial's marker file must exist.
        """

        class TestTrainable(Trainable):
            def step(self):
                return {"name": self.trial_name, "trial_id": self.trial_id}

            def cleanup(self):
                time.sleep(2)
                marker = os.path.join(self.logdir, "marker")
                # Touch the file: open in append mode and close right away.
                with open(marker, "a"):
                    pass
                return 1

        analysis = tune.run(
            TestTrainable, num_samples=10, stop={TRAINING_ITERATION: 1})
        ray.shutdown()

        marker_paths = [
            os.path.join(trial.logdir, "marker") for trial in analysis.trials
        ]
        for marker_path in marker_paths:
            assert os.path.exists(marker_path)
开发者ID:ray-project,项目名称:ray,代码行数:19,代码来源:test_api.py

示例8: test_time

# 需要导入模块: import ray [as 别名]
# 或者: from ray import tune [as 别名]
def test_time(start_ray, tmpdir):
    """Check that ``tune ls`` averages under three seconds over five runs.

    A small "__fake" experiment (two samples, one iteration each) is written
    under ``tmpdir`` first so the CLI has something to list.
    """
    experiment_name = "test_time"
    local_dir = str(tmpdir)
    experiment_spec = {
        "run": "__fake",
        "stop": {"training_iteration": 1},
        "num_samples": 2,
        "local_dir": local_dir,
    }
    tune.run_experiments({experiment_name: experiment_spec})

    experiment_path = os.path.join(local_dir, experiment_name)
    durations = []
    for _ in range(5):
        began = time.time()
        subprocess.check_call(["tune", "ls", experiment_path])
        durations.append(time.time() - began)

    mean_duration = sum(durations) / len(durations)
    assert mean_duration < 3.0, "CLI is taking too long!"
开发者ID:ray-project,项目名称:ray,代码行数:23,代码来源:test_commands.py

示例9: test_ls_with_cfg

# 需要导入模块: import ray [as 别名]
# 或者: from ray import tune [as 别名]
def test_ls_with_cfg(start_ray, tmpdir):
    """Check ``commands.list_trials`` output when a config column is requested.

    Runs a five-way grid search over "test_variable", then lists the trials
    with two info keys and a limit of four, and checks the captured output:
    line 1 (presumably the table header) names both columns and has
    ``len(columns) + 1`` pipe characters, and there are ``3 + limit + 1``
    lines in total.
    """
    experiment_name = "test_ls_with_cfg"
    local_dir = str(tmpdir)
    grid = tune.grid_search(list(range(5)))
    tune.run(
        "__fake",
        name=experiment_name,
        stop={"training_iteration": 1},
        config={"test_variable": grid},
        local_dir=local_dir)

    experiment_path = os.path.join(local_dir, experiment_name)
    columns = [CONFIG_PREFIX + "test_variable", "trial_id"]
    limit = 4
    with Capturing() as output:
        commands.list_trials(experiment_path, info_keys=columns, limit=limit)

    lines = output.captured
    column_row = lines[1]
    assert all(col in column_row for col in columns)
    assert column_row.count("|") == len(columns) + 1
    assert len(lines) == 3 + limit + 1
开发者ID:ray-project,项目名称:ray,代码行数:20,代码来源:test_commands.py

示例10: testCloudFunctions

# 需要导入模块: import ray [as 别名]
# 或者: from ray import tune [as 别名]
def testCloudFunctions(self):
        """Check that a custom ``sync_to_cloud`` callable copies results to ``upload_dir``.

        The sync function copies every ``*.json`` file from the local
        directory to the remote one. After a one-iteration "__fake" run named
        "foo", at least one JSON file must exist under ``<upload_dir>/foo``.

        Both temporary directories are removed in a ``finally`` block so they
        do not leak when ``tune.run`` or the assertion raises.
        """
        tmpdir = tempfile.mkdtemp()
        tmpdir2 = tempfile.mkdtemp()
        try:
            # Pre-create the target folder that the sync function copies into.
            os.mkdir(os.path.join(tmpdir2, "foo"))

            def sync_func(local, remote):
                for filename in glob.glob(os.path.join(local, "*.json")):
                    shutil.copy(filename, remote)

            # Unpacking into a one-element list also checks that exactly one
            # trial was produced.
            [trial] = tune.run(
                "__fake",
                name="foo",
                max_failures=0,
                local_dir=tmpdir,
                stop={
                    "training_iteration": 1
                },
                upload_dir=tmpdir2,
                sync_to_cloud=sync_func).trials
            test_file_path = glob.glob(os.path.join(tmpdir2, "foo", "*.json"))
            self.assertTrue(test_file_path)
        finally:
            # Nested so the second directory is removed even if removing the
            # first one fails.
            try:
                shutil.rmtree(tmpdir)
            finally:
                shutil.rmtree(tmpdir2)
开发者ID:ray-project,项目名称:ray,代码行数:25,代码来源:test_sync.py

示例11: testNoSync

# 需要导入模块: import ray [as 别名]
# 或者: from ray import tune [as 别名]
def testNoSync(self):
        """Sync should not run on a single node.

        ``CommandBasedClient._execute`` is patched with a mock; after a
        one-iteration "__fake" run with a custom ``sync_to_driver`` callable,
        the mock must not have been called.
        """

        def sync_func(source, target):
            pass

        with patch.object(CommandBasedClient, "_execute") as mock_sync:
            analysis = tune.run(
                "__fake",
                name="foo",
                max_failures=0,
                stop={"training_iteration": 1},
                sync_to_driver=sync_func)
            [trial] = analysis.trials
            self.assertEqual(mock_sync.call_count, 0)
开发者ID:ray-project,项目名称:ray,代码行数:20,代码来源:test_sync.py

示例12: parse_args

# 需要导入模块: import ray [as 别名]
# 或者: from ray import tune [as 别名]
def parse_args(argv=None):
    """Parse the command-line options of the hyper-parameter search script.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) ``argparse`` reads ``sys.argv[1:]``, which is what the
            original zero-argument call did. Passing a list lets callers and
            tests use the parser without touching ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``ray_directory``,
        ``output_options``, ``tune``, ``gpus``, ``gpu_percent``, ``topn``,
        ``earlystop``, ``options``, ``threads``, ``resume`` and ``metric``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("-rd", "--ray-directory", default="/data/douillard/ray_results")
    parser.add_argument("-o", "--output-options")
    parser.add_argument("-t", "--tune")
    parser.add_argument("-g", "--gpus", nargs="+", default=["0"])
    parser.add_argument("-per", "--gpu-percent", type=float, default=0.5)
    parser.add_argument("-topn", "--topn", default=5, type=int)
    parser.add_argument("-earlystop", default="ucir", type=str)
    parser.add_argument("-options", "--options", default=None, nargs="+")
    parser.add_argument("-threads", default=2, type=int)
    parser.add_argument("-resume", default=False, action="store_true")
    parser.add_argument("-metric", default="avg_inc_acc", choices=["avg_inc_acc", "last_acc"])

    return parser.parse_args(argv)
开发者ID:arthurdouillard,项目名称:incremental_learning.pytorch,代码行数:18,代码来源:hyperfind.py

示例13: get_tune_config

# 需要导入模块: import ray [as 别名]
# 或者: from ray import tune [as 别名]
def get_tune_config(tune_options, options_files):
    """Build a Tune config dict from a YAML options file.

    Keys starting with "var:" are treated as search dimensions: the prefix
    is stripped and the value is wrapped in ``tune.grid_search``. All other
    keys are copied unchanged.

    Args:
        tune_options: Path of the YAML file to read.
        options_files: Optional iterable of option-file paths. When not
            ``None``, their resolved real paths are stored under the
            "options" key of the returned config.

    Returns:
        dict: The config built from the YAML content.

    Raises:
        ValueError: If the YAML file sets "epochs" to 1.
    """
    var_prefix = "var:"

    with open(tune_options) as f:
        # NOTE(review): yaml.safe_load is the recommended choice if this file
        # can ever come from an untrusted source.
        # An empty YAML document loads as None; treat it as "no options".
        options = yaml.load(f, Loader=yaml.FullLoader) or {}

    if options.get("epochs") == 1:
        raise ValueError("Using only 1 epoch, must be a mistake.")

    config = {}
    for k, v in options.items():
        if k.startswith(var_prefix):
            # Strip only the leading prefix; str.replace would also remove
            # any later "var:" occurrence inside the key.
            config[k[len(var_prefix):]] = tune.grid_search(v)
        else:
            config[k] = v

    if options_files is not None:
        print("Options files: {}".format(options_files))
        config["options"] = [os.path.realpath(op) for op in options_files]

    return config
开发者ID:arthurdouillard,项目名称:incremental_learning.pytorch,代码行数:21,代码来源:hyperfind.py

示例14: get_tune_experiment

# 需要导入模块: import ray [as 别名]
# 或者: from ray import tune [as 别名]
def get_tune_experiment(config, agent, episodes, root_dir, is_schedule):
    """Assemble an ``Experiment`` for the given agent, plus an optional scheduler.

    Args:
        config: Experiment config; ``config["env_config"]["output_dir"]`` is
            used as the experiment's ``local_dir``. When ``is_schedule`` is
            true the dict is modified in place and then replaced by the value
            returned from ``set_tuning_parameters``.
        agent: Agent name; used as the experiment name and resolved to the
            runnable via ``get_agent``.
        episodes: Value for the default "episodes_total" stop criterion.
        root_dir: Not used by this function.
        is_schedule: When true, stop on 300 s of total time instead, use two
            samples, enable "parallel_envs" and apply the tuning parameters.

    Returns:
        tuple: ``(experiment, scheduler)``; ``scheduler`` is ``None`` unless
        ``is_schedule`` is true.
    """
    scheduler = None
    experiment_kwargs = {
        "name": agent,
        "run": get_agent(agent),
        "local_dir": config["env_config"]["output_dir"],
        "stop": {"episodes_total": episodes},
    }

    if is_schedule:
        experiment_kwargs.update(stop={"time_total_s": 300}, num_samples=2)
        config["env_config"]["parallel_envs"] = True
        # custom changes to experiment
        log.info("Performing tune experiment")
        config, scheduler = set_tuning_parameters(agent, config)

    experiment_kwargs["config"] = config
    return Experiment(**experiment_kwargs), scheduler
开发者ID:dcgym,项目名称:iroko,代码行数:21,代码来源:run_ray.py

示例15: ray_trainable

# 需要导入模块: import ray [as 别名]
# 或者: from ray import tune [as 别名]
def ray_trainable(config, reporter):
    '''
    Trainable function for ray tune that runs one SLM Lab Trial per call.

    `config` must carry two extra entries next to the regular search values:
    config = {
        'spec': spec,
        'trial_index': tune.sample_from(lambda spec: gen_trial_index()),
        ... # normal ray config with sample, grid search etc.
    }
    Both entries are popped, so `config` is mutated in place.

    @param dict:config Ray config holding 'spec', 'trial_index' and the sampled values
    @param callable:reporter Ray reporter; called once with `trial_data={trial_index: metrics}`
    '''
    import os
    # drop the CUDA id restriction set by ray before the lab code is imported
    os.environ.pop('CUDA_VISIBLE_DEVICES', None)
    from slm_lab.experiment.control import Trial

    # NOTE(review): 'trial_index' is still inside config when inject_config runs - confirm this is intended
    trial_spec = inject_config(config.pop('spec'), config)

    # store index - 1, then tick 'trial' (presumably advancing it back to trial_index - confirm in spec_util)
    trial_index = config.pop('trial_index')
    trial_spec['meta']['trial'] = trial_index - 1
    spec_util.tick(trial_spec, 'trial')

    # run the trial and attach the sampled config to its metrics for later analysis
    trial_metrics = Trial(trial_spec).run()
    trial_metrics.update(config)

    # report through ray so the data is available in trial.last_result
    reporter(trial_data={trial_index: trial_metrics})
开发者ID:kengz,项目名称:SLM-Lab,代码行数:27,代码来源:search.py


注:本文中的ray.tune方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。