From 82ea6e4d6dbcdf9f75d07e5acec85c18234fd36e Mon Sep 17 00:00:00 2001 From: Aadya Chinubhai <77720426+aadya940@users.noreply.github.com> Date: Fri, 23 Feb 2024 22:42:49 +0530 Subject: [PATCH] Add example for arviz.InferenceData.map (#2304) * Add example for arviz.InferenceData.map * minor fix --- .../WorkingWithInferenceData.ipynb | 13820 +++++++++++++++- 1 file changed, 13250 insertions(+), 570 deletions(-) diff --git a/doc/source/getting_started/WorkingWithInferenceData.ipynb b/doc/source/getting_started/WorkingWithInferenceData.ipynb index af1406cc5c..d9677d5c75 100644 --- a/doc/source/getting_started/WorkingWithInferenceData.ipynb +++ b/doc/source/getting_started/WorkingWithInferenceData.ipynb @@ -13,7 +13,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -35,7 +35,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -49,8 +49,8 @@ "
<xarray.DataArray 'school' (school: 8)>\n", "'Choate' 'Deerfield' 'Phillips Andover' ... "St. Paul's" 'Mt. Hermon'\n", "Coordinates:\n", - " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
PandasIndex(Index(['Essex College', 'Moordale'], dtype='object', name='new_school'))
<xarray.Dataset>\n", + "Dimensions: (draw: 500, school: 8, school_bis: 8)\n", + "Coordinates:\n", + " * draw (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499\n", + " * school (school) <U16 'Choate' 'Deerfield' ... 'Mt. Hermon'\n", + " * school_bis (school_bis) <U16 'Choate' 'Deerfield' ... 'Mt. Hermon'\n", + "Data variables:\n", + " mu (draw) float64 5.974 5.096 7.177 ... 3.284 4.739 3.146\n", + " theta (draw, school) float64 9.519 5.554 6.118 ... 5.595 3.773\n", + " tau (draw) float64 4.068 3.156 3.603 ... 2.725 3.225 2.979\n", + " log_tau (draw) float64 1.322 1.118 1.234 ... 0.958 1.035 0.9508\n", + " mlogtau (draw) float64 nan nan nan nan ... 0.993 1.002 1.01 1.021\n", + " theta_school_diff (draw, school, school_bis) float64 0.0 3.965 ... 0.0
<xarray.Dataset>\n", + "Dimensions: (chain: 4, draw: 500, school: 8)\n", + "Coordinates:\n", + " * chain (chain) int64 0 1 2 3\n", + " * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n", + " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", + "Data variables:\n", + " obs (chain, draw, school) float64 ...\n", + "Attributes: (4)
<xarray.Dataset>\n", + "Dimensions: (chain: 4, draw: 500, new_school: 2)\n", + "Coordinates:\n", + " * chain (chain) int64 0 1 2 3\n", + " * draw (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n", + " * new_school (new_school) <U13 'Essex College' 'Moordale'\n", + "Data variables:\n", + " obs (chain, draw, new_school) float64 2.041 -2.556 ... -0.2822\n", + "Attributes: (2)
<xarray.Dataset>\n", + "Dimensions: (chain: 4, draw: 500, school: 8)\n", + "Coordinates:\n", + " * chain (chain) int64 0 1 2 3\n", + " * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n", + " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", + "Data variables:\n", + " obs (chain, draw, school) float64 ...\n", + "Attributes: (4)
<xarray.Dataset>\n", + "Dimensions: (chain: 4, draw: 500)\n", + "Coordinates:\n", + " * chain (chain) int64 0 1 2 3\n", + " * draw (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499\n", + "Data variables: (12/16)\n", + " max_energy_error (chain, draw) float64 ...\n", + " energy_error (chain, draw) float64 ...\n", + " lp (chain, draw) float64 ...\n", + " index_in_trajectory (chain, draw) int64 ...\n", + " acceptance_rate (chain, draw) float64 ...\n", + " diverging (chain, draw) bool ...\n", + " ... ...\n", + " smallest_eigval (chain, draw) float64 ...\n", + " step_size_bar (chain, draw) float64 ...\n", + " step_size (chain, draw) float64 ...\n", + " energy (chain, draw) float64 ...\n", + " tree_depth (chain, draw) int64 ...\n", + " perf_counter_diff (chain, draw) float64 ...\n", + "Attributes: (6)
<xarray.Dataset>\n", + "Dimensions: (draw: 500, school: 8)\n", + "Coordinates:\n", + " * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n", + " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", + "Data variables:\n", + " tau (draw) float64 1.941 3.388 4.208 5.687 ... 0.8353 0.06893 2.145\n", + " theta (draw, school) float64 4.866 4.59 -0.7404 ... 3.33 -2.031 6.045\n", + " mu (draw) float64 3.903 3.915 -1.751 2.595 ... -2.294 0.7908 2.869
<xarray.Dataset>\n", + "Dimensions: (chain: 1, draw: 500, school: 8)\n", + "Coordinates:\n", + " * chain (chain) int64 0\n", + " * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n", + " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", + "Data variables:\n", + " obs (chain, draw, school) float64 ...\n", + "Attributes: (4)
<xarray.Dataset>\n", + "Dimensions: (school: 8)\n", + "Coordinates:\n", + " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", + "Data variables:\n", + " obs (school) float64 ...\n", + "Attributes: (4)
<xarray.Dataset>\n", + "Dimensions: (school: 8)\n", + "Coordinates:\n", + " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", + "Data variables:\n", + " scores (school) float64 ...\n", + "Attributes: (4)
<xarray.Dataset>\n", + "Dimensions: (draw: 500, school: 8, school_bis: 8)\n", + "Coordinates:\n", + " * draw (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499\n", + " * school (school) <U16 'Choate' 'Deerfield' ... 'Mt. Hermon'\n", + " * school_bis (school_bis) <U16 'Choate' 'Deerfield' ... 'Mt. Hermon'\n", + "Data variables:\n", + " mu (draw) float64 8.974 8.096 10.18 ... 6.284 7.739 6.146\n", + " theta (draw, school) float64 12.52 8.554 9.118 ... 8.595 6.773\n", + " tau (draw) float64 7.068 6.156 6.603 ... 5.725 6.225 5.979\n", + " log_tau (draw) float64 4.322 4.118 4.234 ... 3.958 4.035 3.951\n", + " mlogtau (draw) float64 nan nan nan nan ... 3.993 4.002 4.01 4.021\n", + " theta_school_diff (draw, school, school_bis) float64 3.0 6.965 ... 3.0
<xarray.Dataset>\n", + "Dimensions: (chain: 4, draw: 500, school: 8)\n", + "Coordinates:\n", + " * chain (chain) int64 0 1 2 3\n", + " * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n", + " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", + "Data variables:\n", + " obs (chain, draw, school) float64 ...\n", + "Attributes: (4)
<xarray.Dataset>\n", + "Dimensions: (chain: 4, draw: 500, new_school: 2)\n", + "Coordinates:\n", + " * chain (chain) int64 0 1 2 3\n", + " * draw (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n", + " * new_school (new_school) <U13 'Essex College' 'Moordale'\n", + "Data variables:\n", + " obs (chain, draw, new_school) float64 2.041 -2.556 ... -0.2822\n", + "Attributes: (2)
<xarray.Dataset>\n", + "Dimensions: (chain: 4, draw: 500, school: 8)\n", + "Coordinates:\n", + " * chain (chain) int64 0 1 2 3\n", + " * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n", + " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", + "Data variables:\n", + " obs (chain, draw, school) float64 ...\n", + "Attributes: (4)
<xarray.Dataset>\n", + "Dimensions: (chain: 4, draw: 500)\n", + "Coordinates:\n", + " * chain (chain) int64 0 1 2 3\n", + " * draw (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499\n", + "Data variables: (12/16)\n", + " max_energy_error (chain, draw) float64 ...\n", + " energy_error (chain, draw) float64 ...\n", + " lp (chain, draw) float64 ...\n", + " index_in_trajectory (chain, draw) int64 ...\n", + " acceptance_rate (chain, draw) float64 ...\n", + " diverging (chain, draw) bool ...\n", + " ... ...\n", + " smallest_eigval (chain, draw) float64 ...\n", + " step_size_bar (chain, draw) float64 ...\n", + " step_size (chain, draw) float64 ...\n", + " energy (chain, draw) float64 ...\n", + " tree_depth (chain, draw) int64 ...\n", + " perf_counter_diff (chain, draw) float64 ...\n", + "Attributes: (6)
<xarray.Dataset>\n", + "Dimensions: (draw: 500, school: 8)\n", + "Coordinates:\n", + " * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n", + " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", + "Data variables:\n", + " tau (draw) float64 1.941 3.388 4.208 5.687 ... 0.8353 0.06893 2.145\n", + " theta (draw, school) float64 4.866 4.59 -0.7404 ... 3.33 -2.031 6.045\n", + " mu (draw) float64 3.903 3.915 -1.751 2.595 ... -2.294 0.7908 2.869
<xarray.Dataset>\n", + "Dimensions: (chain: 1, draw: 500, school: 8)\n", + "Coordinates:\n", + " * chain (chain) int64 0\n", + " * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n", + " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", + "Data variables:\n", + " obs (chain, draw, school) float64 ...\n", + "Attributes: (4)
<xarray.Dataset>\n", + "Dimensions: (school: 8)\n", + "Coordinates:\n", + " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", + "Data variables:\n", + " obs (school) float64 ...\n", + "Attributes: (4)
<xarray.Dataset>\n", + "Dimensions: (school: 8)\n", + "Coordinates:\n", + " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", + "Data variables:\n", + " scores (school) float64 ...\n", + "Attributes: (4)
<xarray.Dataset>\n", + "Dimensions: (draw: 500, school: 8, school_bis: 8)\n", + "Coordinates:\n", + " * draw (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499\n", + " * school (school) <U16 'Choate' 'Deerfield' ... 'Mt. Hermon'\n", + " * school_bis (school_bis) <U16 'Choate' 'Deerfield' ... 'Mt. Hermon'\n", + "Data variables:\n", + " mu (draw) float64 5.974 5.096 7.177 ... 3.284 4.739 3.146\n", + " theta (draw, school) float64 9.519 5.554 6.118 ... 5.595 3.773\n", + " tau (draw) float64 4.068 3.156 3.603 ... 2.725 3.225 2.979\n", + " log_tau (draw) float64 1.322 1.118 1.234 ... 0.958 1.035 0.9508\n", + " mlogtau (draw) float64 nan nan nan nan ... 0.993 1.002 1.01 1.021\n", + " theta_school_diff (draw, school, school_bis) float64 0.0 3.965 ... 0.0
<xarray.Dataset>\n", + "Dimensions: (chain: 4, draw: 500, school: 8, Upper: 8)\n", + "Coordinates:\n", + " * chain (chain) int64 0 1 2 3\n", + " * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n", + " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", + " upper (Upper) <U16 'CHOATE' 'DEERFIELD' ... "ST. PAUL'S" 'MT. HERMON'\n", + "Dimensions without coordinates: Upper\n", + "Data variables:\n", + " obs (chain, draw, school) float64 ...\n", + "Attributes: (4)
<xarray.Dataset>\n", + "Dimensions: (chain: 4, draw: 500, new_school: 2)\n", + "Coordinates:\n", + " * chain (chain) int64 0 1 2 3\n", + " * draw (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n", + " * new_school (new_school) <U13 'Essex College' 'Moordale'\n", + "Data variables:\n", + " obs (chain, draw, new_school) float64 2.041 -2.556 ... -0.2822\n", + "Attributes: (2)
<xarray.Dataset>\n", + "Dimensions: (chain: 4, draw: 500, school: 8)\n", + "Coordinates:\n", + " * chain (chain) int64 0 1 2 3\n", + " * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n", + " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", + "Data variables:\n", + " obs (chain, draw, school) float64 ...\n", + "Attributes: (4)
<xarray.Dataset>\n", + "Dimensions: (chain: 4, draw: 500)\n", + "Coordinates:\n", + " * chain (chain) int64 0 1 2 3\n", + " * draw (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499\n", + "Data variables: (12/16)\n", + " max_energy_error (chain, draw) float64 ...\n", + " energy_error (chain, draw) float64 ...\n", + " lp (chain, draw) float64 ...\n", + " index_in_trajectory (chain, draw) int64 ...\n", + " acceptance_rate (chain, draw) float64 ...\n", + " diverging (chain, draw) bool ...\n", + " ... ...\n", + " smallest_eigval (chain, draw) float64 ...\n", + " step_size_bar (chain, draw) float64 ...\n", + " step_size (chain, draw) float64 ...\n", + " energy (chain, draw) float64 ...\n", + " tree_depth (chain, draw) int64 ...\n", + " perf_counter_diff (chain, draw) float64 ...\n", + "Attributes: (6)
<xarray.Dataset>\n", + "Dimensions: (draw: 500, school: 8)\n", + "Coordinates:\n", + " * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n", + " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", + "Data variables:\n", + " tau (draw) float64 1.941 3.388 4.208 5.687 ... 0.8353 0.06893 2.145\n", + " theta (draw, school) float64 4.866 4.59 -0.7404 ... 3.33 -2.031 6.045\n", + " mu (draw) float64 3.903 3.915 -1.751 2.595 ... -2.294 0.7908 2.869
<xarray.Dataset>\n", + "Dimensions: (chain: 1, draw: 500, school: 8, Upper: 8)\n", + "Coordinates:\n", + " * chain (chain) int64 0\n", + " * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n", + " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", + " upper (Upper) <U16 'CHOATE' 'DEERFIELD' ... "ST. PAUL'S" 'MT. HERMON'\n", + "Dimensions without coordinates: Upper\n", + "Data variables:\n", + " obs (chain, draw, school) float64 ...\n", + "Attributes: (4)
<xarray.Dataset>\n", + "Dimensions: (school: 8, Upper: 8)\n", + "Coordinates:\n", + " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", + " upper (Upper) <U16 'CHOATE' 'DEERFIELD' ... "ST. PAUL'S" 'MT. HERMON'\n", + "Dimensions without coordinates: Upper\n", + "Data variables:\n", + " obs (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0\n", + "Attributes: (4)
<xarray.Dataset>\n", + "Dimensions: (school: 8)\n", + "Coordinates:\n", + " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", + "Data variables:\n", + " scores (school) float64 ...\n", + "Attributes: (4)