-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #181 from lincc-frameworks/pzflow
Create a new node type that wraps pzflow and create an introductory notebook
- Loading branch information
Showing
6 changed files
with
489 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,306 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Building Simple Models\n", | ||
"\n", | ||
"In this tutorial we look at how to build a simple source model and sample the parameters from a variety of sources including pzflow." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Parameterized Nodes\n", | ||
"\n", | ||
"All sources of information in TDAstro live as `ParameterizedNode`s. This allows us to link the nodes (and their variables) together and sample them as a single block. As you will see in this tutorial, most of the nodes are specific to the object that you want to simulate. For example if we wanted to create a static source in the night sky with a brightness of 10, we could use: " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from tdastro.sources.static_source import StaticSource\n", | ||
"\n", | ||
"source = StaticSource(brightness=10.0, node_label=\"my_static_source\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"`ParameterizedNode`s can then be sampled with the `sample_parameters()` function. This function will return a `GraphState` data structure that stores all of the data of the samples. \n", | ||
"\n", | ||
"**Note:** Users do not need to know the details of the `GraphState` storage, only that it can be accessed like a dictionary using the node's label and the variable name." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"state = source.sample_parameters(num_samples=10)\n", | ||
"state[\"my_static_source\"][\"brightness\"]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"The sample function produced 10 independent samples of our system's state.\n", | ||
"\n", | ||
"The brightness values of these samples are not particularly interesting because we were sampling from a fixed parameter. The brightness is always 10.0. However TDAstro allows the user to set a node's parameter from a variety of sources including constants (as with 10.0 above), the values stored in other nodes, or even the results of a \"functional\" or \"computation\" type node.\n", | ||
"\n", | ||
"TDAStro includes the built-in `NumpyRandomFunc` which will sample from a given numpy function and use the results to set a given parameter." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from tdastro.math_nodes.np_random import NumpyRandomFunc\n", | ||
"\n", | ||
"brightness_func = NumpyRandomFunc(\"uniform\", low=11.0, high=15.5)\n", | ||
"source2 = StaticSource(brightness=brightness_func, node_label=\"my_static_source_2\")\n", | ||
"state = source2.sample_parameters(num_samples=10)\n", | ||
"\n", | ||
"state[\"my_static_source_2\"][\"brightness\"]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Now each of our 10 samples use different a different brightness value.\n", | ||
"\n", | ||
"We can make the distributions of objects more interesting, by using combinations of randomly generated parameters. Note here that we do not need to create the random nodes separately (as with the line `brightness_func = ...`). \n", | ||
"\n", | ||
"As shown below, we will often include the random sampler definition directly in the arguments as the parameter value.\n", | ||
"\n", | ||
"```\n", | ||
" brightness=NumpyRandomFunc(\"normal\", loc=20.0, scale=2.0),\n", | ||
"```\n", | ||
"\n", | ||
"The sampling process create a vector of samples for each parameter such that the `i`-th value of each parameter is from the same sampling run. Again, the user will rarely (if ever) need to interact with the samples directly.\n", | ||
"\n", | ||
"Here we sample the brightness from a Gaussian and sample the redshift from a uniform distribution." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"source3 = StaticSource(\n", | ||
" brightness=NumpyRandomFunc(\"normal\", loc=20.0, scale=2.0),\n", | ||
" redshift=NumpyRandomFunc(\"uniform\", low=0.1, high=0.5),\n", | ||
" node_label=\"test\",\n", | ||
")\n", | ||
"\n", | ||
"num_samples = 10\n", | ||
"state = source3.sample_parameters(num_samples=num_samples)\n", | ||
"for i in range(num_samples):\n", | ||
" print(f\"{i}: brightness={state['test']['brightness'][i]} redshift={state['test']['redshift'][i]}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Sample 0 consists of all the parameter values for that sample (everything at index=0), sample 1 consists of all parameter values for that sample (everything at index=1), and so forth. We can slice out a single sample using `extract_single_sample()` and display it. This is particularly important when different parameter values for a given sample depend on each other. We will see this case below when sampling a source's RA from it's host's RA." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"single_sample = state.extract_single_sample(0)\n", | ||
"print(str(single_sample))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"You'll notice that there are more parameters than we manually set. Parameters are created automatically by the nodes if needed. In particular functional nodes often create extra parameters for book keeping. In general the user should not need to worry about these extra parameters. They can access the ones of interest with the dictionary notation." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Linked Sources\n", | ||
"\n", | ||
"Often the values of one node might depend on the values of another. A great case of this is a source/host pair where the location of the source depends on that of the host. We can access another node’s sampled parameters using a `.` notation: `{model_object}.{parameter_name}`" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"host = StaticSource(brightness=15.0, ra=1.0, dec=2.0, node_label=\"host\")\n", | ||
"source = StaticSource(brightness=10.0, ra=host.ra, dec=host.dec, node_label=\"source\")\n", | ||
"state = source.sample_parameters(num_samples=5)\n", | ||
"\n", | ||
"for i in range(5):\n", | ||
" print(\n", | ||
" f\"{i}: Host=({state['host']['ra'][i]}, {state['host']['dec'][i]})\"\n", | ||
" f\"Source=({state['source']['ra'][i]}, {state['source']['dec'][i]})\"\n", | ||
" )" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"We can combine the node-parameter references with functional nodes to perform actions such as sampling with noise.\n", | ||
"\n", | ||
"Here we generate the host's (RA, dec) from a uniform patch of the sky and then generate the source's (RA, dec) using a Gaussian distribution centered on the host's position. For each sample the host and source should be close, but not necessarily identical." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"host = StaticSource(\n", | ||
" brightness=15.0,\n", | ||
" ra=NumpyRandomFunc(\"uniform\", low=10.0, high=15.0),\n", | ||
" dec=NumpyRandomFunc(\"uniform\", low=-10.0, high=10.0),\n", | ||
" node_label=\"host\",\n", | ||
")\n", | ||
"\n", | ||
"source = StaticSource(\n", | ||
" brightness=100.0,\n", | ||
" ra=NumpyRandomFunc(\"normal\", loc=host.ra, scale=0.1),\n", | ||
" dec=NumpyRandomFunc(\"normal\", loc=host.dec, scale=0.1),\n", | ||
" node_label=\"source\",\n", | ||
")\n", | ||
"state = source.sample_parameters(num_samples=10)\n", | ||
"\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"\n", | ||
"ax = plt.figure().add_subplot()\n", | ||
"ax.plot(state[\"host\"][\"ra\"], state[\"host\"][\"dec\"], \"b.\")\n", | ||
"ax.plot(state[\"source\"][\"ra\"], state[\"source\"][\"dec\"], \"r.\")\n", | ||
"\n", | ||
"for i in range(5):\n", | ||
" print(\n", | ||
" f\"{i}: Host=({state['host']['ra'][i]}, {state['host']['dec'][i]}) \"\n", | ||
" f\"Source=({state['source']['ra'][i]}, {state['source']['dec'][i]})\"\n", | ||
" )" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Again we can access all the information for a single sample. Here we see the full state tracked by the system. In addition to the `host` and `source` nodes we created, the information for the functional nodes is tracked." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"single_sample = state.extract_single_sample(0)\n", | ||
"print(str(single_sample))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"It is interesting to note that functional nodes themselves are parameterized nodes, allowing for more complex forms of chaining. For example we could set the `low` parameter from one of the `NumpyRandomFunc`s from another function node. This allows us to specify priors and comlex distributions." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## pzflow\n", | ||
"\n", | ||
"We can use the `pzflow` package to generate data from joint distributions that have been learned from real data. Sampling with `pzflow` creates Pandas tables of values. We can access each column using the same `.` notation as above, allowing us to use a consistent set of values from multiple variables.\n", | ||
"\n", | ||
"We start by loading a simple pzflow that we have saved with (nonsense) redshift and brightness values." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from pzflow import Flow\n", | ||
"from tdastro import _TDASTRO_TEST_DATA_DIR\n", | ||
"\n", | ||
"pzflow = Flow(file=_TDASTRO_TEST_DATA_DIR / \"test_flow.pkl\")\n", | ||
"pzflow.sample(nsamples=10)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"We can create a `ParameterizedNode` for the pzflow as shown below. This simply wraps the existing `Flow` object as a `ParameterizedNode` and provides it a user readable name (\"pznode\") for later queries. Note that within the `PZFlowNode` node all parameters for the `i`-th sample are generated together." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from tdastro.astro_utils.pzflow_node import PZFlowNode\n", | ||
"\n", | ||
"pz_node = PZFlowNode(pzflow, node_label=\"pznode\")\n", | ||
"\n", | ||
"source = StaticSource(\n", | ||
" brightness=pz_node.brightness,\n", | ||
" redshift=pz_node.redshift,\n", | ||
" node_label=\"source\",\n", | ||
")\n", | ||
"\n", | ||
"num_samples = 10\n", | ||
"state = source.sample_parameters(num_samples=num_samples)\n", | ||
"for i in range(num_samples):\n", | ||
" print(f\"{i}: brightness={state['source']['brightness'][i]} z={state['source']['redshift'][i]}\")" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "tdastro", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.4" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ dependencies = [ | |
"jax", | ||
"numpy", | ||
"pandas", | ||
"pzflow", | ||
"scipy", | ||
"sncosmo", | ||
] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
"""A wrapper around pzflow sampling. | ||
For the full pzflow package see: | ||
https://github.com/jfcrenshaw/pzflow | ||
""" | ||
|
||
import numpy as np | ||
from pzflow import Flow | ||
|
||
from tdastro.base_models import FunctionNode | ||
|
||
|
||
class PZFlowNode(FunctionNode): | ||
"""A node that wraps sampling from pzflow. | ||
Attributes | ||
---------- | ||
flow : pzflow.flow.Flow or pzflow.flowEnsemble.FlowEnsemble | ||
The object from which to sample. | ||
columns : list of str | ||
The column names for the output columns. | ||
""" | ||
|
||
def __init__(self, flow_obj, node_label=None, **kwargs): | ||
self.flow = flow_obj | ||
|
||
# Add each of the flow's data columns as an output parameter. | ||
self.columns = [x for x in flow_obj.data_columns] | ||
super().__init__(self._non_func, node_label=node_label, outputs=self.columns, **kwargs) | ||
|
||
def _non_func(self): | ||
"""This function does nothing. Everything happens in the overloaded compute().""" | ||
pass | ||
|
||
@classmethod | ||
def from_file(cls, filename, node_label=None): | ||
"""Create a PZFlowNode from a saved flow file. | ||
Parameters | ||
---------- | ||
filename : str or Path | ||
The location of the saved flow. | ||
node_label : `str` | ||
An optional human readable identifier (name) for the current node. | ||
""" | ||
flow_to_use = Flow(file=filename) | ||
return PZFlowNode(flow_to_use, node_label=node_label) | ||
|
||
def compute(self, graph_state, rng_info=None, **kwargs): | ||
"""Return the given values. | ||
Parameters | ||
---------- | ||
graph_state : `GraphState` | ||
An object mapping graph parameters to their values. This object is modified | ||
in place as it is sampled. | ||
rng_info : numpy.random._generator.Generator, optional | ||
Unused in this function, but included to provide consistency with other | ||
compute functions. | ||
**kwargs : `dict`, optional | ||
Additional function arguments. | ||
Returns | ||
------- | ||
results : any | ||
The result of the computation. This return value is provided so that testing | ||
functions can easily access the results. | ||
""" | ||
# If a random number generator is used, use that to compute the seed. | ||
seed = None if rng_info is None else int.from_bytes(rng_info.bytes(4), byteorder="big") | ||
samples = self.flow.sample(graph_state.num_samples, seed=seed) | ||
|
||
# Parse out each column in the flow samples as a result vector. | ||
results = [] | ||
for attr_name in self.flow.data_columns: | ||
attr_values = samples[attr_name].values | ||
if graph_state.num_samples == 1: | ||
results.append(attr_values[0]) | ||
else: | ||
results.append(np.array(attr_values)) | ||
|
||
# Save and return the results. | ||
self._save_results(results, graph_state) | ||
return results |
Oops, something went wrong.