Skip to content

Commit

Permalink
bug in lightning report
Browse files Browse the repository at this point in the history
  • Loading branch information
yreddy31 committed Feb 4, 2021
1 parent fd47831 commit 642576d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
except ImportError:
from distutils.core import setup

VERSION = '0.304'
VERSION = '0.305'
setup(
name = 'torch_snippets', # How you named your package folder (MyLib)
packages = ['torch_snippets'], # Chose the same as "name"
Expand Down
14 changes: 12 additions & 2 deletions torch_snippets/torch_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,12 @@ def plot_epochs(self, keys:list=None, ax=None, **kwargs):
fig, ax = plt.subplots(figsize=kwargs.get('figsize', sz))
avgs = defaultdict(list)
keys = self.logged if keys is None else keys

from tqdm import trange
if isinstance(keys, str):
key_pattern = keys
keys = [key for key in self.logged if re.search(key_pattern, key)]

for epoch in trange(self.n_epochs+1):
for k in keys:
items = takewhile(lambda x: epoch-1<=x.pos<epoch,
Expand Down Expand Up @@ -155,12 +160,17 @@ def report_metrics(self, pos, **report):
try:
from pytorch_lightning.callbacks.progress import ProgressBarBase
class LightningReport(ProgressBarBase):
def __init__(self, epochs, print_total=10, precision=4):
def __init__(self, epochs, print_total=None, precision=4):
super().__init__()
self.enable = True
self.epoch_ix = 0
self.report = Report(epochs, precision)
self.print_every = epochs // print_total
if print_total is None:
if epochs < 11: self.print_every = 1
else:
self.print_every = epochs // 5
else:
self.print_every = epochs // print_total

def disable(self):
self.enable = False
Expand Down

0 comments on commit 642576d

Please sign in to comment.