Source code for uv.reporter.state

#!/usr/bin/python
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Reporter interface and implementations."""

import logging
import os
from contextlib import contextmanager
import google.auth
from typing import Dict, Optional

import mlflow as mlf
import uv.types as t
import uv.util.env as ue
from uv.reporter.base import AbstractReporter
from uv.reporter.store import NullReporter

_active_reporter = None


[docs]def get_reporter() -> AbstractReporter: """Returns the active reporter set using set_reporter()""" if _active_reporter is not None: return _active_reporter return NullReporter()
[docs]def set_reporter(r: AbstractReporter) -> AbstractReporter: """Set the globally available reporter instance. Returns its input.""" global _active_reporter _active_reporter = r return _active_reporter
[docs]@contextmanager def active_reporter(r: AbstractReporter): old_reporter = _active_reporter globals()['_active_reporter'] = r yield r globals()['_active_reporter'] = old_reporter
[docs]def report(step: int, k: t.MetricKey, v: t.Metric) -> None: """Accepts a step (an ordered int referencing some timestep), a metric key and a value, and persists the metric into the globally available reporter returned by uv.get_reporter(). """ return get_reporter().report(step, k, v)
[docs]def report_all(step: int, m: Dict[t.MetricKey, t.Metric]) -> None: """Accepts a step (an ordered int referencing some timestep) and a dictionary of metric key => metric value, and persists the metric into the globally available reporter returned by uv.get_reporter(). """ return get_reporter().report_all(step, m)
[docs]def report_param(k: str, v: str) -> None: """Accepts a key and value parameter and logs these as parameters alongside the reported metrics. Reports to the globally available reporter returned by uv.get_reporter(). """ return get_reporter().report_param(k, v)
[docs]def report_params(m: Dict[str, str]) -> None: """Accepts a dict of parameter name -> value, and logs these as parameters alongside the reported metrics. Reports to the globally available reporter returned by uv.get_reporter(). """ return get_reporter().report_params(m)
def _ensure_non_null_project(artifact_root: Optional[str]): '''Ensures that the google cloud python api methods can get a non-None project id when the mlflow artifact root is a storage bucket. This is necessary because mlflow uses google.cloud.storage.Client() to create a client instance, which requires a project. This project name does not appear to need to be valid for reading and writing artifacts, so we set it to a placeholder string if there is no project available via the standard api methods. ''' if artifact_root is None: return if not artifact_root.startswith('gs://'): return _, project_id = google.auth.default() if project_id is not None: return # we only set this as a last resort os.environ['GOOGLE_CLOUD_PROJECT'] = 'placeholder' return
[docs]def start_run(param_prefix: Optional[str] = None, experiment_name: Optional[str] = None, run_name: Optional[str] = None, artifact_location: Optional[str] = None, **args) -> mlf.ActiveRun: """Close alias of mlflow.start_run. The only difference is that uv.start_run attempts to extract parameters from the environment and log those to the bound UV reporter using `report_params`. Note that if experiment_name is specified and refers to an existing experiment, then the artifact_location will not be honored as this is an immutable property of an mlflow experiment. This method will issue a warning but proceed. Note that the returned value can be used as a context manager: https://www.mlflow.org/docs/latest/python_api/mlflow.html#mlflow.start_run """ if experiment_name is None: experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") if run_name is None: run_name = os.environ.get("MLFLOW_RUN_NAME") if artifact_location is None: artifact_location = os.environ.get("MLFLOW_ARTIFACT_ROOT") _ensure_non_null_project(artifact_location) # Make sure the experiment exists before the run starts. if experiment_name is not None: if mlf.get_experiment_by_name(experiment_name) is None: mlf.create_experiment(experiment_name, artifact_location) mlf.set_experiment(experiment_name) ret = mlf.start_run(run_name=run_name, **args) env_params = ue.extract_params(prefix=param_prefix) mlf.set_tags(env_params) # for CAIP jobs, we add the job id as a tag, along with a link to the # console page cloud_ml_job_id = os.environ.get('CLOUD_ML_JOB_ID') if cloud_ml_job_id is not None: mlf.set_tag( 'cloud_ml_job_details', f'https://console.cloud.google.com/ai-platform/jobs/{cloud_ml_job_id}') mlf.set_tag('cloud_ml_job_id', cloud_ml_job_id) mlf_artifact_uri = mlf.get_artifact_uri() if mlf_artifact_uri is not None and artifact_location is not None: if not mlf_artifact_uri.startswith(artifact_location): logging.warning( f'requested mlflow artifact location {artifact_location} differs ' f'from existing experiment artifact uri {mlf_artifact_uri}') return ret