Source code for uv.sql.reporter

#!/usr/bin/python
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Reporter implementation that relies on SQLAlchemy.

Uses this as inspiration:
https://source.cloud.google.com/research-3141/tf2-jax-notebooks/+/master:experiment_manager.py

"""

from itertools import groupby
from typing import Any, Dict, Iterable, List, Optional, Union

import uv.reader.base as rb
import uv.sql.util as u
import uv.types as t
from sqlalchemy import JSON, REAL, Column, Integer, String
from sqlalchemy.engine import Engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from uv.reporter.base import AbstractReporter

Base = declarative_base()


def create_tables(engine: Engine):
    """Create the tables for every model declared on ``Base``.

    Run this against your local database before writing any metrics so that
    the ``Experiment`` and ``Metric`` tables are in place.

    Args:
      engine: SQLAlchemy engine bound to the database that should receive
        the tables.
    """
    metadata = Base.metadata
    metadata.create_all(engine)
class Experiment(Base):
    """General metadata about one experiment.

    The row's ``id`` is the key that links together all of the metrics
    reported for the experiment. You would typically create one of these when
    you create a reporter.
    """

    __tablename__ = 'experiment'

    # Primary key; SQLReporter writes this value into Metric.experiment_id.
    id = Column(Integer, primary_key=True)

    # Experiment configuration, stored as a JSON document.
    params = Column(JSON)

    def __repr__(self):
        # Delegate to the shared helper so all models render consistently.
        return u.rep_string(self)
class Metric(Base):
    """A single reported measurement, stored as one row of ``metrics``."""

    __tablename__ = 'metrics'

    # Primary key of the metric row itself.
    id = Column(Integer, primary_key=True)

    # Experiment.id and run number that this measurement belongs to.
    experiment_id = Column(Integer)
    run_id = Column(Integer)

    # Step at which the value was recorded.
    step = Column(Integer)

    # Metric name and its numeric value.
    tag = Column(String)
    value = Column(REAL)

    def __repr__(self):
        # Delegate to the shared helper so all models render consistently.
        return u.rep_string(self)
def new_experiment(e: Union[Engine, sessionmaker],
                   config: Dict[str, Any]) -> Experiment:
    """Persist a new experiment row and return it.

    Args:
      e: engine or sessionmaker identifying the database to write to; it is
        normalized through ``u.session_maker``.
      config: experiment configuration, stored in the ``params`` column.

    Returns:
      The committed ``Experiment``. Its columns (including the generated
      ``id``) are loaded before the session is closed, so the instance can be
      read after this function returns without needing a live session.
    """
    session = u.session_maker(e)()
    try:
        exp = Experiment(params=config)
        session.add(exp)
        session.commit()

        # commit() expires loaded attributes by default. Reload them while the
        # session is still open so that `exp.id` and `exp.params` remain
        # readable once the instance is detached by close().
        session.refresh(exp)
        return exp
    finally:
        # Always release the session; close() also rolls back any transaction
        # left open by a failure above.
        session.close()
class SQLReporter(AbstractReporter):
    """SQLite-backed reporter.

    Writes each reported metric as a ``Metric`` row tagged with the
    experiment id and run id supplied at construction time. The experiment
    row itself is not created here; pass in an existing ``Experiment`` (for
    example one returned by ``new_experiment``).
    """

    def __init__(self, e: Union[Engine, sessionmaker], experiment: Experiment,
                 run_id: int):
        """Bind the reporter to a database, an experiment and a run.

        Args:
          e: engine or sessionmaker for the target database; normalized
            through ``u.session_maker``.
          experiment: experiment whose ``id`` is written onto every metric.
          run_id: run identifier written onto every metric.

        Raises:
          Exception: if ``u.sqlite_file_exists`` reports that the database
            behind the engine does not exist.
        """
        self._make_session = u.session_maker(e)
        self._engine = self._make_session.kw["bind"]

        if not u.sqlite_file_exists(self._engine):
            raise Exception(
                f"{self._engine.url} doesn't exist! Create the database before creating a reporter."
            )

        self._experiment = experiment
        self._run_id = run_id

    def _metric(self, step: int, k: str, v: float) -> Metric:
        """Build a ``Metric`` row tied to this reporter's experiment and run."""
        return Metric(experiment_id=self._experiment.id,
                      run_id=self._run_id,
                      step=step,
                      tag=k,
                      value=v)

    def report_all(self, step: int, m: Dict[t.MetricKey, t.Metric]) -> None:
        """Write every entry of ``m`` as a metric row for ``step``.

        All rows are added and committed in a single transaction.

        Args:
          step: step to record on each row.
          m: mapping from metric key to the value to store.
        """
        session = self._make_session()
        try:
            session.add_all([self._metric(step, k, v) for k, v in m.items()])
            session.commit()
        finally:
            # Release the session even if the commit fails; close() rolls
            # back whatever transaction is still open.
            session.close()

    def reader(self) -> rb.AbstractReader:
        """Return a reader over the metrics for this experiment and run."""
        return SQLReader(self._make_session, self._experiment, self._run_id)
class SQLReader(rb.AbstractReader, rb.IterableReader):
    """AbstractReader implementation backed by a sqlite store.

    Every query is restricted to the experiment id and run id supplied at
    construction time. Each returned data point is a dict with two entries:
    the step, stored under ``step_key`` (``"step"`` by default), and the
    metric value, stored under ``"value"``.
    """

    def __init__(self,
                 e: Union[Engine, sessionmaker],
                 experiment: Experiment,
                 run_id: int,
                 step_key: Optional[str] = None):
        """Bind the reader to a database, an experiment and a run.

        Args:
          e: engine or sessionmaker for the database to read from; normalized
            through ``u.session_maker``.
          experiment: experiment whose ``id`` selects the rows to read.
          run_id: run identifier that selects the rows to read.
          step_key: name of the step field in returned data points. Defaults
            to ``"step"`` when None.

        Raises:
          Exception: if ``u.sqlite_file_exists`` reports that the database
            behind the engine does not exist.
        """
        self._make_session = u.session_maker(e)
        self._engine = self._make_session.kw["bind"]

        if not u.sqlite_file_exists(self._engine):
            raise Exception(
                f"{self._engine.url} doesn't exist! Create the database before creating a reader."
            )

        if step_key is None:
            step_key = "step"

        # Previously the resolved key was computed and then dropped; keep it
        # so the read methods can honor the caller's choice.
        self._step_key = step_key
        self._experiment = experiment
        self._run_id = run_id

    def keys(self) -> Iterable[t.MetricKey]:
        """Yield each distinct tag stored for this experiment and run."""
        session = self._make_session()
        try:
            tags = session.query(Metric.tag.distinct())
            filtered = tags.filter_by(experiment_id=self._experiment.id,
                                      run_id=self._run_id)
            for row in filtered:
                # Each result row is a 1-tuple holding the tag.
                yield row[0]
        finally:
            # Runs when the generator is exhausted, closed or collected.
            session.close()

    def read(self, k: t.MetricKey) -> List[t.Metric]:
        """Return all data points recorded for tag ``k``, ordered by step.

        Args:
          k: tag to look up.

        Returns:
          A list of ``{step_key: step, "value": value}`` dicts; empty when
          nothing has been recorded for ``k``.
        """
        session = self._make_session()
        try:
            rows = session.query(Metric.step, Metric.value).filter_by(
                experiment_id=self._experiment.id,
                run_id=self._run_id,
                tag=k,
            ).order_by(Metric.step)
            # Materialize before the session is closed in `finally`.
            return [{self._step_key: step, "value": v} for step, v in rows]
        finally:
            session.close()

    def _values(self, group):
        """Turn (tag, step, value) rows into the reader's list-of-dicts form."""
        return [{self._step_key: step, "value": v} for _, step, v in group]

    def read_all(self,
                 ks: List[t.MetricKey]) -> Dict[t.MetricKey, List[t.Metric]]:
        """Return the data points for every tag in ``ks`` using one query.

        Args:
          ks: tags to look up.

        Returns:
          A dict from tag to its data points ordered by step. Only tags that
          have at least one stored row appear in the result.
        """
        if not ks:
            # Nothing requested; skip the round trip and the empty IN () clause.
            return {}

        session = self._make_session()
        try:
            rows = session.query(Metric.tag, Metric.step, Metric.value).filter_by(
                experiment_id=self._experiment.id,
                run_id=self._run_id,
            ).filter(Metric.tag.in_(ks)).order_by(Metric.tag, Metric.step)

            # groupby only merges adjacent rows, which is why the query is
            # ordered by tag first. Each group is consumed immediately by
            # _values, before groupby advances to the next tag.
            return {
                tag: self._values(group)
                for tag, group in groupby(rows, lambda row: row[0])
            }
        finally:
            session.close()