diff --git a/custom_u2net_test.py b/custom_u2net_test.py new file mode 100644 index 00000000..ea9149bf --- /dev/null +++ b/custom_u2net_test.py @@ -0,0 +1,95 @@ +import os +from skimage import io, transform +import torch +import torchvision +from torch.autograd import Variable +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms#, utils + +import numpy as np +from PIL import Image +import glob + +from data_loader import RescaleT +from data_loader import ToTensor +from data_loader import ToTensorLab +from data_loader import SalObjDataset + +from model import U2NET # full size version 173.6 MB +from model import U2NETP # small version u2net 4.7 MB + + +def test_model(model): + + + # ------- 1. set the directory of test dataset -------- + + test_data_dir = os.path.join(os.getcwd(), 'my_data' + os.sep) + test_image_dir = os.path.join('TDP_test_dataset','TDP_IMAGES' + os.sep) + test_label_dir = os.path.join('TDP_test_dataset','TDP_MASKS' + os.sep) + + + + image_ext = '.jpg' + label_ext = '.png' + + batch_size_val = 1 + + test_img_name_list = glob.glob(test_data_dir + test_image_dir + '*' + image_ext) + + test_lbl_name_list = [] + for img_path in test_img_name_list: + img_name = img_path.split(os.sep)[-1] + + aaa = img_name.split(".") + bbb = aaa[0:-1] + imidx = bbb[0] + for i in range(1,len(bbb)): + imidx = imidx + "." + bbb[i] + + test_lbl_name_list.append(test_data_dir + test_label_dir + imidx + label_ext) + + test_salobj_dataset = SalObjDataset( + img_name_list=test_img_name_list, + lbl_name_list=test_lbl_name_list, + transform=transforms.Compose([ + RescaleT(320), + ToTensorLab(flag=0)])) + test_salobj_dataloader = DataLoader(test_salobj_dataset, + batch_size=batch_size_val, + shuffle=False, + num_workers=1) + + # --------- 2. test process --------- + if torch.cuda.is_available(): + model.cuda() + model.eval() + total_pixels = 0 + correct_pixels = 0 + accuracy = 0 + + with torch.no_grad(): + for i, data in enumerate(test_salobj_dataloader): + inputs, labels = data['image'], data['label'] + + inputs = inputs.type(torch.FloatTensor) + labels = labels.type(torch.FloatTensor) + + if torch.cuda.is_available(): + inputs, labels = inputs.cuda(), labels.cuda() + + outputs = model(inputs) + predicted_masks = (outputs[0] > 0.5).float() + + total_pixels = labels.numel() + correct_pixels = (predicted_masks == labels).sum().item() + accuracy += (correct_pixels/total_pixels)*100 + + avr_accuracy=accuracy/len(test_salobj_dataloader) + print(f'avarage_accuracy: {avr_accuracy}%') + + +if __name__ == "__main__": + pass \ No newline at end of file diff --git a/custom_u2net_train.py b/custom_u2net_train.py new file mode 100644 index 00000000..a51faa91 --- /dev/null +++ b/custom_u2net_train.py @@ -0,0 +1,199 @@ +import os +import torch +import torchvision +from torch.autograd import Variable +import torch.nn as nn +import torch.nn.functional as F + +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms, utils +import torch.optim as optim +import torchvision.transforms as standard_transforms + +import numpy as np +import glob +import os + +from data_loader import Rescale +from data_loader import RescaleT +from data_loader import RandomCrop +from data_loader import ToTensor +from data_loader import ToTensorLab +from data_loader import SalObjDataset + +from model import U2NET +from model import U2NETP +from tqdm import tqdm +from model_processing.prepare_model import get_latest_model, get_latest_version +from model_processing.convert_model import convert_model_to_onnx +from custom_u2net_test import test_model +from model_processing.upload_model_to_S3bucket import upload_folder_to_s3 + + + +def main(): + # ------- 1. define loss function -------- + + bce_loss = nn.BCELoss(size_average=True) + + def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v): + # Convert all tensors to torch.float32 + d0, d1, d2, d3, d4, d5, d6, labels_v = ( + d0.float(), d1.float(), d2.float(), d3.float(), d4.float(), d5.float(), d6.float(), labels_v.float() + ) + + loss0 = bce_loss(d0, labels_v) + loss1 = bce_loss(d1, labels_v) + loss2 = bce_loss(d2, labels_v) + loss3 = bce_loss(d3, labels_v) + loss4 = bce_loss(d4, labels_v) + loss5 = bce_loss(d5, labels_v) + loss6 = bce_loss(d6, labels_v) + + loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6 + print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n" % ( + loss0.data.item(), loss1.data.item(), loss2.data.item(), loss3.data.item(), loss4.data.item(), loss5.data.item(), + loss6.data.item())) + + return loss0, loss + + + # ------- 2. set the directory of training dataset -------- + + model_name = 'u2net' #'u2netp' + + data_dir = os.path.join(os.getcwd(), 'my_data' + os.sep) + tra_image_dir = os.path.join('TDP_train_dataset','TDP_IMAGES' + os.sep) + tra_label_dir = os.path.join('TDP_train_dataset','TDP_MASKS' + os.sep) + + image_ext = '.jpg' + label_ext = '.png' + + model_dir = os.path.join(os.getcwd(), 'saved_models', model_name + os.sep) + + epoch_num = 100 + batch_size_train = 32 + train_num = 0 + + tra_img_name_list = glob.glob(data_dir + tra_image_dir + '*' + image_ext) + + tra_lbl_name_list = [] + for img_path in tra_img_name_list: + img_name = img_path.split(os.sep)[-1] + + aaa = img_name.split(".") + bbb = aaa[0:-1] + imidx = bbb[0] + for i in range(1,len(bbb)): + imidx = imidx + "." + bbb[i] + + tra_lbl_name_list.append(data_dir + tra_label_dir + imidx + label_ext) + + print("---") + print("train images: ", len(tra_img_name_list)) + print("train labels: ", len(tra_lbl_name_list)) + print("---") + + train_num = len(tra_img_name_list) + + salobj_dataset = SalObjDataset( + img_name_list=tra_img_name_list, + lbl_name_list=tra_lbl_name_list, + transform=transforms.Compose([ + RescaleT(320), + RandomCrop(288), + ToTensorLab(flag=0)])) + salobj_dataloader = DataLoader(salobj_dataset, batch_size=batch_size_train, shuffle=True, num_workers=1) + + # ------- 3. define model -------- + # define the net + if(model_name=='u2net'): + net = U2NET(3, 1) + elif(model_name=='u2netp'): + net = U2NETP(3,1) + + if torch.cuda.is_available(): + net.cuda() + + # ------- 4. define optimizer -------- + print("---define optimizer...") + optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) + + # ------- 5. training process -------- + print("---start training...") + ite_num = 0 + running_loss = 0.0 + running_tar_loss = 0.0 + ite_num4val = 0 + # save_frq = 2000 # save the model every 2000 iterations + # Check if there is a pre-trained model to load + pretrained_model_path = get_latest_model("saved_models/u2net") + + if os.path.exists(pretrained_model_path): + # Load the pre-trained model + net.load_state_dict(torch.load(pretrained_model_path)) + print(f"Pre-trained model loaded from {pretrained_model_path}") + else: + print("No pre-trained model found. Training from scratch.") + + for epoch in tqdm(range(0, epoch_num)): + net.train() + + for i, data in tqdm(enumerate(salobj_dataloader)): + ite_num = ite_num + 1 + ite_num4val = ite_num4val + 1 + + inputs, labels = data['image'], data['label'] + + inputs = inputs.type(torch.FloatTensor) + labels = labels.type(torch.FloatTensor) + + # wrap them in Variable + if torch.cuda.is_available(): + inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(), + requires_grad=False) + else: + inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False) + + # y zero the parameter gradients + optimizer.zero_grad() + + # forward + backward + optimize + d0, d1, d2, d3, d4, d5, d6 = net(inputs_v) + loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v) + + loss.backward() + optimizer.step() + + # # print statistics + running_loss += loss.data.item() + running_tar_loss += loss2.data.item() + + # del temporary outputs and loss + del d0, d1, d2, d3, d4, d5, d6, loss2, loss + + print("[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f " % ( + epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val)) + + #test_model + test_model(net) + + # Save model + latest_version = get_latest_version(pretrained_model_path) + if latest_version.isdigit(): + model_name = f"u2net_version_{int(latest_version)+1}.pth" + else: + model_name = f"u2net_version_1.pth" + torch.save(net.state_dict(), model_dir + model_name) + print(f"Final model saved as {model_name}") + + #convert model to onnx + latest_model= os.path.join(model_dir,model_name) + convert_model_to_onnx(latest_model) + + #upload model to AWS S3 bucket + upload_folder_to_s3(model_dir,'tdp-model') + +if __name__ == '__main__': + main() + diff --git a/model_processing/convert_model.py b/model_processing/convert_model.py new file mode 100644 index 00000000..940d09b8 --- /dev/null +++ b/model_processing/convert_model.py @@ -0,0 +1,30 @@ +import io +import torch.onnx +from model import U2NET +import os +from model_processing.prepare_model import get_latest_version + +def convert_model_to_onnx(model_path): + torch_model = U2NET(3,1) + batch_size = 1 + + torch_model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) + torch_model.eval() + + x = torch.randn(batch_size, 3, 512, 512, requires_grad=True) + last_character = get_latest_version(model_path) + model_dir = "saved_models/ABR_model" + if not os.path.exists(model_dir): + os.makedirs(model_dir) + torch.onnx.export(torch_model, x,os.path.join(model_dir,f"ARB_version_{int(last_character)}.onnx"), + export_params=True, + opset_version=11, + do_constant_folding=True, + input_names = ['input'], + output_names = ['output'], + dynamic_axes = {'input' : {0: 'batch_size'}, 'output': {0: 'batch_size'}} + ) + + +if __name__ == '__main__': + pass \ No newline at end of file diff --git a/model_processing/prepare_model.py b/model_processing/prepare_model.py new file mode 100644 index 00000000..956a2499 --- /dev/null +++ b/model_processing/prepare_model.py @@ -0,0 +1,16 @@ +import os + +def get_latest_model(model_folder): + model_list =[file for file in os.listdir(model_folder) if file.startswith("u2net_version_")] + if not model_list: + latest_model = 'u2net.pth' + else: + sorted_model_files = sorted(model_list, key=lambda x: int(x[len("u2net_version_"):-len(".pth")])) + latest_model = sorted_model_files[-1] + return os.path.join(model_folder,latest_model) + + +def get_latest_version(model_path): + model_name = model_path.split("/")[-1].split(".")[0] + latest_version = model_name[-1] + return latest_version diff --git a/model_processing/upload_model_to_S3bucket.py b/model_processing/upload_model_to_S3bucket.py new file mode 100644 index 00000000..46847910 --- /dev/null +++ b/model_processing/upload_model_to_S3bucket.py @@ -0,0 +1,27 @@ +import os +from dotenv import load_dotenv +import boto3 + +# Load environment variables from .env +load_dotenv() + +def upload_folder_to_s3(local_folder,bucket_name): + # Retrieve AWS credentials from environment variables + aws_access_key_id = os.environ.get('AWS_ACCESS_KEY_ID') + aws_secret_access_key = os.environ.get('AWS_SECRET_ACCESS_KEY') + aws_region = os.environ.get('AWS_REGION') + + # Create an S3 client + s3 = boto3.client('s3', aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, region_name=aws_region) + + # Iterate through each file in the local folder + for root, dirs, files in os.walk(local_folder): + for file in files: + local_file_path = os.path.join(root, file) + s3_key = os.path.relpath(local_file_path, local_folder).replace("\\", "/") + + try: + # Upload the file to S3 + s3.upload_file(local_file_path, bucket_name, s3_key) + except Exception as e: + print(f"Error uploading {local_file_path}: {e}") \ No newline at end of file diff --git a/u2net_train.py b/u2net_train.py index 8f19491f..45a09dd6 100644 --- a/u2net_train.py +++ b/u2net_train.py @@ -48,17 +48,17 @@ def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v): model_name = 'u2net' #'u2netp' -data_dir = os.path.join(os.getcwd(), 'train_data' + os.sep) -tra_image_dir = os.path.join('DUTS', 'DUTS-TR', 'DUTS-TR', 'im_aug' + os.sep) -tra_label_dir = os.path.join('DUTS', 'DUTS-TR', 'DUTS-TR', 'gt_aug' + os.sep) +data_dir = os.path.join(os.getcwd(), '/content/Train_data' + os.sep) +tra_image_dir = os.path.join('TDP_IMAGES' + os.sep) +tra_label_dir = os.path.join('TDP_MASKS' + os.sep) image_ext = '.jpg' label_ext = '.png' -model_dir = os.path.join(os.getcwd(), 'saved_models', model_name + os.sep) +model_dir = os.path.join(os.getcwd(), '/content/U-2-Net/saved_models', model_name + os.sep) -epoch_num = 100000 -batch_size_train = 12 +epoch_num = 20 +batch_size_train = 16 batch_size_val = 1 train_num = 0 val_num = 0 @@ -113,7 +113,16 @@ def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v): running_loss = 0.0 running_tar_loss = 0.0 ite_num4val = 0 -save_frq = 2000 # save the model every 2000 iterations +# save_frq = 2000 # save the model every 2000 iterations +# Check if there is a pre-trained model to load +pretrained_model_path = "/content/U-2-Net/saved_models/u2net/u2net.pth" + +if os.path.exists(pretrained_model_path): + # Load the pre-trained model + net.load_state_dict(torch.load(pretrained_model_path)) + print(f"Pre-trained model loaded from {pretrained_model_path}") +else: + print("No pre-trained model found. Training from scratch.") for epoch in range(0, epoch_num): net.train() @@ -154,11 +163,15 @@ def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v): print("[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f " % ( epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val)) - if ite_num % save_frq == 0: + # if ite_num % save_frq == 0: - torch.save(net.state_dict(), model_dir + model_name+"_bce_itr_%d_train_%3f_tar_%3f.pth" % (ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val)) - running_loss = 0.0 - running_tar_loss = 0.0 - net.train() # resume train - ite_num4val = 0 + # torch.save(net.state_dict(), model_dir + model_name+"_bce_itr_%d_train_%3f_tar_%3f.pth" % (ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val)) + # running_loss = 0.0 + # running_tar_loss = 0.0 + # net.train() # resume train + # ite_num4val = 0 +# Save the final model after the entire training process +final_model_name = f"{model_name}_final_itr_{ite_num}_train_{running_loss / ite_num4val:.3f}_tar_{running_tar_loss / ite_num4val:.3f}.pth" +torch.save(net.state_dict(), model_dir + final_model_name) +print(f"Final model saved as {final_model_name}") \ No newline at end of file