diff --git a/sumatra/projects.py b/sumatra/projects.py index 37030853..36e4b6c9 100644 --- a/sumatra/projects.py +++ b/sumatra/projects.py @@ -316,15 +316,29 @@ def find_records(self, tags=None, reverse=False, parameters=None, *args, **kwarg records = [rec for rec in records if len(rec.parameters.diff(parameters)[-1]) == 0] return records - def find_data(self, *args, **kwargs): + def find_input_data(self, *args, **kwargs): + records = self.find_records(*args, **kwargs) + if len(records) == 0: return [] + input_data = [] + for record in records: + for input_file in record.input_data: + input_data.append(input_file) + return input_data + + def find_output_data(self, *args, **kwargs): records = self.find_records(*args, **kwargs) + if (records) == 0: return [] output_data = [] for record in records: - for output_datakey in record.output_data: - if self.data_store.contains_path(output_datakey.path): - output_data.append(os.path.join(self.data_store.root, output_datakey.path)) + for output_file in record.output_data: + output_data.append(output_file) return output_data + def find_data(self, *args, **kwargs): + input_data = self.find_input_data(*args, **kwargs) + output_data = self.find_output_data(*args, **kwargs) + return {'input_data': input_data, 'output_data': output_data} + def format_records(self, format='text', mode='short', tags=None, reverse=False, *args, **kwargs): if format=='text' and mode=='short' and ('parameters' not in kwargs.keys()): return '\n'.join(self.get_labels(tags=tags, reverse=reverse, *args, **kwargs))