Skip to content

Commit

Permalink
CM-38979 - Add the ability to cancel scans (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarshalX authored Aug 6, 2024
1 parent 3f69e89 commit 9c5b75c
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 53 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

## [Unreleased]

## [1.2.0] - 2024-08-XX
## [1.2.0] - 2024-08-06

- Add Open-source Threats (SCA) support
- Add the ability to cancel scans

## [1.1.4] - 2024-07-25

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.IO;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Cycode.VisualStudio.Extension.Shared.JsonContractResolvers;
using Cycode.VisualStudio.Extension.Shared.Services;
Expand All @@ -22,17 +23,20 @@ private async Task<string[]> GetDefaultCliArgsAsync() {
// cache
if (_defaultCliArgs.Length > 0) return _defaultCliArgs;

_defaultCliArgs = new[] {
_defaultCliArgs = [
"-o", "json",
"--user-agent", await UserAgent.GetUserAgentEscapedAsync()
};
];

_logger.Debug("Default CLI args: {0}", string.Join(" ", _defaultCliArgs));

return _defaultCliArgs;
}

public async Task<CliResult<T>> ExecuteCommandAsync<T>(string[] arguments, Func<bool> cancelledCallback = null) {
public async Task<CliResult<T>> ExecuteCommandAsync<T>(
string[] arguments,
CancellationToken cancellationToken = default
) {
General general = await General.GetLiveInstanceAsync();

ProcessStartInfo startInfo = new() {
Expand Down Expand Up @@ -89,12 +93,14 @@ public async Task<CliResult<T>> ExecuteCommandAsync<T>(string[] arguments, Func<
process.BeginErrorReadLine();

while (!process.HasExited) {
if (cancelledCallback != null && cancelledCallback()) {
try {
cancellationToken.ThrowIfCancellationRequested();
await Task.Delay(1000, cancellationToken);
} catch (Exception e) when (e is ObjectDisposedException or OperationCanceledException) {
process.Kill();
_logger.Debug("CLI Execution was canceled by user");
return new CliResult<T>.Panic(ExitCode.Termination, "Execution was canceled");
}

await Task.Delay(1000);
}

int exitCode = await tcs.Task;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Cycode.VisualStudio.Extension.Shared.Cli;
using Cycode.VisualStudio.Extension.Shared.Cli.DTO;
Expand All @@ -13,8 +14,6 @@

namespace Cycode.VisualStudio.Extension.Shared.Services;

using TaskCancelledCallback = Func<bool>;

public class CliService(
ILoggerService logger,
IStateService stateService,
Expand Down Expand Up @@ -79,8 +78,8 @@ private static CliResult<T> ProcessResult<T>(CliResult<T> result) {
}
}

public async Task<bool> HealthCheckAsync(TaskCancelledCallback cancelledCallback = null) {
CliResult<VersionResult> result = await _cli.ExecuteCommandAsync<VersionResult>(["version"], cancelledCallback);
public async Task<bool> HealthCheckAsync(CancellationToken cancellationToken = default) {
CliResult<VersionResult> result = await _cli.ExecuteCommandAsync<VersionResult>(["version"], cancellationToken);
CliResult<VersionResult> processedResult = ProcessResult(result);

if (processedResult is CliResult<VersionResult>.Success successResult) {
Expand All @@ -94,9 +93,9 @@ public async Task<bool> HealthCheckAsync(TaskCancelledCallback cancelledCallback
return false;
}

public async Task<bool> CheckAuthAsync(TaskCancelledCallback cancelledCallback = null) {
public async Task<bool> CheckAuthAsync(CancellationToken cancellationToken = default) {
CliResult<AuthCheckResult> result =
await _cli.ExecuteCommandAsync<AuthCheckResult>(["auth", "check"], cancelledCallback);
await _cli.ExecuteCommandAsync<AuthCheckResult>(["auth", "check"], cancellationToken);
CliResult<AuthCheckResult> processedResult = ProcessResult(result);

if (processedResult is CliResult<AuthCheckResult>.Success successResult) {
Expand All @@ -120,8 +119,8 @@ public async Task<bool> CheckAuthAsync(TaskCancelledCallback cancelledCallback =
return false;
}

public async Task<bool> DoAuthAsync(TaskCancelledCallback cancelledCallback = null) {
CliResult<AuthResult> result = await _cli.ExecuteCommandAsync<AuthResult>(["auth"], cancelledCallback);
public async Task<bool> DoAuthAsync(CancellationToken cancellationToken = default) {
CliResult<AuthResult> result = await _cli.ExecuteCommandAsync<AuthResult>(["auth"], cancellationToken);
CliResult<AuthResult> processedResult = ProcessResult(result);

if (processedResult is not CliResult<AuthResult>.Success successResult) {
Expand Down Expand Up @@ -151,24 +150,24 @@ private static string[] GetCliScanOptions(CliScanType scanType) {
}

private async Task<CliResult<T>> ScanPathsAsync<T>(
List<string> paths, CliScanType scanType, TaskCancelledCallback cancelledCallback = null
List<string> paths, CliScanType scanType, CancellationToken cancellationToken = default
) {
List<string> isolatedPaths = paths.Select(path => $"\"{path}\"").ToList();
string scanTypeString = scanType.ToString().ToLower();
CliResult<T> result = await _cli.ExecuteCommandAsync<T>(
new[] { "scan", "-t", scanTypeString }.Concat(GetCliScanOptions(scanType)).Concat(new[] { "path" })
.Concat(isolatedPaths).ToArray(),
cancelledCallback
cancellationToken
);

return ProcessResult(result);
}

public async Task ScanPathsSecretsAsync(
List<string> paths, bool onDemand = true, TaskCancelledCallback cancelledCallback = null
List<string> paths, bool onDemand = true, CancellationToken cancellationToken = default
) {
CliResult<SecretScanResult> results =
await ScanPathsAsync<SecretScanResult>(paths, CliScanType.Secret, cancelledCallback);
await ScanPathsAsync<SecretScanResult>(paths, CliScanType.Secret, cancellationToken);
if (results == null) {
logger.Warn("Failed to scan Secret paths: {0}", string.Join(", ", paths));
return;
Expand All @@ -185,10 +184,10 @@ public async Task ScanPathsSecretsAsync(
}

public async Task ScanPathsScaAsync(
List<string> paths, bool onDemand = true, TaskCancelledCallback cancelledCallback = null
List<string> paths, bool onDemand = true, CancellationToken cancellationToken = default
) {
CliResult<ScaScanResult> results =
await ScanPathsAsync<ScaScanResult>(paths, CliScanType.Sca, cancelledCallback);
await ScanPathsAsync<ScaScanResult>(paths, CliScanType.Sca, cancellationToken);
if (results == null) {
logger.Warn("Failed to scan SCA paths: {0}", string.Join(", ", paths));
return;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Collections.Generic;
using System.Threading;
using Cycode.VisualStudio.Extension.Shared.DTO;
using Cycode.VisualStudio.Extension.Shared.Helpers;
#if VS16 || VS17
Expand All @@ -18,7 +19,7 @@ IToolWindowMessengerService toolWindowMessengerService

#if VS16 || VS17 // We don't have VS16 constant because we support range of versions in one project
private static async Task WrapWithStatusCenterAsync(
Func<Task> taskFunction,
Func<CancellationToken, Task> taskFunction,
string label,
bool canBeCanceled
) {
Expand All @@ -32,26 +33,30 @@ bool canBeCanceled
data.CanBeCanceled = canBeCanceled;

ITaskHandler handler = tsc.PreRegister(options, data);
// TODO(MarshalX): Support CancellationToken!
// Task task = taskFunction(handler.UserCancellation);
Task task = taskFunction();
Task task = taskFunction(handler.UserCancellation);
handler.RegisterTask(task);

await task; // wait for the task to complete, otherwise it will be run in the background

data.PercentComplete = 100;
handler.Progress.Report(data);
try {
await task; // wait for the task to complete, otherwise it will be run in the background
} finally {
data.PercentComplete = 100;
handler.Progress.Report(data);
}
}
#else
private static async Task WrapWithStatusCenterAsync(
Func<Task> taskFunction,
Func<CancellationToken, Task> taskFunction,
string label,
bool canBeCanceled // For old VS version; doesn't support TaskStatusCenter; doesn't support cancellation
) {
// currentStep must have a value of 1 or higher!
await VS.StatusBar.ShowProgressAsync(label, currentStep: 1, numberOfSteps: 2);
await taskFunction();
await VS.StatusBar.ShowProgressAsync(label, currentStep: 2, numberOfSteps: 2);

try {
await taskFunction(default);
} finally {
await VS.StatusBar.ShowProgressAsync(label, currentStep: 2, numberOfSteps: 2);
}
}
#endif

Expand All @@ -74,7 +79,7 @@ await WrapWithStatusCenterAsync(
);
}

private async Task InstallCliIfNeededAndCheckAuthenticationAsyncInternalAsync() {
private async Task InstallCliIfNeededAndCheckAuthenticationAsyncInternalAsync(CancellationToken cancellationToken) {
try {
toolWindowMessengerService.Send(MessengerCommand.LoadLoadingControl);

Expand All @@ -84,8 +89,8 @@ private async Task InstallCliIfNeededAndCheckAuthenticationAsyncInternalAsync()
return;
}

await cliService.HealthCheckAsync();
await cliService.CheckAuthAsync();
await cliService.HealthCheckAsync(cancellationToken);
await cliService.CheckAuthAsync(cancellationToken);

UpdateToolWindowDependingOnState();
} catch (Exception e) {
Expand All @@ -101,16 +106,16 @@ await WrapWithStatusCenterAsync(
);
}

private async Task StartAuthInternalAsync() {
private async Task StartAuthInternalAsync(CancellationToken cancellationToken) {
if (!_pluginState.CliAuthed) {
logger.Debug("Start auth...");
await cliService.DoAuthAsync();
await cliService.DoAuthAsync(cancellationToken);
UpdateToolWindowDependingOnState();
} else {
logger.Debug("Already authenticated with Cycode CLI");
}
}

public async Task StartSecretScanForCurrentProjectAsync() {
string projectRoot = SolutionHelper.GetSolutionRootDirectory();
if (projectRoot == null) {
Expand All @@ -127,23 +132,26 @@ public async Task StartPathSecretScanAsync(string pathToScan, bool onDemand = fa

public async Task StartPathSecretScanAsync(List<string> pathsToScan, bool onDemand = false) {
await WrapWithStatusCenterAsync(
taskFunction: () => StartPathSecretScanInternalAsync(pathsToScan, onDemand),
taskFunction: cancellationToken =>
StartPathSecretScanInternalAsync(pathsToScan, onDemand, cancellationToken),
label: "Cycode is scanning files for hardcoded secrets...",
canBeCanceled: false // TODO(MarshalX): Should be cancellable. Not implemented yet
canBeCanceled: true
);
}

private async Task StartPathSecretScanInternalAsync(List<string> pathsToScan, bool onDemand = false) {
private async Task StartPathSecretScanInternalAsync(
List<string> pathsToScan, bool onDemand = false, CancellationToken cancellationToken = default
) {
if (!_pluginState.CliAuthed) {
logger.Debug("Not authenticated with Cycode CLI. Aborting scan...");
return;
}

logger.Debug("[Secret] Start scanning paths: {0}", string.Join(", ", pathsToScan));
await cliService.ScanPathsSecretsAsync(pathsToScan, onDemand);
await cliService.ScanPathsSecretsAsync(pathsToScan, onDemand, cancellationToken);
logger.Debug("[Secret] Finish scanning paths: {0}", string.Join(", ", pathsToScan));
}

public async Task StartScaScanForCurrentProjectAsync() {
string projectRoot = SolutionHelper.GetSolutionRootDirectory();
if (projectRoot == null) {
Expand All @@ -160,20 +168,22 @@ public async Task StartPathScaScanAsync(string pathToScan, bool onDemand = false

public async Task StartPathScaScanAsync(List<string> pathsToScan, bool onDemand = false) {
await WrapWithStatusCenterAsync(
taskFunction: () => StartPathScaScanInternalAsync(pathsToScan, onDemand),
taskFunction: cancellationToken => StartPathScaScanInternalAsync(pathsToScan, onDemand, cancellationToken),
label: "Cycode is scanning files for package vulnerabilities...",
canBeCanceled: false // TODO(MarshalX): Should be cancellable. Not implemented yet
canBeCanceled: true
);
}

private async Task StartPathScaScanInternalAsync(List<string> pathsToScan, bool onDemand = false) {
private async Task StartPathScaScanInternalAsync(
List<string> pathsToScan, bool onDemand = false, CancellationToken cancellationToken = default
) {
if (!_pluginState.CliAuthed) {
logger.Debug("Not authenticated with Cycode CLI. Aborting scan...");
return;
}

logger.Debug("[SCA] Start scanning paths: {0}", string.Join(", ", pathsToScan));
await cliService.ScanPathsScaAsync(pathsToScan, onDemand);
await cliService.ScanPathsScaAsync(pathsToScan, onDemand, cancellationToken);
logger.Debug("[SCA] Finish scanning paths: {0}", string.Join(", ", pathsToScan));
}
}
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using TaskCancelledCallback = System.Func<bool>;

namespace Cycode.VisualStudio.Extension.Shared.Services;

public interface ICliService {
Task<bool> HealthCheckAsync(TaskCancelledCallback cancelledCallback = null);
Task<bool> CheckAuthAsync(TaskCancelledCallback cancelledCallback = null);
Task<bool> DoAuthAsync(TaskCancelledCallback cancelledCallback = null);
Task<bool> HealthCheckAsync(CancellationToken cancellationToken = default);
Task<bool> CheckAuthAsync(CancellationToken cancellationToken = default);
Task<bool> DoAuthAsync(CancellationToken cancellationToken = default);

Task ScanPathsSecretsAsync(
List<string> paths, bool onDemand = true, TaskCancelledCallback cancelledCallback = null
List<string> paths, bool onDemand = true, CancellationToken cancellationToken = default
);

Task ScanPathsScaAsync(
List<string> paths, bool onDemand = true, TaskCancelledCallback cancelledCallback = null
List<string> paths, bool onDemand = true, CancellationToken cancellationToken = default
);
}

0 comments on commit 9c5b75c

Please sign in to comment.