

Python mlflow.tracking Code Examples

This article collects typical usage examples of mlflow.tracking in Python. If you are wondering how mlflow.tracking is used in practice, or are looking for concrete examples from real code, the curated snippets below may help. You can also explore further usage examples from the mlflow package it belongs to.


A total of 15 code examples of mlflow.tracking are shown below, sorted by popularity by default.

Example 1: get_workspace_kwargs

# Required module: import mlflow [as alias]
# Or: from mlflow import tracking [as alias]
def get_workspace_kwargs() -> dict:
    """Get AzureML keyword arguments from environment

    The name of this environment variable is set in the Argo workflow template,
    and its value should be in the format:
    `<subscription_id>:<resource_group>:<workspace_name>`.

    Returns
    -------
    workspace_kwargs: dict
        AzureML Workspace configuration to use for remote MLFlow tracking. See
        :func:`gordo.builder.mlflow_utils.get_mlflow_client`.
    """
    return get_kwargs_from_secret(
        "AZUREML_WORKSPACE_STR", ["subscription_id", "resource_group", "workspace_name"]
    ) 
Developer: equinor, Project: gordo, Lines of code: 18, Source file: mlflow.py
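
The helper `get_kwargs_from_secret` used above is not shown on this page. Below is a minimal sketch of what such a helper could look like, assuming the secret is the colon-delimited string described in the docstring; this is an illustration, not gordo's actual implementation.

import os
from typing import List


def get_kwargs_from_secret(name: str, keys: List[str]) -> dict:
    """Split a colon-delimited environment variable into keyword arguments (sketch)."""
    secret = os.environ.get(name)
    if secret is None:
        # No secret configured; the caller falls back to local MLflow tracking.
        return {}
    values = secret.split(":")
    if len(values) != len(keys):
        raise ValueError(f"Expected {len(keys)} fields in {name}, got {len(values)}")
    return dict(zip(keys, values))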

Example 2: __init__

# Required module: import mlflow [as alias]
# Or: from mlflow import tracking [as alias]
def __init__(self, artifact_uri):
        super(DatabricksArtifactRepository, self).__init__(artifact_uri)
        if not artifact_uri.startswith('dbfs:/'):
            raise MlflowException(message='DatabricksArtifactRepository URI must start with dbfs:/',
                                  error_code=INVALID_PARAMETER_VALUE)
        if not is_databricks_acled_artifacts_uri(artifact_uri):
            raise MlflowException(message=('Artifact URI incorrect. Expected path prefix to be'
                                           ' databricks/mlflow-tracking/path/to/artifact/..'),
                                  error_code=INVALID_PARAMETER_VALUE)
        self.run_id = self._extract_run_id(self.artifact_uri)

        # Fetch the artifact root for the MLflow Run associated with `artifact_uri` and compute
        # the path of `artifact_uri` relative to the MLflow Run's artifact root
        # (the `run_relative_artifact_repo_root_path`). All operations performed on this artifact
        # repository will be performed relative to this computed location
        artifact_repo_root_path = extract_and_normalize_path(artifact_uri)
        run_artifact_root_uri = self._get_run_artifact_root(self.run_id)
        run_artifact_root_path = extract_and_normalize_path(run_artifact_root_uri)
        run_relative_root_path = posixpath.relpath(
            path=artifact_repo_root_path, start=run_artifact_root_path
        )
        # If the paths are equal, then use empty string over "./" for ListArtifact compatibility.
        self.run_relative_artifact_repo_root_path = \
            "" if run_artifact_root_path == artifact_repo_root_path else run_relative_root_path 
Developer: mlflow, Project: mlflow, Lines of code: 26, Source file: databricks_artifact_repo.py
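
For context, the run-relative root computation above can be illustrated with plain `posixpath`; the paths below are made up.

import posixpath

run_artifact_root_path = "databricks/mlflow-tracking/123/abcdef/artifacts"
artifact_repo_root_path = "databricks/mlflow-tracking/123/abcdef/artifacts/model"

# Path of the artifact repository root relative to the run's artifact root.
print(posixpath.relpath(artifact_repo_root_path, start=run_artifact_root_path))  # -> 'model'

# When the two paths are equal, relpath returns '.', which the code above replaces with "".
print(posixpath.relpath(run_artifact_root_path, start=run_artifact_root_path))  # -> '.'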

Example 3: run_uuid

# Required module: import mlflow [as alias]
# Or: from mlflow import tracking [as alias]
def run_uuid(self):
        return mlflow.tracking.fluent.active_run().info.run_uuid 
Developer: Unbabel, Project: OpenKiwi, Lines of code: 4, Source file: loggers.py

Example 4: experiment_id

# Required module: import mlflow [as alias]
# Or: from mlflow import tracking [as alias]
def experiment_id(self):
        return mlflow.tracking.fluent.active_run().info.experiment_id 
Developer: Unbabel, Project: OpenKiwi, Lines of code: 4, Source file: loggers.py

Example 5: _is_remote

# Required module: import mlflow [as alias]
# Or: from mlflow import tracking [as alias]
def _is_remote(self):
        return not mlflow.tracking.utils._is_local_uri(
            mlflow.get_tracking_uri()
        ) 
Developer: Unbabel, Project: OpenKiwi, Lines of code: 6, Source file: loggers.py
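
The snippet above relies on a private mlflow helper. A rough public-API equivalent, assuming an empty or `file` scheme indicates a local file store, might look like this:

from urllib.parse import urlparse

import mlflow


def is_remote_tracking_uri() -> bool:
    # Empty scheme (plain path) or 'file' indicates a local file store; anything else is remote.
    scheme = urlparse(mlflow.get_tracking_uri()).scheme
    return scheme not in ("", "file")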

Example 6: get_run_id

# Required module: import mlflow [as alias]
# Or: from mlflow import tracking [as alias]
def get_run_id(client: MlflowClient, experiment_name: str, model_key: str) -> str:
    """
    Get an existing or create a new run for the given model_key and experiment_name.

    The model key corresponds to a unique configuration of the model. The corresponding
    run must be manually stopped using the `mlflow.tracking.MlflowClient.set_terminated`
    method.

    Parameters
    ----------
    client: mlflow.tracking.MlflowClient
        Client with tracking uri set to AzureML if configured.
    experiment_name: str
        Name of experiment to log to.
    model_key: str
        Unique ID of model configuration.

    Returns
    -------
    run_id: str
        Unique ID of MLflow run to log to.
    """
    experiment = client.get_experiment_by_name(experiment_name)

    experiment_id = (
        getattr(experiment, "experiment_id")
        if experiment
        else client.create_experiment(experiment_name)
    )
    return client.create_run(experiment_id, tags={"model_key": model_key}).info.run_id 
Developer: equinor, Project: gordo, Lines of code: 32, Source file: mlflow.py
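
A short usage sketch for the function above, assuming a locally configured tracking backend; the experiment name and model key are placeholders.

from mlflow.tracking import MlflowClient

client = MlflowClient()  # uses the currently configured tracking URI
run_id = get_run_id(client, experiment_name="my-experiment", model_key="config-hash-123")
client.log_param(run_id, "lr", 0.01)
# The run is not ended automatically; it must be closed explicitly, as the docstring notes.
client.set_terminated(run_id)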

Example 7: mlflow_context

# Required module: import mlflow [as alias]
# Or: from mlflow import tracking [as alias]
@contextmanager  # from contextlib; required for the `with mlflow_context(...)` usage shown in the docstring
def mlflow_context(
    name: str,
    model_key: str = uuid4().hex,
    workspace_kwargs: dict = {},
    service_principal_kwargs: dict = {},
):
    """
    Generate MLflow logger function with either a local or AzureML backend

    Parameters
    ----------
    name: str
        The name of the log group to log to (e.g. a model name).
    model_key: str
        Unique ID of logging run.
    workspace_kwargs: dict
        AzureML Workspace configuration to use for remote MLFlow tracking. See
        :func:`gordo.builder.mlflow_utils.get_mlflow_client`.
    service_principal_kwargs: dict
        AzureML ServicePrincipalAuthentication keyword arguments. See
        :func:`gordo.builder.mlflow_utils.get_mlflow_client`

    Example
    -------
    >>> with tempfile.TemporaryDirectory() as tmp_dir:
    ...     mlflow.set_tracking_uri(f"file:{tmp_dir}")
    ...     with mlflow_context("log_group", "unique_key", {}, {}) as (mlflow_client, run_id):
    ...         log_machine(machine) # doctest: +SKIP
    """
    mlflow_client = get_mlflow_client(workspace_kwargs, service_principal_kwargs)
    run_id = get_run_id(mlflow_client, experiment_name=name, model_key=model_key)

    logger.info(
        f"MLflow client configured to use {'AzureML' if workspace_kwargs else 'local backend'}"
    )

    yield mlflow_client, run_id

    mlflow_client.set_terminated(run_id) 
Developer: equinor, Project: gordo, Lines of code: 41, Source file: mlflow.py

Example 8: commands

# Required module: import mlflow [as alias]
# Or: from mlflow import tracking [as alias]
def commands():
    """
    Manage runs. To manage runs of experiments associated with a tracking server, set the
    MLFLOW_TRACKING_URI environment variable to the URL of the desired server.
    """
    pass 
Developer: mlflow, Project: mlflow, Lines of code: 8, Source file: runs.py

Example 9: list_run

# Required module: import mlflow [as alias]
# Or: from mlflow import tracking [as alias]
def list_run(experiment_id, view):
    """
    List all runs of the specified experiment in the configured tracking server.
    """
    store = _get_store()
    view_type = ViewType.from_string(view) if view else ViewType.ACTIVE_ONLY
    runs = store.search_runs([experiment_id], None, view_type)
    table = []
    for run in runs:
        tags = {k: v for k, v in run.data.tags.items()}
        run_name = tags.get(MLFLOW_RUN_NAME, "")
        table.append([conv_longdate_to_str(run.info.start_time), run_name, run.info.run_id])
    print(tabulate(sorted(table, reverse=True), headers=["Date", "Name", "ID"])) 
Developer: mlflow, Project: mlflow, Lines of code: 15, Source file: runs.py
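
The function above backs the MLflow CLI's run listing. Roughly the same listing can be produced programmatically through the public client API; the experiment ID "0" below is a placeholder.

from mlflow.entities import ViewType
from mlflow.tracking import MlflowClient

client = MlflowClient()
for run in client.search_runs(["0"], run_view_type=ViewType.ACTIVE_ONLY):
    # The run name is stored under the 'mlflow.runName' tag when set.
    name = run.data.tags.get("mlflow.runName", "")
    print(run.info.start_time, name, run.info.run_id)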

Example 10: _extract_run_id

# Required module: import mlflow [as alias]
# Or: from mlflow import tracking [as alias]
def _extract_run_id(artifact_uri):
        """
        The artifact_uri is expected to be
        dbfs:/databricks/mlflow-tracking/<EXP_ID>/<RUN_ID>/artifacts/<path>
        Once the path from the input uri is extracted and normalized, it is
        expected to be of the form
        databricks/mlflow-tracking/<EXP_ID>/<RUN_ID>/artifacts/<path>

        Hence the run_id is the 4th element of the normalized path.

        :return: run_id extracted from the artifact_uri
        """
        artifact_path = extract_and_normalize_path(artifact_uri)
        return artifact_path.split('/')[3] 
Developer: mlflow, Project: mlflow, Lines of code: 16, Source file: databricks_artifact_repo.py
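
For illustration, here is a simplified standalone version of the same extraction. It assumes `extract_and_normalize_path` essentially strips the `dbfs:` scheme and leading slashes, which is a guess rather than the actual mlflow helper.

import posixpath


def extract_run_id_simplified(artifact_uri):
    path = artifact_uri.split(":", 1)[1].lstrip("/")  # drop the 'dbfs:' scheme
    path = posixpath.normpath(path)
    # databricks/mlflow-tracking/<EXP_ID>/<RUN_ID>/artifacts/<path> -> index 3 is the run ID
    return path.split("/")[3]


print(extract_run_id_simplified(
    "dbfs:/databricks/mlflow-tracking/123/abcdef/artifacts/model"
))  # -> 'abcdef'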

Example 11: _call_endpoint

# Required module: import mlflow [as alias]
# Or: from mlflow import tracking [as alias]
def _call_endpoint(self, service, api, json_body):
        db_profile = get_db_profile_from_uri(mlflow.tracking.get_tracking_uri())
        db_creds = get_databricks_host_creds(db_profile)
        endpoint, method = _SERVICE_AND_METHOD_TO_INFO[service][api]
        response_proto = api.Response()
        return call_endpoint(db_creds, endpoint, method, json_body, response_proto) 
Developer: mlflow, Project: mlflow, Lines of code: 8, Source file: databricks_artifact_repo.py

Example 12: test_run_databricks_validations

# Required module: import mlflow [as alias]
# Or: from mlflow import tracking [as alias]
def test_run_databricks_validations(
        tmpdir, cluster_spec_mock,  # pylint: disable=unused-argument
        tracking_uri_mock, dbfs_mocks, set_tag_mock):  # pylint: disable=unused-argument
    """
    Tests that running on Databricks fails before making any API requests if validations fail.
    """
    with mock.patch.dict(os.environ, {'DATABRICKS_HOST': 'test-host', 'DATABRICKS_TOKEN': 'foo'}),\
        mock.patch("mlflow.projects.databricks.DatabricksJobRunner._databricks_api_request")\
            as db_api_req_mock:
        # Test bad tracking URI
        tracking_uri_mock.return_value = tmpdir.strpath
        with pytest.raises(ExecutionException):
            run_databricks_project(cluster_spec_mock, synchronous=True)
        assert db_api_req_mock.call_count == 0
        db_api_req_mock.reset_mock()
        mlflow_service = mlflow.tracking.MlflowClient()
        assert (len(mlflow_service.list_run_infos(experiment_id=FileStore.DEFAULT_EXPERIMENT_ID))
                == 0)
        tracking_uri_mock.return_value = "http://"
        # Test misspecified parameters
        with pytest.raises(ExecutionException):
            mlflow.projects.run(
                TEST_PROJECT_DIR, backend="databricks", entry_point="greeter",
                backend_config=cluster_spec_mock)
        assert db_api_req_mock.call_count == 0
        db_api_req_mock.reset_mock()
        # Test bad cluster spec
        with pytest.raises(ExecutionException):
            mlflow.projects.run(TEST_PROJECT_DIR, backend="databricks", synchronous=True,
                                backend_config=None)
        assert db_api_req_mock.call_count == 0
        db_api_req_mock.reset_mock()
        # Test that validations pass with good tracking URIs
        databricks.before_run_validations("http://", cluster_spec_mock)
        databricks.before_run_validations("databricks", cluster_spec_mock) 
Developer: mlflow, Project: mlflow, Lines of code: 37, Source file: test_databricks.py

Example 13: test_get_tracking_uri_for_run

# Required module: import mlflow [as alias]
# Or: from mlflow import tracking [as alias]
def test_get_tracking_uri_for_run():
    mlflow.set_tracking_uri("http://some-uri")
    assert databricks._get_tracking_uri_for_run() == "http://some-uri"
    mlflow.set_tracking_uri("databricks://profile")
    assert databricks._get_tracking_uri_for_run() == "databricks"
    mlflow.set_tracking_uri(None)
    with mock.patch.dict(os.environ, {mlflow.tracking._TRACKING_URI_ENV_VAR: "http://some-uri"}):
        assert mlflow.tracking._tracking_service.utils.get_tracking_uri() == "http://some-uri" 
Developer: mlflow, Project: mlflow, Lines of code: 10, Source file: test_databricks.py

Example 14: test_integration

# Required module: import mlflow [as alias]
# Or: from mlflow import tracking [as alias]
def test_integration(dirname):

    n_epochs = 5
    data = list(range(50))

    losses = torch.rand(n_epochs * len(data))
    losses_iter = iter(losses)

    def update_fn(engine, batch):
        return next(losses_iter)

    trainer = Engine(update_fn)

    mlflow_logger = MLflowLogger(tracking_uri=os.path.join(dirname, "mlruns"))

    true_values = []

    def dummy_handler(engine, logger, event_name):
        global_step = engine.state.get_event_attrib_value(event_name)
        v = global_step * 0.1
        true_values.append(v)
        logger.log_metrics({"{}".format("test_value"): v}, step=global_step)

    mlflow_logger.attach(trainer, log_handler=dummy_handler, event_name=Events.EPOCH_COMPLETED)

    import mlflow

    active_run = mlflow.active_run()

    trainer.run(data, max_epochs=n_epochs)
    mlflow_logger.close()

    from mlflow.tracking import MlflowClient

    client = MlflowClient(tracking_uri=os.path.join(dirname, "mlruns"))
    stored_values = client.get_metric_history(active_run.info.run_id, "test_value")

    for t, s in zip(true_values, stored_values):
        assert pytest.approx(t) == s.value 
Developer: pytorch, Project: ignite, Lines of code: 41, Source file: test_mlflow_logger.py

Example 15: test_integration_as_context_manager

# Required module: import mlflow [as alias]
# Or: from mlflow import tracking [as alias]
def test_integration_as_context_manager(dirname):

    n_epochs = 5
    data = list(range(50))

    losses = torch.rand(n_epochs * len(data))
    losses_iter = iter(losses)

    def update_fn(engine, batch):
        return next(losses_iter)

    true_values = []

    with MLflowLogger(os.path.join(dirname, "mlruns")) as mlflow_logger:

        trainer = Engine(update_fn)

        def dummy_handler(engine, logger, event_name):
            global_step = engine.state.get_event_attrib_value(event_name)
            v = global_step * 0.1
            true_values.append(v)
            logger.log_metrics({"{}".format("test_value"): v}, step=global_step)

        mlflow_logger.attach(trainer, log_handler=dummy_handler, event_name=Events.EPOCH_COMPLETED)

        import mlflow

        active_run = mlflow.active_run()

        trainer.run(data, max_epochs=n_epochs)

    from mlflow.tracking import MlflowClient

    client = MlflowClient(tracking_uri=os.path.join(dirname, "mlruns"))
    stored_values = client.get_metric_history(active_run.info.run_id, "test_value")

    for t, s in zip(true_values, stored_values):
        assert pytest.approx(t) == s.value 
Developer: pytorch, Project: ignite, Lines of code: 40, Source file: test_mlflow_logger.py


Note: The mlflow.tracking examples in this article were compiled by 純淨天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets were selected from open-source projects contributed by various developers; copyright of the source code remains with the original authors. Please consult each project's License before distributing or using the code, and do not reproduce this compilation without permission.