Skip to content

Commit

Permalink
Porting back code from pytorch/pytorch and tensorflow/tensorboard (1) (
Browse files Browse the repository at this point in the history
…#422)

* backport from pytorch side:
comments
docstrings
variable names

* fix logdir

* backport from tensorboard
  • Loading branch information
lanpa authored May 13, 2019
1 parent 0bf6c07 commit bf8c679
Show file tree
Hide file tree
Showing 9 changed files with 464 additions and 311 deletions.
4 changes: 2 additions & 2 deletions examples/demo_purge.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from time import sleep
from tensorboardX import SummaryWriter

with SummaryWriter(log_dir='runs/purge') as w:
with SummaryWriter(logdir='runs/purge') as w:
for i in range(100):
w.add_scalar('purgetest', i, i)

sleep(1.0)

with SummaryWriter(log_dir='runs/purge', purge_step=42) as w:
with SummaryWriter(logdir='runs/purge', purge_step=42) as w:
# event 42~99 are removed (inclusively)
for i in range(42, 100):
w.add_scalar('purgetest', 42, i)
73 changes: 32 additions & 41 deletions tensorboardX/event_file_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import division
from __future__ import print_function

import os.path
import os
import socket
import threading
import time
Expand All @@ -39,18 +39,12 @@ def __init__(self, file_prefix, filename_suffix=''):
'''
self._file_name = file_prefix + ".out.tfevents." + str(time.time())[:10] + "." +\
socket.gethostname() + filename_suffix

self._num_outstanding_events = 0

self._py_recordio_writer = RecordWriter(self._file_name)

# Initialize an event instance.
self._event = event_pb2.Event()

self._event.wall_time = time.time()

self._lock = threading.Lock()

self.write_event(self._event)

def write_event(self, event):
Expand Down Expand Up @@ -84,35 +78,30 @@ def close(self):

class EventFileWriter(object):
"""Writes `Event` protocol buffers to an event file.
The `EventFileWriter` class creates an event file in the specified directory,
and asynchronously writes Event protocol buffers to the file. The Event file
is encoded using the tfrecord format, which is similar to RecordIO.
@@__init__
@@add_event
@@flush
@@close
"""

def __init__(self, logdir, max_queue=10, flush_secs=120, filename_suffix=''):
def __init__(self, logdir, max_queue_size=10, flush_secs=120, filename_suffix=''):
"""Creates a `EventFileWriter` and an event file to write to.
On construction the summary writer creates a new event file in `logdir`.
This event file will contain `Event` protocol buffers, which are written to
disk via the add_event method.
The other arguments to the constructor control the asynchronous writes to
the event file:
* `flush_secs`: How often, in seconds, to flush the added summaries
and events to disk.
* `max_queue`: Maximum number of summaries or events pending to be
written to disk before one of the 'add' calls blocks.
Args:
logdir: A string. Directory where event file will be written.
max_queue: Integer. Size of the queue for pending events and summaries.
max_queue_size: Integer. Size of the queue for pending events and summaries.
flush_secs: Number. How often, in seconds, to flush the
pending events and summaries to disk.
"""
self._logdir = logdir
directory_check(self._logdir)
self._event_queue = six.moves.queue.Queue(max_queue)
self._event_queue = six.moves.queue.Queue(max_queue_size)
self._ev_writer = EventsWriter(os.path.join(
self._logdir, "events"), filename_suffix)
self._flush_secs = flush_secs
Expand Down Expand Up @@ -141,6 +130,7 @@ def reopen(self):

def add_event(self, event):
"""Adds an event to the event file.
Args:
event: An `Event` protocol buffer.
"""
Expand All @@ -149,6 +139,7 @@ def add_event(self, event):

def flush(self):
"""Flushes the event file to disk.
Call this method to make sure that all pending events have been written to
disk.
"""
Expand All @@ -171,61 +162,61 @@ def close(self):
class _EventLoggerThread(threading.Thread):
"""Thread that logs events."""

def __init__(self, queue, ev_writer, flush_secs):
def __init__(self, queue, record_writer, flush_secs):
"""Creates an _EventLoggerThread.
Args:
queue: A Queue from which to dequeue events.
ev_writer: An event writer. Used to log brain events for
queue: A Queue from which to dequeue data.
record_writer: A data writer. Used to log brain events for
the visualizer.
flush_secs: How often, in seconds, to flush the
pending file to disk.
"""
threading.Thread.__init__(self)
self.daemon = True
self._queue = queue
self._ev_writer = ev_writer
self._record_writer = record_writer
self._flush_secs = flush_secs
# The first event will be flushed immediately.
self._next_event_flush_time = 0
self._has_pending_events = False
# The first data will be flushed immediately.
self._next_flush_time = 0
self._has_pending_data = False
self._shutdown_signal = object()

def stop(self):
    """Signal the logger thread to shut down and wait for it to finish.

    Enqueues the shutdown sentinel so the run loop returns after draining
    up to that point, then joins the thread. NOTE(review): items enqueued
    after this call will not be written — confirm callers stop adding first.
    """
    self._queue.put(self._shutdown_signal)
    self.join()

def run(self):
# Here wait on the queue until an event appears, or till the next
# Here wait on the queue until data appears, or till the next
# time to flush the writer, whichever is earlier. If we have an
# event, write it. If not, an empty queue exception will be raised
# item of data, write it. If not, an empty queue exception will be raised
# and we can proceed to flush the writer.
while True:
now = time.time()
queue_wait_duration = self._next_event_flush_time - now
event = None
queue_wait_duration = self._next_flush_time - now
data = None
try:
if queue_wait_duration > 0:
event = self._queue.get(True, queue_wait_duration)
data = self._queue.get(True, queue_wait_duration)
else:
event = self._queue.get(False)
data = self._queue.get(False)

if event == self._shutdown_signal:
if data == self._shutdown_signal:
return
self._ev_writer.write_event(event)
self._has_pending_events = True
self._record_writer.write_event(data)
self._has_pending_data = True
except six.moves.queue.Empty:
pass
finally:
if event:
if data:
self._queue.task_done()

now = time.time()
if now > self._next_event_flush_time:
if self._has_pending_events:
# Small optimization - if there are no pending events,
if now > self._next_flush_time:
if self._has_pending_data:
# Small optimization - if there are no pending data,
# there's no need to flush, since each flush can be
# expensive (e.g. uploading a new file to a server).
self._ev_writer.flush()
self._has_pending_events = False
self._record_writer.flush()
self._has_pending_data = False
# Do it again in flush_secs.
self._next_event_flush_time = now + self._flush_secs
self._next_flush_time = now + self._flush_secs
6 changes: 3 additions & 3 deletions tensorboardX/onnx_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
from .proto.versions_pb2 import VersionDef
from .proto.attr_value_pb2 import AttrValue
from .proto.tensor_shape_pb2 import TensorShapeProto
# from .proto.onnx_pb2 import ModelProto


def gg(fname):
import onnx # 0.2.1
def load_onnx_graph(fname):
import onnx
m = onnx.load(fname)
g = m.graph
return parse(g)
Expand Down Expand Up @@ -46,6 +45,7 @@ def parse(graph):
input=node.input,
attr={'parameters': AttrValue(s=attr)},
))

# two pass token replacement, appends opname to object id
mapping = {}
for node in nodes:
Expand Down
37 changes: 19 additions & 18 deletions tensorboardX/proto_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,46 +4,47 @@
from .proto.attr_value_pb2 import AttrValue
from .proto.tensor_shape_pb2 import TensorShapeProto

from collections import defaultdict

# nodes.append(
# NodeDef(name=node['name'], op=node['op'], input=node['inputs'],
# attr={'lanpa': AttrValue(s=node['attr'].encode(encoding='utf_8')),
# '_output_shapes': AttrValue(list=AttrValue.ListValue(shape=[shapeproto]))}))


def AttrValue_proto(dtype,
shape,
s,
):
def attr_value_proto(dtype, shape, s):
"""Creates a dict of objects matching
https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/attr_value.proto
specifically designed for a NodeDef. The values have been
reverse engineered from standard TensorBoard logged data.
"""
attr = {}

if s is not None:
attr['attr'] = AttrValue(s=s.encode(encoding='utf_8'))

if shape is not None:
shapeproto = TensorShape_proto(shape)
shapeproto = tensor_shape_proto(shape)
attr['_output_shapes'] = AttrValue(list=AttrValue.ListValue(shape=[shapeproto]))
return attr


def TensorShape_proto(outputsize):
def tensor_shape_proto(outputsize):
    """Creates an object matching
    https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/tensor_shape.proto
    """
    dims = []
    for size in outputsize:
        dims.append(TensorShapeProto.Dim(size=size))
    return TensorShapeProto(dim=dims)


def Node_proto(name,
def node_proto(name,
op='UnSpecified',
input=[],
input=None,
dtype=None,
shape=None, # type: tuple
outputsize=None,
attributes=''
):
"""Creates an object matching
https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/node_def.proto
"""
if input is None:
input = []
if not isinstance(input, list):
input = [input]
return NodeDef(
name=name.encode(encoding='utf_8'),
op=op,
input=input,
attr=AttrValue_proto(dtype, outputsize, attributes)
attr=attr_value_proto(dtype, outputsize, attributes)
)
Loading

0 comments on commit bf8c679

Please sign in to comment.