Skip to content

Commit

Permalink
Add a demo notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Oct 30, 2024
1 parent b879035 commit 9a258a9
Show file tree
Hide file tree
Showing 3 changed files with 307 additions and 1 deletion.
306 changes: 306 additions & 0 deletions docs/notebooks/introduction_demo.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,306 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Building Simple Models\n",
"\n",
"In this tutorial we look at how to build a simple source model and sample the parameters from a variety of sources including pzflow."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Parameterized Nodes\n",
"\n",
"All sources of information in TDAstro live as `ParameterizedNode`s. This allows us to link the nodes (and their variables) together and sample them as a single block. As you will see in this tutorial, most of the nodes are specific to the object that you want to simulate. For example if we wanted to create a static source in the night sky with a brightness of 10, we could use: "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from tdastro.sources.static_source import StaticSource\n",
"\n",
"source = StaticSource(brightness=10.0, node_label=\"my_static_source\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`ParameterizedNode`s can then be sampled with the `sample_parameters()` function. This function will return a `GraphState` data structure that stores all of the data of the samples. \n",
"\n",
"**Note:** Users do not need to know the details of the `GraphState` storage, only that it can be accessed like a dictionary using the node's label and the variable name."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"state = source.sample_parameters(num_samples=10)\n",
"state[\"my_static_source\"][\"brightness\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sample function produced 10 independent samples of our system's state.\n",
"\n",
"The brightness values of these samples are not particularly interesting because we were sampling from a fixed parameter. The brightness is always 10.0. However TDAstro allows the user to set a node's parameter from a variety of sources including constants (as with 10.0 above), the values stored in other nodes, or even the results of a \"functional\" or \"computation\" type node.\n",
"\n",
"TDAStro includes the built-in `NumpyRandomFunc` which will sample from a given numpy function and use the results to set a given parameter."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from tdastro.math_nodes.np_random import NumpyRandomFunc\n",
"\n",
"brightness_func = NumpyRandomFunc(\"uniform\", low=11.0, high=15.5)\n",
"source2 = StaticSource(brightness=brightness_func, node_label=\"my_static_source_2\")\n",
"state = source2.sample_parameters(num_samples=10)\n",
"\n",
"state[\"my_static_source_2\"][\"brightness\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now each of our 10 samples use different a different brightness value.\n",
"\n",
"We can make the distributions of objects more interesting, by using combinations of randomly generated parameters. Note here that we do not need to create the random nodes separately (as with the line `brightness_func = ...`). \n",
"\n",
"As shown below, we will often include the random sampler definition directly in the arguments as the parameter value.\n",
"\n",
"```\n",
" brightness=NumpyRandomFunc(\"normal\", loc=20.0, scale=2.0),\n",
"```\n",
"\n",
"The sampling process create a vector of samples for each parameter such that the `i`-th value of each parameter is from the same sampling run. Again, the user will rarely (if ever) need to interact with the samples directly.\n",
"\n",
"Here we sample the brightness from a Gaussian and sample the redshift from a uniform distribution."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"source3 = StaticSource(\n",
" brightness=NumpyRandomFunc(\"normal\", loc=20.0, scale=2.0),\n",
" redshift=NumpyRandomFunc(\"uniform\", low=0.1, high=0.5),\n",
" node_label=\"test\",\n",
")\n",
"\n",
"num_samples = 10\n",
"state = source3.sample_parameters(num_samples=num_samples)\n",
"for i in range(num_samples):\n",
" print(f\"{i}: brightness={state['test']['brightness'][i]} redshift={state['test']['redshift'][i]}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Sample 0 consists of all the parameter values for that sample (everything at index=0), sample 1 consists of all parameter values for that sample (everything at index=1), and so forth. We can slice out a single sample using `extract_single_sample()` and display it. This is particularly important when different parameter values for a given sample depend on each other. We will see this case below when sampling a source's RA from it's host's RA."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"single_sample = state.extract_single_sample(0)\n",
"print(str(single_sample))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You'll notice that there are more parameters than we manually set. Parameters are created automatically by the nodes if needed. In particular functional nodes often create extra parameters for book keeping. In general the user should not need to worry about these extra parameters. They can access the ones of interest with the dictionary notation."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Linked Sources\n",
"\n",
"Often the values of one node might depend on the values of another. A great case of this is a source/host pair where the location of the source depends on that of the host. We can access another node’s sampled parameters using a `.` notation: `{model_object}.{parameter_name}`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"host = StaticSource(brightness=15.0, ra=1.0, dec=2.0, node_label=\"host\")\n",
"source = StaticSource(brightness=10.0, ra=host.ra, dec=host.dec, node_label=\"source\")\n",
"state = source.sample_parameters(num_samples=5)\n",
"\n",
"for i in range(5):\n",
" print(\n",
" f\"{i}: Host=({state['host']['ra'][i]}, {state['host']['dec'][i]})\"\n",
" f\"Source=({state['source']['ra'][i]}, {state['source']['dec'][i]})\"\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can combine the node-parameter references with functional nodes to perform actions such as sampling with noise.\n",
"\n",
"Here we generate the host's (RA, dec) from a uniform patch of the sky and then generate the source's (RA, dec) using a Gaussian distribution centered on the host's position. For each sample the host and source should be close, but not necessarily identical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"host = StaticSource(\n",
" brightness=15.0,\n",
" ra=NumpyRandomFunc(\"uniform\", low=10.0, high=15.0),\n",
" dec=NumpyRandomFunc(\"uniform\", low=-10.0, high=10.0),\n",
" node_label=\"host\",\n",
")\n",
"\n",
"source = StaticSource(\n",
" brightness=100.0,\n",
" ra=NumpyRandomFunc(\"normal\", loc=host.ra, scale=0.1),\n",
" dec=NumpyRandomFunc(\"normal\", loc=host.dec, scale=0.1),\n",
" node_label=\"source\",\n",
")\n",
"state = source.sample_parameters(num_samples=10)\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"ax = plt.figure().add_subplot()\n",
"ax.plot(state[\"host\"][\"ra\"], state[\"host\"][\"dec\"], \"b.\")\n",
"ax.plot(state[\"source\"][\"ra\"], state[\"source\"][\"dec\"], \"r.\")\n",
"\n",
"for i in range(5):\n",
" print(\n",
" f\"{i}: Host=({state['host']['ra'][i]}, {state['host']['dec'][i]}) \"\n",
" f\"Source=({state['source']['ra'][i]}, {state['source']['dec'][i]})\"\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Again we can access all the information for a single sample. Here we see the full state tracked by the system. In addition to the `host` and `source` nodes we created, the information for the functional nodes is tracked."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"single_sample = state.extract_single_sample(0)\n",
"print(str(single_sample))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It is interesting to note that functional nodes themselves are parameterized nodes, allowing for more complex forms of chaining. For example we could set the `low` parameter from one of the `NumpyRandomFunc`s from another function node. This allows us to specify priors and comlex distributions."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## pzflow\n",
"\n",
"We can use the `pzflow` package to generate data from joint distributions that have been learned from real data. Sampling with `pzflow` creates Pandas tables of values. We can access each column using the same `.` notation as above, allowing us to use a consistent set of values from multiple variables.\n",
"\n",
"We start by loading a simple pzflow that we have saved with (nonsense) redshift and brightness values."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pzflow import Flow\n",
"from tdastro import _TDASTRO_TEST_DATA_DIR\n",
"\n",
"pzflow = Flow(file=_TDASTRO_TEST_DATA_DIR / \"test_flow.pkl\")\n",
"pzflow.sample(nsamples=10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can create a `ParameterizedNode` for the pzflow as shown below. This simply wraps the existing `Flow` object as a `ParameterizedNode` and provides it a user readable name (\"pznode\") for later queries. Note that within the `PZFlowNode` node all parameters for the `i`-th sample are generated together."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from tdastro.astro_utils.pzflow_node import PZFlowNode\n",
"\n",
"pz_node = PZFlowNode(pzflow, node_label=\"pznode\")\n",
"\n",
"source = StaticSource(\n",
" brightness=pz_node.brightness,\n",
" redshift=pz_node.redshift,\n",
" node_label=\"source\",\n",
")\n",
"\n",
"num_samples = 10\n",
"state = source.sample_parameters(num_samples=num_samples)\n",
"for i in range(num_samples):\n",
" print(f\"{i}: brightness={state['source']['brightness'][i]} z={state['source']['redshift'][i]}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "tdastro",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
2 changes: 1 addition & 1 deletion tests/tdastro/astro_utils/test_pzflow_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,4 @@ def test_pzflow_node_from_file(test_flow_filename):
state = pz_node.sample_parameters(num_samples=100)
assert len(state["loaded_node"]) == 2
assert len(state["loaded_node"]["redshift"]) == 100
assert len(state["loaded_node"]["hostmass"]) == 100
assert len(state["loaded_node"]["brightness"]) == 100
Binary file modified tests/tdastro/data/test_flow.pkl
Binary file not shown.

0 comments on commit 9a258a9

Please sign in to comment.