Training updates #215

Merged Nov 13, 2024 (119 commits)
a57aecf
tweak training
Aug 27, 2024
ec115ab
rain to fp
Aug 27, 2024
a5eb2ee
xla
Aug 27, 2024
26857b9
add config
Aug 27, 2024
07084f8
make bash and rename
Aug 28, 2024
79dccd1
fix seg width
Sep 15, 2024
504e747
fix name
Sep 15, 2024
ddaf4fe
debug
Sep 17, 2024
a0a6774
use rust binding
Sep 17, 2024
65d962f
remove unneeded
Sep 17, 2024
2586011
remove unneeded
Sep 17, 2024
a63e9a7
avoid bad regions
Sep 18, 2024
91ea9e7
update python-cptv
Sep 20, 2024
5fb72e9
save some files to test
Sep 20, 2024
38228ac
more debugging
Sep 22, 2024
f90c6ed
more debugging
Sep 22, 2024
127940d
double chicken
Sep 22, 2024
31f2e86
add missing station id
Sep 23, 2024
2d31161
load small
Sep 23, 2024
138bdfb
add check
Sep 23, 2024
3dfefc3
use model lbls
Sep 23, 2024
cdc6ef6
remap labels
Sep 23, 2024
deb1a2b
fix new
Sep 23, 2024
6731c5b
no need to show
Sep 23, 2024
e878f0c
add ext
Sep 23, 2024
2880801
add mode
Sep 23, 2024
2657382
add more debug
Sep 23, 2024
a1a37d7
exclude most
Sep 23, 2024
61338c7
comma
Sep 23, 2024
513fcf3
remaining 2
Sep 23, 2024
8ff1771
dont save strange values
Sep 23, 2024
ae90871
weighting
Sep 23, 2024
b66fe31
debugging
Sep 24, 2024
29f0d69
fixed resize and keep edge
Sep 24, 2024
09be246
fix tf
Sep 24, 2024
79d931e
rought balance
Sep 24, 2024
31db14d
more debug
Sep 24, 2024
91b196f
debug source
Sep 24, 2024
03316c0
print id
Sep 24, 2024
5478116
max smaples
Sep 24, 2024
5c16876
more test
Sep 24, 2024
8d9eed7
add lbls
Sep 24, 2024
c13aa4b
tidy up
Sep 25, 2024
670ec4b
remove load config
Sep 25, 2024
1531bec
delete load config
Sep 25, 2024
c5f2106
remove load
Sep 25, 2024
40d8874
fix base_training default
Sep 25, 2024
f38b83a
labels
Sep 25, 2024
f8fffb6
adjusted defaults
Sep 25, 2024
7720d59
remove source id
Sep 25, 2024
1ac2e23
add config to split data by location
Sep 25, 2024
b50918e
up requirements
Sep 25, 2024
a174fc5
none location check
Sep 26, 2024
d40f7e6
added fine tune option
Sep 29, 2024
4ab73d8
fix load
Sep 29, 2024
55898cf
Merge remote-tracking branch 'origin/master' into debugging
Oct 1, 2024
552b788
add date filtering
Oct 1, 2024
c52246d
count action
Oct 1, 2024
a24dade
adjust
Oct 1, 2024
5d0ab22
fix confusion
Oct 1, 2024
34d072a
add date filter
Oct 1, 2024
5ccea12
add loading of metadata
Oct 1, 2024
166c252
adjust
Oct 1, 2024
68bc3b8
no land bird
Oct 1, 2024
2d1eb3f
ignore no meta
Oct 1, 2024
70a4ff8
add none and unid
Oct 1, 2024
e6eb8b8
catch non existend labels
Oct 1, 2024
d7e3c10
use logging
Oct 1, 2024
c6c6abb
exclude unknown tag
Oct 1, 2024
4131892
correct method
Oct 1, 2024
f56fca7
comma
Oct 1, 2024
2b88fe4
add get id
Oct 1, 2024
f160436
remove debugg
Oct 1, 2024
ae4a81c
catch ex
Oct 1, 2024
23eed72
and support for therml norm diff
Oct 1, 2024
1e67f11
remove test logging
Oct 1, 2024
c3ce590
add smooething
Oct 1, 2024
8f77f5a
dont sq
Oct 1, 2024
34f2f21
build frames dataset
Oct 2, 2024
a34b205
support for frames model
Oct 2, 2024
8d12239
skip frames on edge
Oct 3, 2024
86fad59
tweak a few defaults and min mass filtering
Oct 3, 2024
a9044d6
add max samples
Oct 3, 2024
29dc459
fix cap
Oct 3, 2024
64e283a
load fp or animal model
Oct 4, 2024
acb90bd
fix variable name
Oct 6, 2024
f254e40
fix variable
Oct 6, 2024
7d87815
fix confusion
Oct 6, 2024
51ef0a3
add num frames
Oct 6, 2024
602d8e0
load params properly
Oct 6, 2024
4ecf684
set shuffle based on number of frames
Oct 6, 2024
5b7732c
do not sort
Oct 8, 2024
708afad
dont resample evenly some labels
Oct 8, 2024
f81ee53
add fp_frames
Oct 10, 2024
10762a0
read fp model predictions
Oct 14, 2024
9c65452
check for int
Oct 15, 2024
1ec5202
remove log
Oct 15, 2024
39754a4
add country code into tf records
Oct 16, 2024
7550768
remove some
Oct 17, 2024
c008df5
dont filte rby fp
Oct 17, 2024
88aa205
fix excluded
Oct 17, 2024
a8717ec
fix fine tune
Oct 18, 2024
84a5da6
add parsing
Oct 21, 2024
4def7a8
try limit memory
Oct 25, 2024
50faca5
dont validate bins for after date test clips
Oct 28, 2024
37c68e1
added start time
Oct 28, 2024
600af58
union set
Oct 29, 2024
43180b7
repeat frames at random rather than only last frame
Nov 5, 2024
7ca8817
less jobs
Nov 5, 2024
46b431b
add multiple segment type option
Nov 7, 2024
e9b51a9
try random section
Nov 11, 2024
25ae03b
add min path
Nov 11, 2024
7cd7efc
fix to small tracks
Nov 11, 2024
0ca2c93
added mask segment type as default
Nov 12, 2024
d39902e
tidy up
Nov 12, 2024
48a25e3
add check for none
Nov 13, 2024
52bfd99
tidy up
Nov 13, 2024
e27935b
Merge remote-tracking branch 'origin/master' into debugging
Nov 13, 2024
092280f
fix segment type load for old meta
Nov 13, 2024
4 changes: 2 additions & 2 deletions pirequirements.txt
@@ -9,7 +9,7 @@ scipy==1.9.3
python-dateutil
scikit-learn==1.1.3
tables==3.8.0
h5py==3.8.0
h5py==3.10.0
pyyaml==6.0
pillow==10.0.1
attrs==24.2.0
@@ -26,4 +26,4 @@ dbus-python==1.3.2
importlib_resources==5.10.2
opencv-python==4.8.0.76
inotify_simple==1.3.5
python-cptv==0.0.3
python-cptv==0.0.5
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -48,7 +48,7 @@ dependencies = [
"importlib_resources==5.10.2",
"opencv-python==4.8.0.76",
"inotify_simple==1.3.5",
"python-cptv==0.0.3"
"python-cptv==0.0.5"
]

[project.scripts]
6 changes: 3 additions & 3 deletions requirements.txt
@@ -1,4 +1,4 @@
tensorflow~=2.14.0
tensorflow~=2.17.0
matplotlib~=3.0
pytz
cptv~=1.5.4
@@ -7,7 +7,7 @@ scipy
python-dateutil
scikit-learn
tables~=3.8.0
h5py~=3.9.0
h5py~=3.10.0
pyyaml>=4.2b1
pillow~=10.0.1
attrs~=24.2.0
@@ -26,4 +26,4 @@ joblib
#requires sudo apt-get install libopencv-dev used for ir track extraction on server
# pybgs==3.2.0.post1 this was used for ir
inotify_simple==1.3.5
python-cptv==0.0.3
python-cptv==0.0.5
5 changes: 5 additions & 0 deletions src/autobuild-cron
@@ -0,0 +1,5 @@
#run the first of every month
SHELL=/bin/bash
BASH_ENV=~/.bashrc_conda

* * 1 * * cp ( cd /home/cp/cacophony/classifier-pipeline/src && ./autobuild.sh /data2/cptv-files) 2>&1 | logger --tag classifier-auto-build
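A crontab schedule has five space-separated fields: minute, hour, day of month, month and day of week. The sketch below is illustrative only. `cron_matches` is a hypothetical helper, not part of this repository, and it handles just `*` and plain integers. It suggests that `* * 1 * *` matches every minute of the first day of the month, whereas pinning the minute and hour (for example `0 0 1 * *`) matches once, at midnight.

```python
from datetime import datetime


def cron_matches(expr: str, when: datetime) -> bool:
    """Return True if `when` matches a five-field cron expression.

    Hypothetical helper: supports only "*" and plain integers, and simply
    requires every field to match (no ranges, lists or steps).
    """
    fields = expr.split()
    if len(fields) != 5:
        raise ValueError("expected minute hour day-of-month month day-of-week")
    # cron counts Sunday as 0, Python's weekday() counts Monday as 0
    values = (
        when.minute,
        when.hour,
        when.day,
        when.month,
        (when.weekday() + 1) % 7,
    )
    return all(f == "*" or int(f) == v for f, v in zip(fields, values))


midnight = datetime(2024, 11, 1, 0, 0)
afternoon = datetime(2024, 11, 1, 13, 37)

print(cron_matches("* * 1 * *", afternoon))  # True: any minute on the 1st
print(cron_matches("0 0 1 * *", afternoon))  # False: only 00:00 matches
print(cron_matches("0 0 1 * *", midnight))   # True
```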
18 changes: 11 additions & 7 deletions src/autobuild.sh
@@ -1,11 +1,15 @@
#!/bin/sh

#!/bin/bash
set -e
set -x
conda init bash
conda activate tf
config="classifier-thermal.yaml"
month_ago=$(python3 rebuildDate.py -c $config)
echo "Saving into $1"
month_ago=$(python3 rebuildDate.py $1)
echo $month_ago
python3 ../../cptv-download/cptv-download.py -l 0 -i 'poor tracking' -i 'untagged' -i 'part' -i 'untagged-by-humans' -i 'unknown' -i 'unidentified' -m 'human-tagged' --start-date "$month_ago" "../clips$month_ago" [email protected] userpassword
echo "Downloading into ../clips$month_ago"
python3 load.py -target "../clips$month_ago" -c $config
python3 build.py -c $config
python3 ../../cptv-download/cptv-download.py -l 0 -i 'poor tracking' -i 'untagged' -i 'part' -i 'untagged-by-humans' -i 'unknown' -i 'unidentified' -m 'human-tagged' --start-date "$month_ago" "$1" [email protected] userpassword
echo "Downloading into $1"
python3 build.py -c $config --ext ".cptv" $1
dt=$(date '+%d%m%Y-%H%M%S');
export XLA_FLAGS=--xla_gpu_cuda_data_dir=/home/cp/miniconda3/envs/tf/lib/
python3 train.py -c $config $dt
70 changes: 62 additions & 8 deletions src/build.py
@@ -19,8 +19,8 @@
from ml_tools.tfwriter import create_tf_records
from ml_tools.irwriter import save_data as save_ir_data
from ml_tools.thermalwriter import save_data as save_thermal_data


from ml_tools.tools import CustomJSONEncoder
import attrs
import numpy as np

from pathlib import Path
@@ -57,7 +57,7 @@ def parse_args():
)
parser.add_argument("--split-file", help="Json file defining a split")
parser.add_argument(
"--ext", default=".hdf5", help="Extension of files to load .mp4,.cptv,.hdf5"
"--ext", default=".cptv", help="Extension of files to load .mp4,.cptv,.hdf5"
)

parser.add_argument("-c", "--config-file", help="Path to config file to use")
@@ -571,7 +571,7 @@ def add_samples(
dataset.add_samples(samples)


def validate_datasets(datasets, test_bins, date):
def validate_datasets(datasets, test_bins, after_date):
# check that clips are only in one dataset
# that only test set has clips after date
# that test set is the only dataset with test_clips
@@ -580,7 +580,7 @@ def validate_datasets(datasets, test_bins, date):
# for track in dataset.tracks:
# assert track.start_time < date

for i, dataset in enumerate(datasets):
for i, dataset in enumerate(datasets[:2]):
dont_check = set(
[
sample.bin_id
@@ -608,6 +608,15 @@ def validate_datasets(datasets, test_bins, date):
if sample.label in split_by_clip
]
)
if other.name == "test" and after_date is not None:
dont_check_other = set(
[
sample.bin_id
for sample in other.samples_by_id.values()
if sample.rec_time > after_date
]
)
dont_check = dont_check | dont_check_other
other_bins = set([sample.bin_id for sample in other.samples_by_id.values()])
other_bins = other_bins - dont_check
other_clips = set(
@@ -717,6 +726,42 @@ def dump_split_ids(datasets, out_file="datasplit.json"):
return


def rough_balance(datasets):
dev_threshold = 2000
logging.info("Roughly Balancing")
print_counts(*datasets)

for dataset in datasets:
lbl_counts = {}
counts = []
for label in dataset.labels:
label_count = len(dataset.samples_by_label.get(label, []))
lbl_counts[label] = label_count
counts.append(label_count)
counts.sort()
std_dev = np.std(counts)
logging.info("Counts are %s std dev %s", counts, std_dev)
if std_dev < dev_threshold or len(counts) <= 1:
logging.info("Not balancing")
continue
if len(counts) <= 2:
cap_at = counts[-2]
elif len(counts) < 7:
cap_at = counts[-2]
else:
cap_at = counts[-2]
logging.info("Capping dataset %s at %s", dataset.name, cap_at)
for lbl, count in lbl_counts.items():
if count <= cap_at:
continue
samples_to_remove = count - cap_at
by_labels = dataset.samples_by_label[lbl]
np.random.shuffle(by_labels)
for i in range(samples_to_remove):
dataset.remove_sample(by_labels[i])
print_counts(*datasets)
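
The capping logic in `rough_balance` can be tried in isolation. The sketch below is a simplified stand-in, not the repository code. It works on a plain `dict` of label to sample list instead of the project's dataset objects, returns a new dict rather than calling `remove_sample`, and uses `statistics.pstdev` (population standard deviation, which matches the default behaviour of `np.std`). The name `rough_balance_sketch` and the `rng` parameter are assumptions made for the example.

```python
import random
from statistics import pstdev


def rough_balance_sketch(samples_by_label, dev_threshold=2000, rng=None):
    """Cap over-represented labels at the second-largest label count.

    Simplified illustration: returns a new dict and leaves the input alone.
    """
    rng = rng or random.Random()
    counts = sorted(len(samples) for samples in samples_by_label.values())
    # Nothing to balance with a single label or an already even spread
    if len(counts) <= 1 or pstdev(counts) < dev_threshold:
        return {label: list(s) for label, s in samples_by_label.items()}

    cap_at = counts[-2]  # second-largest count, as in the diff
    balanced = {}
    for label, samples in samples_by_label.items():
        samples = list(samples)
        if len(samples) > cap_at:
            # keep a random subset of cap_at samples
            samples = rng.sample(samples, cap_at)
        balanced[label] = samples
    return balanced


data = {
    "bird": list(range(10000)),
    "possum": list(range(3000)),
    "cat": list(range(500)),
}
balanced = rough_balance_sketch(data, rng=random.Random(0))
print({label: len(samples) for label, samples in balanced.items()})
# {'bird': 3000, 'possum': 3000, 'cat': 500}
```

Only the largest label is ever reduced by this rule, since every other count is already at or below the second-largest value.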


def main():
init_logging()
args = parse_args()
@@ -782,6 +827,8 @@ def main():
print("Splitting data set into train / validation")

datasets = split_randomly(master_dataset, config, args.date, test_clips)

rough_balance(datasets)
validate_datasets(datasets, test_clips, args.date)
dump_split_ids(datasets, record_dir / "datasplit.json")

@@ -849,15 +896,20 @@ def main():
{
"segment_frame_spacing": master_dataset.segment_spacing * 9,
"segment_width": master_dataset.segment_length,
"segment_type": master_dataset.segment_type,
"segment_types": master_dataset.segment_types,
"segment_min_avg_mass": master_dataset.segment_min_avg_mass,
"max_segments": master_dataset.max_segments,
"dont_filter_segment": True,
"skip_ffc": True,
"tag_precedence": config.load.tag_precedence,
"tag_precedence": config.build.tag_precedence,
"min_mass": master_dataset.min_frame_mass,
"thermal_diff_norm": config.build.thermal_diff_norm,
"filter_by_lq": master_dataset.filter_by_lq,
"max_frames": master_dataset.max_frames,
}
)
# dont filter the test set,
extra_args["filter_by_fp"] = dataset.name != "test"
create_tf_records(
dataset,
dir,
@@ -879,10 +931,12 @@ def main():
"type": config.train.type,
"counts": dataset_counts,
"by_label": False,
"config": attrs.asdict(config),
"segment_types": master_dataset.segment_types,
}

with open(meta_filename, "w") as f:
json.dump(meta_data, f, indent=4)
json.dump(meta_data, f, indent=4, cls=CustomJSONEncoder)


if __name__ == "__main__":
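
Passing `cls=` to `json.dump` lets the encoder handle values that the standard encoder rejects, which becomes relevant once the whole config is written into the metadata. The actual `CustomJSONEncoder` in `ml_tools.tools` is not shown in this diff, so the following is only a sketch of the general pattern. `SketchJSONEncoder` and the sample `meta` values are hypothetical, and the types it converts are assumptions.

```python
import json
from datetime import datetime
from pathlib import PurePath, PurePosixPath


class SketchJSONEncoder(json.JSONEncoder):
    """Hypothetical stand-in showing how a custom encoder hooks in."""

    def default(self, obj):
        # default() is only called for objects json cannot serialise itself
        if isinstance(obj, PurePath):
            return str(obj)
        if isinstance(obj, datetime):
            return obj.isoformat()
        if isinstance(obj, (set, frozenset)):
            # sorted() assumes the members are mutually comparable
            return sorted(obj)
        return super().default(obj)


meta = {
    "base_dir": PurePosixPath("/data2/cptv-files"),
    "built": datetime(2024, 11, 13, 9, 30),
    "labels": {"cat", "bird"},
}
text = json.dumps(meta, indent=4, cls=SketchJSONEncoder)
print(text)
```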
8 changes: 1 addition & 7 deletions src/classify/clipclassifier.py
@@ -11,13 +11,8 @@
from track.clip import Clip
from track.cliptrackextractor import ClipTrackExtractor, is_affected_by_ffc
from ml_tools import tools
from ml_tools.kerasmodel import KerasModel
from track.irtrackextractor import IRTrackExtractor
from ml_tools.previewer import Previewer
from track.track import Track

from cptv import CPTVReader
from datetime import datetime
from ml_tools.interpreter import get_interpreter


@@ -134,7 +129,7 @@ def process_file(self, filename, cache=None, reuse_frames=None):
clip = Clip(track_extractor.config, filename)
clip.load_metadata(
meta_data,
self.config.load.tag_precedence,
self.config.build.tag_precedence,
)
track_extractor.parse_clip(clip)

@@ -250,7 +245,6 @@ def save_metadata(
prediction = predictions.prediction_for(track.get_id())
if prediction is None:
continue

prediction_meta = prediction.get_metadata()
prediction_meta["model_id"] = model_id
prediction_info.append(prediction_meta)
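
The `tag_precedence` value handed to `load_metadata` is a mapping from a numeric group to a list of tags. How `load_metadata` consumes it is not shown in this diff. The sketch below assumes that a lower group number wins and that tags not listed anywhere fall into the group containing `"default"`. `pick_tag`, `build_rank_lookup` and the shortened `GROUPS` table are hypothetical.

```python
# Shortened, illustrative precedence table (lower number = higher priority,
# which is an assumption of this sketch)
GROUPS = {
    0: ["bird", "cat", "possum"],
    1: ["unidentified", "other"],
    2: ["part", "bad track"],
    3: ["default"],
}


def build_rank_lookup(precedence):
    """Flatten {rank: [tags]} into {tag: rank}."""
    lookup = {}
    for rank, tags in precedence.items():
        for tag in tags:
            lookup[tag] = rank
    return lookup


def pick_tag(tags, precedence):
    """Return the tag with the best (lowest) rank, or None for no tags."""
    if not tags:
        return None
    lookup = build_rank_lookup(precedence)
    # Unlisted tags share the rank of the "default" entry
    fallback = lookup.get("default", max(precedence) + 1)
    return min(tags, key=lambda tag: lookup.get(tag, fallback))


print(pick_tag(["part", "cat", "unidentified"], GROUPS))  # cat
print(pick_tag(["unidentified", "part"], GROUPS))         # unidentified
print(pick_tag(["wallaby", "part"], GROUPS))              # part
```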
19 changes: 13 additions & 6 deletions src/classify/trackprediction.py
@@ -79,6 +79,9 @@ def clarity(self):
best = np.argsort(self.prediction)
return self.prediction[best[-1]] - self.prediction[best[-2]]

def __str__(self):
return f"{self.frames} conf: {np.round(100*self.prediction)}"


class TrackPrediction:
"""
@@ -107,18 +110,23 @@ def __init__(self, track_id, labels, keep_all=True, start_frame=None):
self.masses = []

def classified_clip(
self, predictions, smoothed_predictions, prediction_frames, top_score=None
self,
predictions,
smoothed_predictions,
prediction_frames,
masses,
top_score=None,
):
self.num_frames_classified = len(predictions)
for prediction, smoothed_prediction, frames in zip(
predictions, smoothed_predictions, prediction_frames
for prediction, smoothed_prediction, frames, mass in zip(
predictions, smoothed_predictions, prediction_frames, masses
):
prediction = Prediction(
prediction,
smoothed_prediction,
frames,
np.amax(frames),
None,
mass,
)
self.predictions.append(prediction)

@@ -162,11 +170,10 @@ def classified_frames(self, frame_numbers, predictions, mass):
self.class_best_score += smoothed_prediction

def classified_frame(self, frame_number, predictions, mass):
self.prediction_frames.append([frame_number])
self.last_frame_classified = frame_number
self.num_frames_classified += 1
self.masses.append(mass)
smoothed_prediction = prediction * prediction * mass
smoothed_prediction = predictions**2 * mass

prediction = Prediction(
predictions,
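
The per-frame weighting `predictions**2 * mass` squares each class probability and scales it by the tracked mass, so confident predictions on large regions dominate the running score. A minimal pure-Python sketch of that arithmetic follows. The repository operates on NumPy arrays inside `TrackPrediction`, and the function names and sample numbers here are made up for illustration.

```python
def smooth(prediction, mass):
    """Square each class probability and weight it by the frame's mass."""
    return [p * p * mass for p in prediction]


def accumulate(frame_predictions, masses):
    """Sum the smoothed per-frame scores into one score per class."""
    totals = [0.0] * len(frame_predictions[0])
    for prediction, mass in zip(frame_predictions, masses):
        for i, score in enumerate(smooth(prediction, mass)):
            totals[i] += score
    return totals


# Two frames, two classes: the second frame has four times the mass
frames = [[0.6, 0.4], [0.2, 0.8]]
masses = [10, 40]

totals = accumulate(frames, masses)
best_class = max(range(len(totals)), key=totals.__getitem__)
print(totals, best_class)
```

Here the totals come out at roughly 5.2 and 27.2, so class 1 wins even though class 0 led in the first, lighter frame.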
54 changes: 53 additions & 1 deletion src/config/buildconfig.py
@@ -22,6 +22,7 @@
import logging
from os import path
from .defaultconfig import DefaultConfig
from ml_tools.rectangle import Rectangle


@attr.s
@@ -34,6 +35,45 @@ class BuildConfig(DefaultConfig):
min_frame_mass = attr.ib()
filter_by_lq = attr.ib()
max_segments = attr.ib()
thermal_diff_norm = attr.ib()
tag_precedence = attr.ib()
excluded_tags = attr.ib()
country = attr.ib()
use_segments = attr.ib()
max_frames = attr.ib()

EXCLUDED_TAGS = ["poor tracking", "part", "untagged", "unidentified"]
NO_MIN_FRAMES = ["stoat", "mustelid", "weasel", "ferret"]
# country bounding boxs
COUNTRY_LOCATIONS = {
"AU": Rectangle.from_ltrb(
113.338953078, -10.6681857235, 153.569469029, -43.6345972634
),
"NZ": Rectangle.from_ltrb(
166.509144322, -34.4506617165, 178.517093541, -46.641235447
),
}

DEFAULT_GROUPS = {
0: [
"bird",
"false-positive",
"hedgehog",
"possum",
"rodent",
"mustelid",
"cat",
"kiwi",
"dog",
"leporidae",
"human",
"insect",
"pest",
],
1: ["unidentified", "other"],
2: ["part", "bad track"],
3: ["default"],
}

@classmethod
def load(cls, build):
@@ -46,6 +86,12 @@ def load(cls, build):
min_frame_mass=build["min_frame_mass"],
filter_by_lq=build["filter_by_lq"],
max_segments=build["max_segments"],
thermal_diff_norm=build["thermal_diff_norm"],
tag_precedence=build["tag_precedence"],
excluded_tags=build["excluded_tags"],
country=build["country"],
use_segments=build["use_segments"],
max_frames=build["max_frames"],
)

@classmethod
@@ -58,7 +104,13 @@ def get_defaults(cls):
segment_min_avg_mass=10,
min_frame_mass=10,
filter_by_lq=False,
max_segments=5,
max_segments=3,
thermal_diff_norm=False,
tag_precedence=BuildConfig.DEFAULT_GROUPS,
excluded_tags=BuildConfig.EXCLUDED_TAGS,
country=None,
use_segments=True,
max_frames=75,
)

def validate(self):
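
The `COUNTRY_LOCATIONS` boxes are given as left, top, right, bottom in longitude and latitude, so `top` is the northern (larger) latitude and `bottom` the southern one. How the pipeline uses the `country` setting is not visible in this diff. The sketch below only illustrates a point-in-box lookup, with a hypothetical `Box` dataclass standing in for `ml_tools.rectangle.Rectangle` and a hypothetical `country_for` helper.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Box:
    """Hypothetical stand-in for Rectangle, in lng/lat degrees."""

    left: float    # western longitude
    top: float     # northern latitude
    right: float   # eastern longitude
    bottom: float  # southern latitude

    def contains(self, lng: float, lat: float) -> bool:
        return self.left <= lng <= self.right and self.bottom <= lat <= self.top


# Coordinates copied from COUNTRY_LOCATIONS in the diff
COUNTRY_BOXES = {
    "AU": Box(113.338953078, -10.6681857235, 153.569469029, -43.6345972634),
    "NZ": Box(166.509144322, -34.4506617165, 178.517093541, -46.641235447),
}


def country_for(lng, lat):
    """Return the first country code whose box contains the point, else None."""
    for code, box in COUNTRY_BOXES.items():
        if box.contains(lng, lat):
            return code
    return None


print(country_for(172.64, -43.53))  # NZ (Christchurch)
print(country_for(151.21, -33.87))  # AU (Sydney)
print(country_for(-0.13, 51.5))     # None (London)
```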