diff --git a/gp_emulator/multivariate_gp.py b/gp_emulator/multivariate_gp.py index ef8a014..87103db 100644 --- a/gp_emulator/multivariate_gp.py +++ b/gp_emulator/multivariate_gp.py @@ -148,30 +148,32 @@ def dump_emulator ( self, fname, model_name, sza, vza, raa ): sza = int ( sza ) vza = int ( vza ) raa = int ( raa ) - try: - f = h5py.File (fname, 'r+') - except IOError: - print "The file %s did not exist. Creating it" % fname - f = h5py.File (fname, 'w') - f - group = '%s_%03d_%03d_%03d' % ( model_name, sza, vza, raa ) - if group in f.keys(): - raise ValueError, "Emulator already exists!" - f.create_group ("/%s" % group ) - f.create_dataset ( "/%s/X_train" % group, data=self.X_train, compression="gzip" ) - f.create_dataset ( "/%s/y_train" % group, data=self.y_train, compression="gzip" ) - f.create_dataset ( "/%s/hyperparams" % group, data=self.hyperparams, - compression="gzip" ) - f.create_dataset ( "/%s/basis_functions" % group, data=self.basis_functions, - compression="gzip" ) - f.create_dataset ( "/%s/thresh" % group, data=self.thresh ) - f.create_dataset ( "/%s/n_pcs" % group, data=self.n_pcs) - f.close() - print "Emulator safely saved" - - #np.savez_compressed ( fname, X=self.X_train, y=self.y_train, \ - #hyperparams=self.hyperparams, thresh=self.thresh, \ - #basis_functions=self.basis_functions, n_pcs=self.n_pcs ) + if fname.find ( ".npz" ) < 0 and ( fname.find ( "h5" ) >= 0 \ + or fname.find ( ".hdf" ) >= 0 ): + try: + f = h5py.File (fname, 'r+') + except IOError: + print "The file %s did not exist. Creating it" % fname + f = h5py.File (fname, 'w') + f + group = '%s_%03d_%03d_%03d' % ( model_name, sza, vza, raa ) + if group in f.keys(): + raise ValueError, "Emulator already exists!" + f.create_group ("/%s" % group ) + f.create_dataset ( "/%s/X_train" % group, data=self.X_train, compression="gzip" ) + f.create_dataset ( "/%s/y_train" % group, data=self.y_train, compression="gzip" ) + f.create_dataset ( "/%s/hyperparams" % group, data=self.hyperparams, + compression="gzip" ) + f.create_dataset ( "/%s/basis_functions" % group, data=self.basis_functions, + compression="gzip" ) + f.create_dataset ( "/%s/thresh" % group, data=self.thresh ) + f.create_dataset ( "/%s/n_pcs" % group, data=self.n_pcs) + f.close() + print "Emulator safely saved" + else: + np.savez_compressed ( fname, X=self.X_train, y=self.y_train, \ + hyperparams=self.hyperparams, thresh=self.thresh, \ + basis_functions=self.basis_functions, n_pcs=self.n_pcs ) def calculate_decomposition ( self, X, thresh ):