当前位置: 首页>>代码示例>>Python>>正文


Python mlflow.tracking模块代码示例

本文整理汇总了Python中mlflow.tracking模块的典型用法代码示例。如果您正苦于以下问题:Python mlflow.tracking模块的具体用法?Python mlflow.tracking怎么用?Python mlflow.tracking使用的例子?那么,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该模块所在的mlflow包的用法示例。


在下文中一共展示了mlflow.tracking模块的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: get_workspace_kwargs

# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def get_workspace_kwargs() -> dict:
    """Read the AzureML workspace settings from the environment.

    The environment variable name is set in the Argo workflow template, and
    its value should be in the format:
    `<subscription_id>:<resource_group>:<workspace_name>`.

    The lookup itself is delegated to ``get_kwargs_from_secret``, which is
    given the variable name and the three element names listed above.

    Returns
    -------
    workspace_kwargs: dict
        AzureML Workspace configuration to use for remote MLFlow tracking. See
        :func:`gordo.builder.mlflow_utils.get_mlflow_client`.
    """
    secret_name = "AZUREML_WORKSPACE_STR"
    element_names = ["subscription_id", "resource_group", "workspace_name"]
    return get_kwargs_from_secret(secret_name, element_names)
开发者ID:equinor,项目名称:gordo,代码行数:18,代码来源:mlflow.py

示例2: __init__

# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def __init__(self, artifact_uri):
        """
        Validate ``artifact_uri`` and record where it sits relative to its run's artifact root.

        :param artifact_uri: URI of the artifact location. It must start with ``dbfs:/`` and
                             must be accepted by ``is_databricks_acled_artifacts_uri``.
        :raises MlflowException: with ``INVALID_PARAMETER_VALUE`` when either check fails.
        """
        super(DatabricksArtifactRepository, self).__init__(artifact_uri)

        if not artifact_uri.startswith('dbfs:/'):
            scheme_error = 'DatabricksArtifactRepository URI must start with dbfs:/'
            raise MlflowException(message=scheme_error, error_code=INVALID_PARAMETER_VALUE)

        if not is_databricks_acled_artifacts_uri(artifact_uri):
            prefix_error = ('Artifact URI incorrect. Expected path prefix to be'
                            ' databricks/mlflow-tracking/path/to/artifact/..')
            raise MlflowException(message=prefix_error, error_code=INVALID_PARAMETER_VALUE)

        self.run_id = self._extract_run_id(self.artifact_uri)

        # Every operation on this repository is performed relative to the run's artifact
        # root, so work out where `artifact_uri` sits underneath that root.
        repo_root = extract_and_normalize_path(artifact_uri)
        run_root_uri = self._get_run_artifact_root(self.run_id)
        run_root = extract_and_normalize_path(run_root_uri)
        relative_root = posixpath.relpath(path=repo_root, start=run_root)

        # When both paths coincide, store "" rather than the relative form, for
        # ListArtifact compatibility.
        if run_root == repo_root:
            self.run_relative_artifact_repo_root_path = ""
        else:
            self.run_relative_artifact_repo_root_path = relative_root
开发者ID:mlflow,项目名称:mlflow,代码行数:26,代码来源:databricks_artifact_repo.py

示例3: run_uuid

# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def run_uuid(self):
        """Return the ``run_uuid`` stored in the info of the currently active MLflow run."""
        current_run = mlflow.tracking.fluent.active_run()
        return current_run.info.run_uuid
开发者ID:Unbabel,项目名称:OpenKiwi,代码行数:4,代码来源:loggers.py

示例4: experiment_id

# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def experiment_id(self):
        """Return the ``experiment_id`` stored in the info of the currently active MLflow run."""
        current_run = mlflow.tracking.fluent.active_run()
        return current_run.info.experiment_id
开发者ID:Unbabel,项目名称:OpenKiwi,代码行数:4,代码来源:loggers.py

示例5: _is_remote

# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def _is_remote(self):
        """Return True when MLflow's current tracking URI is not classified as a local URI."""
        is_local_uri = mlflow.tracking.utils._is_local_uri
        return not is_local_uri(mlflow.get_tracking_uri())
开发者ID:Unbabel,项目名称:OpenKiwi,代码行数:6,代码来源:loggers.py

示例6: get_run_id

# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def get_run_id(client: MlflowClient, experiment_name: str, model_key: str) -> str:
    """
    Create a new run tagged with model_key and return its ID.

    The experiment named ``experiment_name`` is looked up first and created
    when the lookup returns nothing. A new run is then always created in that
    experiment, tagged with ``{"model_key": model_key}``.

    The model key corresponds to a unique configuration of the model. The
    run is not stopped here; it must be stopped manually using the
    `mlflow.tracking.MlflowClient.set_terminated` method.

    Parameters
    ----------
    client: mlflow.tracking.MlflowClient
        Client with tracking uri set to AzureML if configured.
    experiment_name: str
        Name of experiment to log to.
    model_key: str
        Unique ID of model configuration.

    Returns
    -------
    run_id: str
        Unique ID of MLflow run to log to.
    """
    existing_experiment = client.get_experiment_by_name(experiment_name)

    if existing_experiment:
        experiment_id = existing_experiment.experiment_id
    else:
        experiment_id = client.create_experiment(experiment_name)

    new_run = client.create_run(experiment_id, tags={"model_key": model_key})
    return new_run.info.run_id
开发者ID:equinor,项目名称:gordo,代码行数:32,代码来源:mlflow.py

示例7: mlflow_context

# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def mlflow_context(
    name: str,
    model_key: str = uuid4().hex,
    workspace_kwargs: dict = {},
    service_principal_kwargs: dict = {},
):
    """
    Generate MLflow logger function with either a local or AzureML backend

    Builds an MLflow client, creates a run in the experiment called ``name``,
    yields ``(mlflow_client, run_id)`` and terminates the run once the caller
    is done with it. The run is terminated even when the caller's code raises:
    in that case it is marked ``"FAILED"`` and the exception is re-raised.

    This is a generator function; the example below uses it in a ``with``
    statement, so it is presumably wrapped with
    ``contextlib.contextmanager`` where it is defined (decorator not shown
    here).

    Parameters
    ----------
    name: str
        The name of the log group to log to (e.g. a model name).
    model_key: str
        Unique ID of logging run. NOTE: the default is evaluated once, when
        the function is defined, so every call that omits ``model_key`` in the
        same process shares one value. Pass an explicit key when distinct
        keys are needed.
    workspace_kwargs: dict
        AzureML Workspace configuration to use for remote MLFlow tracking. See
        :func:`gordo.builder.mlflow_utils.get_mlflow_client`.
    service_principal_kwargs: dict
        AzureML ServicePrincipalAuthentication keyword arguments. See
        :func:`gordo.builder.mlflow_utils.get_mlflow_client`

    Yields
    ------
    (mlflow_client, run_id)
        The configured client and the ID of the newly created run.

    Example
    -------
    >>> with tempfile.TemporaryDirectory() as tmp_dir:
    ...     mlflow.set_tracking_uri(f"file:{tmp_dir}")
    ...     with mlflow_context("log_group", "unique_key", {}, {}) as (mlflow_client, run_id):
    ...         log_machine(machine) # doctest: +SKIP
    """
    # The dict defaults are shared between calls; this function only reads
    # them and passes them on, it never mutates them itself.
    mlflow_client = get_mlflow_client(workspace_kwargs, service_principal_kwargs)
    run_id = get_run_id(mlflow_client, experiment_name=name, model_key=model_key)

    logger.info(
        f"MLflow client configured to use {'AzureML' if workspace_kwargs else 'local backend'}"
    )

    try:
        yield mlflow_client, run_id
    except BaseException:
        # Without this the run would be left unterminated whenever the
        # managed block fails; close it as failed and let the error propagate.
        mlflow_client.set_terminated(run_id, status="FAILED")
        raise
    else:
        mlflow_client.set_terminated(run_id)
开发者ID:equinor,项目名称:gordo,代码行数:41,代码来源:mlflow.py

示例8: commands

# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def commands():
    """
    Manage runs. To manage runs of experiments associated with a tracking server, set the
    MLFLOW_TRACKING_URI environment variable to the URL of the desired server.
    """
    # Intentionally empty: the function has no behavior of its own beyond carrying
    # the docstring above (presumably used as a command-group entry point — confirm
    # against the decorator, which is not shown here).
开发者ID:mlflow,项目名称:mlflow,代码行数:8,代码来源:runs.py

示例9: list_run

# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def list_run(experiment_id, view):
    """
    List all runs of the specified experiment in the configured tracking server.
    """
    store = _get_store()

    # An empty/None `view` falls back to active runs only.
    if view:
        view_type = ViewType.from_string(view)
    else:
        view_type = ViewType.ACTIVE_ONLY

    rows = []
    for run in store.search_runs([experiment_id], None, view_type):
        run_tags = dict(run.data.tags.items())
        # Runs without a name tag are listed with an empty name.
        run_name = run_tags.get(MLFLOW_RUN_NAME, "")
        started = conv_longdate_to_str(run.info.start_time)
        rows.append([started, run_name, run.info.run_id])

    # Descending order on [date string, name, id].
    rows.sort(reverse=True)
    print(tabulate(rows, headers=["Date", "Name", "ID"]))
开发者ID:mlflow,项目名称:mlflow,代码行数:15,代码来源:runs.py

示例10: _extract_run_id

# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def _extract_run_id(artifact_uri):
        """
        Pull the run ID out of a Databricks artifact URI.

        The artifact_uri is expected to be
        dbfs:/databricks/mlflow-tracking/<EXP_ID>/<RUN_ID>/artifacts/<path>
        which, after extraction and normalization of its path, is expected to read
        databricks/mlflow-tracking/<EXP_ID>/<RUN_ID>/artifacts/<path>

        The run_id is therefore the 4th '/'-separated element (index 3) of that path.

        :param artifact_uri: artifact URI of the form described above
        :return: run_id extracted from the artifact_uri
        """
        normalized_path = extract_and_normalize_path(artifact_uri)
        segments = normalized_path.split('/')
        return segments[3]
开发者ID:mlflow,项目名称:mlflow,代码行数:16,代码来源:databricks_artifact_repo.py

示例11: _call_endpoint

# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def _call_endpoint(self, service, api, json_body):
        """
        Send ``json_body`` to the endpoint registered for ``(service, api)``.

        Credentials are obtained from the Databricks profile derived from the current MLflow
        tracking URI. The endpoint and method come from ``_SERVICE_AND_METHOD_TO_INFO``, and a
        fresh ``api.Response()`` is handed to ``call_endpoint`` alongside the request.

        :return: whatever ``call_endpoint`` returns
        """
        profile = get_db_profile_from_uri(mlflow.tracking.get_tracking_uri())
        host_creds = get_databricks_host_creds(profile)
        endpoint_info = _SERVICE_AND_METHOD_TO_INFO[service][api]
        endpoint, method = endpoint_info
        empty_response = api.Response()
        return call_endpoint(host_creds, endpoint, method, json_body, empty_response)
开发者ID:mlflow,项目名称:mlflow,代码行数:8,代码来源:databricks_artifact_repo.py

示例12: test_run_databricks_validations

# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def test_run_databricks_validations(
        tmpdir, cluster_spec_mock,  # pylint: disable=unused-argument
        tracking_uri_mock, dbfs_mocks, set_tag_mock):  # pylint: disable=unused-argument
    """
    Tests that running on Databricks fails before making any API requests if validations fail.

    Three failing scenarios are exercised in sequence (bad tracking URI, misspecified
    parameters, bad cluster spec). Each must raise ``ExecutionException`` and leave the patched
    ``DatabricksJobRunner._databricks_api_request`` uncalled. Finally,
    ``before_run_validations`` is called with two tracking URIs that are expected to pass.
    """
    # Databricks credentials come from a patched environment, and the job runner's API
    # request method is replaced with a mock so its call count can be inspected.
    with mock.patch.dict(os.environ, {'DATABRICKS_HOST': 'test-host', 'DATABRICKS_TOKEN': 'foo'}),\
        mock.patch("mlflow.projects.databricks.DatabricksJobRunner._databricks_api_request")\
            as db_api_req_mock:
        # Test bad tracking URI
        # (the tracking URI mock reports a plain local directory path)
        tracking_uri_mock.return_value = tmpdir.strpath
        with pytest.raises(ExecutionException):
            run_databricks_project(cluster_spec_mock, synchronous=True)
        assert db_api_req_mock.call_count == 0
        db_api_req_mock.reset_mock()
        # The failed attempt must not have left any run in the default experiment.
        mlflow_service = mlflow.tracking.MlflowClient()
        assert (len(mlflow_service.list_run_infos(experiment_id=FileStore.DEFAULT_EXPERIMENT_ID))
                == 0)
        # Switch the reported tracking URI to an http one for the remaining scenarios.
        tracking_uri_mock.return_value = "http://"
        # Test misspecified parameters
        with pytest.raises(ExecutionException):
            mlflow.projects.run(
                TEST_PROJECT_DIR, backend="databricks", entry_point="greeter",
                backend_config=cluster_spec_mock)
        assert db_api_req_mock.call_count == 0
        db_api_req_mock.reset_mock()
        # Test bad cluster spec
        # (no backend_config is supplied at all)
        with pytest.raises(ExecutionException):
            mlflow.projects.run(TEST_PROJECT_DIR, backend="databricks", synchronous=True,
                                backend_config=None)
        assert db_api_req_mock.call_count == 0
        db_api_req_mock.reset_mock()
        # Test that validations pass with good tracking URIs
        # (passing here means the calls return without raising)
        databricks.before_run_validations("http://", cluster_spec_mock)
        databricks.before_run_validations("databricks", cluster_spec_mock)
开发者ID:mlflow,项目名称:mlflow,代码行数:37,代码来源:test_databricks.py

示例13: test_get_tracking_uri_for_run

# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def test_get_tracking_uri_for_run():
    """
    Check the tracking URI reported for a run under several configurations.

    An http URI is expected back unchanged and a ``databricks://profile`` URI is expected
    to be reported as plain ``databricks``. With the URI reset to None, the value of the
    tracking-URI environment variable is expected to be returned by ``get_tracking_uri``.
    """
    configured_and_expected = (
        ("http://some-uri", "http://some-uri"),
        ("databricks://profile", "databricks"),
    )
    for configured_uri, expected_uri in configured_and_expected:
        mlflow.set_tracking_uri(configured_uri)
        assert databricks._get_tracking_uri_for_run() == expected_uri

    # Clear the explicit setting so that the environment variable is consulted.
    mlflow.set_tracking_uri(None)
    env_override = {mlflow.tracking._TRACKING_URI_ENV_VAR: "http://some-uri"}
    with mock.patch.dict(os.environ, env_override):
        resolved_uri = mlflow.tracking._tracking_service.utils.get_tracking_uri()
        assert resolved_uri == "http://some-uri"
开发者ID:mlflow,项目名称:mlflow,代码行数:10,代码来源:test_databricks.py

示例14: test_integration

# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def test_integration(dirname):
    """
    Run an engine with an attached MLflowLogger and compare logged metrics with stored ones.

    A handler attached on ``Events.EPOCH_COMPLETED`` logs ``global_step * 0.1`` under the
    metric name "test_value" and remembers each value. After the run, the metric history is
    read back through ``MlflowClient`` and compared pairwise with the remembered values.
    """
    n_epochs = 5
    data = list(range(50))
    mlruns_dir = os.path.join(dirname, "mlruns")

    # One random loss per iteration; the update function simply hands out the next one.
    losses = torch.rand(n_epochs * len(data))
    losses_iter = iter(losses)

    def update_fn(engine, batch):
        return next(losses_iter)

    trainer = Engine(update_fn)

    mlflow_logger = MLflowLogger(tracking_uri=mlruns_dir)

    true_values = []

    def dummy_handler(engine, logger, event_name):
        step = engine.state.get_event_attrib_value(event_name)
        value = step * 0.1
        true_values.append(value)
        logger.log_metrics({"{}".format("test_value"): value}, step=step)

    mlflow_logger.attach(trainer, log_handler=dummy_handler, event_name=Events.EPOCH_COMPLETED)

    import mlflow

    # Grab the run before it is closed so its ID can be used for the lookup afterwards.
    active_run = mlflow.active_run()

    trainer.run(data, max_epochs=n_epochs)
    mlflow_logger.close()

    from mlflow.tracking import MlflowClient

    client = MlflowClient(tracking_uri=mlruns_dir)
    stored_values = client.get_metric_history(active_run.info.run_id, "test_value")

    # NOTE(review): zip stops at the shorter sequence, so a length mismatch goes unnoticed.
    for expected, recorded in zip(true_values, stored_values):
        assert pytest.approx(expected) == recorded.value
开发者ID:pytorch,项目名称:ignite,代码行数:41,代码来源:test_mlflow_logger.py

示例15: test_integration_as_context_manager

# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def test_integration_as_context_manager(dirname):
    """
    Same check as the plain integration test, with MLflowLogger used as a context manager.

    A handler attached on ``Events.EPOCH_COMPLETED`` logs ``global_step * 0.1`` under the
    metric name "test_value" and remembers each value. After the ``with`` block exits, the
    metric history is read back through ``MlflowClient`` and compared pairwise.
    """
    n_epochs = 5
    data = list(range(50))
    mlruns_dir = os.path.join(dirname, "mlruns")

    # One random loss per iteration; the update function simply hands out the next one.
    losses = torch.rand(n_epochs * len(data))
    losses_iter = iter(losses)

    def update_fn(engine, batch):
        return next(losses_iter)

    true_values = []

    with MLflowLogger(mlruns_dir) as mlflow_logger:

        trainer = Engine(update_fn)

        def dummy_handler(engine, logger, event_name):
            step = engine.state.get_event_attrib_value(event_name)
            value = step * 0.1
            true_values.append(value)
            logger.log_metrics({"{}".format("test_value"): value}, step=step)

        mlflow_logger.attach(trainer, log_handler=dummy_handler, event_name=Events.EPOCH_COMPLETED)

        import mlflow

        # Grab the run inside the block so its ID is still available after the block exits.
        active_run = mlflow.active_run()

        trainer.run(data, max_epochs=n_epochs)

    from mlflow.tracking import MlflowClient

    client = MlflowClient(tracking_uri=mlruns_dir)
    stored_values = client.get_metric_history(active_run.info.run_id, "test_value")

    # NOTE(review): zip stops at the shorter sequence, so a length mismatch goes unnoticed.
    for expected, recorded in zip(true_values, stored_values):
        assert pytest.approx(expected) == recorded.value
开发者ID:pytorch,项目名称:ignite,代码行数:40,代码来源:test_mlflow_logger.py


注:本文中的mlflow.tracking方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。