-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
yolo object detection example code updated
- Loading branch information
1 parent
45cee39
commit c4499fe
Showing
10 changed files
with
299 additions
and
83 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,4 +3,5 @@ data | |
*.tgz | ||
*.tar.gz | ||
.mnist-keras | ||
client.yaml | ||
client.yaml | ||
.darknet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,10 @@ | ||
python_env : python_env.yaml | ||
entry_points: | ||
build: | ||
command: python3 entrypoint.py init_seed | ||
command: python model.py | ||
startup: | ||
command: python get_data.py | ||
train: | ||
command: python entrypoint.py train | ||
command: python train.py | ||
validate: | ||
command: python entrypoint.py validate | ||
predict: | ||
command: python entrypoint.py predict | ||
command: python validate.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,97 +1,79 @@ | ||
import os | ||
import os,sys | ||
from math import floor | ||
import random | ||
|
||
import glob | ||
import numpy as np | ||
import zipfile | ||
import subprocess | ||
import requests | ||
|
||
|
||
def splitset(dataset, parts): | ||
n = dataset.shape[0] | ||
local_n = floor(n / parts) | ||
result = [] | ||
for i in range(parts): | ||
result.append(dataset[i * local_n : (i + 1) * local_n]) | ||
return np.array(result) | ||
def write_list(array, fname): | ||
textfile = open(fname, "w") | ||
for element in array: | ||
textfile.write(f"{ds_path}/{element}\n") | ||
textfile.close() | ||
|
||
|
||
def split(dataset="yolo_computer/data/combine_data", outdir="data", n_splits=1): | ||
|
||
#Splitting the images into train and test | ||
|
||
dir_path = os.path.dirname(os.path.realpath(__file__)) | ||
def download_blob(): | ||
|
||
url = "https://storage.googleapis.com/public-scaleout/Yolo-object-detection/data.zip" | ||
response = requests.get(url) | ||
|
||
with open(dir_path + "/data/1/data.zip", "wb") as f: | ||
f.write(response.content) | ||
zip_file_path = dir_path + "/data/1/data.zip" | ||
extract_to_path = dir_path + "/data/1/" | ||
|
||
import glob, os | ||
os.makedirs(extract_to_path, exist_ok=True) | ||
|
||
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref: | ||
zip_ref.extractall(extract_to_path) | ||
|
||
def train_test_creation(dataset=dir_path + "/data/1/data/", outdir=dir_path + "/data/1/", test_percentage=10): | ||
|
||
current_dir = dataset | ||
|
||
|
||
# Percentage of images to be used for the test set | ||
|
||
percentage_test = 10 | ||
|
||
|
||
|
||
# Create and/or truncate train.txt and test.txt | ||
|
||
file_train = open('data/train.txt', 'a') | ||
|
||
file_test = open('data/val.txt', 'a') | ||
|
||
import glob | ||
|
||
images_list1 = glob.glob("/home/sowmya/yolo_computer/data/combine_data/*.jpg") | ||
|
||
images_list2 = glob.glob("/home/sowmya/yolo_computer/data/combine_data/*.png") | ||
|
||
images_list3 = glob.glob("/home/sowmya/yolo_computer/data/combine_data/*.jpeg") | ||
percentage_test = test_percentage | ||
file_train = open(outdir+'train.txt', 'w') | ||
file_test = open(outdir+'val.txt', 'w') | ||
images_list1 = glob.glob(dataset+"*.jpg") | ||
images_list2 = glob.glob(dataset+"*.png") | ||
images_list3 = glob.glob(dataset+"*.jpeg") | ||
images_list = images_list1 + images_list2 + images_list3 | ||
|
||
|
||
print(images_list) | ||
|
||
# types = ('*.jpg', '*.png', "*jpeg") | ||
|
||
# Populate train.txt and test.txt | ||
|
||
counter = 1 | ||
|
||
index_test = round(100 / percentage_test) | ||
|
||
# for pathAndFilename in glob.iglob(os.path.join(current_dir, types)): | ||
|
||
# title, ext = os.path.splitext(os.path.basename(pathAndFilename)) | ||
|
||
|
||
# file = open("data/train.txt", "w") | ||
|
||
for id, name in enumerate(images_list): | ||
# file_train.write("\n".join(images_list)) | ||
# file.close() | ||
if counter == index_test: | ||
counter = 1 | ||
#print('in') | ||
|
||
file_test.write(name + "\n") | ||
else: | ||
# print('in') | ||
|
||
file_train.write(name + "\n") | ||
counter = counter + 1 | ||
|
||
|
||
def get_data(out_dir="data"): | ||
# Make dir if necessary | ||
if not os.path.exists(out_dir): | ||
os.mkdir(out_dir) | ||
|
||
|
||
def update_all_keys_in_obj_data(file_path, updates_dict): | ||
with open(file_path, 'r') as file: | ||
lines = file.readlines() | ||
|
||
with open(file_path, 'w') as file: | ||
for line in lines: | ||
if '=' in line: | ||
key, value = line.split('=', 1) | ||
key = key.strip() | ||
if key in updates_dict: | ||
file.write(f"{key} = {updates_dict[key]}\n") | ||
else: | ||
file.write(line) | ||
else: | ||
file.write(line) | ||
|
||
if __name__ == "__main__": | ||
pass | ||
#get_data() | ||
#split() | ||
download_blob() | ||
train_test_creation() | ||
obj_data_file = dir_path + "/data/1/obj.data" # Path to your obj.data file | ||
updates_dict = { | ||
"classes": "1", | ||
"train": dir_path + "/data/1/train.txt", | ||
"valid":dir_path + "/data/1/val.txt", | ||
"names": dir_path + "/data/1/obj.names", | ||
"backup": dir_path + "/data/1/yolov4_tiny" | ||
} | ||
|
||
update_all_keys_in_obj_data(obj_data_file, updates_dict) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import os | ||
from math import floor | ||
import numpy as np | ||
import subprocess | ||
|
||
from fedn.utils.helpers.helpers import get_helper | ||
|
||
HELPER_MODULE = "numpyhelper" | ||
helper = get_helper(HELPER_MODULE) | ||
dir_path = os.path.dirname(os.path.realpath(__file__)) | ||
|
||
def git_clone(repo_url="https://github.com/AlexeyAB/darknet.git", clone_dir="../.darknet"): | ||
parent_dir = os.path.abspath(os.path.join(os.getcwd(), "../.darknet")) | ||
if clone_dir: | ||
target_path = clone_dir | ||
else: | ||
target_path = parent_dir | ||
command = ["git", "clone", repo_url,target_path] | ||
|
||
try: | ||
subprocess.run(command, check=True) | ||
print(f"Successfully cloned {repo_url}") | ||
except subprocess.CalledProcessError as e: | ||
print(f"Error during cloning: {e}") | ||
def init_seed(out_path="../seed.npz"): | ||
"""Initialize seed model and save it to file. | ||
:param out_path: The path to save the seed model to. | ||
:type out_path: str | ||
""" | ||
darkfile="../yolov4-tiny.weights" | ||
fp = open(darkfile, "rb") | ||
header=np.fromfile(fp,dtype=np.int32,count=5) | ||
buf = np.fromfile(fp, dtype=np.float32) | ||
helper.save([buf], out_path) | ||
fp.close() | ||
if __name__ == "__main__": | ||
init_seed("../seed.npz") | ||
git_clone() | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,6 @@ build_dependencies: | |
- setuptools | ||
- wheel | ||
dependencies: | ||
- tensorflow>=2.13.1 | ||
- fire==0.3.1 | ||
- fedn | ||
- numpy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
|
||
|
||
import json | ||
import os | ||
import fire | ||
import numpy as np | ||
import json | ||
import re | ||
import subprocess | ||
from fedn.utils.dist import get_package_path | ||
from fedn.utils.helpers.helpers import get_helper, save_metadata, save_metrics | ||
|
||
HELPER_MODULE = "numpyhelper" | ||
helper = get_helper(HELPER_MODULE) | ||
|
||
NUM_CLASSES = 1 | ||
|
||
dir_path = os.path.dirname(os.path.realpath(__file__)) | ||
abs_path = os.path.abspath(dir_path) | ||
def save_darknet2fedn(darkfile, fednfile ): | ||
fp = open(darkfile, "rb") | ||
header=np.fromfile(fp,dtype=np.int32,count=5) | ||
with open('sow.json', 'w') as f: | ||
header={"header":header.tolist()} | ||
json.dump(header, f) | ||
buf = np.fromfile(fp, dtype=np.float32) | ||
helper.save([buf], fednfile) | ||
fp.close() | ||
def save_fedn2darknet(fednfile, darkfile): | ||
|
||
buf = helper.load(fednfile)[0] | ||
if os.path.exists("sow.json"): | ||
with open('sow.json') as f: | ||
header_data = json.load(f) | ||
image_seen=header_data.get('header')[3] | ||
else: | ||
image_seen=0 | ||
with open(darkfile, "wb") as f: | ||
header = np.array([0,2, 5, 0, 0],dtype=np.int32) | ||
header.tofile(f) | ||
buf.tofile(f) | ||
def number_of_lines(file): | ||
with open(file, "r") as f: | ||
lines = f.readlines() | ||
line_count=len(lines) | ||
return line_count | ||
def train(in_model_path, out_model_path, data_path=None, batch_size=64, epochs=1): | ||
"""Complete a model update. | ||
Load model paramters from in_model_path (managed by the FEDn client), | ||
perform a model update, and write updated paramters | ||
to out_model_path (picked up by the FEDn client). | ||
:param in_model_path: The path to the input model. | ||
:type in_model_path: str | ||
:param out_model_path: The path to save the output model to. | ||
:type out_model_path: str | ||
""" | ||
dir_path = os.path.dirname(os.path.realpath(__file__)) | ||
os.chdir("../../.darknet") | ||
data_path = _get_data_path() | ||
darkfile = dir_path + "/data/1/example.weights" | ||
save_fedn2darknet(in_model_path, darkfile) | ||
data_file=dir_path + "/data/1/obj.data" | ||
cfg_file = dir_path+"/data/1/yolov4-tiny.cfg" | ||
darknet_path = "./darknet" # Make sure this path is correct | ||
yolo_converted_weights = dir_path + "/data/1/example.weights" | ||
command = [darknet_path, "detector", "train", data_file, cfg_file, yolo_converted_weights,"-dont_show"] | ||
try: | ||
subprocess.run(command, check=True) | ||
except subprocess.CalledProcessError as e: | ||
print(f"Error during training: {e}") | ||
number_of_examples=number_of_lines(dir_path+"/data/1/train.txt") | ||
metadata = { | ||
# num_examples are mandatory | ||
"num_examples": number_of_examples, | ||
"batch_size": 64, | ||
"epochs": 1, | ||
"lr": 0.01, | ||
} | ||
save_metadata(metadata, out_model_path) | ||
save_darknet2fedn(dir_path+"/data/1/yolov4_tiny/yolov4-tiny_final.weights", out_model_path) | ||
if __name__ == "__main__": | ||
train(sys.argv[1], sys.argv[2]) |
Oops, something went wrong.