diff --git a/setup.py b/setup.py index baca03e..2c4c57c 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ except ImportError: from distutils.core import setup -VERSION = '0.304' +VERSION = '0.305' setup( name = 'torch_snippets', # How you named your package folder (MyLib) packages = ['torch_snippets'], # Chose the same as "name" diff --git a/torch_snippets/torch_loader.py b/torch_snippets/torch_loader.py index 6588de1..de3a3e5 100644 --- a/torch_snippets/torch_loader.py +++ b/torch_snippets/torch_loader.py @@ -98,7 +98,12 @@ def plot_epochs(self, keys:list=None, ax=None, **kwargs): fig, ax = plt.subplots(figsize=kwargs.get('figsize', sz)) avgs = defaultdict(list) keys = self.logged if keys is None else keys + from tqdm import trange + if isinstance(keys, str): + key_pattern = keys + keys = [key for key in self.logged if re.search(key_pattern, key)] + for epoch in trange(self.n_epochs+1): for k in keys: items = takewhile(lambda x: epoch-1<=x.pos