Skip to content

Commit

Permalink
Refactoring _run_pyscf
Browse files Browse the repository at this point in the history
  • Loading branch information
max-radin committed Feb 9, 2024
1 parent 32ef2e8 commit 3d6c68f
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 97 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,11 @@ def get_double_factorized_hamiltonian_block_encoding(
lam = compute_lambda_df(h1, eri_rr, LR)

allowable_phase_estimation_error = 1
(step_cost, total_cost, num_qubits,) = _get_double_factorized_qpe_info(
(
step_cost,
total_cost,
num_qubits,
) = _get_double_factorized_qpe_info(
h1,
eri,
threshold,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,12 +214,16 @@ def _run_pyscf(
Raises:
SCFConvergenceError: If the SCF calculation does not converge.
ValueError: If mlflow_experiment_name is set but orq_workspace_id is not or
scf_options contains a "callback" key.
"""
molecule = _get_pyscf_molecule(mol_spec)
mean_field_object = (scf.RHF if mol_spec.multiplicity == 1 else scf.ROHF)(molecule)
mean_field_object.max_memory = 1e6 # set allowed memory high so tests pass

run_id = None
updated_scf_options = {}
if scf_options is not None:
updated_scf_options.update(scf_options)

if mlflow_experiment_name is not None:
if orq_workspace_id is None:
Expand All @@ -232,63 +236,26 @@ def _run_pyscf(
flat_mol_dict = _flatten_dict(asdict(mol_spec))
flat_active_dict = _flatten_dict(asdict(active_space_spec))

if scf_options is not None:
if "callback" in scf_options:
# user has defined a callback in scf_options
raise ValueError("scf_options should not contain a 'callback' key if mlflow_experiment_name is set.")
else:
# we want to log to mlflow, BUT haven't defined the
# callback in scf_options
client, run_id = _create_mlflow_setup(
mlflow_experiment_name, orq_workspace_id
)

for key, val in flat_mol_dict.items():
client.log_param(run_id, key, val)
for key, val in flat_active_dict.items():
client.log_param(run_id, key, val)
temp_options = deepcopy(scf_options)
temp_options["callback"] = create_mlflow_scf_callback(client, run_id)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="The 'sym_pos' keyword is deprecated and should be",
)
mean_field_object.run(**temp_options)
else:
# we want to log to mlflow, but haven't defined scf_options
client, run_id = _create_mlflow_setup(
mlflow_experiment_name, orq_workspace_id
if scf_options is not None and "callback" in scf_options:
raise ValueError(
"scf_options should not contain a 'callback' key if mlflow_experiment_name is set."
)

for key, val in flat_mol_dict.items():
client.log_param(run_id, key, val)
for key, val in flat_active_dict.items():
client.log_param(run_id, key, val)
temp_options = {"callback": create_mlflow_scf_callback(client, run_id)}
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="The 'sym_pos' keyword is deprecated and should be",
)
mean_field_object.run(**temp_options)
else:
if scf_options is not None:
# we don't want to run on mlflow, but we've specified scf_options
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="The 'sym_pos' keyword is deprecated and should be",
)
mean_field_object.run(**scf_options)
else:
# we don't want to run on mlflow, and haven't specified scf_options
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="The 'sym_pos' keyword is deprecated and should be",
)
mean_field_object.run()
client, run_id = _create_mlflow_setup(mlflow_experiment_name, orq_workspace_id)

for key, val in flat_mol_dict.items():
client.log_param(run_id, key, val)
for key, val in flat_active_dict.items():
client.log_param(run_id, key, val)

updated_scf_options["callback"] = create_mlflow_scf_callback(client, run_id)

with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="The 'sym_pos' keyword is deprecated and should be",
)
mean_field_object.run(**updated_scf_options)

if not mean_field_object.converged:
raise SCFConvergenceError()
Expand Down
50 changes: 10 additions & 40 deletions tests/benchq/problem_ingestion/test_molecular_hamiltonians.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,35 +227,20 @@ def test_get_active_space_hamiltonian_logs_to_mlflow_no_specified_callback(
patch_log_metric.assert_any_call(ANY, ANY, "cput0_0", ANY)


def test_get_active_space_hamiltonian_logs_to_mlflow_with_specified_callback(
patch_sdk_token,
patch_sdk_uri,
patch_local_client,
):
def test_get_active_space_hamiltonian_raises_error_when_mlflow_and_callback_specified():
# Given
experiment_name = patch_local_client.create_experiment(
"test_get_active_space_hamiltonian_logs_to_mlflow_with_specified_callback",
)
experiment = patch_local_client.get_experiment_by_name(name=experiment_name)
run_id = patch_local_client.create_run(experiment.experiment_id).info.run_id
scf_callback = create_mlflow_scf_callback(patch_local_client, run_id)
scf_options = {"callback": scf_callback}
scf_options = {"callback": lambda: None}
new_hydrogen_chain_instance = get_hydrogen_chain_hamiltonian_generator(
2,
mlflow_experiment_name="pytest",
scf_options=scf_options,
orq_workspace_id="testing",
)

# When
_ = new_hydrogen_chain_instance.get_active_space_hamiltonian()

# Then
patch_local_client.log_metric.assert_called()

# last param (value) depends on optimization, so could be different run-to-run
patch_local_client.log_metric.assert_any_call(ANY, "last_hf_e", ANY)
patch_local_client.log_metric.assert_any_call(ANY, "cput0_0", ANY)
with pytest.raises(ValueError):
# When
_ = new_hydrogen_chain_instance.get_active_space_hamiltonian()


def test_get_active_space_hamiltonian_logs_to_mlflow_with_scf_options_no_callback(
Expand Down Expand Up @@ -325,35 +310,20 @@ def test_get_active_space_meanfield_object_logs_to_mlflow_no_specified_callback(
patch_log_metric.assert_any_call(ANY, ANY, "cput0_0", ANY)


def test_get_active_space_meanfield_object_logs_to_mlflow_with_specified_callback(
patch_sdk_token,
patch_sdk_uri,
patch_local_client,
):
def test_get_active_space_meanfield_object_raises_error_when_mlflow_and_callback_specified():
# Given
experiment_name = patch_local_client.create_experiment(
"test_get_active_space_hamiltonian_logs_to_mlflow_with_specified_callback",
)
experiment = patch_local_client.get_experiment_by_name(name=experiment_name)
run_id = patch_local_client.create_run(experiment.experiment_id).info.run_id
scf_callback = create_mlflow_scf_callback(patch_local_client, run_id)
scf_options = {"callback": scf_callback}
scf_options = {"callback": lambda: None}
new_hydrogen_chain_instance = get_hydrogen_chain_hamiltonian_generator(
2,
mlflow_experiment_name="pytest",
orq_workspace_id="testing",
scf_options=scf_options,
)

# When
_ = new_hydrogen_chain_instance.get_active_space_meanfield_object()

# Then
patch_local_client.log_metric.assert_called()

# last param (value) depends on optimization, so could be different run-to-run
patch_local_client.log_metric.assert_any_call(ANY, "last_hf_e", ANY)
patch_local_client.log_metric.assert_any_call(ANY, "cput0_0", ANY)
with pytest.raises(ValueError):
# When
_ = new_hydrogen_chain_instance.get_active_space_meanfield_object()


def test_get_active_space_meanfield_object_logs_to_mlflow_with_scf_options_no_callback(
Expand Down

0 comments on commit 3d6c68f

Please sign in to comment.