Skip to content

Commit

Permalink
adds document for add_histogram_raw (#424)
Browse files Browse the repository at this point in the history
* add_histogram_raw fix #421

* add images for doc
  • Loading branch information
lanpa authored May 14, 2019
1 parent bf8c679 commit b73e5ae
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 11 deletions.
Binary file added docs/_static/img/tensorboard/add_histogram.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/img/tensorboard/add_image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/img/tensorboard/add_images.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/img/tensorboard/add_scalar.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/img/tensorboard/add_scalars.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/img/tensorboard/hier_tags.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
43 changes: 34 additions & 9 deletions tensorboardX/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def __init__(self, logdir=None, comment='', purge_step=None, max_queue=10,
Examples::
from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter
# create a summary writer with automatically generated folder name.
writer = SummaryWriter()
Expand Down Expand Up @@ -322,7 +322,7 @@ def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
Examples::
from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter
writer = SummaryWriter()
x = range(100)
for i in x:
Expand Down Expand Up @@ -353,7 +353,7 @@ def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None
Examples::
from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter
writer = SummaryWriter()
r = 5
for i in range(100):
Expand Down Expand Up @@ -410,7 +410,7 @@ def add_histogram(self, tag, values, global_step=None, bins='tensorflow', wallti
Examples::
from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter
import numpy as np
writer = SummaryWriter()
for i in range(10):
Expand Down Expand Up @@ -443,12 +443,37 @@ def add_histogram_raw(self, tag, min, max, num, sum, sum_squares,
num (int): Number of values
sum (float or int): Sum of all values
sum_squares (float or int): Sum of squares for all values
bucket_limits (torch.Tensor, numpy.array): Upper value per bucket
bucket_limits (torch.Tensor, numpy.array): Upper value per
bucket, note that the bucket_limits returned from `np.histogram`
has one more element. See the comment in the following example.
bucket_counts (torch.Tensor, numpy.array): Number of values per bucket
global_step (int): Global step value to record
walltime (float): Optional override default walltime (time.time()) of event
see: https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/histogram/README.md
Examples::
import numpy as np
dummy_data = []
for idx, value in enumerate(range(30)):
dummy_data += [idx + 0.001] * value
values = np.array(dummy_data).astype(float).reshape(-1)
counts, limits = np.histogram(values)
sum_sq = values.dot(values)
with SummaryWriter() as summary_writer:
summary_writer.add_histogram_raw(
tag='hist_dummy_data',
min=values.min(),
max=values.max(),
num=len(values),
sum=values.sum(),
sum_squares=sum_sq,
bucket_limits=limits[1:].tolist(), # <- note here.
bucket_counts=counts.tolist(),
global_step=0)
"""
if len(bucket_limits) != len(bucket_counts):
raise ValueError('len(bucket_limits) != len(bucket_counts), see the document.')
self._get_file_writer().add_summary(
histogram_raw(tag,
min,
Expand Down Expand Up @@ -479,7 +504,7 @@ def add_image(self, tag, img_tensor, global_step=None, walltime=None, dataformat
Examples::
from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter
import numpy as np
img = np.zeros((3, 100, 100))
img[0] = np.arange(0, 10000).reshape(100, 100) / 10000
Expand Down Expand Up @@ -523,7 +548,7 @@ def add_images(self, tag, img_tensor, global_step=None, walltime=None, dataforma
Examples::
from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter
import numpy as np
img_batch = np.zeros((16, 3, 100, 100))
Expand Down Expand Up @@ -784,7 +809,7 @@ def add_pr_curve(self, tag, labels, predictions, global_step=None,
Examples::
from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter
import numpy as np
labels = np.random.randint(2, size=100) # binary label
predictions = np.random.rand(100)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_pytorch_np.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_pytorch_histogram_raw(self):
num=num,
sum=floats.sum().item(),
sum_squares=sum_sq,
bucket_limits=limits.tolist(),
bucket_limits=limits[1:].tolist(),
bucket_counts=counts.tolist())

ints = x2num.make_np(torch.randint(0, 100, (num,)))
Expand All @@ -66,5 +66,5 @@ def test_pytorch_histogram_raw(self):
num=num,
sum=ints.sum().item(),
sum_squares=sum_sq,
bucket_limits=limits.tolist(),
bucket_limits=limits[1:].tolist(),
bucket_counts=counts.tolist())

0 comments on commit b73e5ae

Please sign in to comment.