Skip to content

Commit

Permalink
minor changes in tf dataset and forward model
Browse files Browse the repository at this point in the history
  • Loading branch information
Justinezgh committed Mar 5, 2024
1 parent e8707b8 commit b0763c1
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
18 changes: 9 additions & 9 deletions sbi_lens/gen_dataset/lensing_lpt_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def __init__(
super().__init__(description=("LPT lensing simulations."), version=v1, **kwargs)
self.N = N
self.map_size = map_size
box_size = box_size
box_shape = box_shape
self.box_size = box_size
self.box_shape = box_shape
self.gal_per_arcmin2 = gal_per_arcmin2
self.sigma_e = sigma_e
self.nbins = nbins
Expand All @@ -69,8 +69,8 @@ class LensingLPTDataset(tfds.core.GeneratorBasedBuilder):
BUILDER_CONFIGS = [
LensingLPTDatasetConfig(
name="year_10_with_noise_score_density",
N=config_lsst_y_10.N,
map_size=config_lsst_y_10.map_size,
N=60,
map_size=5,
box_size=[400.0, 400.0, 4000.0],
box_shape=[300, 300, 128],
gal_per_arcmin2=config_lsst_y_10.gals_per_arcmin2,
Expand All @@ -84,8 +84,8 @@ class LensingLPTDataset(tfds.core.GeneratorBasedBuilder):
),
LensingLPTDatasetConfig(
name="year_10_without_noise_score_density",
N=config_lsst_y_10.N,
map_size=config_lsst_y_10.map_size,
N=60,
map_size=5,
box_size=[400.0, 400.0, 4000.0],
box_shape=[300, 300, 128],
gal_per_arcmin2=config_lsst_y_10.gals_per_arcmin2,
Expand All @@ -99,8 +99,8 @@ class LensingLPTDataset(tfds.core.GeneratorBasedBuilder):
),
LensingLPTDatasetConfig(
name="year_10_with_noise_score_conditional",
N=config_lsst_y_10.N,
map_size=config_lsst_y_10.map_size,
N=60,
map_size=5,
box_size=[400.0, 400.0, 4000.0],
box_shape=[300, 300, 128],
gal_per_arcmin2=config_lsst_y_10.gals_per_arcmin2,
Expand Down Expand Up @@ -150,7 +150,7 @@ def _split_generators(self, dl_manager: tfds.download.DownloadManager):

def _generate_examples(self, size):
"""Yields examples."""
bs = 20
bs = 5

model = partial(
lensingLpt,
Expand Down
4 changes: 2 additions & 2 deletions sbi_lens/simulator/Lpt_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def lensingLpt(
a,
b,
z0,
with_noise=True
with_noise=True,
):
"""
This function defines the top-level forward model for our observations
Expand Down Expand Up @@ -215,7 +215,7 @@ def lensingLpt(
)

# Generate random convergence maps
nz = jc.redshift.smail_nz(a, b, z0, gals_per_arcmin2=gal_per_arcmin2)
nz = jc.redshift.smail_nz(a, b, z0, gals_per_arcmin2=gal_per_arcmin2, zmax=2.6)
nz_shear = subdivide(nz, nbins=nbins, zphot_sigma=0.05)

lensing_model = jax.jit(
Expand Down

0 comments on commit b0763c1

Please sign in to comment.