diff --git a/skgstat/SpaceTimeVariogram.py b/skgstat/SpaceTimeVariogram.py index 92a9cfb..17b8f7c 100644 --- a/skgstat/SpaceTimeVariogram.py +++ b/skgstat/SpaceTimeVariogram.py @@ -5,10 +5,12 @@ from scipy.spatial.distance import pdist from scipy.ndimage.interpolation import zoom from scipy.interpolate import griddata +from scipy.optimize import curve_fit import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D +import inspect -from skgstat import binning, estimators, Variogram +from skgstat import binning, estimators, Variogram, stmodels class SpaceTimeVariogram: @@ -27,6 +29,7 @@ def __init__(self, tbins='even', estimator='matheron', use_nugget=False, + model='product-sum', verbose=False ): @@ -90,6 +93,11 @@ def __init__(self, self._use_nugget = None self.use_nugget = use_nugget + # set the model + self._model = model + self.set_model(model_name=model) + self._model_params = {} + # _x and values are set, build the marginal Variogram objects # marginal space variogram self.create_XMarginal() @@ -97,8 +105,8 @@ def __init__(self, # marginal time variogram self.create_TMarginal() - # do one preprocessing run - self.preprocessing(force=True) + # fit the model with forced preprocessing + #self.fit(force=True) # ------------------------------------------------------------------------ # # ATTRIBUTE SETTING # @@ -167,7 +175,6 @@ def set_values(self, values): if self.TMarginal is not None: self.create_TMarginal() - @values.setter def values(self, new_values): self.set_values(values=new_values) @@ -569,6 +576,48 @@ def set_estimator(self, estimator_name): self._set_xmarg_params() self._set_tmarg_params() + @property + def model(self): + return self._model + + @model.setter + def model(self, value): + self.set_model(model_name=value) + + def set_model(self, model_name): + """Set space-time model + + Set a new space-time model. It has to be either a callable of correct + signature or a string identifying one of the predefined models + + Parameters + ---------- + model_name : str, callable + Either a callable of correct signature or a valid model name. + Valid names are: + + * sum + * product + * product-sum + + + """ + # reset fitting + self.cof, self.cov = None, None + + if isinstance(model_name, str): + name = model_name.lower() + if name == 'sum': + self._model = stmodels.sum + elif name == 'product': + self._model = stmodels.product + elif name == 'product-sum' or name == 'product_sum': + self._model = stmodels.product_sum + elif callable(model_name): + self._model = model_name + else: + raise ValueError('model_name has to be a string or callable.') + def create_XMarginal(self): """ Create an instance of skgstat.Variogram for the space marginal variogram @@ -919,6 +968,90 @@ def preprocessing(self, force=False): # ------------------------------------------------------------------------ # # FITTING # # ------------------------------------------------------------------------ # + def fit(self, force=False): + # delete the last cov and cof + self.cof = None + self.cov = None + + # if force, force a clean preprocessing + self.preprocessing(force=force) + + # This is not finished + return + + # load the fitting data + xx, yy = self.meshbins + z = self.experimental + + # remove NaN values + ydata = z[np.where(~np.isnan(z))] + _xx = xx.flatten()[np.where(~np.isnan(z))[0]] + _yy = yy.flatten()[np.where(~np.isnan(z))[0]] + xdata = np.vstack((_xx, _yy)) + + # get the marginal variogram functions + Vx = self.XMarginal.fitted_model + Vt = self.TMarginal.fitted_model + + # get the params of the model + _code_obj = self._model.__wrapped__.__code__ + model_args = inspect.getargs(_code_obj).args + self._model_params = dict() + +# if 'Vx' in model_args: +# self._model_params['Vx'] = Vx +# if 'Vt' in model_args: +# self._model_params['Vt'] = Vt + + # fix the sills? + fix_sills = True # TODO: Make this a param in __init__ + if fix_sills and 'Cx' in model_args: + self._model_params['Cx'] = self.XMarginal.describe()['sill'] + if fix_sills and 'Ct' in model_args: + self._model_params['Ct'] = self.TMarginal.describe()['sill'] + + # are there parameters left to fit? + free_args = len(model_args) - 3 - len(self._model_params.keys()) + if free_args == 0: + # no params left + self.cof = [] + self.cov = [] + return + + # wrap the model + def _model(lags, *args): + return self._model(lags, Vx, Vt, *args, **self._model_params) + + self.cof, self.cov = curve_fit( + _model, xdata, ydata, bounds=[0, np.inf], p0=[1.] * free_args + ) + + return + + @property + def fitted_model(self): + """ + + Returns + ------- + + """ + # if not model not fitted, fit it + if self.cof is None or self.cov is None: + self.fit(force=False) + + # get the model func + func = self._model + + # get the marginal Variograms + Vx = self.XMarginal.fitted_model + Vt = self.TMarginal.fitted_model + + # define the function + def model(lags): + return func(lags, Vx, Vt, *self.cof, **self._model_params) + + return model # ------------------------------------------------------------------------ # # RESULTS # @@ -968,19 +1101,6 @@ def _get_member(self, xlag, tlag): t_idxs = self._tgroups == tlag return self._diff[np.where(x_idxs)[0]][:, np.where(t_idxs)[0]].flatten() - def build_marginal_variograms(self): - """build marginal Variogram classes - - The two marginal variograms for space and time axis will be - initialized and added to this instance. Both are an instance of - skgstat.Variogram in order to model them properly. Use these classes - to well working valid models to the marginal Variograms before - modeling the space-time model. - The two objects will be available as SpaceTimeVariogram.XMarginal and - SpaceTimeVariogram.TMarginal. - - """ - # ------------------------------------------------------------------------ # # PLOTTING # # ------------------------------------------------------------------------ # @@ -1372,7 +1492,8 @@ def _plot2d(self, kind='contour', ax=None, zoom_factor=100., return fig - def marginals(self, plot=True, axes=None, sharey=True, **kwargs): + def marginals(self, plot=True, axes=None, sharey=True, include_model=False, + **kwargs): """Plot marginal variograms Plots the two marginal variograms into a new or existing figure. The @@ -1396,6 +1517,9 @@ class 0. In case the expected variability is not of same magnitude, If True (default), the two marginal variograms will share their y-axis to increase comparability. Should be set to False in the variances are of different magnitude. + include_model : bool + If True, the marginal variogram models fitted to the respective + axis are included into the plot. kwargs : dict Only kwargs accepted is ``figsize``, if ax is None. Anything else will be ignored. @@ -1410,8 +1534,10 @@ class 0. In case the expected variability is not of same magnitude, """ # get the marginal space variogram - vx = self.get_marginal(axis='space', lag=0) - vy = self.get_marginal(axis='time', lag=0) +# vx = self.get_marginal(axis='space', lag=0) +# vy = self.get_marginal(axis='time', lag=0) + vx = self.XMarginal.experimental + vy = self.TMarginal.experimental # if no plot is desired, return the experimental variograms if not plot: @@ -1434,8 +1560,22 @@ class 0. In case the expected variability is not of same magnitude, ax2 = axes[1] ax3 = ax2.twinx() ax3.get_shared_y_axes().join(ax3, ax) - ax.plot(self.xbins, vx, '-ob') - ax3.plot(self.tbins, vy, '-og') + + if include_model: + # transform + xx = np.linspace(0, self.xbins[-1], 50) + xy = np.linspace(0, self.tbins[-1], 50) + y_vx = self.XMarginal.transform(xx) + y_vy = self.TMarginal.transform(xy) + + # plot + ax.plot(self.xbins, vx, 'Db') + ax.plot(xx, y_vx, '-b') + ax3.plot(self.tbins, vy, 'Dg') + ax3.plot(xy, y_vy, '-g') + else: + ax.plot(self.xbins, vx, '-ob') + ax3.plot(self.tbins, vy, '-og') # set labels ax.set_xlabel('distance [spatial]')