Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fixing delete trained model to prevent deleting base trained mod… #255

Merged
merged 1 commit into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 50 additions & 8 deletions MainApp.BL/Services/TrainingServices/TrainingRunService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -298,12 +298,20 @@ public async Task<ResultDTO> StartTrainingRun(Guid trainingRunId)
Directory.CreateDirectory(detectionCliLogsAbsPath);
}

BufferedCommandResult powerShellResults = await Cli.Wrap(_MMDetectionConfiguration.GetCondaExeAbsPath())
.WithWorkingDirectory(_MMDetectionConfiguration.GetRootDirAbsPath())
string? condaExeAbsPath = _MMDetectionConfiguration.GetCondaExeAbsPath();
if (!File.Exists(condaExeAbsPath))
return ResultDTO.Fail($"Conda.exe file does not exist on this path: {condaExeAbsPath}");

string? rootDirAbsPath = _MMDetectionConfiguration.GetRootDirAbsPath();
if (!Directory.Exists(rootDirAbsPath))
return ResultDTO.Fail($"Root directory does not exist on this path: {rootDirAbsPath}");

BufferedCommandResult powerShellResults = await Cli.Wrap(condaExeAbsPath)
.WithWorkingDirectory(rootDirAbsPath)
.WithValidation(CommandResultValidation.None)
.WithArguments(trainingCommand.ToLower())
.WithStandardOutputPipe(PipeTarget.ToFile(Path.Combine(_MMDetectionConfiguration.GetTrainingRunCliOutDirAbsPath(), $"succ_{trainingRunId}.txt")))
.WithStandardErrorPipe(PipeTarget.ToFile(Path.Combine(_MMDetectionConfiguration.GetTrainingRunCliOutDirAbsPath(), $"error_{trainingRunId}.txt")))
.WithStandardOutputPipe(PipeTarget.ToFile(Path.Combine(trainingCliLogsAbsPath, $"succ_{trainingRunId}.txt")))
.WithStandardErrorPipe(PipeTarget.ToFile(Path.Combine(trainingCliLogsAbsPath, $"error_{trainingRunId}.txt")))
.ExecuteBufferedAsync();

if (powerShellResults.IsSuccess == false)
Expand Down Expand Up @@ -643,6 +651,34 @@ public async Task<ResultDTO> DeleteTrainingRun(Guid trainingRunId, string wwwroo
}
}

//delete config folder for training run from mmdetection
string? configTrainingRunFolder = _MMDetectionConfiguration.GetTrainingRunConfigDirAbsPathByRunId(trainingRunId);
if (Directory.Exists(configTrainingRunFolder))
{
try
{
Directory.Delete(configTrainingRunFolder, recursive: true);
}
catch (Exception ex)
{
return ResultDTO.Fail($"Failed to delete folder: {ex.Message}");
}
}

//delete data folder for training run from mmdetection
string datasetTrainingRunFolder = _MMDetectionConfiguration.GetTrainingRunDatasetDirAbsPath(trainingRunId);
if(Directory.Exists(datasetTrainingRunFolder))
{
try
{
Directory.Delete(datasetTrainingRunFolder, recursive: true);
}
catch (Exception ex)
{
return ResultDTO.Fail($"Failed to delete folder: {ex.Message}");
}
}

//get trained model entity
ResultDTO<TrainedModel?>? resultGetTrainedModel = await _trainedModelsRepository.GetById(resultGetEntity.Data.TrainedModelId!.Value, track: true);
if (!resultGetTrainedModel.IsSuccess && resultGetTrainedModel.HandleError())
Expand All @@ -661,10 +697,16 @@ public async Task<ResultDTO> DeleteTrainingRun(Guid trainingRunId, string wwwroo
List<Guid?>? trainingModelIdsList = resultGetAllTrainingRuns.Data.Where(x => x.Id != trainingRunId).Select(x => x.TrainedModelId).ToList();
if (!trainingModelIdsList.Contains(resultGetTrainedModel.Data.Id))
{
//delete trained model from db if it is not contained in other training runs
ResultDTO? resultDeleteTrainedModel = await _trainedModelsRepository.Delete(resultGetTrainedModel.Data);
if (!resultDeleteTrainedModel.IsSuccess && resultDeleteTrainedModel.HandleError())
return ResultDTO.Fail(resultDeleteTrainedModel.ErrMsg!);
//check that the trained model is not a base model, to prevent deleting base trained models
if(resultGetTrainedModel.Data.BaseModelId != null)
{
//delete trained model from db if it is not contained in other training runs
ResultDTO? resultDeleteTrainedModel = await _trainedModelsRepository.Delete(resultGetTrainedModel.Data);
if (!resultDeleteTrainedModel.IsSuccess && resultDeleteTrainedModel.HandleError())
return ResultDTO.Fail(resultDeleteTrainedModel.ErrMsg!);
}


}

//delete training run from db
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,21 @@ public async Task<ResultDTO> DeleteDetectionRun(Guid detectionRunId)
}
}

JobList<ProcessingJobDto>? processingJobs = monitoringApi.ProcessingJobs(0, int.MaxValue);
if (processingJobs == null)
return ResultDTO.Fail("Processing jobs not found");

foreach (KeyValuePair<string, ProcessingJobDto> job in processingJobs)
{
string jobId = job.Key;
using (IStorageConnection connection = JobStorage.Current.GetConnection())
{
string storedKey = connection.GetJobParameter(jobId, "detectionRunId");
if (storedKey == detectionRunId.ToString())
return ResultDTO.Fail("Can not delete detection run because it is in process");
}
}

ResultDTO resultDeleteEntity = await _detectionRunService.DeleteDetectionRun(detectionRunId, _webHostEnvironment.WebRootPath);
if (!resultDeleteEntity.IsSuccess && resultDeleteEntity.HandleError())
return ResultDTO.Fail(resultDeleteEntity.ErrMsg!);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public TrainingRunsController(ITrainingRunService trainingRunService,
_configuration = configuration;
_mapper = mapper;
}

[HttpGet]
[HasAuthClaim(nameof(SD.AuthClaims.ViewTrainingRuns))]
public async Task<IActionResult> Index()
Expand All @@ -81,7 +81,7 @@ public async Task<IActionResult> Index()
{
return HandleErrorRedirect("ErrorViewsPath:Error", 400);
}

}

[HttpGet]
Expand Down Expand Up @@ -181,6 +181,7 @@ public async Task<ResultDTO> ScheduleTrainingRun(TrainingRunViewModel viewModel)
[HasAuthClaim(nameof(SD.AuthClaims.ScheduleTrainingRun))]
public async Task<ResultDTO> ExecuteTrainingRunProcess(TrainingRunDTO trainingRunDTO, TrainingRunTrainParamsDTO trainingRunTrainParamsDTO)
{
bool result = false;
try
{
//int numEpochs = 1;
Expand Down Expand Up @@ -238,6 +239,7 @@ public async Task<ResultDTO> ExecuteTrainingRunProcess(TrainingRunDTO trainingRu
if (updateTrainRunResultSuccess.IsSuccess == false && updateTrainRunResultSuccess.HandleError())
return ResultDTO.Fail(updateTrainRunResultSuccess.ErrMsg!);

result = true;
return ResultDTO.Ok();
}
catch (Exception ex)
Expand All @@ -246,7 +248,67 @@ public async Task<ResultDTO> ExecuteTrainingRunProcess(TrainingRunDTO trainingRu
}
finally
{
// TODO: Clean Up Training Run Files, Later
//Clean Up Training Run Files

//1. delete data folder for training run
string? datasetTrainingRunFolder = _MMDetectionConfiguration.GetTrainingRunDatasetDirAbsPath(trainingRunDTO.Id!.Value);
if (Directory.Exists(datasetTrainingRunFolder))
Directory.Delete(datasetTrainingRunFolder, recursive: true);

//2. delete config folder for training run from mmdetection
string? configTrainingRunFolder = _MMDetectionConfiguration.GetTrainingRunConfigDirAbsPathByRunId(trainingRunDTO.Id!.Value);
if (Directory.Exists(configTrainingRunFolder))
Directory.Delete(configTrainingRunFolder, recursive: true);

//3. delete all epoch checkpoints (.pth only) if the run failed; if successful, keep only the best epoch
string? trainingRunFolderPath = Path.Combine(_MMDetectionConfiguration.GetTrainingRunsBaseOutDirAbsPath(), trainingRunDTO.Id!.Value.ToString());
if (Directory.Exists(trainingRunFolderPath))
{
//get all .pth files
string[]? pthFiles = Directory.GetFiles(trainingRunFolderPath, "*.pth", SearchOption.TopDirectoryOnly);
if (pthFiles != null && pthFiles.Length > 0)
{
//check the result of try catch block
if (result)
{
//get best epoch
ResultDTO<TrainingRunResultsDTO>? resultGetBestEpoch = _trainingRunService.GetBestEpochForTrainingRun(trainingRunDTO.Id!.Value);
if (resultGetBestEpoch.IsSuccess && resultGetBestEpoch.Data != null)
{
int bestEpoch = resultGetBestEpoch.Data.BestEpochMetrics.Epoch;
foreach (string? file in pthFiles)
{
if (file != null)
{
string? fileName = Path.GetFileNameWithoutExtension(file);
if (fileName != null && fileName.StartsWith("epoch_"))
{
string? numberPart = fileName.Substring("epoch_".Length);
if (numberPart != null && int.TryParse(numberPart, out int epochNumber))
{
//delete all .pth files except best epoch
if (epochNumber != bestEpoch && System.IO.File.Exists(file))
{
System.IO.File.Delete(file);
}
}

}
}
}
}
}
else
{
//result is failed so delete all .pth files
foreach (string? file in pthFiles)
{
if (System.IO.File.Exists(file))
System.IO.File.Delete(file);
}
}
}
}
}
}

Expand Down Expand Up @@ -418,6 +480,21 @@ public async Task<ResultDTO> DeleteTrainingRun(Guid trainingRunId)
}
}

JobList<ProcessingJobDto>? processingJobs = monitoringApi.ProcessingJobs(0, int.MaxValue);
if (processingJobs == null)
return ResultDTO.Fail("Processing jobs not found");

foreach (KeyValuePair<string, ProcessingJobDto> job in processingJobs)
{
string jobId = job.Key;
using (IStorageConnection connection = JobStorage.Current.GetConnection())
{
string storedKey = connection.GetJobParameter(jobId, "trainingRunId");
if (storedKey == trainingRunId.ToString())
return ResultDTO.Fail("Can not delete training run because it is in process");
}
}

ResultDTO? resultDeleteEntity = await _trainingRunService.DeleteTrainingRun(trainingRunId, _webHostEnvironment.WebRootPath);
if (!resultDeleteEntity.IsSuccess && resultDeleteEntity.HandleError())
return ResultDTO.Fail(resultDeleteEntity.ErrMsg!);
Expand Down Expand Up @@ -513,9 +590,9 @@ private async Task<ResultDTO> CreateErrMsgFile(Guid trainingRunId, string errMsg
return ResultDTO.Fail("Directory path not found");

string? filePath = System.IO.Path.Combine(_webHostEnvironment.WebRootPath, trainingRunsErrorLogsFolder.Data);
if (!Directory.Exists(filePath))
if (!Directory.Exists(filePath))
Directory.CreateDirectory(filePath);

string fileName = $"{trainingRunId}_errMsg.txt";
string? fullFilePath = System.IO.Path.Combine(filePath, fileName);
if (fullFilePath == null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

<ul class="nav nav-pills nav-sidebar flex-column" data-widget="treeview" role="menu" data-accordion="false">
<li class="nav-item">
<a asp-action="Index" asp-controller="Map" asp-area="IntranetPortal" id="NavlinkMap" class=" nav-link" title="@DbResHtml.T("Map", "Resources")">
<a asp-action="Index" asp-controller="Map" asp-area="IntranetPortal" id="NavlinkMap" class="@Html.IsActive(controllers: "Map") nav-link" title="@DbResHtml.T("Map", "Resources")">
<i class="fas fa-map"></i> &nbsp;&nbsp;
<p>
@DbResHtml.T("Map", "Resources")
Expand Down
2 changes: 1 addition & 1 deletion Tests/MainAppBLTests/Services/DatasetServiceTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ public async Task GetDatasetById_ShouldReturnFailResult_WhenDatasetNotFound()
var datasetId = Guid.NewGuid();

_mockDatasetRepository
.Setup(repo => repo.GetById(datasetId, false, null))
.Setup(repo => repo.GetById(datasetId, false, "CreatedBy"))
.ReturnsAsync(ResultDTO<Dataset?>.Ok(null));

// Act
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,9 @@ public async Task DeleteTrainingRun_NoMatchingJobs_ReturnsSuccess()
var trainingRunId = Guid.NewGuid();
var enqueuedJobs = new JobList<EnqueuedJobDto>(new List<KeyValuePair<string, EnqueuedJobDto>>());
Mock<IMonitoringApi> mockMonitoringApi = new Mock<IMonitoringApi>();

mockMonitoringApi.Setup(api => api.EnqueuedJobs("default", 0, int.MaxValue)).Returns(enqueuedJobs);
mockMonitoringApi.Setup(api => api.ProcessingJobs(0, int.MaxValue)).Returns(new JobList<ProcessingJobDto>(new List<KeyValuePair<string, ProcessingJobDto>>()));

JobStorage.Current = Mock.Of<JobStorage>(storage => storage.GetMonitoringApi() == mockMonitoringApi.Object);

Expand All @@ -837,6 +839,10 @@ public async Task DeleteTrainingRun_DeleteTrainingRunFails_ReturnsFailureResult(
.Setup(api => api.EnqueuedJobs("default", 0, int.MaxValue))
.Returns(new JobList<EnqueuedJobDto>(new List<KeyValuePair<string, EnqueuedJobDto>>()));

mockMonitoringApi
.Setup(api => api.ProcessingJobs(0, int.MaxValue))
.Returns(new JobList<ProcessingJobDto>(new List<KeyValuePair<string, ProcessingJobDto>>()));

var mockJobStorage = new Mock<JobStorage>();
mockJobStorage.Setup(js => js.GetMonitoringApi()).Returns(mockMonitoringApi.Object);
JobStorage.Current = mockJobStorage.Object;
Expand Down Expand Up @@ -864,6 +870,10 @@ public async Task DeleteTrainingRun_DeleteTrainingRunSucceeds_ReturnsSuccessResu
.Setup(api => api.EnqueuedJobs("default", 0, int.MaxValue))
.Returns(new JobList<EnqueuedJobDto>(new List<KeyValuePair<string, EnqueuedJobDto>>()));

mockMonitoringApi
.Setup(api => api.ProcessingJobs(0, int.MaxValue))
.Returns(new JobList<ProcessingJobDto>(new List<KeyValuePair<string, ProcessingJobDto>>()));

var mockJobStorage = new Mock<JobStorage>();
mockJobStorage.Setup(js => js.GetMonitoringApi()).Returns(mockMonitoringApi.Object);
JobStorage.Current = mockJobStorage.Object;
Expand All @@ -890,6 +900,8 @@ public async Task ExecuteTrainingRunProcess_ExceptionThrown_ReturnsExceptionFail
.Setup(service => service.UpdateTrainingRunEntity(trainingRunDTO.Id.Value, null, nameof(ScheduleRunsStatus.Processing), false,null))
.Throws(new Exception("Unexpected error"));

_mockMMDetectionConfiguration.Setup(s => s.GetTrainingRunsBaseOutDirAbsPath()).Returns("TrainingRunsBaseOytDirAbsPath");

// Act
var result = await _controller.ExecuteTrainingRunProcess(trainingRunDTO, paramsDTO);

Expand All @@ -907,6 +919,8 @@ public async Task ExecuteTrainingRunProcess_ShouldReturnFail_WhenUpdateTrainRunF
_mockTrainingRunService.Setup(s => s.UpdateTrainingRunEntity(It.IsAny<Guid>(), null, nameof(ScheduleRunsStatus.Processing), null, null))
.ReturnsAsync(ResultDTO.Fail("Object reference not set to an instance of an object."));

_mockMMDetectionConfiguration.Setup(s => s.GetTrainingRunsBaseOutDirAbsPath()).Returns("TrainingRunsBaseOytDirAbsPath");

// Act
var result = await _controller.ExecuteTrainingRunProcess(trainingRunDTO, paramsDTO);

Expand All @@ -926,6 +940,8 @@ public async Task ExecuteTrainingRunProcess_ShouldReturnFail_WhenGetDatasetFails
_mockDatasetService.Setup(s => s.GetDatasetDTOFullyIncluded(It.IsAny<Guid>(), false))
.ReturnsAsync(ResultDTO<DatasetDTO>.Fail("Object reference not set to an instance of an object."));

_mockMMDetectionConfiguration.Setup(s => s.GetTrainingRunsBaseOutDirAbsPath()).Returns("TrainingRunsBaseOytDirAbsPath");

// Act
var result = await _controller.ExecuteTrainingRunProcess(trainingRunDTO, paramsDTO);

Expand All @@ -945,6 +961,7 @@ public async Task ExecuteTrainingRunProcess_ShouldReturnExceptionFail_WhenExcept
_mockTrainingRunService.Setup(s => s.UpdateTrainingRunEntity(It.IsAny<Guid>(), null, nameof(ScheduleRunsStatus.Processing), null, null))
.ThrowsAsync(new NullReferenceException("Object reference not set to an instance of an object"));

_mockMMDetectionConfiguration.Setup(s => s.GetTrainingRunsBaseOutDirAbsPath()).Returns("TrainingRunsBaseOytDirAbsPath");
// Act
var result = await _controller.ExecuteTrainingRunProcess(trainingRunDTO, paramsDTO);

Expand Down Expand Up @@ -1029,6 +1046,7 @@ public async Task ExecuteTrainingRunProcess_ShouldReturnSuccess_WhenAllStepsAreS
It.IsAny<string>()))
.ReturnsAsync(ResultDTO.Ok());

_mockMMDetectionConfiguration.Setup(s => s.GetTrainingRunsBaseOutDirAbsPath()).Returns("TrainingRunsBaseOytDirAbsPath");
// Act
var result = await _controller.ExecuteTrainingRunProcess(trainingRunDTO, paramsDTO);

Expand Down
Loading