diff --git a/src/classify/clipclassifier.py b/src/classify/clipclassifier.py index 4de7ce0c..9f6945a7 100644 --- a/src/classify/clipclassifier.py +++ b/src/classify/clipclassifier.py @@ -182,8 +182,6 @@ def classify_clip(self, clip, model, meta_data, reuse_frames=None): predictions.model_load_time = time.time() - start for i, track in enumerate(clip.tracks): - logging.info("Track id is %s", track.get_id()) - segment_frames = None if reuse_frames: tracks = meta_data.get("tracks") @@ -247,17 +245,6 @@ def save_metadata( prediction = predictions.prediction_for(track.get_id()) if prediction is None: continue - # DEBUGGING STUFF REMOVE ME - # logging.info("Track predictions %s", track) - # for p in prediction.predictions: - # logging.info( - # "Have %s sum %s smoothed %s mass %s", - # p, - # np.sum(p.prediction), - # np.round(p.smoothed_prediction), - # p.mass, - # ) - # logging.info("smoothed %s", np.round(100 * prediction.class_best_score)) prediction_meta = prediction.get_metadata() prediction_meta["model_id"] = model_id prediction_info.append(prediction_meta) diff --git a/src/ml_tools/hyperparams.py b/src/ml_tools/hyperparams.py index b1868fd0..db558eff 100644 --- a/src/ml_tools/hyperparams.py +++ b/src/ml_tools/hyperparams.py @@ -90,7 +90,6 @@ def segment_width(self): @property def segment_types(self): - segment_types = self.get("segment_type", [SegmentType.ALL_RANDOM]) # convert string to enum type if isinstance(segment_types[0], str): diff --git a/src/ml_tools/tools.py b/src/ml_tools/tools.py index ce604906..cad64667 100644 --- a/src/ml_tools/tools.py +++ b/src/ml_tools/tools.py @@ -194,9 +194,17 @@ def saveclassify_image(data, filename): # saves image channels side by side, expected data to be values in the range of 0->1 Path(filename).parent.mkdir(parents=True, exist_ok=True) r = Image.fromarray(np.uint8(data[:, :, 0])) - g = Image.fromarray(np.uint8(data[:, :, 1])) - b = g - # b = Image.fromarray(np.uint8(data[:, :, 2])) + _, _, channels = data.shape + + if channels == 1: + g = r + else: + g = Image.fromarray(np.uint8(data[:, :, 1])) + + if channels == 2: + b = r + else: + b = Image.fromarray(np.uint8(data[:, :, 2])) concat = np.concatenate((r, g, b), axis=1) # horizontally img = Image.fromarray(np.uint8(concat)) img.save(filename + ".png") diff --git a/src/rebuildDate.py b/src/rebuildDate.py index 7693842d..661e2d60 100644 --- a/src/rebuildDate.py +++ b/src/rebuildDate.py @@ -9,7 +9,7 @@ from dateutil.parser import parse as parse_date parser = argparse.ArgumentParser() -parser.add_argument("data_dir", help="Directory of hdf5 files") +parser.add_argument("data_dir", help="Directory of cptv files") args = parser.parse_args() args.data_dir = Path(args.data_dir) latest_date = None