Skip to content

Commit

Permalink
Format all python files with black
Browse files Browse the repository at this point in the history
  • Loading branch information
langdal committed Apr 17, 2023
1 parent 8edd1b9 commit e31d3e2
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 134 deletions.
23 changes: 12 additions & 11 deletions optimizerapi/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@
import os
from keycloak import KeycloakOpenID

AUTH_API_KEY = os.getenv("AUTH_API_KEY", 'none')
AUTH_API_KEY = os.getenv("AUTH_API_KEY", "none")
AUTH_SERVER = os.getenv("AUTH_SERVER", None)
AUTH_CLIENT_ID = os.getenv("AUTH_CLIENT_ID", None)
AUTH_CLIENT_SECRET = os.getenv("AUTH_CLIENT_SECRET", None)
AUTH_REALM_NAME = os.getenv("AUTH_REALM_NAME", None)

keycloak_openid = KeycloakOpenID(server_url=AUTH_SERVER,
realm_name=AUTH_REALM_NAME,
client_id=AUTH_CLIENT_ID,
client_secret_key=AUTH_CLIENT_SECRET
)
keycloak_openid = KeycloakOpenID(
server_url=AUTH_SERVER,
realm_name=AUTH_REALM_NAME,
client_id=AUTH_CLIENT_ID,
client_secret_key=AUTH_CLIENT_SECRET,
)


def token_info(access_token) -> dict:
Expand All @@ -29,13 +30,13 @@ def token_info(access_token) -> dict:
"""
print(access_token)
if not AUTH_SERVER:
return {'scope': []}
return {"scope": []}
token = access_token
token_data = keycloak_openid.introspect(token)
if 'active' in token_data and token_data['active']:
print('OK')
if "active" in token_data and token_data["active"]:
print("OK")
return token_data
print('NOT OK')
print("NOT OK")
print(token_data)
return None

Expand All @@ -50,5 +51,5 @@ def apikey_handler(access_token) -> dict:
None in case of invalid token
"""
if not AUTH_SERVER and AUTH_API_KEY == access_token:
return {'scope': []}
return {"scope": []}
return None
129 changes: 65 additions & 64 deletions optimizerapi/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

queue = Queue(connection=Redis.from_url(REDIS_URL))

plt.switch_backend('Agg')
plt.switch_backend("Agg")


def run(body) -> dict:
Expand All @@ -52,42 +52,45 @@ def run(body) -> dict:


def do_run_work(body) -> dict:
""""Handle the run request
"""
""" "Handle the run request"""
try:
return __handle_run(body)
except IOError as err:
return ({'message': 'I/O error', 'error': str(err)}, 400)
return ({"message": "I/O error", "error": str(err)}, 400)
except TypeError as err:
return ({'message': 'Type error', 'error': str(err)}, 400)
return ({"message": "Type error", "error": str(err)}, 400)
except ValueError as err:
return ({'message': 'Validation error', 'error': str(err)}, 400)
return ({"message": "Validation error", "error": str(err)}, 400)
except Exception as err:
# Log unknown exceptions to support debugging
traceback.print_exc()
return ({'message': 'Unknown error', 'error': str(err)}, 500)
return ({"message": "Unknown error", "error": str(err)}, 500)


def __handle_run(body) -> dict:
""""Handle the run request
"""
""" "Handle the run request"""
# print("Receive: " + str(body))
data = [(run["xi"], run["yi"]) for run in body["data"]]
cfg = body["optimizerConfig"]
extras = {}
if "extras" in body:
extras = body["extras"]
print("Received extras " + str(extras))
space = [(convert_number_type(x["from"], x["type"]),
convert_number_type(x["to"], x["type"]))
if (x["type"] == "discrete" or x["type"] == "continuous")
else tuple(x["categories"]) for x in cfg["space"]]
space = [
(
convert_number_type(x["from"], x["type"]),
convert_number_type(x["to"], x["type"]),
)
if (x["type"] == "discrete" or x["type"] == "continuous")
else tuple(x["categories"])
for x in cfg["space"]
]
dimensions = [x["name"] for x in cfg["space"]]
hyperparams = {
'base_estimator': cfg["baseEstimator"],
'acq_func': cfg["acqFunc"],
'n_initial_points': cfg["initialPoints"],
'acq_func_kwargs': {'kappa': cfg["kappa"], 'xi': cfg["xi"]}
"base_estimator": cfg["baseEstimator"],
"acq_func": cfg["acqFunc"],
"n_initial_points": cfg["initialPoints"],
"acq_func_kwargs": {"kappa": cfg["kappa"], "xi": cfg["xi"]},
}

Xi = []
Expand All @@ -110,16 +113,15 @@ def __handle_run(body) -> dict:
else:
result = []

response = process_result(
result, optimizer, dimensions, cfg, extras, data, space)
response = process_result(result, optimizer, dimensions, cfg, extras, data, space)

response["result"]["extras"]["parameters"] = {
"dimensions": dimensions,
"space": space,
"hyperparams": hyperparams,
"Xi": Xi,
"Yi": Yi,
"extras": extras
"extras": extras,
}

# It is necesarry to convert response to a json string and then back to
Expand All @@ -128,8 +130,7 @@ def __handle_run(body) -> dict:


def convert_number_type(value, num_type):
"""Converts input value to either integer or float depending on the string supplied in numType
"""
"""Converts input value to either integer or float depending on the string supplied in numType"""
if num_type == "discrete":
return int(value)
return float(value)
Expand Down Expand Up @@ -168,17 +169,9 @@ def process_result(result, optimizer, dimensions, cfg, extras, data, space):
model representation etc.}
}
"""
result_details = {
"next": [],
"models": [],
"pickled": "",
"extras": {}
}
result_details = {"next": [], "models": [], "pickled": "", "extras": {}}
plots = []
response = {
"plots": plots,
"result": result_details
}
response = {"plots": plots, "result": result_details}
# GraphFormat should, at the moment, be either "png" or "none". Default (legacy)
# behavior is "png", so the API returns png images. Any other input is interpreted
# as "None" at the moment.
Expand All @@ -201,40 +194,49 @@ def process_result(result, optimizer, dimensions, cfg, extras, data, space):
if len(data) >= cfg["initialPoints"]:
# Some calculations are only possible if the model has
# processed more than "initialPoints" data points
result_details["models"] = [process_model(
model, optimizer) for model in result]
result_details["models"] = [process_model(model, optimizer) for model in result]
if graph_format == "png":
for idx, model in enumerate(result):
plot_convergence(model)
add_plot(plots, f"convergence_{idx}")

plot_objective(model, dimensions=dimensions,
usepartialdependence=False,
show_confidence=True,
pars=objective_pars)
plot_objective(
model,
dimensions=dimensions,
usepartialdependence=False,
show_confidence=True,
pars=objective_pars,
)
add_plot(plots, f"objective_{idx}")

if optimizer.n_objectives == 1:
minimum = expected_minimum(result[0])

result_details["expected_minimum"] = [
round_to_length_scales(minimum[0], optimizer.space), round(minimum[1], 2)]
round_to_length_scales(minimum[0], optimizer.space),
round(minimum[1], 2),
]
else:
plot_Pareto(optimizer)
add_plot(plots, "pareto")

result_details["pickled"] = pickleToString(
result, get_crypto())
result_details["pickled"] = pickleToString(result, get_crypto())

add_version_info(result_details["extras"])

# print(str(response))
org_models = response["result"]['models']
org_models = response["result"]["models"]
for model in org_models:
# Flatten expected minimum entries
model['expected_minimum'] = [[
item for sublist in [x if isinstance(
x, list) else [x] for x in model['expected_minimum']] for item in sublist]]
model["expected_minimum"] = [
[
item
for sublist in [
x if isinstance(x, list) else [x] for x in model["expected_minimum"]
]
for item in sublist
]
]
return response


Expand All @@ -251,13 +253,12 @@ def process_model(model, optimizer):
dict
a dictionary containing the model specific results.
"""
result_details = {
"expected_minimum": [],
"extras": {}
}
result_details = {"expected_minimum": [], "extras": {}}
minimum = expected_minimum(model)
result_details["expected_minimum"] = [
round_to_length_scales(minimum[0], optimizer.space), round(minimum[1], 2)]
round_to_length_scales(minimum[0], optimizer.space),
round(minimum[1], 2),
]
return result_details


Expand All @@ -283,29 +284,26 @@ def add_plot(result, id="generic", close=True, debug=False):
relative to current working directory. (default is False)
"""
pic_io_bytes = io.BytesIO()
plt.savefig(pic_io_bytes, format='png', bbox_inches='tight')
plt.savefig(pic_io_bytes, format="png", bbox_inches="tight")
pic_io_bytes.seek(0)
pic_hash = base64.b64encode(pic_io_bytes.read())
result.append({
"id": id,
"plot": str(pic_hash, "utf-8")
})
result.append({"id": id, "plot": str(pic_hash, "utf-8")})

if debug:
with open('tmp/process_optimizer_' + id + '.png', 'wb') as imgfile:
plt.savefig(imgfile, bbox_inches='tight', pad_inches=0)
with open("tmp/process_optimizer_" + id + ".png", "wb") as imgfile:
plt.savefig(imgfile, bbox_inches="tight", pad_inches=0)

# print("IMAGE: " + str(pic_hash, "utf-8"))
if close:
plt.clf()


def round_to_length_scales(x, space):
""" Rounds a suggested experiment to to the length scales of each dimension
"""Rounds a suggested experiment to to the length scales of each dimension
For each dimension the length of the dimension is calculated and the
length scale is defined as 1/1000th of the length.
The precision is the n in 10^n which is the closest to the
The precision is the n in 10^n which is the closest to the
length_scale (rounded) up .
The suggested experiment value is then rounded to n decimals
Expand All @@ -326,10 +324,10 @@ def round_to_length_scales(x, space):
if isinstance(dim, Real):
length = dim.high - dim.low
# Length scale of the dimension is 1/1000 of the dimension length
length_scale = length/1000
length_scale = length / 1000
# The precision is found by taking the
# negative log10 to the length scale ceiled
precision = int(numpy.ceil(- numpy.log10(length_scale)))
precision = int(numpy.ceil(-numpy.log10(length_scale)))

# If multiple experiments round dimension values for all experiments
# else round dimension value
Expand Down Expand Up @@ -361,9 +359,12 @@ def add_version_info(extras):
extras["apiVersion"] = version_file.readline().rstrip()
else:
try:
extras["apiVersion"] = subprocess.check_output(
["git", "describe", "--always"]).strip().decode()
extras["apiVersion"] = (
subprocess.check_output(["git", "describe", "--always"])
.strip()
.decode()
)
except IOError:
extras["apiVersion"] = 'Unknown development version'
extras["apiVersion"] = "Unknown development version"

extras["timeOfExecution"] = strftime("%Y-%m-%d %H:%M:%S")
15 changes: 6 additions & 9 deletions optimizerapi/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,11 @@
from flask_cors import CORS
from .securepickle import get_crypto

if __name__ == '__main__':
if __name__ == "__main__":
# Initialize crypto
get_crypto()
app = connexion.FlaskApp(
__name__, port=9090, specification_dir='./openapi/')
app.add_api('specification.yml', strict_validation=True,
validate_responses=True)
app = connexion.FlaskApp(__name__, port=9090, specification_dir="./openapi/")
app.add_api("specification.yml", strict_validation=True, validate_responses=True)

DEVELOPMENT = "development"
flask_env = os.getenv("FLASK_ENV", DEVELOPMENT)
Expand All @@ -25,8 +23,7 @@
# It should be easy to get started developing locally which is the reason
# why we allow for all origins in development mode.
ALLOW_ALL_ORIGINS = ".*"
cors_origin = os.getenv(
"CORS_ORIGIN", ALLOW_ALL_ORIGINS if development else None)
cors_origin = os.getenv("CORS_ORIGIN", ALLOW_ALL_ORIGINS if development else None)

# By default we do not want to enable CORS. That should be a conscious
# descision from the host of the API server. This way we do not expose any
Expand All @@ -44,7 +41,7 @@
# our environment variable. The List would be cumbersome to
# parse and the simple string is not enough functionality for
# what we want to support.
origins=re.compile(cors_origin)
origins=re.compile(cors_origin),
)
print("CORS: " + cors_origin)
except re.error:
Expand All @@ -55,4 +52,4 @@
if development:
app.run()
else:
serve(app, listen='*:9090')
serve(app, listen="*:9090")
2 changes: 1 addition & 1 deletion optimizerapi/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
else:
REDIS_URL = "redis://localhost:6379"

if __name__ == '__main__':
if __name__ == "__main__":
with Connection(Redis.from_url(REDIS_URL)):
queue = Queue()
Worker(queue).work()
6 changes: 4 additions & 2 deletions tests/context.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
import sys

if True:
sys.path.insert(0, os.path.abspath(os.path.join(
os.path.dirname(__file__), '../optimizerapi')))
sys.path.insert(
0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../optimizerapi"))
)
import optimizer
import securepickle
Loading

0 comments on commit e31d3e2

Please sign in to comment.