diff --git a/scikit_tt/tensor_train.py b/scikit_tt/tensor_train.py index c9a6783..0373f35 100644 --- a/scikit_tt/tensor_train.py +++ b/scikit_tt/tensor_train.py @@ -1323,30 +1323,12 @@ def ortho(self, threshold: float=0, max_rank: Union[int, List[int]]=np.infty) -> ------- TT right-orthonormalized representation of self - - Raises - ------ - ValueError - if threshold is less than 0 - ValueError - if max_rank is not a positive integer """ - if isinstance(threshold, (int, np.int32, np.int64, float, np.float32, np.float64)) and threshold >= 0: + # left- and right-orthonormalize self + self.ortho_left(threshold=threshold, max_rank=np.infty).ortho_right(threshold=threshold, max_rank=max_rank) - if (isinstance(max_rank, (int, np.int32, np.int64)) and max_rank > 0) or max_rank == np.infty: - - # left- and right-orthonormalize self - self.ortho_left(threshold=threshold, max_rank=np.infty).ortho_right(threshold=threshold, - max_rank=max_rank) - - return self - - else: - raise ValueError('Maximum rank must be a positive integer.') - - else: - raise ValueError('Threshold must be greater or equal 0.') + return self def norm(self, p: int=2) -> float: """