diff --git a/exponax/normalized/__init__.py b/exponax/normalized/__init__.py index 41819ee..109a1dc 100644 --- a/exponax/normalized/__init__.py +++ b/exponax/normalized/__init__.py @@ -9,7 +9,7 @@ """ from ._convection import DifficultyConvectionStepper, NormalizedConvectionStepper from ._general_nonlinear import NormlizedGeneralNonlinearStepper -from ._gradient_norm import NormalizedGradientNormStepper +from ._gradient_norm import DifficultyGradientNormStepper, NormalizedGradientNormStepper from ._linear import ( DifficultyLinearStepper, DiffultyLinearStepperSimple, @@ -23,12 +23,14 @@ denormalize_polynomial_scales, extract_normalized_coefficients_from_difficulty, extract_normalized_convection_scale_from_difficulty, + extract_normalized_gradient_norm_scale_from_difficulty, normalize_coefficients, normalize_convection_scale, normalize_gradient_norm_scale, normalize_polynomial_scales, reduce_normalized_coefficients_to_difficulty, reduce_normalized_convection_scale_to_difficulty, + reduce_normalized_gradient_norm_scale_to_difficulty, ) from ._vorticity_convection import NormalizedVorticityConvection @@ -36,6 +38,7 @@ "DifficultyLinearStepper", "DiffultyLinearStepperSimple", "DifficultyConvectionStepper", + "DifficultyGradientNormStepper", "NormalizedConvectionStepper", "NormlizedGeneralNonlinearStepper", "NormalizedGradientNormStepper", @@ -54,4 +57,6 @@ "extract_normalized_coefficients_from_difficulty", "reduce_normalized_convection_scale_to_difficulty", "extract_normalized_convection_scale_from_difficulty", + "reduce_normalized_gradient_norm_scale_to_difficulty", + "extract_normalized_gradient_norm_scale_from_difficulty", ] diff --git a/exponax/normalized/_gradient_norm.py b/exponax/normalized/_gradient_norm.py index 181d798..dee5c72 100644 --- a/exponax/normalized/_gradient_norm.py +++ b/exponax/normalized/_gradient_norm.py @@ -3,6 +3,10 @@ from .._base_stepper import BaseStepper from ..nonlin_fun import GradientNormNonlinearFun +from ._utils import ( + extract_normalized_coefficients_from_difficulty, + extract_normalized_gradient_norm_scale_from_difficulty, +) class NormalizedGradientNormStepper(BaseStepper): @@ -69,3 +73,52 @@ def _build_nonlinear_fun(self, derivative_operator: Array): scale=self.normalized_gradient_norm_scale, zero_mode_fix=True, ) + + +class DifficultyGradientNormStepper(NormalizedGradientNormStepper): + linear_difficulties: tuple[float, ...] + gradient_norm_difficulty: float + + def __init__( + self, + num_spatial_dims: int = 1, + num_points: int = 48, + *, + linear_difficulties: tuple[float, ...] = (0.0, 0.0, -0.064, 0.0, -0.04096), + gradient_norm_difficulty: float = 0.064, + maximum_absolute: float = 1.0, + order: int = 2, + dealiasing_fraction: float = 2 / 3, + num_circle_points: int = 16, + circle_radius: float = 1.0, + ): + """ + By default: KS equation + """ + self.linear_difficulties = linear_difficulties + self.gradient_norm_difficulty = gradient_norm_difficulty + + normalized_coefficients = extract_normalized_coefficients_from_difficulty( + linear_difficulties, + num_spatial_dims=num_spatial_dims, + num_points=num_points, + ) + normalized_gradient_norm_scale = ( + extract_normalized_gradient_norm_scale_from_difficulty( + gradient_norm_difficulty, + num_spatial_dims=num_spatial_dims, + num_points=num_points, + maximum_absolute=maximum_absolute, + ) + ) + + super().__init__( + num_spatial_dims=num_spatial_dims, + num_points=num_points, + normalized_coefficients=normalized_coefficients, + normalized_gradient_norm_scale=normalized_gradient_norm_scale, + order=order, + dealiasing_fraction=dealiasing_fraction, + num_circle_points=num_circle_points, + circle_radius=circle_radius, + ) diff --git a/exponax/normalized/_utils.py b/exponax/normalized/_utils.py index 8f66960..c48eabc 100644 --- a/exponax/normalized/_utils.py +++ b/exponax/normalized/_utils.py @@ -189,3 +189,32 @@ def extract_normalized_convection_scale_from_difficulty( maximum_absolute * num_points * num_spatial_dims ) return normalized_convection_scale + + +def reduce_normalized_gradient_norm_scale_to_difficulty( + normalized_gradient_norm_scale: float, + *, + num_spatial_dims: int, + num_points: int, + maximum_absolute: float, +): + difficulty_gradient_norm_scale = ( + normalized_gradient_norm_scale + * maximum_absolute + * jnp.square(num_points) + * num_spatial_dims + ) + return difficulty_gradient_norm_scale + + +def extract_normalized_gradient_norm_scale_from_difficulty( + difficulty_gradient_norm_scale: float, + *, + num_spatial_dims: int, + num_points: int, + maximum_absolute: float, +): + normalized_gradient_norm_scale = difficulty_gradient_norm_scale / ( + maximum_absolute * jnp.square(num_points) * num_spatial_dims + ) + return normalized_gradient_norm_scale