From 82ea6e4d6dbcdf9f75d07e5acec85c18234fd36e Mon Sep 17 00:00:00 2001 From: Aadya Chinubhai <77720426+aadya940@users.noreply.github.com> Date: Fri, 23 Feb 2024 22:42:49 +0530 Subject: [PATCH] Add example for arviz.InferenceData.map (#2304) * Add example for arviz.InferenceData.map * minor fix --- .../WorkingWithInferenceData.ipynb | 13820 +++++++++++++++- 1 file changed, 13250 insertions(+), 570 deletions(-) diff --git a/doc/source/getting_started/WorkingWithInferenceData.ipynb b/doc/source/getting_started/WorkingWithInferenceData.ipynb index af1406cc5c..d9677d5c75 100644 --- a/doc/source/getting_started/WorkingWithInferenceData.ipynb +++ b/doc/source/getting_started/WorkingWithInferenceData.ipynb @@ -13,7 +13,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -35,7 +35,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -49,8 +49,8 @@ "
\n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:41.460544
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:37.487399
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:37.324929
    inference_library :
    pymc
    inference_library_version :
    4.2.2
    sampling_time :
    7.480114936828613
    tuning_steps :
    1000

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:26.602116
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:26.604969
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:26.606375
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:26.607471
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", @@ -3509,7 +3509,7 @@ "\t> constant_data" ] }, - "execution_count": 2, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -3528,7 +3528,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -3907,13 +3907,13 @@ " mu (chain, draw) float64 ...\n", " theta (chain, draw, school) float64 ...\n", " tau (chain, draw) float64 ...\n", - "Attributes: (6)
  • created_at :
    2022-10-13T14:37:37.315398
    arviz_version :
    0.13.0.dev0
    inference_library :
    pymc
    inference_library_version :
    4.2.2
    sampling_time :
    7.480114936828613
    tuning_steps :
    1000
  • " ], "text/plain": [ "\n", @@ -3929,7 +3929,7 @@ "Attributes: (6)" ] }, - "execution_count": 3, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -3959,7 +3959,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -4339,23 +4339,23 @@ " theta (chain, draw, school) float64 ...\n", " tau (chain, draw) float64 4.726 3.909 4.844 1.857 ... 2.741 2.932 4.461\n", " log_tau (chain, draw) float64 1.553 1.363 1.578 ... 1.008 1.076 1.495\n", - "Attributes: (6)
  • created_at :
    2022-10-13T14:37:37.315398
    arviz_version :
    0.13.0.dev0
    inference_library :
    pymc
    inference_library_version :
    4.2.2
    sampling_time :
    7.480114936828613
    tuning_steps :
    1000
  • " ], "text/plain": [ "\n", @@ -4372,7 +4372,7 @@ "Attributes: (6)" ] }, - "execution_count": 4, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -4391,7 +4391,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "metadata": { "scrolled": true }, @@ -4774,9 +4774,9 @@ " theta (school, sample) float64 12.32 11.29 5.709 ... -2.623 8.452 1.295\n", " tau (sample) float64 4.726 3.909 4.844 1.857 ... 2.741 2.932 4.461\n", " log_tau (sample) float64 1.553 1.363 1.578 0.6188 ... 1.008 1.076 1.495\n", - "Attributes: (6)
  • created_at :
    2022-10-13T14:37:37.315398
    arviz_version :
    0.13.0.dev0
    inference_library :
    pymc
    inference_library_version :
    4.2.2
    sampling_time :
    7.480114936828613
    tuning_steps :
    1000
  • " ], "text/plain": [ "\n", @@ -4831,7 +4831,7 @@ "Attributes: (6)" ] }, - "execution_count": 5, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -4862,7 +4862,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -5236,239 +5236,239 @@ "Coordinates:\n", " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", " * sample (sample) object MultiIndex\n", - " * chain (sample) int64 0 2 2 2 2 2 0 1 3 3 2 1 ... 1 1 1 0 3 3 2 0 3 3 2 2\n", - " * draw (sample) int64 358 475 205 429 168 25 14 ... 136 467 99 271 50 395\n", + " * chain (sample) int64 3 1 1 3 2 3 3 3 3 1 0 0 ... 2 0 2 2 3 2 2 3 2 0 3 2\n", + " * draw (sample) int64 214 202 176 487 371 27 ... 207 411 170 102 70 385\n", "Data variables:\n", - " mu (sample) float64 0.8319 7.433 2.281 6.153 ... 0.6951 3.324 9.385\n", - " theta (school, sample) float64 10.78 6.728 2.047 ... 0.7347 -0.8202 7.359\n", - " tau (sample) float64 8.878 3.047 2.865 2.852 ... 1.176 2.038 6.46 1.431\n", - " log_tau (sample) float64 2.184 1.114 1.052 1.048 ... 0.7121 1.866 0.3586\n", - "Attributes: (6)
  • created_at :
    2022-10-13T14:37:37.315398
    arviz_version :
    0.13.0.dev0
    inference_library :
    pymc
    inference_library_version :
    4.2.2
    sampling_time :
    7.480114936828613
    tuning_steps :
    1000
  • " ], "text/plain": [ "\n", @@ -5476,17 +5476,17 @@ "Coordinates:\n", " * school (school)
    <xarray.DataArray 'school' (school: 8)>\n",
            "'Choate' 'Deerfield' 'Phillips Andover' ... "St. Paul's" 'Mt. Hermon'\n",
            "Coordinates:\n",
    -       "  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
  • " ], "text/plain": [ "\n", @@ -5959,7 +5959,7 @@ " * school (school) \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -6372,9 +6372,9 @@ " theta (chain, draw, school) float64 12.32 9.905 14.95 ... 2.363 -2.968\n", " tau (chain, draw) float64 4.726 3.909 4.844 1.857 ... 4.09 2.72 1.917\n", " log_tau (chain, draw) float64 1.553 1.363 1.578 ... 1.408 1.001 0.6508\n", - "Attributes: (6)
  • created_at :
    2022-10-13T14:37:37.315398
    arviz_version :
    0.13.0.dev0
    inference_library :
    pymc
    inference_library_version :
    4.2.2
    sampling_time :
    7.480114936828613
    tuning_steps :
    1000

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -6812,20 +6812,20 @@ " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", "Data variables:\n", " obs (chain, draw, school) float64 ...\n", - "Attributes: (4)
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:41.460544
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -7200,20 +7200,20 @@ " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", "Data variables:\n", " obs (chain, draw, school) float64 ...\n", - "Attributes: (4)
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:37.487399
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -7599,17 +7599,17 @@ " energy (chain, draw) float64 ...\n", " tree_depth (chain, draw) int64 ...\n", " perf_counter_diff (chain, draw) float64 ...\n", - "Attributes: (6)
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:37.324929
    inference_library :
    pymc
    inference_library_version :
    4.2.2
    sampling_time :
    7.480114936828613
    tuning_steps :
    1000

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -7986,20 +7986,20 @@ " tau (chain, draw) float64 ...\n", " theta (chain, draw, school) float64 ...\n", " mu (chain, draw) float64 ...\n", - "Attributes: (4)
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:26.602116
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -8374,20 +8374,20 @@ " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", "Data variables:\n", " obs (chain, draw, school) float64 ...\n", - "Attributes: (4)
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:26.604969
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -8760,17 +8760,17 @@ " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", "Data variables:\n", " obs (school) float64 ...\n", - "Attributes: (4)
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:26.606375
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -9143,10 +9143,10 @@ " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", "Data variables:\n", " scores (school) float64 ...\n", - "Attributes: (4)
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:26.607471
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", @@ -9506,7 +9506,7 @@ "\t> constant_data" ] }, - "execution_count": 10, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -9526,7 +9526,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -9540,8 +9540,8 @@ "
      \n", " \n", "
    • \n", - " \n", - " \n", + " \n", + " \n", "
      \n", "
      \n", "
        \n", @@ -9919,11 +9919,11 @@ " theta (chain, draw, school) float64 14.23 9.72 9.195 ... 6.762 1.295\n", " tau (chain, draw) float64 4.289 2.765 2.457 1.719 ... 2.741 2.932 4.461\n", " log_tau (chain, draw) float64 1.456 1.017 0.8991 ... 1.008 1.076 1.495\n", - "Attributes: (6)
    • created_at :
      2022-10-13T14:37:37.315398
      arviz_version :
      0.13.0.dev0
      inference_library :
      pymc
      inference_library_version :
      4.2.2
      sampling_time :
      7.480114936828613
      tuning_steps :
      1000

    \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -10342,20 +10342,20 @@ " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", "Data variables:\n", " obs (chain, draw, school) float64 ...\n", - "Attributes: (4)
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:41.460544
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -10730,20 +10730,20 @@ " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", "Data variables:\n", " obs (chain, draw, school) float64 ...\n", - "Attributes: (4)
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:37.487399
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -11129,17 +11129,17 @@ " energy (chain, draw) float64 ...\n", " tree_depth (chain, draw) int64 ...\n", " perf_counter_diff (chain, draw) float64 ...\n", - "Attributes: (6)
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:37.324929
    inference_library :
    pymc
    inference_library_version :
    4.2.2
    sampling_time :
    7.480114936828613
    tuning_steps :
    1000

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -11516,20 +11516,20 @@ " tau (chain, draw) float64 ...\n", " theta (chain, draw, school) float64 ...\n", " mu (chain, draw) float64 ...\n", - "Attributes: (4)
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:26.602116
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -11904,20 +11904,20 @@ " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", "Data variables:\n", " obs (chain, draw, school) float64 ...\n", - "Attributes: (4)
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:26.604969
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -12290,17 +12290,17 @@ " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", "Data variables:\n", " obs (school) float64 ...\n", - "Attributes: (4)
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:26.606375
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -12673,10 +12673,10 @@ " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", "Data variables:\n", " scores (school) float64 ...\n", - "Attributes: (4)
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:26.607471
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", @@ -13036,7 +13036,7 @@ "\t> constant_data" ] }, - "execution_count": 11, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -13054,7 +13054,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -13068,8 +13068,8 @@ "
      \n", " \n", "
    • \n", - " \n", - " \n", + " \n", + " \n", "
      \n", "
      \n", "
        \n", @@ -13447,11 +13447,11 @@ " theta (chain, draw, school) float64 14.23 9.72 9.195 ... 6.762 1.295\n", " tau (chain, draw) float64 4.289 2.765 2.457 1.719 ... 2.741 2.932 4.461\n", " log_tau (chain, draw) float64 1.456 1.017 0.8991 ... 1.008 1.076 1.495\n", - "Attributes: (6)
    • created_at :
      2022-10-13T14:37:37.315398
      arviz_version :
      0.13.0.dev0
      inference_library :
      pymc
      inference_library_version :
      4.2.2
      sampling_time :
      7.480114936828613
      tuning_steps :
      1000

    \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -13870,20 +13870,20 @@ " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", "Data variables:\n", " obs (chain, draw, school) float64 ...\n", - "Attributes: (4)
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:41.460544
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -14258,20 +14258,20 @@ " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", "Data variables:\n", " obs (chain, draw, school) float64 ...\n", - "Attributes: (4)
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:37.487399
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -14657,17 +14657,17 @@ " energy (chain, draw) float64 ...\n", " tree_depth (chain, draw) int64 ...\n", " perf_counter_diff (chain, draw) float64 ...\n", - "Attributes: (6)
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:37.324929
    inference_library :
    pymc
    inference_library_version :
    4.2.2
    sampling_time :
    7.480114936828613
    tuning_steps :
    1000

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -15044,20 +15044,20 @@ " tau (chain, draw) float64 ...\n", " theta (chain, draw, school) float64 ...\n", " mu (chain, draw) float64 ...\n", - "Attributes: (4)
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:26.602116
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -15432,20 +15432,20 @@ " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", "Data variables:\n", " obs (chain, draw, school) float64 ...\n", - "Attributes: (4)
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:26.604969
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -15818,17 +15818,17 @@ " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", "Data variables:\n", " obs (school) float64 ...\n", - "Attributes: (4)
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:26.606375
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -16201,10 +16201,10 @@ " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", "Data variables:\n", " scores (school) float64 ...\n", - "Attributes: (4)
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:26.607471
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", @@ -16564,7 +16564,7 @@ "\t> constant_data" ] }, - "execution_count": 12, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -16584,7 +16584,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -16959,7 +16959,7 @@ " mu float64 4.486\n", " theta float64 4.912\n", " tau float64 4.124\n", - " log_tau float64 1.173" + " log_tau float64 1.173" ], "text/plain": [ "\n", @@ -16971,7 +16971,7 @@ " log_tau float64 1.173" ] }, - "execution_count": 13, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -16991,7 +16991,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -17368,11 +17368,11 @@ " mu float64 4.486\n", " theta (school) float64 6.46 5.028 3.938 4.872 3.667 3.975 6.581 4.772\n", " tau float64 4.124\n", - " log_tau float64 1.173
  • " ], "text/plain": [ "\n", @@ -17386,7 +17386,7 @@ " log_tau float64 1.173" ] }, - "execution_count": 14, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -17415,7 +17415,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -17442,7 +17442,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -17468,7 +17468,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -17851,12 +17851,12 @@ " log_tau (chain, draw) float64 1.553 1.363 1.578 ... 1.076 1.495\n", " mlogtau (chain, draw) float64 nan nan nan ... 1.494 1.496 1.511\n", " theta_school_diff (chain, draw, school, school_bis) float64 0.0 ... 0.0\n", - "Attributes: (6)
  • created_at :
    2022-10-13T14:37:37.315398
    arviz_version :
    0.13.0.dev0
    inference_library :
    pymc
    inference_library_version :
    4.2.2
    sampling_time :
    7.480114936828613
    tuning_steps :
    1000
  • " ], "text/plain": [ "\n", @@ -17962,7 +17962,7 @@ "Attributes: (6)" ] }, - "execution_count": 17, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -17981,7 +17981,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -18356,17 +18356,17 @@ " * chain (chain) int64 0 1 2 3\n", " * draw (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n", " school <U16 'Choate'\n", - " school_bis <U16 'Deerfield'
  • " ], "text/plain": [ "\n", @@ -18378,7 +18378,7 @@ " school_bis
  • " ], "text/plain": [ "
  • " ], "text/plain": [ "\n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -19754,12 +19754,12 @@ " log_tau (chain, draw) float64 1.553 1.363 1.578 ... 1.076 1.495\n", " mlogtau (chain, draw) float64 nan nan nan ... 1.494 1.496 1.511\n", " theta_school_diff (chain, draw, school, school_bis) float64 0.0 ... 0.0\n", - "Attributes: (6)
  • created_at :
    2022-10-13T14:37:37.315398
    arviz_version :
    0.13.0.dev0
    inference_library :
    pymc
    inference_library_version :
    4.2.2
    sampling_time :
    7.480114936828613
    tuning_steps :
    1000

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -20227,20 +20227,20 @@ " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", "Data variables:\n", " obs (chain, draw, school) float64 ...\n", - "Attributes: (4)
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:41.460544
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -20615,7 +20615,7 @@ " * new_school (new_school) <U13 'Essex College' 'Moordale'\n", "Data variables:\n", " obs (chain, draw, new_school) float64 2.041 -2.556 ... -0.2822\n", - "Attributes: (2)
    • new_school
      PandasIndex
      PandasIndex(Index(['Essex College', 'Moordale'], dtype='object', name='new_school'))
  • created_at :
    2023-12-28T12:47:21.311677
    arviz_version :
    0.16.1

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -21030,20 +21030,20 @@ " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", "Data variables:\n", " obs (chain, draw, school) float64 ...\n", - "Attributes: (4)
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:37.487399
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -21429,17 +21429,17 @@ " energy (chain, draw) float64 ...\n", " tree_depth (chain, draw) int64 ...\n", " perf_counter_diff (chain, draw) float64 ...\n", - "Attributes: (6)
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:37.324929
    inference_library :
    pymc
    inference_library_version :
    4.2.2
    sampling_time :
    7.480114936828613
    tuning_steps :
    1000

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -21816,20 +21816,20 @@ " tau (chain, draw) float64 ...\n", " theta (chain, draw, school) float64 ...\n", " mu (chain, draw) float64 ...\n", - "Attributes: (4)
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:26.602116
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -22204,20 +22204,20 @@ " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", "Data variables:\n", " obs (chain, draw, school) float64 ...\n", - "Attributes: (4)
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:26.604969
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -22590,17 +22590,17 @@ " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", "Data variables:\n", " obs (school) float64 ...\n", - "Attributes: (4)
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:26.606375
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -22973,10 +22973,10 @@ " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", "Data variables:\n", " scores (school) float64 ...\n", - "Attributes: (4)
  • arviz_version :
    0.13.0.dev0
    created_at :
    2022-10-13T14:37:26.607471
    inference_library :
    pymc
    inference_library_version :
    4.2.2

  • \n", " \n", " \n", " \n", @@ -23337,7 +23337,7 @@ "\t> constant_data" ] }, - "execution_count": 23, + "execution_count": 25, "metadata": {}, "output_type": "execute_result" } @@ -23352,6 +23352,12686 @@ "idata" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Add Transformations to Multiple Groups" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can also add transformations to Multiple InferenceData Groups using {meth}`arviz.InferenceData.map`. It takes a function as an input and applies the function groupwise to the selected InferenceData groups and overwrites the group with the result of the function." + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
    \n", + "
    \n", + "
    arviz.InferenceData
    \n", + "
    \n", + "
      \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset>\n",
        +       "Dimensions:            (draw: 500, school: 8, school_bis: 8)\n",
        +       "Coordinates:\n",
        +       "  * draw               (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499\n",
        +       "  * school             (school) <U16 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
        +       "  * school_bis         (school_bis) <U16 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
        +       "Data variables:\n",
        +       "    mu                 (draw) float64 5.974 5.096 7.177 ... 3.284 4.739 3.146\n",
        +       "    theta              (draw, school) float64 9.519 5.554 6.118 ... 5.595 3.773\n",
        +       "    tau                (draw) float64 4.068 3.156 3.603 ... 2.725 3.225 2.979\n",
        +       "    log_tau            (draw) float64 1.322 1.118 1.234 ... 0.958 1.035 0.9508\n",
        +       "    mlogtau            (draw) float64 nan nan nan nan ... 0.993 1.002 1.01 1.021\n",
        +       "    theta_school_diff  (draw, school, school_bis) float64 0.0 3.965 ... 0.0

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset>\n",
        +       "Dimensions:  (chain: 4, draw: 500, school: 8)\n",
        +       "Coordinates:\n",
        +       "  * chain    (chain) int64 0 1 2 3\n",
        +       "  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n",
        +       "  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
        +       "Data variables:\n",
        +       "    obs      (chain, draw, school) float64 ...\n",
        +       "Attributes: (4)

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset>\n",
        +       "Dimensions:     (chain: 4, draw: 500, new_school: 2)\n",
        +       "Coordinates:\n",
        +       "  * chain       (chain) int64 0 1 2 3\n",
        +       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
        +       "  * new_school  (new_school) <U13 'Essex College' 'Moordale'\n",
        +       "Data variables:\n",
        +       "    obs         (chain, draw, new_school) float64 2.041 -2.556 ... -0.2822\n",
        +       "Attributes: (2)

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset>\n",
        +       "Dimensions:  (chain: 4, draw: 500, school: 8)\n",
        +       "Coordinates:\n",
        +       "  * chain    (chain) int64 0 1 2 3\n",
        +       "  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n",
        +       "  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
        +       "Data variables:\n",
        +       "    obs      (chain, draw, school) float64 ...\n",
        +       "Attributes: (4)

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset>\n",
        +       "Dimensions:              (chain: 4, draw: 500)\n",
        +       "Coordinates:\n",
        +       "  * chain                (chain) int64 0 1 2 3\n",
        +       "  * draw                 (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499\n",
        +       "Data variables: (12/16)\n",
        +       "    max_energy_error     (chain, draw) float64 ...\n",
        +       "    energy_error         (chain, draw) float64 ...\n",
        +       "    lp                   (chain, draw) float64 ...\n",
        +       "    index_in_trajectory  (chain, draw) int64 ...\n",
        +       "    acceptance_rate      (chain, draw) float64 ...\n",
        +       "    diverging            (chain, draw) bool ...\n",
        +       "    ...                   ...\n",
        +       "    smallest_eigval      (chain, draw) float64 ...\n",
        +       "    step_size_bar        (chain, draw) float64 ...\n",
        +       "    step_size            (chain, draw) float64 ...\n",
        +       "    energy               (chain, draw) float64 ...\n",
        +       "    tree_depth           (chain, draw) int64 ...\n",
        +       "    perf_counter_diff    (chain, draw) float64 ...\n",
        +       "Attributes: (6)

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset>\n",
        +       "Dimensions:  (draw: 500, school: 8)\n",
        +       "Coordinates:\n",
        +       "  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n",
        +       "  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
        +       "Data variables:\n",
        +       "    tau      (draw) float64 1.941 3.388 4.208 5.687 ... 0.8353 0.06893 2.145\n",
        +       "    theta    (draw, school) float64 4.866 4.59 -0.7404 ... 3.33 -2.031 6.045\n",
        +       "    mu       (draw) float64 3.903 3.915 -1.751 2.595 ... -2.294 0.7908 2.869

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset>\n",
        +       "Dimensions:  (chain: 1, draw: 500, school: 8)\n",
        +       "Coordinates:\n",
        +       "  * chain    (chain) int64 0\n",
        +       "  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n",
        +       "  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
        +       "Data variables:\n",
        +       "    obs      (chain, draw, school) float64 ...\n",
        +       "Attributes: (4)

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset>\n",
        +       "Dimensions:  (school: 8)\n",
        +       "Coordinates:\n",
        +       "  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
        +       "Data variables:\n",
        +       "    obs      (school) float64 ...\n",
        +       "Attributes: (4)

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset>\n",
        +       "Dimensions:  (school: 8)\n",
        +       "Coordinates:\n",
        +       "  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
        +       "Data variables:\n",
        +       "    scores   (school) float64 ...\n",
        +       "Attributes: (4)

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    \n", + "
    \n", + " " + ], + "text/plain": [ + "Inference data with groups:\n", + "\t> posterior\n", + "\t> posterior_predictive\n", + "\t> predictions\n", + "\t> log_likelihood\n", + "\t> sample_stats\n", + "\t> prior\n", + "\t> prior_predictive\n", + "\t> observed_data\n", + "\t> constant_data" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "selected_groups = [\"posterior\", \"prior\"]\n", + "\n", + "def calc_mean(dataset, *args, **kwargs):\n", + " result = dataset.mean(dim=\"chain\", *args, **kwargs)\n", + " return result\n", + "\n", + "means = idata.map(calc_mean, groups=selected_groups, inplace=False)\n", + "means" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can also pass a lambda function in `map`" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
    \n", + "
    \n", + "
    arviz.InferenceData
    \n", + "
    \n", + "
      \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset>\n",
        +       "Dimensions:            (draw: 500, school: 8, school_bis: 8)\n",
        +       "Coordinates:\n",
        +       "  * draw               (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499\n",
        +       "  * school             (school) <U16 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
        +       "  * school_bis         (school_bis) <U16 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
        +       "Data variables:\n",
        +       "    mu                 (draw) float64 8.974 8.096 10.18 ... 6.284 7.739 6.146\n",
        +       "    theta              (draw, school) float64 12.52 8.554 9.118 ... 8.595 6.773\n",
        +       "    tau                (draw) float64 7.068 6.156 6.603 ... 5.725 6.225 5.979\n",
        +       "    log_tau            (draw) float64 4.322 4.118 4.234 ... 3.958 4.035 3.951\n",
        +       "    mlogtau            (draw) float64 nan nan nan nan ... 3.993 4.002 4.01 4.021\n",
        +       "    theta_school_diff  (draw, school, school_bis) float64 3.0 6.965 ... 3.0

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset>\n",
        +       "Dimensions:  (chain: 4, draw: 500, school: 8)\n",
        +       "Coordinates:\n",
        +       "  * chain    (chain) int64 0 1 2 3\n",
        +       "  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n",
        +       "  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
        +       "Data variables:\n",
        +       "    obs      (chain, draw, school) float64 ...\n",
        +       "Attributes: (4)

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset>\n",
        +       "Dimensions:     (chain: 4, draw: 500, new_school: 2)\n",
        +       "Coordinates:\n",
        +       "  * chain       (chain) int64 0 1 2 3\n",
        +       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
        +       "  * new_school  (new_school) <U13 'Essex College' 'Moordale'\n",
        +       "Data variables:\n",
        +       "    obs         (chain, draw, new_school) float64 2.041 -2.556 ... -0.2822\n",
        +       "Attributes: (2)

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset>\n",
        +       "Dimensions:  (chain: 4, draw: 500, school: 8)\n",
        +       "Coordinates:\n",
        +       "  * chain    (chain) int64 0 1 2 3\n",
        +       "  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n",
        +       "  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
        +       "Data variables:\n",
        +       "    obs      (chain, draw, school) float64 ...\n",
        +       "Attributes: (4)

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset>\n",
        +       "Dimensions:              (chain: 4, draw: 500)\n",
        +       "Coordinates:\n",
        +       "  * chain                (chain) int64 0 1 2 3\n",
        +       "  * draw                 (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499\n",
        +       "Data variables: (12/16)\n",
        +       "    max_energy_error     (chain, draw) float64 ...\n",
        +       "    energy_error         (chain, draw) float64 ...\n",
        +       "    lp                   (chain, draw) float64 ...\n",
        +       "    index_in_trajectory  (chain, draw) int64 ...\n",
        +       "    acceptance_rate      (chain, draw) float64 ...\n",
        +       "    diverging            (chain, draw) bool ...\n",
        +       "    ...                   ...\n",
        +       "    smallest_eigval      (chain, draw) float64 ...\n",
        +       "    step_size_bar        (chain, draw) float64 ...\n",
        +       "    step_size            (chain, draw) float64 ...\n",
        +       "    energy               (chain, draw) float64 ...\n",
        +       "    tree_depth           (chain, draw) int64 ...\n",
        +       "    perf_counter_diff    (chain, draw) float64 ...\n",
        +       "Attributes: (6)

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset>\n",
        +       "Dimensions:  (draw: 500, school: 8)\n",
        +       "Coordinates:\n",
        +       "  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n",
        +       "  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
        +       "Data variables:\n",
        +       "    tau      (draw) float64 1.941 3.388 4.208 5.687 ... 0.8353 0.06893 2.145\n",
        +       "    theta    (draw, school) float64 4.866 4.59 -0.7404 ... 3.33 -2.031 6.045\n",
        +       "    mu       (draw) float64 3.903 3.915 -1.751 2.595 ... -2.294 0.7908 2.869

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset>\n",
        +       "Dimensions:  (chain: 1, draw: 500, school: 8)\n",
        +       "Coordinates:\n",
        +       "  * chain    (chain) int64 0\n",
        +       "  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n",
        +       "  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
        +       "Data variables:\n",
        +       "    obs      (chain, draw, school) float64 ...\n",
        +       "Attributes: (4)

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset>\n",
        +       "Dimensions:  (school: 8)\n",
        +       "Coordinates:\n",
        +       "  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
        +       "Data variables:\n",
        +       "    obs      (school) float64 ...\n",
        +       "Attributes: (4)

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset>\n",
        +       "Dimensions:  (school: 8)\n",
        +       "Coordinates:\n",
        +       "  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
        +       "Data variables:\n",
        +       "    scores   (school) float64 ...\n",
        +       "Attributes: (4)

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    \n", + "
    \n", + " " + ], + "text/plain": [ + "Inference data with groups:\n", + "\t> posterior\n", + "\t> posterior_predictive\n", + "\t> predictions\n", + "\t> log_likelihood\n", + "\t> sample_stats\n", + "\t> prior\n", + "\t> prior_predictive\n", + "\t> observed_data\n", + "\t> constant_data" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "idata_shifted_obs = idata.map(lambda x: x + 3, groups=\"posterior\")\n", + "idata_shifted_obs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can also add extra coordinates using `map`" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
    \n", + "
    \n", + "
    arviz.InferenceData
    \n", + "
    \n", + "
      \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset>\n",
        +       "Dimensions:            (draw: 500, school: 8, school_bis: 8)\n",
        +       "Coordinates:\n",
        +       "  * draw               (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499\n",
        +       "  * school             (school) <U16 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
        +       "  * school_bis         (school_bis) <U16 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
        +       "Data variables:\n",
        +       "    mu                 (draw) float64 5.974 5.096 7.177 ... 3.284 4.739 3.146\n",
        +       "    theta              (draw, school) float64 9.519 5.554 6.118 ... 5.595 3.773\n",
        +       "    tau                (draw) float64 4.068 3.156 3.603 ... 2.725 3.225 2.979\n",
        +       "    log_tau            (draw) float64 1.322 1.118 1.234 ... 0.958 1.035 0.9508\n",
        +       "    mlogtau            (draw) float64 nan nan nan nan ... 0.993 1.002 1.01 1.021\n",
        +       "    theta_school_diff  (draw, school, school_bis) float64 0.0 3.965 ... 0.0

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset>\n",
        +       "Dimensions:  (chain: 4, draw: 500, school: 8, Upper: 8)\n",
        +       "Coordinates:\n",
        +       "  * chain    (chain) int64 0 1 2 3\n",
        +       "  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n",
        +       "  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
        +       "    upper    (Upper) <U16 'CHOATE' 'DEERFIELD' ... "ST. PAUL'S" 'MT. HERMON'\n",
        +       "Dimensions without coordinates: Upper\n",
        +       "Data variables:\n",
        +       "    obs      (chain, draw, school) float64 ...\n",
        +       "Attributes: (4)

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset>\n",
        +       "Dimensions:     (chain: 4, draw: 500, new_school: 2)\n",
        +       "Coordinates:\n",
        +       "  * chain       (chain) int64 0 1 2 3\n",
        +       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
        +       "  * new_school  (new_school) <U13 'Essex College' 'Moordale'\n",
        +       "Data variables:\n",
        +       "    obs         (chain, draw, new_school) float64 2.041 -2.556 ... -0.2822\n",
        +       "Attributes: (2)

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset>\n",
        +       "Dimensions:  (chain: 4, draw: 500, school: 8)\n",
        +       "Coordinates:\n",
        +       "  * chain    (chain) int64 0 1 2 3\n",
        +       "  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n",
        +       "  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
        +       "Data variables:\n",
        +       "    obs      (chain, draw, school) float64 ...\n",
        +       "Attributes: (4)

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset>\n",
        +       "Dimensions:              (chain: 4, draw: 500)\n",
        +       "Coordinates:\n",
        +       "  * chain                (chain) int64 0 1 2 3\n",
        +       "  * draw                 (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499\n",
        +       "Data variables: (12/16)\n",
        +       "    max_energy_error     (chain, draw) float64 ...\n",
        +       "    energy_error         (chain, draw) float64 ...\n",
        +       "    lp                   (chain, draw) float64 ...\n",
        +       "    index_in_trajectory  (chain, draw) int64 ...\n",
        +       "    acceptance_rate      (chain, draw) float64 ...\n",
        +       "    diverging            (chain, draw) bool ...\n",
        +       "    ...                   ...\n",
        +       "    smallest_eigval      (chain, draw) float64 ...\n",
        +       "    step_size_bar        (chain, draw) float64 ...\n",
        +       "    step_size            (chain, draw) float64 ...\n",
        +       "    energy               (chain, draw) float64 ...\n",
        +       "    tree_depth           (chain, draw) int64 ...\n",
        +       "    perf_counter_diff    (chain, draw) float64 ...\n",
        +       "Attributes: (6)

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset>\n",
        +       "Dimensions:  (draw: 500, school: 8)\n",
        +       "Coordinates:\n",
        +       "  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n",
        +       "  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
        +       "Data variables:\n",
        +       "    tau      (draw) float64 1.941 3.388 4.208 5.687 ... 0.8353 0.06893 2.145\n",
        +       "    theta    (draw, school) float64 4.866 4.59 -0.7404 ... 3.33 -2.031 6.045\n",
        +       "    mu       (draw) float64 3.903 3.915 -1.751 2.595 ... -2.294 0.7908 2.869

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset>\n",
        +       "Dimensions:  (chain: 1, draw: 500, school: 8, Upper: 8)\n",
        +       "Coordinates:\n",
        +       "  * chain    (chain) int64 0\n",
        +       "  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n",
        +       "  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
        +       "    upper    (Upper) <U16 'CHOATE' 'DEERFIELD' ... "ST. PAUL'S" 'MT. HERMON'\n",
        +       "Dimensions without coordinates: Upper\n",
        +       "Data variables:\n",
        +       "    obs      (chain, draw, school) float64 ...\n",
        +       "Attributes: (4)

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset>\n",
        +       "Dimensions:  (school: 8, Upper: 8)\n",
        +       "Coordinates:\n",
        +       "  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
        +       "    upper    (Upper) <U16 'CHOATE' 'DEERFIELD' ... "ST. PAUL'S" 'MT. HERMON'\n",
        +       "Dimensions without coordinates: Upper\n",
        +       "Data variables:\n",
        +       "    obs      (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0\n",
        +       "Attributes: (4)

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset>\n",
        +       "Dimensions:  (school: 8)\n",
        +       "Coordinates:\n",
        +       "  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
        +       "Data variables:\n",
        +       "    scores   (school) float64 ...\n",
        +       "Attributes: (4)

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    \n", + "
    \n", + " " + ], + "text/plain": [ + "Inference data with groups:\n", + "\t> posterior\n", + "\t> posterior_predictive\n", + "\t> predictions\n", + "\t> log_likelihood\n", + "\t> sample_stats\n", + "\t> prior\n", + "\t> prior_predictive\n", + "\t> observed_data\n", + "\t> constant_data" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "_upper = np.array([\n", + " x.upper() for x in idata.observed_data.school.values\n", + "]).T \n", + "idata_with_upper = idata.map(\n", + " lambda ds, **kwargs: ds.assign_coords(**kwargs),\n", + " groups=\"observed_vars\",\n", + " upper=(\"Upper\", _upper),\n", + ")\n", + "idata_with_upper" + ] + }, { "cell_type": "code", "execution_count": null, @@ -23376,7 +36056,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.10.12" }, "varInspector": { "cols": {