当前位置: 首页>>代码示例>>Python>>正文


Python tune.run方法代码示例

本文整理汇总了Python中ray.tune.run方法的典型用法代码示例。如果您正苦于以下问题：Python tune.run方法的具体用法？Python tune.run怎么用？Python tune.run使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块ray.tune的用法示例。


在下文中一共展示了tune.run方法的15个代码示例，这些例子默认根据受欢迎程度排序。注意：只有示例6、10、11、12、13直接调用了tune.run；其余示例使用run_experiments等相关API，或只是与tune.run配合使用的配置/辅助代码。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: ray_trainable

# 需要导入模块: from ray import tune [as 别名]
# 或者: from ray.tune import run [as 别名]
def ray_trainable(config, reporter):
    '''Ray Tune function trainable that runs one lab Trial.

    See https://ray.readthedocs.io/en/latest/tune-usage.html#training-api.
    The lab needs a `spec` and a `trial_index` carried through `config`;
    pass them inside the `config` given to `tune.run()` like so:

        config = {
            'spec': spec,
            'trial_index': tune.sample_from(lambda spec: gen_trial_index()),
            ...  # normal ray config with sample, grid search etc.
        }

    Args:
        config: Tune config dict. Must contain 'spec' and 'trial_index'; all
            remaining keys are passed to `inject_config` together with the
            spec, and are also merged into the reported metrics. Keys are not
            popped from the caller's dict (the spec object it references is
            still modified: its ['meta']['trial'] is set).
        reporter: Tune reporter; called once with
            `trial_data={trial_index: metrics}` so the data ends up in
            `trial.last_result`.
    '''
    from convlab.experiment.control import Trial
    # Shallow copy: popping from the caller's dict would make a second call
    # with the same config fail with KeyError('spec').
    config = dict(config)
    spec = config.pop('spec')
    trial_index = config.pop('trial_index')
    spec['meta']['trial'] = trial_index
    spec = inject_config(spec, config)
    # run the lab trial
    metrics = Trial(spec).run()
    metrics.update(config)  # carry the search config along for analysis
    # ray report to carry data in ray trial.last_result
    reporter(trial_data={trial_index: metrics})
开发者ID:ConvLab,项目名称:ConvLab,代码行数:23,代码来源:search.py

示例2: _rsync_func

# 需要导入模块: from ray import tune [as 别名]
# 或者: from ray.tune import run [as 别名]
def _rsync_func(local_dir, remote_uri):
    """rsync data from worker to a remote location (by default the driver)."""
    # SOMEDAY: This function blocks until syncing completes, which is unfortunate.
    # If we instead specified a shell command, ray.tune._LogSyncer would run it asynchronously.
    # But we need to do a two-stage command, creating the directories first, because rsync will
    # balk if destination directory does not exist; so no easy way to do that.
    remote_host, ssh_key, *remainder = remote_uri.split(":")
    remote_dir = ":".join(remainder)  # remote directory may contain :
    remote_dir = shlex.quote(remote_dir)  # make safe for SSH/rsync call

    ssh_command = ["ssh", "-o", "StrictHostKeyChecking=no", "-i", ssh_key]
    ssh_mkdir = ssh_command + [remote_host, "mkdir", "-p", remote_dir]
    subprocess.run(ssh_mkdir, check=True)

    rsync = [
        "rsync",
        "-rlptv",
        "-e",
        " ".join(ssh_command),
        f"{local_dir}/",
        f"{remote_host}:{remote_dir}",
    ]
    subprocess.run(rsync) 
开发者ID:HumanCompatibleAI,项目名称:adversarial-policies,代码行数:25,代码来源:common.py

示例3: config

# 需要导入模块: from ray import tune [as 别名]
# 或者: from ray.tune import run [as 别名]
# Default configuration for running a Sacred experiment many times in parallel
# through `ray.tune.run`.
# NOTE(review): the decorator is not visible here; this appears to be a Sacred
# config scope, where each local variable becomes a config entry and its
# trailing comment becomes that entry's documentation — keep names and values
# stable when editing.
def config():
    sacred_ex_name = "expert_demos"  # The experiment to parallelize
    init_kwargs = {}  # Keyword arguments to pass to ray.init()
    _uuid = make_unique_timestamp()  # Suffix for the default run_name below
    run_name = f"DEFAULT_{_uuid}"  # CLI --name option. For analysis grouping.
    resources_per_trial = {}  # Argument to `tune.run`
    base_named_configs = []  # Named configs applied before search_space is applied
    base_config_updates = {}  # Config updates applied before search_space is applied
    search_space = {
        "named_configs": [],
        "config_updates": {},
    }  # `config` argument to `ray.tune.run(trainable, config)`

    local_dir = None  # `local_dir` arg for `ray.tune.run`
    upload_dir = None  # `upload_dir` arg for `ray.tune.run`
    n_seeds = 3  # Number of seeds to search over by default
开发者ID:HumanCompatibleAI,项目名称:imitation,代码行数:18,代码来源:parallel.py

示例4: generate_test_data

# 需要导入模块: from ray import tune [as 别名]
# 或者: from ray.tune import run [as 别名]
def generate_test_data():
    """Used by tests/generate_test_data.sh to generate tests/data/gather_tb/.

    "tests/data/gather_tb/" should contain 4 Tensorboard run directories ("sb_tb/" and
    "tb/" for each of two trials in the search space below).

    NOTE(review): this appears to be a Sacred named config that overrides the
    defaults of the parallel experiment; local names are config entries.
    """
    sacred_ex_name = "expert_demos"
    run_name = "TEST"
    n_seeds = 1  # one seed per grid point, so the grid below gives two trials
    search_space = {
        "config_updates": {
            "init_rl_kwargs": {
                # grid over 3e-4 / 3 = 1e-4 and 3e-4 / 2 = 1.5e-4
                "learning_rate": tune.grid_search([3e-4 * x for x in (1 / 3, 1 / 2)]),
            },
        }
    }
    base_named_configs = ["cartpole", "fast"]
    base_config_updates = {
        # presumably what produces the Tensorboard dirs described above
        "init_tensorboard": True,
        "rollout_save_final": False,
    }
开发者ID:HumanCompatibleAI,项目名称:imitation,代码行数:23,代码来源:parallel.py

示例5: testRegisterEnvOverwrite

# 需要导入模块: from ray import tune [as 别名]
# 或者: from ray.tune import run [as 别名]
def testRegisterEnvOverwrite(self):
        """Re-registering a trainable under an existing name replaces it.

        Both functions are registered as "f1"; the experiment must run the
        second one, which reports 200 timesteps instead of 100.
        """
        def train(config, reporter):
            reporter(timesteps_total=100, done=True)

        def train2(config, reporter):
            reporter(timesteps_total=200, done=True)

        for trainable in (train, train2):
            register_trainable("f1", trainable)

        experiment_spec = {"foo": {"run": "f1"}}
        (trial,) = run_experiments(experiment_spec)

        self.assertEqual(trial.status, Trial.TERMINATED)
        self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 200)
开发者ID:ray-project,项目名称:ray,代码行数:18,代码来源:test_api.py

示例6: testTrainableCallable

# 需要导入模块: from ray import tune [as 别名]
# 或者: from ray.tune import run [as 别名]
def testTrainableCallable(self):
        """A functools.partial trainable works both when registered by name
        and when passed directly to tune.run."""
        from functools import partial

        def dummy_fn(config, reporter, steps):
            reporter(timesteps_total=steps, done=True)

        steps = 500

        def check_finished(trial):
            self.assertEqual(trial.status, Trial.TERMINATED)
            self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], steps)

        # Path 1: register under a name, launch via run_experiments.
        register_trainable("test", partial(dummy_fn, steps=steps))
        (registered_trial,) = run_experiments({"foo": {"run": "test"}})
        check_finished(registered_trial)

        # Path 2: hand a fresh partial straight to tune.run.
        (direct_trial,) = tune.run(partial(dummy_fn, steps=steps)).trials
        check_finished(direct_trial)
开发者ID:ray-project,项目名称:ray,代码行数:19,代码来源:test_api.py

示例7: testLogdirStartingWithTilde

# 需要导入模块: from ray import tune [as 别名]
# 或者: from ray.tune import run [as 别名]
def testLogdirStartingWithTilde(self):
        """A `local_dir` beginning with "~" must be user-expanded: the trial's
        working directory is expected under the expanded path, never "~"."""
        local_dir = "~/ray_results/local_dir"

        def train(config, reporter):
            working_dir = os.getcwd()
            expanded_root = os.path.expanduser(local_dir)
            assert working_dir.startswith(expanded_root), working_dir
            assert not working_dir.startswith("~"), working_dir
            reporter(timesteps_total=1)

        register_trainable("f1", train)
        experiment = {
            "run": "f1",
            "local_dir": local_dir,
            "config": {"a": "b"},
        }
        run_experiments({"foo": experiment})
开发者ID:ray-project,项目名称:ray,代码行数:21,代码来源:test_api.py

示例8: testLongFilename

# 需要导入模块: from ray import tune [as 别名]
# 或者: from ray.tune import run [as 别名]
def testLongFilename(self):
        """Trials whose config has very long keys and values must still run
        inside <user temp dir>/logdir/foo."""
        def train(config, reporter):
            expected_fragment = os.path.join(
                ray.utils.get_user_temp_dir(), "logdir", "foo")
            assert expected_fragment in os.getcwd(), os.getcwd()
            reporter(timesteps_total=1)

        register_trainable("f1", train)

        # 50-character keys; the second value is a 160-character string.
        long_config = {
            "a" * 50: tune.sample_from(lambda spec: 5.0 / 7),
            "b" * 50: tune.sample_from(lambda spec: "long" * 40),
        }
        log_root = os.path.join(ray.utils.get_user_temp_dir(), "logdir")
        run_experiments({
            "foo": {
                "run": "f1",
                "local_dir": log_root,
                "config": long_config,
            }
        })
开发者ID:ray-project,项目名称:ray,代码行数:20,代码来源:test_api.py

示例9: testBadStoppingReturn

# 需要导入模块: from ray import tune [as 别名]
# 或者: from ray.tune import run [as 别名]
def testBadStoppingReturn(self):
        """Stopping on a metric ("time") that the trainable never reports must
        make run_experiments raise TuneError."""
        def train(config, reporter):
            reporter()

        register_trainable("f1", train)

        experiments = {
            "foo": {
                "run": "f1",
                "stop": {"time": 10},
            }
        }
        with self.assertRaises(TuneError):
            run_experiments(experiments)
开发者ID:ray-project,项目名称:ray,代码行数:19,代码来源:test_api.py

示例10: testNestedStoppingReturn

# 需要导入模块: from ray import tune [as 别名]
# 或者: from ray.tune import run [as 别名]
def testNestedStoppingReturn(self):
        """A nested dict as a stop criterion is rejected with TuneError, while
        the flattened "test/test1/test2" key stops on the nested metric."""
        def train(config, reporter):
            for step in range(10):
                reporter(test={"test1": {"test2": step}})

        nested_stop = {"test": {"test1": {"test2": 6}}}
        with self.assertRaises(TuneError):
            [trial] = tune.run(train, stop=nested_stop).trials

        flat_stop = {"test/test1/test2": 6}
        [trial] = tune.run(train, stop=flat_stop).trials
        # The value 6 is reported on the 7th call (step == 6).
        self.assertEqual(trial.last_result["training_iteration"], 7)
开发者ID:ray-project,项目名称:ray,代码行数:18,代码来源:test_api.py

示例11: testBadStoppingFunction

# 需要导入模块: from ray import tune [as 别名]
# 或者: from ray.tune import run [as 别名]
def testBadStoppingFunction(self):
        """Stop callables taking only `result` must be rejected with TuneError,
        whether given as a bound method or as a plain function."""
        def train(config, reporter):
            for value in range(10):
                reporter(test=value)

        class CustomStopper:
            def stop(self, result):
                return result["test"] > 6

        def stop(result):
            return result["test"] > 6

        for bad_stop in (CustomStopper().stop, stop):
            with self.assertRaises(TuneError):
                tune.run(train, stop=bad_stop)
开发者ID:ray-project,项目名称:ray,代码行数:18,代码来源:test_api.py

示例12: testTrialInfoAccessFunction

# 需要导入模块: from ray import tune [as 别名]
# 或者: from ray.tune import run [as 别名]
def testTrialInfoAccessFunction(self):
        """Function trainables can read their trial name and id, both through
        the `reporter` argument and through tune.get_trial_name/get_trial_id."""
        def train(config, reporter):
            reporter(name=reporter.trial_name, trial_id=reporter.trial_id)

        def track_train(config):
            tune.report(
                name=tune.get_trial_name(), trial_id=tune.get_trial_id())

        for trainable in (train, track_train):
            analysis = tune.run(trainable, stop={TRAINING_ITERATION: 1})
            first = analysis.trials[0]
            reported = first.last_result
            self.assertEqual(reported.get("name"), str(first))
            self.assertEqual(reported.get("trial_id"), first.trial_id)
开发者ID:ray-project,项目名称:ray,代码行数:19,代码来源:test_api.py

示例13: testLotsOfStops

# 需要导入模块: from ray import tune [as 别名]
# 或者: from ray.tune import run [as 别名]
def testLotsOfStops(self):
        """Each of 10 trials stopped after one iteration must have run its
        `cleanup` to completion, which is checked via a marker file in the
        trial's logdir."""
        class TestTrainable(Trainable):
            def step(self):
                return {"name": self.trial_name, "trial_id": self.trial_id}

            def cleanup(self):
                # Deliberately slow cleanup; presumably to check that stopping
                # waits for it before the marker is inspected below.
                time.sleep(2)
                marker_path = os.path.join(self.logdir, "marker")
                with open(marker_path, "a"):
                    pass
                return 1

        analysis = tune.run(
            TestTrainable, num_samples=10, stop={TRAINING_ITERATION: 1})
        ray.shutdown()
        for trial in analysis.trials:
            assert os.path.exists(os.path.join(trial.logdir, "marker"))
开发者ID:ray-project,项目名称:ray,代码行数:19,代码来源:test_api.py

示例14: testIterationCounter

# 需要导入模块: from ray import tune [as 别名]
# 或者: from ray.tune import run [as 别名]
def testIterationCounter(self):
        """After 100 reports of one timestep each, stopping at
        timesteps_total == 100 leaves training_iteration at 100 and the last
        reported `itr` at 99."""
        def train(config, reporter):
            for step in range(100):
                reporter(itr=step, timesteps_this_iter=1)

        register_trainable("exp", train)
        experiment = {
            "run": "exp",
            "config": {"iterations": 100},
            "stop": {"timesteps_total": 100},
        }
        [trial] = run_experiments({"my_exp": experiment})

        final = trial.last_result
        self.assertEqual(trial.status, Trial.TERMINATED)
        self.assertEqual(final[TRAINING_ITERATION], 100)
        self.assertEqual(final["itr"], 99)
开发者ID:ray-project,项目名称:ray,代码行数:23,代码来源:test_api.py

示例15: test

# 需要导入模块: from ray import tune [as 别名]
# 或者: from ray.tune import run [as 别名]
def test(model, data_loader):
    """Return the classification accuracy of `model` on `data_loader`.

    Puts `model` in eval mode and evaluates without gradients. To keep the
    example fast, it stops once more than `TEST_SIZE` samples have been seen
    (counted as batch_idx * batch size), so slightly more than `TEST_SIZE`
    samples may be evaluated.

    Args:
        model: a torch module whose output has class scores along dim 1
            (assumed shape (batch, num_classes) — TODO confirm for new models).
        data_loader: iterable of (data, target) tensor batches.

    Returns:
        Fraction of correctly classified samples, as a float in [0, 1].

    Raises:
        ValueError: if `data_loader` yields no samples (accuracy undefined).
    """
    # Evaluate on the device the model's parameters live on, so inputs and
    # weights always match; parameter-less models fall back to CUDA-if-available.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(data_loader):
            # We set this just for the example to run quickly.
            if batch_idx * len(data) > TEST_SIZE:
                break
            data, target = data.to(device), target.to(device)
            predicted = model(data).argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    if total == 0:
        raise ValueError("data_loader yielded no samples; accuracy is undefined")
    return correct / total
# __train_def_end__


# __train_func_begin__ 
开发者ID:ray-project,项目名称:ray,代码行数:23,代码来源:tutorial.py


注:本文中的ray.tune.run方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。