Skip to content

Commit

Permalink
minor refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
aromberg committed Nov 14, 2024
1 parent a91f9e3 commit faf40eb
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 60 deletions.
7 changes: 3 additions & 4 deletions src/xspect/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,11 @@ def get_model_by_slug(model_slug: str):
model_metadata = get_model_metadata(model_path)
if model_metadata["model_class"] == "ProbabilisticSingleFilterModel":
return ProbabilisticSingleFilterModel.load(model_path)
elif model_metadata["model_class"] == "ProbabilisticFilterSVMModel":
if model_metadata["model_class"] == "ProbabilisticFilterSVMModel":
return ProbabilisticFilterSVMModel.load(model_path)
elif model_metadata["model_class"] == "ProbabilisticFilterModel":
if model_metadata["model_class"] == "ProbabilisticFilterModel":
return ProbabilisticFilterModel.load(model_path)
else:
raise ValueError(f"Model class {model_metadata['model_class']} not recognized.")
raise ValueError(f"Model class {model_metadata['model_class']} not recognized.")


def get_model_metadata(model: str | Path):
Expand Down
13 changes: 2 additions & 11 deletions src/xspect/models/probabilistic_filter_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from Bio import SeqIO
from slugify import slugify
import cobs_index as cobs
from xspect.definitions import fasta_endings, fastq_endings
from xspect.file_io import get_record_iterator
from xspect.models.result import ModelResult

Expand Down Expand Up @@ -64,10 +65,6 @@ def to_dict(self) -> dict:
"num_hashes": self.num_hashes,
}

def __dict__(self) -> dict:
"""Returns a dictionary representation of the model"""
return self.to_dict()

def slug(self) -> str:
"""Returns a slug representation of the model"""
return slugify(self.model_display_name + "-" + str(self.model_type))
Expand All @@ -89,13 +86,7 @@ def fit(self, dir_path: Path, display_names: dict = None) -> None:

doclist = cobs.DocumentList()
for file in dir_path.iterdir():
if file.is_file() and file.suffix in [
".fasta",
".fna",
".fa",
".fastq",
".fq",
]:
if file.is_file() and file.suffix[1:] in fasta_endings + fastq_endings:
# cobs only uses the file name to the first "." as the document name
if file.name in display_names:
self.display_names[file.name.split(".")[0]] = display_names[
Expand Down
3 changes: 3 additions & 0 deletions src/xspect/models/probabilistic_filter_svm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,11 @@ def fit(
) -> None:
"""Fit the SVM to the sequences and labels"""

# Since the SVM works with score data, we need to train
# the underlying data structure for score generation first
super().fit(dir_path, display_names=display_names)

# calculate scores for SVM training
score_list = []
for file in svm_path.iterdir():
if not file.is_file():
Expand Down
10 changes: 4 additions & 6 deletions src/xspect/models/probabilistic_single_filter_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Probabilistic filter SVM model for sequence data"""
"""Base probabilistic filter model for sequence data"""

# pylint: disable=no-name-in-module, too-many-instance-attributes

Expand All @@ -14,7 +14,7 @@


class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
"""Probabilistic filter SVM model for sequence data"""
"""Base probabilistic filter model for sequence data"""

def __init__(
self,
Expand All @@ -25,7 +25,6 @@ def __init__(
model_type: str,
base_path: Path,
fpr: float = 0.01,
num_hashes: int = 7,
) -> None:
super().__init__(
k=k,
Expand All @@ -35,12 +34,12 @@ def __init__(
model_type=model_type,
base_path=base_path,
fpr=fpr,
num_hashes=num_hashes,
num_hashes=1,
)
self.bf = None

def fit(self, file_path: Path, display_name: str) -> None:
"""Fit the SVM to the sequences and labels"""
"""Fit the cobs classic index to the sequences and labels"""
# estimate number of kmers
total_length = 0
for record in get_record_iterator(file_path):
Expand Down Expand Up @@ -89,7 +88,6 @@ def load(path: Path) -> "ProbabilisticSingleFilterModel":
model_json["model_type"],
path.parent,
fpr=model_json["fpr"],
num_hashes=model_json["num_hashes"],
)
model.display_names = model_json["display_names"]
bloom_path = model.base_path / model.slug() / "filter.bloom"
Expand Down
13 changes: 7 additions & 6 deletions src/xspect/models/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

def get_last_processing_step(result: "ModelResult") -> "ModelResult":
"""Get the last subprocessing step of the result. First path only."""
last_step = result
while last_step.subprocessing_steps:
last_step = last_step.subprocessing_steps[-1].result
return last_step

# traverse result tree to get last step
while result.subprocessing_steps:
result = result.subprocessing_steps[-1].result
return result


class StepType(Enum):
Expand Down Expand Up @@ -82,9 +83,9 @@ def get_scores(self) -> dict:
scores = {
subsequence: {
label: round(hits / self.num_kmers[subsequence], 2)
for label, hits in subseuqence_hits.items()
for label, hits in subsequence_hits.items()
}
for subsequence, subseuqence_hits in self.hits.items()
for subsequence, subsequence_hits in self.hits.items()
}

# calculate total scores
Expand Down
34 changes: 1 addition & 33 deletions src/xspect/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def check_user_input(user_input: str):
rank = metadata["rank"]
lineage = metadata["lineage"]
bacteria_id = 2
if not sci_name == user_input and not tax_id == user_input:
if user_input not in (sci_name, tax_id):
print(
f"{get_current_time()}| The given genus: {user_input} was found as"
f" genus: {sci_name} ID: {tax_id}"
Expand All @@ -60,38 +60,6 @@ def check_user_input(user_input: str):
sys.exit()


def copy_custom_data(bf_path: str, svm_path: str, dir_name: str):
"""
:param bf_path:
:param svm_path:
:param dir_name:
:return:
"""
path = Path(os.getcwd()) / "genus_metadata" / dir_name
new_bf_path = path / "concatenate"
new_svm_path = path / "training_data"

# Make the new directories.
path.mkdir(exist_ok=True)
new_bf_path.mkdir(exist_ok=True)
new_svm_path.mkdir(exist_ok=True)

# Move bloomfilter files.
bf_files = os.listdir(bf_path)
for file in bf_files:
file_path = Path(bf_path) / file
new_file_path = new_bf_path / file
shutil.copy2(file_path, new_file_path)

# Move svm files.
svm_files = os.listdir(svm_path)
for file in svm_files:
file_path = Path(svm_path) / file
new_file_path = new_svm_path / file
shutil.copy2(file_path, new_file_path)


def set_logger(dir_name: str):
"""Sets the logger parameters.
Expand Down

0 comments on commit faf40eb

Please sign in to comment.