Skip to content

Commit

Permalink
new function to extrapolate lists
Browse files Browse the repository at this point in the history
  • Loading branch information
jdebacker committed Jul 26, 2024
1 parent 7206399 commit 632ba93
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 41 deletions.
48 changes: 7 additions & 41 deletions ogcore/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import paramtools
import ogcore
from ogcore import elliptical_u_est
from ogcore.utils import rate_conversion, extrapolate_arrays
from ogcore.utils import rate_conversion, extrapolate_arrays, extrapolate_nested_list
from ogcore.constants import BASELINE_DIR

CURRENT_PATH = os.path.abspath(os.path.dirname(__file__))
Expand Down Expand Up @@ -223,53 +223,19 @@ def compute_default_params(self):
"mtry_params",
]
for item in tax_params_to_TP:
tax_to_set = getattr(self, item)
tax_to_set_in = getattr(self, item)
try:
tax_to_set = (
tax_to_set.tolist()
) # in case parameters are numpy arrays
except AttributeError: # catches if they are lists already
pass
if len(tax_to_set) == 1 and isinstance(tax_to_set[0], float):
setattr(
self,
item,
[
[[tax_to_set] for i in range(self.S)]
for t in range(self.T)
],
)
elif any(
[
isinstance(tax_to_set[i][j], list)
for i, v in enumerate(tax_to_set)
for j, vv in enumerate(tax_to_set[i])
]
):
if len(tax_to_set) > self.T + self.S:
tax_to_set = tax_to_set[: self.T + self.S]
if len(tax_to_set) < self.T + self.S:
tax_params_to_add = [tax_to_set[-1]] * (
self.T + self.S - len(tax_to_set)
)
tax_to_set.extend(tax_params_to_add)
if len(tax_to_set[0]) > self.S:
for t, v in enumerate(tax_to_set):
tax_to_set[t] = tax_to_set[t][: self.S]
if len(tax_to_set[0]) < self.S:
tax_params_to_add = [tax_to_set[:][-1]] * (
self.S - len(tax_to_set[0])
)
tax_to_set[0].extend(tax_params_to_add)
setattr(self, item, tax_to_set)
else:
len(tax_to_set_in[0][0])
except TypeError:
print(
"please give a "
+ item
+ " that is a single element or nested lists of"
+ " that is a nested lists of"
+ " lists that is three lists deep"
)
assert False
tax_to_set_out = extrapolate_nested_list(tax_to_set_in, dims=(self.T, self.S, len(tax_to_set_in[0][0])))
setattr(self, item, tax_to_set_out)

# Try to deal with size of eta. It may vary by S, J, T, but
# want to allow user to enter one that varies by only S, S and J,
Expand Down
63 changes: 63 additions & 0 deletions ogcore/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,6 +1022,69 @@ def extrapolate_arrays(param_in, dims=None, item="Parameter Name"):
return param_out


def extrapolate_lists(list_in, dims=(400, 80, 1)):
try:
tax_to_set = (
tax_to_set.tolist()
) # in case parameters are numpy arrays
except AttributeError: # catches if they are lists already
pass
if len(tax_to_set) == 1 and isinstance(tax_to_set[0], float):
setattr(
self,
item,
[
[[tax_to_set] for i in range(self.S)]
for t in range(self.T)
],
)
elif any(
[
isinstance(tax_to_set[i][j], list)
for i, v in enumerate(tax_to_set)
for j, vv in enumerate(tax_to_set[i])
]
):
if len(tax_to_set) > self.T + self.S:
tax_to_set = tax_to_set[: self.T + self.S]
if len(tax_to_set) < self.T + self.S:
tax_params_to_add = [tax_to_set[-1]] * (
self.T + self.S - len(tax_to_set)
)
tax_to_set.extend(tax_params_to_add)
if len(tax_to_set[0]) > self.S:
for t, v in enumerate(tax_to_set):
tax_to_set[t] = tax_to_set[t][: self.S]
if len(tax_to_set[0]) < self.S:
tax_params_to_add = [tax_to_set[:][-1]] * (
self.S - len(tax_to_set[0])
)
tax_to_set[0].extend(tax_params_to_add)

# for t, v in enumerate(tax_to_set):
# for j, k in enumerate(tax_to_set)
# tax_params_to_add = [tax_to_set[t][-1]] * (
# self.S - len(tax_to_set[t])
# )
# tax_to_set[t].extend(tax_params_to_add)



print("TAX PARAMS TO ADD: ", tax_params_to_add)
print("TAX TO SET sizes before: ", len(tax_to_set), len(tax_to_set[0]), len(tax_to_set[0][0]))
tax_to_set[0].extend(tax_params_to_add)
print("TAX TO SET sizes after: ", len(tax_to_set), len(tax_to_set[0]), len(tax_to_set[0][0]))

setattr(self, item, tax_to_set)
else:
print(
"please give a "
+ item
+ " that is a single element or nested lists of"
+ " lists that is three lists deep"
)
assert False

class CustomHttpAdapter(requests.adapters.HTTPAdapter):
"""
The UN Data Portal server doesn't support "RFC 5746 secure renegotiation". This causes and error when the client is using OpenSSL 3, which enforces that standard by default.
Expand Down

0 comments on commit 632ba93

Please sign in to comment.