diff --git a/CHANGELOG.md b/CHANGELOG.md
index eee7de05d7..1d2e6a9e74 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,7 +7,7 @@
- Adds an option "voltage as a state" that can be "false" (default) or "true". If "true" adds an explicit algebraic equation for the voltage. ([#4507](https://github.com/pybamm-team/PyBaMM/pull/4507))
- Improved `QuickPlot` accuracy for simulations with Hermite interpolation. ([#4483](https://github.com/pybamm-team/PyBaMM/pull/4483))
- Added Hermite interpolation to the (`IDAKLUSolver`) that improves the accuracy and performance of post-processing variables. ([#4464](https://github.com/pybamm-team/PyBaMM/pull/4464))
-- Added basic telemetry to record which functions are being run. See [Telemetry section in the User Guide](https://docs.pybamm.org/en/latest/source/user_guide/index.html#telemetry) for more information. ([#4441](https://github.com/pybamm-team/PyBaMM/pull/4441))
+- Added basic telemetry to record which functions are being run. See [Telemetry section in the User Guide](https://docs.pybamm.org/en/latest/source/user_guide/index.html#telemetry) for more information. ([#4441](https://github.com/pybamm-team/PyBaMM/pull/4441), [#4583](https://github.com/pybamm-team/PyBaMM/pull/4583))
- Added `BasicDFN` model for sodium-ion batteries ([#4451](https://github.com/pybamm-team/PyBaMM/pull/4451))
- Added sensitivity calculation support for `pybamm.Simulation` and `pybamm.Experiment` ([#4415](https://github.com/pybamm-team/PyBaMM/pull/4415))
- Added OpenMP parallelization to IDAKLU solver for lists of input parameters ([#4449](https://github.com/pybamm-team/PyBaMM/pull/4449))
diff --git a/docs/source/user_guide/installation/index.rst b/docs/source/user_guide/installation/index.rst
index 9225f1ee98..d475a1b308 100644
--- a/docs/source/user_guide/installation/index.rst
+++ b/docs/source/user_guide/installation/index.rst
@@ -72,6 +72,7 @@ Package Minimum supp
`pandas `__ 1.5.0
`pooch `__ 1.8.1
`posthog `__ 3.6.5
+`platformdirs `__ 4.0.0
=================================================================== ==========================
.. _install.optional_dependencies:
diff --git a/noxfile.py b/noxfile.py
index d65812b8ed..bd38528778 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -253,9 +253,11 @@ def run_tests(session):
set_environment_variables(PYBAMM_ENV, session=session)
session.install("setuptools", silent=False)
session.install("-e", ".[all,dev,jax]", silent=False)
- specific_test_files = session.posargs if session.posargs else []
session.run(
- "python", "-m", "pytest", *specific_test_files, "-m", "unit or integration"
+ "python",
+ "-m",
+ "pytest",
+ *(session.posargs if session.posargs else ["-m", "unit or integration"]),
)
diff --git a/pyproject.toml b/pyproject.toml
index b83c3704fe..b25f73e7c7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,6 +45,7 @@ dependencies = [
"pandas>=1.5.0",
"pooch>=1.8.1",
"posthog",
+ "platformdirs",
]
[project.urls]
diff --git a/src/pybamm/__init__.py b/src/pybamm/__init__.py
index b466c3896b..5630ae8526 100644
--- a/src/pybamm/__init__.py
+++ b/src/pybamm/__init__.py
@@ -4,10 +4,10 @@
demote_expressions_to_32bit = False
# Utility classes and methods
-from .util import root_dir
from .util import Timer, TimerTime, FuzzyDict
from .util import (
root_dir,
+ is_notebook,
load,
is_constant_and_can_evaluate,
)
@@ -190,7 +190,7 @@
from .plotting.dynamic_plot import dynamic_plot
# Simulation
-from .simulation import Simulation, load_sim, is_notebook
+from .simulation import Simulation, load_sim
# Batch Study
from .batch_study import BatchStudy
diff --git a/src/pybamm/config.py b/src/pybamm/config.py
index ba7171e5d2..69b5e0e6d7 100644
--- a/src/pybamm/config.py
+++ b/src/pybamm/config.py
@@ -1,11 +1,14 @@
-import uuid
import os
-import platformdirs
-from pathlib import Path
-import pybamm
import sys
-import threading
import time
+import uuid
+
+from pathlib import Path
+
+import pybamm
+import platformdirs
+
+from pybamm.util import is_notebook
def is_running_tests(): # pragma: no cover
@@ -45,20 +48,118 @@ def is_running_tests(): # pragma: no cover
return False
-def ask_user_opt_in(timeout=10):
+def get_input_or_timeout(timeout): # pragma: no cover
"""
- Ask the user if they want to opt in to telemetry.
+ Cross-platform input with timeout, using various methods depending on the
+ environment. Works in Jupyter notebooks, Windows, and Unix-like systems.
+
+ Args:
+ timeout (float): Timeout in seconds
+
+ Returns:
+ str | None: The user input if received before the timeout, or None if:
+ - The timeout was reached
+ - Telemetry is disabled via environment variable
+ - Running in a non-interactive environment
+ - An error occurred
+
+ The caller can distinguish between an empty response (user just pressed Enter)
+ which returns '', and a timeout/error which returns None.
+ """
+ # Check for telemetry disable flag
+ if os.getenv("PYBAMM_DISABLE_TELEMETRY", "false").lower() != "false":
+ return None
+
+ if not (sys.stdin.isatty() or is_notebook()):
+ return None
+
+ # 1. Handling for Jupyter notebooks. This is a simplified
+ # implementation in comparison to widgets because notebooks
+ # are usually run interactively, and we don't need to worry
+ # about the event loop. The input is disabled when running
+ # tests due to the environment variable set.
+ if is_notebook(): # pragma: no cover
+ try:
+ from IPython.display import clear_output
+
+ user_input = input("Do you want to enable telemetry? (Y/n): ")
+ clear_output()
+
+ return user_input
- Parameters
- ----------
- timeout : float, optional
- The timeout for the user to respond to the prompt. Default is 10 seconds.
+ except Exception: # pragma: no cover
+ return None
- Returns
- -------
- bool
- True if the user opts in, False otherwise.
+ # 2. Windows-specific handling
+ if sys.platform == "win32":
+ try:
+ import msvcrt
+
+ start_time = time.time()
+ input_chars = []
+ sys.stdout.write("Do you want to enable telemetry? (Y/n): ")
+ sys.stdout.flush()
+
+ while time.time() - start_time < timeout:
+ if msvcrt.kbhit():
+ char = msvcrt.getwche()
+ if char in ("\r", "\n"):
+ sys.stdout.write("\n")
+ return "".join(input_chars)
+ input_chars.append(char)
+ time.sleep(0.1)
+ return None
+ except Exception:
+ return None
+
+ # 3. POSIX-like systems will need to use termios
+ else: # pragma: no cover
+ try:
+ import termios
+ import tty
+ import select
+
+ # Save terminal settings for later
+ old_settings = termios.tcgetattr(sys.stdin)
+ try:
+ # Set terminal to raw mode
+ tty.setraw(sys.stdin.fileno())
+
+ sys.stdout.write("Do you want to enable telemetry? (Y/n): ")
+ sys.stdout.flush()
+
+ input_chars = []
+ start_time = time.time()
+
+ while time.time() - start_time < timeout:
+ rlist, _, _ = select.select([sys.stdin], [], [], 0.1)
+ if rlist:
+ char = sys.stdin.read(1)
+ if char in ("\r", "\n"):
+ sys.stdout.write("\n")
+ return "".join(input_chars)
+ input_chars.append(char)
+ sys.stdout.write(char)
+ sys.stdout.flush()
+ return None
+
+ finally:
+ # Restore saved terminal settings
+ termios.tcsetattr(sys.stdin, termios.TCSAFLUSH, old_settings)
+ sys.stdout.write("\n")
+ sys.stdout.flush()
+
+ except Exception: # pragma: no cover
+ return None
+
+ return None
+
+
+def ask_user_opt_in(timeout=10): # pragma: no cover
"""
+ Ask the user if they want to opt in to telemetry.
+ """
+
print(
"PyBaMM can collect usage data and send it to the PyBaMM team to "
"help us improve the software.\n"
@@ -69,44 +170,25 @@ def ask_user_opt_in(timeout=10):
"For more information, see https://docs.pybamm.org/en/latest/source/user_guide/index.html#telemetry"
)
- def get_input(): # pragma: no cover
- try:
- user_input = (
- input("Do you want to enable telemetry? (Y/n): ").strip().lower()
- )
- answer.append(user_input)
- except Exception:
- # Handle any input errors
- pass
+ user_input = get_input_or_timeout(timeout)
- time_start = time.time()
+ if user_input is None:
+ print("\nTimeout reached. Defaulting to not enabling telemetry.")
+ return False
while True:
- if time.time() - time_start > timeout:
- print("\nTimeout reached. Defaulting to not enabling telemetry.")
+ if user_input.lower() in ["y", "yes", ""]:
+ print("Telemetry enabled.\n")
+ return True
+ elif user_input.lower() in ["n", "no"]:
+ print("Telemetry disabled.")
return False
-
- answer = []
- # Create and start input thread
- input_thread = threading.Thread(target=get_input)
- input_thread.daemon = True
- input_thread.start()
-
- # Wait for either timeout or input
- input_thread.join(timeout)
-
- if answer:
- if answer[0] in ["yes", "y", ""]:
- print("\nTelemetry enabled.\n")
- return True
- elif answer[0] in ["no", "n"]:
- print("\nTelemetry disabled.\n")
- return False
- else:
- print("\nInvalid input. Please enter 'yes/y' for yes or 'no/n' for no.")
else:
- print("\nTimeout reached. Defaulting to not enabling telemetry.")
- return False
+ print("Invalid input. Please enter 'Y/y' for yes or 'n/N' for no:")
+ user_input = get_input_or_timeout(timeout)
+ if user_input is None:
+ print("\nTimeout reached. Defaulting to not enabling telemetry.")
+ return False
def generate():
diff --git a/src/pybamm/simulation.py b/src/pybamm/simulation.py
index 75799d6334..e555e27683 100644
--- a/src/pybamm/simulation.py
+++ b/src/pybamm/simulation.py
@@ -16,24 +16,6 @@
from pybamm.expression_tree.operations.serialise import Serialise
-def is_notebook():
- try:
- shell = get_ipython().__class__.__name__
- if shell == "ZMQInteractiveShell": # pragma: no cover
- # Jupyter notebook or qtconsole
- cfg = get_ipython().config
- nb = len(cfg["InteractiveShell"].keys()) == 0
- return nb
- elif shell == "TerminalInteractiveShell": # pragma: no cover
- return False # Terminal running IPython
- elif shell == "Shell": # pragma: no cover
- return True # Google Colab notebook
- else: # pragma: no cover
- return False # Other type (?)
- except NameError:
- return False # Probably standard Python interpreter
-
-
class Simulation:
"""A Simulation class for easy building and running of PyBaMM simulations.
@@ -141,7 +123,7 @@ def __init__(
self._set_random_seed()
# ignore runtime warnings in notebooks
- if is_notebook(): # pragma: no cover
+ if pybamm.is_notebook(): # pragma: no cover
import warnings
warnings.filterwarnings("ignore")
diff --git a/src/pybamm/util.py b/src/pybamm/util.py
index 5b10b23fcb..c0c669fd9d 100644
--- a/src/pybamm/util.py
+++ b/src/pybamm/util.py
@@ -22,6 +22,24 @@ def root_dir():
return str(pathlib.Path(pybamm.__path__[0]).parent.parent)
+def is_notebook():
+ try:
+ shell = get_ipython().__class__.__name__
+ if shell == "ZMQInteractiveShell": # pragma: no cover
+ # Jupyter notebook or qtconsole
+ cfg = get_ipython().config
+ nb = len(cfg["InteractiveShell"].keys()) == 0
+ return nb
+ elif shell == "TerminalInteractiveShell": # pragma: no cover
+ return False # Terminal running IPython
+ elif shell == "Shell": # pragma: no cover
+ return True # Google Colab notebook
+ else: # pragma: no cover
+ return False # Other type (?)
+ except NameError:
+ return False # Probably standard Python interpreter
+
+
def get_git_commit_info():
"""
Get the git commit info for the current PyBaMM version, e.g. v22.8-39-gb25ce8c41
diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py
index 62906b348d..67f7ae1dec 100644
--- a/tests/unit/test_config.py
+++ b/tests/unit/test_config.py
@@ -1,6 +1,6 @@
import pytest
-import select
import sys
+import os
import pybamm
import uuid
@@ -35,86 +35,67 @@ def test_write_read_uuid(self, tmp_path, write_opt_in):
else:
assert config_dict["enable_telemetry"] is False
- @pytest.mark.parametrize("user_opted_in, user_input", [(True, "y"), (False, "n")])
- def test_ask_user_opt_in(self, monkeypatch, capsys, user_opted_in, user_input):
- # Mock select.select to simulate user input
- def mock_select(*args, **kwargs):
- return [sys.stdin], [], []
-
- monkeypatch.setattr(select, "select", mock_select)
-
- # Mock sys.stdin.readline to return the desired input
- monkeypatch.setattr(sys.stdin, "readline", lambda: user_input + "\n")
-
- # Call the function to ask the user if they want to opt in
- opt_in = pybamm.config.ask_user_opt_in()
-
- # Check the result
- assert opt_in is user_opted_in
-
- # Check that the prompt was printed
- captured = capsys.readouterr()
- assert "Do you want to enable telemetry? (Y/n):" in captured.out
-
- def test_ask_user_opt_in_invalid_input(self, monkeypatch, capsys):
- # Mock select.select to simulate user input and then timeout
- def mock_select(*args, **kwargs):
- nonlocal call_count
- if call_count == 0:
- call_count += 1
- return [sys.stdin], [], []
- else:
- return [], [], []
-
- monkeypatch.setattr(select, "select", mock_select)
-
- # Mock sys.stdin.readline to return invalid input
- monkeypatch.setattr(sys.stdin, "readline", lambda: "invalid\n")
-
- # Initialize call count
- call_count = 0
-
- # Call the function to ask the user if they want to opt in
- opt_in = pybamm.config.ask_user_opt_in(timeout=1)
-
- # Check the result (should be False for timeout after invalid input)
- assert opt_in is False
+ @pytest.mark.parametrize(
+ "input_sequence,expected_output,expected_messages",
+ [
+ (["y"], True, ["Telemetry enabled"]),
+ (["n"], False, ["Telemetry disabled"]),
+ ([""], True, ["Telemetry enabled"]),
+ (["x", "y"], True, ["Invalid input", "Telemetry enabled"]),
+ (["x", "n"], False, ["Invalid input", "Telemetry disabled"]),
+ (["x", None], False, ["Invalid input", "Timeout reached"]),
+ ([None], False, ["Timeout reached"]),
+ ],
+ )
+ def test_ask_user_opt_in_scenarios(
+ self, monkeypatch, capsys, input_sequence, expected_output, expected_messages
+ ):
+ # mock is_running_tests to return False. This is done
+ # temporarily here in order to prevent an early return.
+ monkeypatch.setattr(pybamm.config, "is_running_tests", lambda: False)
+ monkeypatch.setattr(os, "getenv", lambda x, y: "false")
+ monkeypatch.setattr(sys.stdin, "isatty", lambda: True)
+ monkeypatch.setattr(pybamm.util, "is_notebook", lambda: False)
- # Check that the prompt, invalid input message, and timeout message were printed
- captured = capsys.readouterr()
- assert "Do you want to enable telemetry? (Y/n):" in captured.out
- assert (
- "Invalid input. Please enter 'yes/y' for yes or 'no/n' for no."
- in captured.out
- )
- assert "Timeout reached. Defaulting to not enabling telemetry." in captured.out
+ # Mock get_input_or_timeout to return sequence of inputs
+ input_iter = iter(input_sequence)
- def test_ask_user_opt_in_timeout(self, monkeypatch, capsys):
- # Mock select.select to simulate a timeout
- def mock_select(*args, **kwargs):
- return [], [], []
+ def mock_get_input(timeout):
+ print("Do you want to enable telemetry? (Y/n): ", end="")
+ try:
+ return next(input_iter)
+ except StopIteration:
+ return None
- monkeypatch.setattr(select, "select", mock_select)
+ monkeypatch.setattr(pybamm.config, "get_input_or_timeout", mock_get_input)
- # Call the function to ask the user if they want to opt in
opt_in = pybamm.config.ask_user_opt_in(timeout=1)
-
- # Check the result (should be False for timeout)
- assert opt_in is False
-
- # Check that the prompt and timeout message were printed
captured = capsys.readouterr()
- assert "Do you want to enable telemetry? (Y/n):" in captured.out
- assert "Timeout reached. Defaulting to not enabling telemetry." in captured.out
- def test_generate_and_read(self, monkeypatch, tmp_path):
+ assert "Do you want to enable telemetry?" in captured.out
+ assert "PyBaMM can collect usage data" in captured.out
+ for message in expected_messages:
+ assert message in captured.out
+ assert opt_in is expected_output
+
+ @pytest.mark.parametrize(
+ "test_scenario",
+ [
+ "first_generation", # Test first-time config generation
+ "config_exists", # Test when config already exists
+ ],
+ )
+ def test_generate_and_read(self, monkeypatch, tmp_path, test_scenario, timeout=2):
# Mock is_running_tests to return False
monkeypatch.setattr(pybamm.config, "is_running_tests", lambda: False)
# Mock ask_user_opt_in to return True
- monkeypatch.setattr(pybamm.config, "ask_user_opt_in", lambda: True)
+ def mock_ask_user_opt_in(timeout=10):
+ return True
+
+ monkeypatch.setattr(pybamm.config, "ask_user_opt_in", mock_ask_user_opt_in)
- # Mock telemetry capture
+ # Track if capture was called
capture_called = False
def mock_capture(event):
@@ -127,31 +108,49 @@ def mock_capture(event):
# Mock config directory
monkeypatch.setattr(platformdirs, "user_config_dir", lambda x: str(tmp_path))
- # Test generate() creates new config
- pybamm.config.generate()
-
- # Verify config was created
- config = pybamm.config.read()
- assert config is not None
- assert config["enable_telemetry"] is True
- assert "uuid" in config
- assert capture_called is True
-
- # Test generate() does nothing if config exists
- capture_called = False
- pybamm.config.generate()
- assert capture_called is False
-
- def test_read_uuid_from_file_no_file(self):
- config_dict = pybamm.config.read_uuid_from_file(Path("nonexistent_file.yml"))
- assert config_dict is None
-
- def test_read_uuid_from_file_invalid_yaml(self, tmp_path):
- # Create a temporary directory and file with invalid YAML content
- invalid_yaml = tmp_path / "invalid_yaml.yml"
- with open(invalid_yaml, "w") as f:
- f.write("invalid: yaml: content:")
-
- config_dict = pybamm.config.read_uuid_from_file(invalid_yaml)
-
- assert config_dict is None
+ if test_scenario == "first_generation":
+ # Test first-time generation
+ pybamm.config.generate()
+
+ # Verify config was created
+ config = pybamm.config.read()
+ assert config is not None
+ assert config["enable_telemetry"] is True
+ assert "uuid" in config
+ assert (
+ capture_called is True
+ ) # Should not ask for capturing telemetry when config exists
+
+ else: # config_exists case
+ # First create a config
+ pybamm.config.generate()
+ capture_called = False # Reset the flag
+
+ # Now test that generating again does nothing
+ pybamm.config.generate()
+ assert (
+ capture_called is False
+ ) # Should not ask for capturing telemetry when config exists
+
+ @pytest.mark.parametrize(
+ "file_scenario,expected_output",
+ [
+ ("nonexistent", None),
+ ("invalid_yaml", None),
+ ],
+ )
+ def test_read_uuid_from_file_scenarios(
+ self, tmp_path, file_scenario, expected_output
+ ):
+ if file_scenario == "nonexistent":
+ config_dict = pybamm.config.read_uuid_from_file(
+ Path("nonexistent_file.yml")
+ )
+ else: # invalid_yaml
+ # Create a temporary directory and file with invalid YAML content
+ invalid_yaml = tmp_path / "invalid_yaml.yml"
+ with open(invalid_yaml, "w") as f:
+ f.write("invalid: yaml: content:")
+ config_dict = pybamm.config.read_uuid_from_file(invalid_yaml)
+
+ assert config_dict is expected_output