Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom x axes to TensorBoard #38

Merged
merged 7 commits into from
Dec 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/dowel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
This module instantiates a global logger singleton.
"""
from dowel.histogram import Histogram
from dowel.logger import Logger, LogOutput
from dowel.logger import Logger, LoggerWarning, LogOutput
from dowel.simple_outputs import StdOutput, TextOutput
from dowel.tabular_input import TabularInput
from dowel.csv_output import CsvOutput # noqa: I100
Expand All @@ -19,6 +19,7 @@
'StdOutput',
'TextOutput',
'LogOutput',
'LoggerWarning',
'TabularInput',
'TensorBoardOutput',
'logger',
Expand Down
2 changes: 0 additions & 2 deletions src/dowel/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,5 +327,3 @@ def disable_warnings(self):

class LoggerWarning(UserWarning):
"""Warning class for the Logger."""

pass
87 changes: 75 additions & 12 deletions src/dowel/tensor_board_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
tensorboard summary operations through tensorboardX.

Note:
Neither tensorboardX or TensorBoard does not support log parametric
Neither TensorboardX nor TensorBoard supports log parametric
distributions. We add this feature by sampling data from a
`tfp.distributions.Distribution` object.
"""
import functools
import warnings

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -20,28 +21,52 @@
tf = None

from dowel import Histogram
from dowel import LoggerWarning
from dowel import LogOutput
from dowel import TabularInput
from dowel.utils import colorize


class TensorBoardOutput(LogOutput):
"""TensorBoard output for logger.

Args:
log_dir(str): The save location of the tensorboard event files.
x_axis(str): The name of data used as x-axis for scalar tabular.
If None, x-axis will be the number of dump() is called.
additional_x_axes(list[str]): Names of data to used be as additional
x-axes.
flush_secs(int): How often, in seconds, to flush the added summaries
and events to disk.
histogram_samples(int): Number of samples to generate when logging
random distribution.
"""
TensorBoard output for logger.

:param log_dir(str): The save location of the tensorboard event files.
:param flush_secs(int): How often, in seconds, to flush the added summaries
and events to disk.
"""
def __init__(self,
log_dir,
x_axis=None,
additional_x_axes=None,
flush_secs=120,
histogram_samples=1e3):
if x_axis is None:
assert not additional_x_axes, (
'You have to specify an x_axis if you want additional axes.')

additional_x_axes = additional_x_axes or []

def __init__(self, log_dir, flush_secs=120, histogram_samples=1e3):
self._writer = tbX.SummaryWriter(log_dir, flush_secs=flush_secs)
self._x_axis = x_axis
self._additional_x_axes = additional_x_axes
self._default_step = 0
self._histogram_samples = int(histogram_samples)
self._added_graph = False
self._waiting_for_dump = []
# Used in tests to emulate Tensorflow not being installed.
self._tf = tf

self._warned_once = set()
self._disable_warnings = False

@property
def types_accepted(self):
"""Return the types that the logger may pass to this output."""
Expand All @@ -51,11 +76,12 @@ def types_accepted(self):
return (TabularInput, self._tf.Graph)

def record(self, data, prefix=''):
"""
Add data to tensorboard summary.
"""Add data to tensorboard summary.

Args:
data: The data to be logged by the output.
prefix(str): A prefix placed before a log entry in text outputs.

:param data: The data to be logged by the output.
:param prefix(str): A prefix placed before a log entry in text outputs.
"""
if isinstance(data, TabularInput):
self._waiting_for_dump.append(
Expand All @@ -66,8 +92,29 @@ def record(self, data, prefix=''):
raise ValueError('Unacceptable type.')

def _record_tabular(self, data, step):
if self._x_axis:
nonexist_axes = []
for axis in [self._x_axis] + self._additional_x_axes:
if axis not in data.as_dict:
nonexist_axes.append(axis)
if nonexist_axes:
self._warn('{} {} exist in the tabular data.'.format(
', '.join(nonexist_axes),
'do not' if len(nonexist_axes) > 1 else 'does not'))

for key, value in data.as_dict.items():
self._record_kv(key, value, step)
if isinstance(value,
np.ScalarType) and self._x_axis in data.as_dict:
if self._x_axis is not key:
x = data.as_dict[self._x_axis]
self._record_kv(key, value, x)

for axis in self._additional_x_axes:
if key is not axis and key in data.as_dict:
x = data.as_dict[axis]
self._record_kv('{}/{}'.format(key, axis), value, x)
else:
self._record_kv(key, value, step)
data.mark(key)

def _record_kv(self, key, value, step):
Expand Down Expand Up @@ -106,3 +153,19 @@ def dump(self, step=None):
def close(self):
"""Flush all the events to disk and close the file."""
self._writer.close()

def _warn(self, msg):
"""Warns the user using warnings.warn.

The stacklevel parameter needs to be 3 to ensure the call to logger.log
is the one printed.
"""
if not self._disable_warnings and msg not in self._warned_once:
warnings.warn(
colorize(msg, 'yellow'), NonexistentAxesWarning, stacklevel=3)
self._warned_once.add(msg)
return msg


class NonexistentAxesWarning(LoggerWarning):
"""Raise when the specified x axes do not exist in the tabular."""
57 changes: 57 additions & 0 deletions tests/dowel/test_tensor_board_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,60 @@ def test_types_accepted_without_tensorflow(self):
self.tensor_board_output._tf = None
assert TabularInput in self.tensor_board_output.types_accepted
assert tf.Graph not in self.tensor_board_output.types_accepted


class TestTensorBoardOutputXAxesMocked(TBOutputTest):
"""Test custom x axes."""

def run(self, result=None):
with mock.patch('tensorboardX.SummaryWriter'):
super().run(result)

def setUp(self):
super().setUp()
self.mock_writer = self.tensor_board_output._writer

def test_record_scalar_one_axis(self):
self.tensor_board_output._x_axis = 'foo'
self.tensor_board_output._additional_x_axes = []

foo = 5
bar = 10.0
self.tabular.record('foo', foo)
self.tabular.record('bar', bar)
self.tensor_board_output.record(self.tabular)
self.tensor_board_output.dump()

self.mock_writer.add_scalar.assert_any_call('bar', bar, foo)
assert self.mock_writer.add_scalar.call_count == 1

def test_record_scalar_two_axes(self):
self.tensor_board_output._x_axis = 'foo'
self.tensor_board_output._additional_x_axes = ['bar']

foo = 5
bar = 10.0
self.tabular.record('foo', foo)
self.tabular.record('bar', bar)
self.tensor_board_output.record(self.tabular)
self.tensor_board_output.dump()

self.mock_writer.add_scalar.assert_any_call('foo/bar', foo, bar)
self.mock_writer.add_scalar.assert_any_call('bar', bar, foo)
assert self.mock_writer.add_scalar.call_count == 2

def test_record_scalar_nonexistent_axis(self):
self.tensor_board_output._default_step = 0
self.tensor_board_output._x_axis = 'qux'
self.tensor_board_output._additional_x_axes = ['bar']

foo = 5
bar = 10.0
self.tabular.record('foo', foo)
self.tabular.record('bar', bar)
self.tensor_board_output.record(self.tabular)
self.tensor_board_output.dump()

self.mock_writer.add_scalar.assert_any_call('foo', foo, 0)
self.mock_writer.add_scalar.assert_any_call('bar', bar, 0)
assert self.mock_writer.add_scalar.call_count == 2