Skip to content

Commit

Permalink
add classifier type flag
Browse files Browse the repository at this point in the history
  • Loading branch information
santi1234567 committed May 2, 2024
1 parent 268a3d0 commit 52a3545
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,9 @@ def parse_args():
parser.add_argument(
"--group", default=[], nargs="+", help="clients to group during classification"
)
parser.add_argument(
"--classifier-type", default="knn", choices=["knn", "mlp"], help="the type of classifier to use"
)
parser.add_argument(
"--persist",
action="store_true",
Expand Down Expand Up @@ -296,7 +299,7 @@ def main():
grouped_clients = args.group
should_persist = args.should_persist
graffiti_only = args.graffiti_only

classifier_type = args.classifier_type
disabled_clients = args.disable
enabled_clients = [
client
Expand Down Expand Up @@ -326,6 +329,7 @@ def main():
graffiti_only_clients=graffiti_only,
features=feature_vec,
enable_cv=True,
classifier_type=classifier_type
)
print(f"enabled clients: {classifier.enabled_clients}")
print(f"classifier scores: {classifier.scores['test_score']}")
Expand All @@ -343,7 +347,8 @@ def main():
assert classify_dir is not None, "classify dir required"
print(f"classifying all data in directory {classify_dir}")
print(f"grouped clients: {grouped_clients}")
classifier = Classifier(data_dir, grouped_clients=grouped_clients)
classifier = Classifier(data_dir, grouped_clients=grouped_clients,
classifier_type=classifier_type)

if args.plot is not None:
classifier.plot_feature_matrix(args.plot)
Expand Down

0 comments on commit 52a3545

Please sign in to comment.