本文整理汇总了Python中mlflow.tracking模块的典型用法代码示例。如果您正苦于以下问题：Python mlflow.tracking模块的具体用法？Python mlflow.tracking怎么用？Python mlflow.tracking的使用示例？那么，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该模块所在包mlflow的用法示例。
在下文中一共展示了mlflow.tracking模块的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: get_workspace_kwargs
# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def get_workspace_kwargs() -> dict:
    """Read the AzureML workspace configuration from the ``AZUREML_WORKSPACE_STR`` secret.

    The variable name is set in the Argo workflow template, and its value is
    expected to be a colon-separated string of the form
    ``<subscription_id>:<resource_group>:<workspace_name>``. Splitting it into
    keyword arguments is delegated to ``get_kwargs_from_secret``.

    Returns
    -------
    workspace_kwargs: dict
        AzureML Workspace configuration to use for remote MLflow tracking. See
        :func:`gordo.builder.mlflow_utils.get_mlflow_client`.
    """
    secret_name = "AZUREML_WORKSPACE_STR"
    # Key names, in the same order as the colon-separated fields of the secret.
    field_names = ["subscription_id", "resource_group", "workspace_name"]
    return get_kwargs_from_secret(secret_name, field_names)
示例2: __init__
# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def __init__(self, artifact_uri):
    """Build a repository for an ACL'ed Databricks artifact location.

    :param artifact_uri: A ``dbfs:/`` URI under ``databricks/mlflow-tracking/``.
    :raises MlflowException: With ``INVALID_PARAMETER_VALUE`` if the URI is not a
        ``dbfs:/`` URI or does not have the expected ACL'ed artifact prefix.
    """
    super(DatabricksArtifactRepository, self).__init__(artifact_uri)

    # Reject anything that is not a DBFS URI with the ACL'ed artifact prefix.
    if not artifact_uri.startswith('dbfs:/'):
        raise MlflowException(message='DatabricksArtifactRepository URI must start with dbfs:/',
                              error_code=INVALID_PARAMETER_VALUE)
    if not is_databricks_acled_artifacts_uri(artifact_uri):
        raise MlflowException(message=('Artifact URI incorrect. Expected path prefix to be'
                                       ' databricks/mlflow-tracking/path/to/artifact/..'),
                              error_code=INVALID_PARAMETER_VALUE)

    self.run_id = self._extract_run_id(self.artifact_uri)

    # Every operation on this repository is resolved relative to where
    # `artifact_uri` sits inside the run's own artifact root.
    repo_root = extract_and_normalize_path(artifact_uri)
    run_root = extract_and_normalize_path(self._get_run_artifact_root(self.run_id))
    relative_root = posixpath.relpath(path=repo_root, start=run_root)
    if repo_root == run_root:
        # relpath() yields "." for identical paths; ListArtifacts expects "" instead.
        relative_root = ""
    self.run_relative_artifact_repo_root_path = relative_root
示例3: run_uuid
# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def run_uuid(self):
    """Return the ID of the currently active MLflow run.

    MLflow documents ``RunInfo.run_uuid`` as a deprecated alias of
    ``RunInfo.run_id``, so the non-deprecated attribute is read here. The
    returned value is the same.

    NOTE(review): per the MLflow docs, ``active_run()`` returns ``None`` when no
    run is active, in which case the ``.info`` access raises ``AttributeError``.
    """
    return mlflow.tracking.fluent.active_run().info.run_id
示例4: experiment_id
# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def experiment_id(self):
    """Return the experiment ID that the currently active MLflow run belongs to.

    If ``active_run()`` yields ``None`` (no active run), the attribute access
    below fails.
    """
    current_run = mlflow.tracking.fluent.active_run()
    return current_run.info.experiment_id
示例5: _is_remote
# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def _is_remote(self):
    """Report whether the configured MLflow tracking URI is not a local one.

    Relies on MLflow's private ``mlflow.tracking.utils._is_local_uri`` helper.
    """
    tracking_uri = mlflow.get_tracking_uri()
    is_local = mlflow.tracking.utils._is_local_uri(tracking_uri)
    return not is_local
示例6: get_run_id
# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def get_run_id(client: MlflowClient, experiment_name: str, model_key: str) -> str:
    """
    Create a new MLflow run tagged with ``model_key`` in ``experiment_name``.

    The experiment is looked up by name and created if it does not exist yet.
    A new run is created on every call; an existing run with the same
    ``model_key`` is NOT reused. The model key corresponds to a unique
    configuration of the model and is stored as the ``model_key`` run tag.
    The returned run must be stopped manually with
    `mlflow.tracking.MlflowClient.set_terminated`.

    Parameters
    ----------
    client: mlflow.tracking.MlflowClient
        Client with tracking uri set to AzureML if configured.
    experiment_name: str
        Name of experiment to log to.
    model_key: str
        Unique ID of model configuration.

    Returns
    -------
    run_id: str
        Unique ID of the newly created MLflow run to log to.
    """
    experiment = client.get_experiment_by_name(experiment_name)
    # get_experiment_by_name returns None when no experiment has that name.
    if experiment is None:
        experiment_id = client.create_experiment(experiment_name)
    else:
        experiment_id = experiment.experiment_id
    run = client.create_run(experiment_id, tags={"model_key": model_key})
    return run.info.run_id
示例7: mlflow_context
# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def mlflow_context(
    name: str,
    model_key: "str | None" = None,
    workspace_kwargs: "dict | None" = None,
    service_principal_kwargs: "dict | None" = None,
):
    """
    Yield an MLflow client and a fresh run ID, using either a local or AzureML backend.

    This is a generator: it yields exactly once, then terminates the run. The
    ``with`` usage in the example implies it is wrapped with
    ``contextlib.contextmanager`` where it is defined (the decorator is not visible
    here; confirm).

    Parameters
    ----------
    name: str
        The name of the log group to log to (e.g. a model name); used as the
        MLflow experiment name.
    model_key: str, optional
        Unique ID of logging run. When omitted, a fresh ``uuid4().hex`` is
        generated for this call.
    workspace_kwargs: dict, optional
        AzureML Workspace configuration to use for remote MLFlow tracking. See
        :func:`gordo.builder.mlflow_utils.get_mlflow_client`. Empty/omitted means
        a local backend.
    service_principal_kwargs: dict, optional
        AzureML ServicePrincipalAuthentication keyword arguments. See
        :func:`gordo.builder.mlflow_utils.get_mlflow_client`

    Yields
    ------
    (mlflow_client, run_id)
        The configured client and the ID of the run created for this context.
        On normal exit the run is terminated with the default status. If the
        ``with`` body raises an ``Exception``, the run is marked ``"FAILED"``
        and the exception is re-raised.

    Example
    -------
    >>> with tempfile.TemporaryDirectory() as tmp_dir:
    ...     mlflow.set_tracking_uri(f"file:{tmp_dir}")
    ...     with mlflow_context("log_group", "unique_key", {}, {}) as (mlflow_client, run_id):
    ...         log_machine(machine)  # doctest: +SKIP
    """
    # Resolve defaults per call. A `uuid4().hex` default argument would be evaluated
    # once at import time (so every call would share one key), and a `{}` default is
    # a single dict object shared by all calls.
    if model_key is None:
        model_key = uuid4().hex
    if workspace_kwargs is None:
        workspace_kwargs = {}
    if service_principal_kwargs is None:
        service_principal_kwargs = {}

    mlflow_client = get_mlflow_client(workspace_kwargs, service_principal_kwargs)
    run_id = get_run_id(mlflow_client, experiment_name=name, model_key=model_key)
    logger.info(
        f"MLflow client configured to use {'AzureML' if workspace_kwargs else 'local backend'}"
    )
    try:
        yield mlflow_client, run_id
    except Exception:
        # Don't leave the run dangling in RUNNING state when the caller's block fails.
        mlflow_client.set_terminated(run_id, status="FAILED")
        raise
    else:
        mlflow_client.set_terminated(run_id)
示例8: commands
# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def commands():
    """
    Manage runs. To manage runs of experiments associated with a tracking server, set the
    MLFLOW_TRACKING_URI environment variable to the URL of the desired server.
    """
    # No body beyond the docstring. Presumably this is a CLI command group whose
    # decorators are not visible here, in which case the docstring above doubles as
    # its help text and is intentionally kept verbatim.
示例9: list_run
# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def list_run(experiment_id, view):
    """
    List all runs of the specified experiment in the configured tracking server.
    """
    # Prints a table of (start date, run name, run ID), newest-looking rows first
    # (rows are sorted in reverse on their formatted date string). `view` selects
    # the ViewType; a falsy value means only active runs.
    store = _get_store()
    view_type = ViewType.from_string(view) if view else ViewType.ACTIVE_ONLY
    # NOTE(review): in newer MLflow versions store.search_runs returns a single page
    # of results; confirm whether paging is needed to truly list *all* runs.
    runs = store.search_runs([experiment_id], None, view_type)
    # Read the tag directly instead of copying the whole tag dict for every run.
    table = [
        [
            conv_longdate_to_str(run.info.start_time),
            run.data.tags.get(MLFLOW_RUN_NAME, ""),
            run.info.run_id,
        ]
        for run in runs
    ]
    print(tabulate(sorted(table, reverse=True), headers=["Date", "Name", "ID"]))
示例10: _extract_run_id
# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def _extract_run_id(artifact_uri):
    """
    Pull the MLflow run ID out of a Databricks ACL'ed artifact URI.

    The URI is expected to look like
        dbfs:/databricks/mlflow-tracking/<EXP_ID>/<RUN_ID>/artifacts/<path>
    which, once extracted and normalized by ``extract_and_normalize_path``,
    is expected to become
        databricks/mlflow-tracking/<EXP_ID>/<RUN_ID>/artifacts/<path>
    so the run ID is the path segment at index 3. A path with fewer segments
    raises ``IndexError``.

    :return: run_id extracted from the artifact_uri
    """
    normalized = extract_and_normalize_path(artifact_uri)
    # Only the first four separators matter; the tail of the path stays unsplit.
    segments = normalized.split('/', 4)
    return segments[3]
示例11: _call_endpoint
# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def _call_endpoint(self, service, api, json_body):
    """Invoke the Databricks REST endpoint registered for ``(service, api)``.

    Credentials come from the Databricks profile named in the current MLflow
    tracking URI. ``api.Response()`` provides the response proto handed to
    ``call_endpoint``, whose result is returned as-is (presumably the populated
    response message; confirm in ``call_endpoint``).
    """
    tracking_uri = mlflow.tracking.get_tracking_uri()
    host_creds = get_databricks_host_creds(get_db_profile_from_uri(tracking_uri))
    endpoint, method = _SERVICE_AND_METHOD_TO_INFO[service][api]
    return call_endpoint(host_creds, endpoint, method, json_body, api.Response())
示例12: test_run_databricks_validations
# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def test_run_databricks_validations(
        tmpdir, cluster_spec_mock,  # pylint: disable=unused-argument
        tracking_uri_mock, dbfs_mocks, set_tag_mock):  # pylint: disable=unused-argument
    """
    Tests that running on Databricks fails before making any API requests if validations fail.
    """
    # Fake Databricks credentials via the environment, and intercept every Databricks
    # API call so the test can assert that none was made when validation fails.
    with mock.patch.dict(os.environ, {'DATABRICKS_HOST': 'test-host', 'DATABRICKS_TOKEN': 'foo'}),\
            mock.patch("mlflow.projects.databricks.DatabricksJobRunner._databricks_api_request")\
            as db_api_req_mock:
        # Test bad tracking URI: a local filesystem path is expected to be rejected.
        tracking_uri_mock.return_value = tmpdir.strpath
        with pytest.raises(ExecutionException):
            run_databricks_project(cluster_spec_mock, synchronous=True)
        assert db_api_req_mock.call_count == 0
        db_api_req_mock.reset_mock()
        # No run should have been recorded in the default experiment either.
        # NOTE(review): list_run_infos is deprecated/removed in newer MLflow releases
        # (search_runs replaces it) — confirm against the pinned MLflow version.
        mlflow_service = mlflow.tracking.MlflowClient()
        assert (len(mlflow_service.list_run_infos(experiment_id=FileStore.DEFAULT_EXPERIMENT_ID))
                == 0)
        # From here on, use a tracking URI that passes the tracking-URI check, so any
        # failure below comes from the other validations.
        tracking_uri_mock.return_value = "http://"
        # Test misspecified parameters
        with pytest.raises(ExecutionException):
            mlflow.projects.run(
                TEST_PROJECT_DIR, backend="databricks", entry_point="greeter",
                backend_config=cluster_spec_mock)
        assert db_api_req_mock.call_count == 0
        db_api_req_mock.reset_mock()
        # Test bad cluster spec (backend_config=None)
        with pytest.raises(ExecutionException):
            mlflow.projects.run(TEST_PROJECT_DIR, backend="databricks", synchronous=True,
                                backend_config=None)
        assert db_api_req_mock.call_count == 0
        db_api_req_mock.reset_mock()
        # Test that validations pass with good tracking URIs (no exception expected)
        databricks.before_run_validations("http://", cluster_spec_mock)
        databricks.before_run_validations("databricks", cluster_spec_mock)
示例13: test_get_tracking_uri_for_run
# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def test_get_tracking_uri_for_run():
    """Check the tracking URI that the Databricks backend reports for a run."""
    # (configured tracking URI, URI the Databricks backend should report)
    expectations = [
        ("http://some-uri", "http://some-uri"),
        ("databricks://profile", "databricks"),
    ]
    for configured_uri, expected_uri in expectations:
        mlflow.set_tracking_uri(configured_uri)
        assert databricks._get_tracking_uri_for_run() == expected_uri

    # With the explicit URI cleared, the value set in the environment variable
    # is expected to be returned.
    mlflow.set_tracking_uri(None)
    env_override = {mlflow.tracking._TRACKING_URI_ENV_VAR: "http://some-uri"}
    with mock.patch.dict(os.environ, env_override):
        assert mlflow.tracking._tracking_service.utils.get_tracking_uri() == "http://some-uri"
示例14: test_integration
# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def test_integration(dirname):
    """Log one metric per epoch through MLflowLogger and read it back with MlflowClient."""
    n_epochs = 5
    data = list(range(50))

    losses = torch.rand(n_epochs * len(data))
    losses_iter = iter(losses)

    def update_fn(engine, batch):
        return next(losses_iter)

    trainer = Engine(update_fn)

    mlflow_logger = MLflowLogger(tracking_uri=os.path.join(dirname, "mlruns"))

    true_values = []

    def dummy_handler(engine, logger, event_name):
        # Log a deterministic value derived from the epoch number and remember it.
        global_step = engine.state.get_event_attrib_value(event_name)
        v = global_step * 0.1
        true_values.append(v)
        logger.log_metrics({"test_value": v}, step=global_step)

    mlflow_logger.attach(trainer, log_handler=dummy_handler, event_name=Events.EPOCH_COMPLETED)

    import mlflow

    active_run = mlflow.active_run()

    trainer.run(data, max_epochs=n_epochs)
    mlflow_logger.close()

    from mlflow.tracking import MlflowClient

    client = MlflowClient(tracking_uri=os.path.join(dirname, "mlruns"))
    stored_values = client.get_metric_history(active_run.info.run_id, "test_value")

    # zip() stops at the shorter sequence, so without these checks an empty metric
    # history would let the comparison loop below pass vacuously.
    assert len(true_values) == n_epochs
    assert len(stored_values) == len(true_values)
    for t, s in zip(true_values, stored_values):
        assert pytest.approx(t) == s.value
示例15: test_integration_as_context_manager
# 需要导入模块: import mlflow [as 别名]
# 或者: from mlflow import tracking [as 别名]
def test_integration_as_context_manager(dirname):
    """Same round-trip as the plain integration test, with MLflowLogger used as a context manager."""
    n_epochs = 5
    data = list(range(50))

    losses = torch.rand(n_epochs * len(data))
    losses_iter = iter(losses)

    def update_fn(engine, batch):
        return next(losses_iter)

    true_values = []

    with MLflowLogger(os.path.join(dirname, "mlruns")) as mlflow_logger:

        trainer = Engine(update_fn)

        def dummy_handler(engine, logger, event_name):
            # Log a deterministic value derived from the epoch number and remember it.
            global_step = engine.state.get_event_attrib_value(event_name)
            v = global_step * 0.1
            true_values.append(v)
            logger.log_metrics({"test_value": v}, step=global_step)

        mlflow_logger.attach(trainer, log_handler=dummy_handler, event_name=Events.EPOCH_COMPLETED)

        import mlflow

        active_run = mlflow.active_run()

        trainer.run(data, max_epochs=n_epochs)

    from mlflow.tracking import MlflowClient

    client = MlflowClient(tracking_uri=os.path.join(dirname, "mlruns"))
    stored_values = client.get_metric_history(active_run.info.run_id, "test_value")

    # zip() stops at the shorter sequence, so without these checks an empty metric
    # history would let the comparison loop below pass vacuously.
    assert len(true_values) == n_epochs
    assert len(stored_values) == len(true_values)
    for t, s in zip(true_values, stored_values):
        assert pytest.approx(t) == s.value