Skip to content

Commit

Permalink
Interval classifiers (#5)
Browse files Browse the repository at this point in the history
* actually change python version

* dummy classifiers and sklearn lower bound change

* test fix

* test fix

* dev

* early sklearn version fixes

* all interval classifiers
  • Loading branch information
MatthewMiddlehurst authored Apr 6, 2023
1 parent d0a8c6d commit fb4f7ab
Show file tree
Hide file tree
Showing 20 changed files with 617 additions and 261 deletions.
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "tsml"
version = "0.0.4"
version = "0.0.5"
description = "A toolkit for time series machine learning algorithms."
authors = [
{name = "Matthew Middlehurst", email = "[email protected]"},
Expand Down Expand Up @@ -42,8 +42,9 @@ dependencies = [

[project.optional-dependencies]
extras = [
"pycatch22",
"pyfftw"
"pycatch22>=0.4.2",
"pyfftw>=0.12.0",
"statsmodels>=0.12.1",
]
dev = [
"pre-commit",
Expand Down
2 changes: 1 addition & 1 deletion tsml/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-
"""tsml."""

__version__ = "0.0.4"
__version__ = "0.0.5"
6 changes: 3 additions & 3 deletions tsml/dummy/_dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def fit(self, X, y):
for index, classVal in enumerate(self.classes_):
self.class_dictionary_[classVal] = index

if len(self.classes_) == 1:
if self.n_classes_ == 1:
return self

self._clf = SklearnDummyClassifier(
Expand All @@ -120,12 +120,12 @@ def predict(self, X) -> np.ndarray:
""""""
check_is_fitted(self)

X = self._validate_data(X=X, reset=False, ensure_min_series_length=1)

# treat case of single class seen in fit
if self.n_classes_ == 1:
return np.repeat(list(self.class_dictionary_.keys()), X.shape[0], axis=0)

X = self._validate_data(X=X, reset=False, ensure_min_series_length=1)

return self._clf.predict(np.zeros(X.shape))

def predict_proba(self, X) -> np.ndarray:
Expand Down
2 changes: 1 addition & 1 deletion tsml/feature_based/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
"Catch22Regressor",
]

from tsml.feature_based._catch22_classifier import Catch22Classifier, Catch22Regressor
from tsml.feature_based._catch22 import Catch22Classifier, Catch22Regressor
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ def fit(self, X, y):
for index, classVal in enumerate(self.classes_):
self.class_dictionary_[classVal] = index

if self.n_classes_ == 1:
return self

self._n_jobs = check_n_jobs(self.n_jobs)

self._transformer = Catch22Transformer(
Expand Down Expand Up @@ -164,6 +167,10 @@ def predict(self, X) -> np.ndarray:
"""
check_is_fitted(self)

# treat case of single class seen in fit
if self.n_classes_ == 1:
return np.repeat(list(self.class_dictionary_.keys()), X.shape[0], axis=0)

X = self._validate_data(X=X, reset=False)

return self._estimator.predict(self._transformer.transform(X))
Expand All @@ -183,6 +190,10 @@ def predict_proba(self, X) -> np.ndarray:
"""
check_is_fitted(self)

# treat case of single class seen in fit
if self.n_classes_ == 1:
return np.repeat([[1]], X.shape[0], axis=0)

X = self._validate_data(X=X, reset=False)

m = getattr(self._estimator, "predict_proba", None)
Expand Down
21 changes: 14 additions & 7 deletions tsml/interval_based/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,28 @@
"BaseIntervalForest",
"CIFClassifier",
"CIFRegressor",
# "DrCIFClassifier",
# "DrCIFRegressor",
"DrCIFClassifier",
"DrCIFRegressor",
"IntervalForestClassifier",
"IntervalForestRegressor",
"RandomIntervalClassifier",
"RandomIntervalRegressor",
"SupervisedIntervalClassifier",
# "RISEClassifier",
# "RISERegressor",
# "STSFClassifier",
# "RSTSFClassifier",
"RISEClassifier",
"RISERegressor",
"STSFClassifier",
"RSTSFClassifier",
"TSFClassifier",
"TSFRegressor",
]

from tsml.interval_based._base import BaseIntervalForest
from tsml.interval_based._cif import CIFClassifier, CIFRegressor
from tsml.interval_based._cif import (
CIFClassifier,
CIFRegressor,
DrCIFClassifier,
DrCIFRegressor,
)
from tsml.interval_based._interval_forest import (
IntervalForestClassifier,
IntervalForestRegressor,
Expand All @@ -31,4 +36,6 @@
RandomIntervalRegressor,
SupervisedIntervalClassifier,
)
from tsml.interval_based._rise import RISEClassifier, RISERegressor
from tsml.interval_based._stsf import RSTSFClassifier, STSFClassifier
from tsml.interval_based._tsf import TSFClassifier, TSFRegressor
Loading

0 comments on commit fb4f7ab

Please sign in to comment.