Skip to content

Commit

Permalink
Add custom x axes to TensorBoard (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
naeioi authored and ryanjulian committed Dec 4, 2019
1 parent 1ae0b7c commit 7b9fed2
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 15 deletions.
3 changes: 2 additions & 1 deletion src/dowel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
This module instantiates a global logger singleton.
"""
from dowel.histogram import Histogram
from dowel.logger import Logger, LogOutput
from dowel.logger import Logger, LoggerWarning, LogOutput
from dowel.simple_outputs import StdOutput, TextOutput
from dowel.tabular_input import TabularInput
from dowel.csv_output import CsvOutput # noqa: I100
Expand All @@ -19,6 +19,7 @@
'StdOutput',
'TextOutput',
'LogOutput',
'LoggerWarning',
'TabularInput',
'TensorBoardOutput',
'logger',
Expand Down
2 changes: 0 additions & 2 deletions src/dowel/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,5 +327,3 @@ def disable_warnings(self):

class LoggerWarning(UserWarning):
"""Warning class for the Logger."""

pass
87 changes: 75 additions & 12 deletions src/dowel/tensor_board_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
tensorboard summary operations through tensorboardX.
Note:
Neither tensorboardX or TensorBoard does not support log parametric
Neither TensorboardX nor TensorBoard supports log parametric
distributions. We add this feature by sampling data from a
`tfp.distributions.Distribution` object.
"""
import functools
import warnings

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -20,28 +21,52 @@
tf = None

from dowel import Histogram
from dowel import LoggerWarning
from dowel import LogOutput
from dowel import TabularInput
from dowel.utils import colorize


class TensorBoardOutput(LogOutput):
"""TensorBoard output for logger.
Args:
log_dir(str): The save location of the tensorboard event files.
x_axis(str): The name of data used as x-axis for scalar tabular.
If None, x-axis will be the number of dump() is called.
additional_x_axes(list[str]): Names of data to used be as additional
x-axes.
flush_secs(int): How often, in seconds, to flush the added summaries
and events to disk.
histogram_samples(int): Number of samples to generate when logging
random distribution.
"""
TensorBoard output for logger.

:param log_dir(str): The save location of the tensorboard event files.
:param flush_secs(int): How often, in seconds, to flush the added summaries
and events to disk.
"""
def __init__(self,
log_dir,
x_axis=None,
additional_x_axes=None,
flush_secs=120,
histogram_samples=1e3):
if x_axis is None:
assert not additional_x_axes, (
'You have to specify an x_axis if you want additional axes.')

additional_x_axes = additional_x_axes or []

def __init__(self, log_dir, flush_secs=120, histogram_samples=1e3):
self._writer = tbX.SummaryWriter(log_dir, flush_secs=flush_secs)
self._x_axis = x_axis
self._additional_x_axes = additional_x_axes
self._default_step = 0
self._histogram_samples = int(histogram_samples)
self._added_graph = False
self._waiting_for_dump = []
# Used in tests to emulate Tensorflow not being installed.
self._tf = tf

self._warned_once = set()
self._disable_warnings = False

@property
def types_accepted(self):
"""Return the types that the logger may pass to this output."""
Expand All @@ -51,11 +76,12 @@ def types_accepted(self):
return (TabularInput, self._tf.Graph)

def record(self, data, prefix=''):
"""
Add data to tensorboard summary.
"""Add data to tensorboard summary.
Args:
data: The data to be logged by the output.
prefix(str): A prefix placed before a log entry in text outputs.
:param data: The data to be logged by the output.
:param prefix(str): A prefix placed before a log entry in text outputs.
"""
if isinstance(data, TabularInput):
self._waiting_for_dump.append(
Expand All @@ -66,8 +92,29 @@ def record(self, data, prefix=''):
raise ValueError('Unacceptable type.')

def _record_tabular(self, data, step):
if self._x_axis:
nonexist_axes = []
for axis in [self._x_axis] + self._additional_x_axes:
if axis not in data.as_dict:
nonexist_axes.append(axis)
if nonexist_axes:
self._warn('{} {} exist in the tabular data.'.format(
', '.join(nonexist_axes),
'do not' if len(nonexist_axes) > 1 else 'does not'))

for key, value in data.as_dict.items():
self._record_kv(key, value, step)
if isinstance(value,
np.ScalarType) and self._x_axis in data.as_dict:
if self._x_axis is not key:
x = data.as_dict[self._x_axis]
self._record_kv(key, value, x)

for axis in self._additional_x_axes:
if key is not axis and key in data.as_dict:
x = data.as_dict[axis]
self._record_kv('{}/{}'.format(key, axis), value, x)
else:
self._record_kv(key, value, step)
data.mark(key)

def _record_kv(self, key, value, step):
Expand Down Expand Up @@ -106,3 +153,19 @@ def dump(self, step=None):
def close(self):
"""Flush all the events to disk and close the file."""
self._writer.close()

def _warn(self, msg):
"""Warns the user using warnings.warn.
The stacklevel parameter needs to be 3 to ensure the call to logger.log
is the one printed.
"""
if not self._disable_warnings and msg not in self._warned_once:
warnings.warn(
colorize(msg, 'yellow'), NonexistentAxesWarning, stacklevel=3)
self._warned_once.add(msg)
return msg


class NonexistentAxesWarning(LoggerWarning):
"""Raise when the specified x axes do not exist in the tabular."""
57 changes: 57 additions & 0 deletions tests/dowel/test_tensor_board_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,60 @@ def test_types_accepted_without_tensorflow(self):
self.tensor_board_output._tf = None
assert TabularInput in self.tensor_board_output.types_accepted
assert tf.Graph not in self.tensor_board_output.types_accepted


class TestTensorBoardOutputXAxesMocked(TBOutputTest):
"""Test custom x axes."""

def run(self, result=None):
with mock.patch('tensorboardX.SummaryWriter'):
super().run(result)

def setUp(self):
super().setUp()
self.mock_writer = self.tensor_board_output._writer

def test_record_scalar_one_axis(self):
self.tensor_board_output._x_axis = 'foo'
self.tensor_board_output._additional_x_axes = []

foo = 5
bar = 10.0
self.tabular.record('foo', foo)
self.tabular.record('bar', bar)
self.tensor_board_output.record(self.tabular)
self.tensor_board_output.dump()

self.mock_writer.add_scalar.assert_any_call('bar', bar, foo)
assert self.mock_writer.add_scalar.call_count == 1

def test_record_scalar_two_axes(self):
self.tensor_board_output._x_axis = 'foo'
self.tensor_board_output._additional_x_axes = ['bar']

foo = 5
bar = 10.0
self.tabular.record('foo', foo)
self.tabular.record('bar', bar)
self.tensor_board_output.record(self.tabular)
self.tensor_board_output.dump()

self.mock_writer.add_scalar.assert_any_call('foo/bar', foo, bar)
self.mock_writer.add_scalar.assert_any_call('bar', bar, foo)
assert self.mock_writer.add_scalar.call_count == 2

def test_record_scalar_nonexistent_axis(self):
self.tensor_board_output._default_step = 0
self.tensor_board_output._x_axis = 'qux'
self.tensor_board_output._additional_x_axes = ['bar']

foo = 5
bar = 10.0
self.tabular.record('foo', foo)
self.tabular.record('bar', bar)
self.tensor_board_output.record(self.tabular)
self.tensor_board_output.dump()

self.mock_writer.add_scalar.assert_any_call('foo', foo, 0)
self.mock_writer.add_scalar.assert_any_call('bar', bar, 0)
assert self.mock_writer.add_scalar.call_count == 2

0 comments on commit 7b9fed2

Please sign in to comment.