Skip to content

Commit

Permalink
Iter
Browse files Browse the repository at this point in the history
  • Loading branch information
Strilanc committed Aug 7, 2024
1 parent a4574fb commit 4269062
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 49 deletions.
91 changes: 58 additions & 33 deletions glue/sample/src/sinter/_collection/_collection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pathlib
import queue
import tempfile
import threading
from typing import Any, Optional, List, Dict, Iterable, Callable, Tuple
from typing import Union
from typing import cast
Expand All @@ -20,29 +21,32 @@
class _ManagedWorkerState:
def __init__(self, worker_id: int, *, cpu_pin: Optional[int] = None):
self.worker_id: int = worker_id
self.process: Optional[multiprocessing.Process] = None
self.process: Union[multiprocessing.Process, threading.Thread, None] = None
self.input_queue: Optional[multiprocessing.Queue[Tuple[str, Any]]] = None
self.assigned_work_key: Any = None
self.assigned_shots_remote: int = 0
self.asked_to_drop_shots: int = 0
self.cpu_pin = cpu_pin

# Shots transfer into this field when manager sends shot requests to workers.
# Shots transfer out of this field when clients flush results or respond to work return requests.
self.assigned_shots: int = 0

def send_message(self, message: Any):
    """Deliver a message tuple to the worker over its input queue."""
    # input_queue is assigned by the manager in start_workers.
    target_queue = self.input_queue
    target_queue.put(message)

def ask_to_return_all_shots(self):
    """Ask the worker to hand back every shot currently assigned to it.

    Sends at most one outstanding 'return_shots' request: if a drop
    request is already pending (asked_to_drop_shots != 0) or the worker
    holds no shots, this is a no-op.
    """
    if self.asked_to_drop_shots == 0 and self.assigned_shots > 0:
        self.send_message((
            'return_shots',
            (
                self.assigned_work_key,
                self.assigned_shots,
            ),
        ))
        # Remember how many shots we asked for, so duplicate requests
        # aren't sent before the worker responds.
        self.asked_to_drop_shots = self.assigned_shots

def has_returned_all_shots(self) -> bool:
    """True when the worker holds no shots and no return request is pending."""
    return self.assigned_shots == 0 and self.asked_to_drop_shots == 0

def is_available_to_reassign(self) -> bool:
    """Report whether this worker currently has no work item assigned."""
    currently_assigned = self.assigned_work_key is not None
    return not currently_assigned
Expand Down Expand Up @@ -78,6 +82,7 @@ def __init__(
count_observable_error_combos: bool = False,
count_detection_events: bool = False,
custom_error_count_key: Optional[str] = None,
use_threads_for_debugging: bool = False,
):
assert isinstance(custom_decoders, dict)
self.existing_data = existing_data
Expand All @@ -92,6 +97,7 @@ def __init__(
self.count_observable_error_combos = count_observable_error_combos
self.count_detection_events = count_detection_events
self.custom_error_count_key = custom_error_count_key
self.use_threads_for_debugging = use_threads_for_debugging

self.shared_worker_output_queue: Optional[multiprocessing.SimpleQueue[Tuple[str, int, Any]]] = None
self.task_states: Dict[Any, _ManagedTaskState] = {}
Expand Down Expand Up @@ -149,18 +155,25 @@ def start_workers(self, *, actually_start_worker_processes: bool = True):
worker_state.input_queue = multiprocessing.Queue()
worker_state.input_queue.cancel_join_thread()
worker_state.assigned_work_key = None
worker_state.process = multiprocessing.Process(
target=collection_worker_loop,
args=(
self.worker_flush_period,
worker_id,
sampler,
worker_state.input_queue,
self.shared_worker_output_queue,
worker_state.cpu_pin,
self.custom_error_count_key,
),
args = (
self.worker_flush_period,
worker_id,
sampler,
worker_state.input_queue,
self.shared_worker_output_queue,
worker_state.cpu_pin,
self.custom_error_count_key,
)
if self.use_threads_for_debugging:
worker_state.process = threading.Thread(
target=collection_worker_loop,
args=args,
)
else:
worker_state.process = multiprocessing.Process(
target=collection_worker_loop,
args=args,
)

if actually_start_worker_processes:
worker_state.process.start()
Expand Down Expand Up @@ -237,6 +250,8 @@ def hard_stop(self):

removed_workers = [state.process for state in self.worker_states]
for state in self.worker_states:
if isinstance(state.process, threading.Thread):
state.send_message('stop')
state.process = None
state.assigned_work_key = None
state.input_queue = None
Expand All @@ -246,7 +261,8 @@ def hard_stop(self):

# SIGKILL everything.
for w in removed_workers:
w.kill()
if isinstance(w, multiprocessing.Process):
w.kill()
# Wait for them to be done.
for w in removed_workers:
w.join()
Expand All @@ -260,7 +276,7 @@ def _handle_task_progress(self, task_id: Any):
del self.task_states[task_id]
for worker_id in task_state.workers_assigned:
w = self.worker_states[worker_id]
assert w.assigned_shots_remote == 0
assert w.assigned_shots <= 0
assert w.asked_to_drop_shots == 0
w.assigned_work_key = None
self._distribute_work()
Expand All @@ -278,7 +294,7 @@ def state_summary(self) -> str:
for worker_id, worker in enumerate(self.worker_states):
lines.append(f'worker {worker_id}:'
f' asked_to_drop_shots={worker.asked_to_drop_shots}'
f' assigned_shots_remote={worker.assigned_shots_remote}'
f' assigned_shots={worker.assigned_shots}'
f' assigned_work_key={worker.assigned_work_key}')
for task in self.task_states.values():
lines.append(f'task {task.strong_id=}:\n'
Expand All @@ -304,8 +320,18 @@ def process_message(self) -> bool:
assert worker_state.assigned_work_key == task_strong_id
task_state = self.task_states[task_strong_id]

worker_state.assigned_shots_remote -= anon_stat.shots
worker_state.assigned_shots -= anon_stat.shots
task_state.shots_left -= anon_stat.shots
if worker_state.assigned_shots < 0:
# Worker over-achieved. Correct the imbalance by giving them the shots.
extra_shots = abs(worker_state.assigned_shots)
worker_state.assigned_shots += extra_shots
task_state.shots_unassigned -= extra_shots
worker_state.send_message((
'accept_shots',
(task_state.strong_id, extra_shots),
))

if self.custom_error_count_key is None:
task_state.errors_left -= anon_stat.errors
else:
Expand Down Expand Up @@ -346,9 +372,8 @@ def process_message(self) -> bool:
worker_state.asked_to_drop_shots = 0
worker_state.asked_to_drop_errors = 0
task_state.shots_unassigned += shots_returned
worker_state.assigned_shots_remote -= shots_returned
if worker_state.assigned_shots_remote < 0:
worker_state.assigned_shots_remote = 0
worker_state.assigned_shots -= shots_returned
assert worker_state.assigned_shots >= 0
self._handle_task_progress(task_key)

elif message_type == 'stopped_due_to_exception':
Expand Down Expand Up @@ -415,14 +440,14 @@ def _distribute_unassigned_work_to_workers_within_a_job(self, task_state: _Manag
expected_shots_per_worker = (task_state.shots_left + num_task_workers - 1) // num_task_workers

# Give unassigned shots to idle workers.
for worker_id in sorted(task_state.workers_assigned, key=lambda wid: self.worker_states[wid].assigned_shots_remote):
for worker_id in sorted(task_state.workers_assigned, key=lambda wid: self.worker_states[wid].assigned_shots):
worker_state = self.worker_states[worker_id]
if worker_state.assigned_shots_remote < expected_shots_per_worker:
shots_to_assign = min(expected_shots_per_worker - worker_state.assigned_shots_remote,
if worker_state.assigned_shots < expected_shots_per_worker:
shots_to_assign = min(expected_shots_per_worker - worker_state.assigned_shots,
task_state.shots_unassigned)
if shots_to_assign > 0:
task_state.shots_unassigned -= shots_to_assign
worker_state.assigned_shots_remote += shots_to_assign
worker_state.assigned_shots += shots_to_assign
worker_state.send_message((
'accept_shots',
(task_state.strong_id, shots_to_assign),
Expand Down Expand Up @@ -461,7 +486,7 @@ def status_message(self) -> str:
line = [
f'{w}',
self.partial_tasks[k].decoder,
("?" if dt is None else "[draining]" if dt <= 0 else "<1m" if dt < 1 else str(round(dt)) + 'm') + ('·∞' if w == 0 else ''),
("?" if dt is None or dt == 0 else "[draining]" if dt <= 0 else "<1m" if dt < 1 else str(round(dt)) + 'm') + ('·∞' if w == 0 else ''),
f'{max_shots - c.shots}' if max_shots is not None else f'{c.shots}',
f'{max_errors - c_errors}' if max_errors is not None else f'{c_errors}',
",".join(
Expand Down Expand Up @@ -518,18 +543,18 @@ def _take_work_if_unsatisfied_workers_within_a_job(self, task_state: _ManagedTas
if not self.started or not task_state.workers_assigned or task_state.shots_left <= 0:
return

if all(self.worker_states[w].assigned_shots_remote for w in task_state.workers_assigned):
if all(self.worker_states[w].assigned_shots > 0 for w in task_state.workers_assigned):
return

w = len(task_state.workers_assigned)
expected_shots_per_worker = (task_state.shots_left + w - 1) // w

# There are idle workers that couldn't be given any shots. Take shots from other workers.
for worker_id in sorted(task_state.workers_assigned, key=lambda w: self.worker_states[w].assigned_shots_remote, reverse=True):
for worker_id in sorted(task_state.workers_assigned, key=lambda w: self.worker_states[w].assigned_shots, reverse=True):
worker_state = self.worker_states[worker_id]
if worker_state.asked_to_drop_shots or worker_state.assigned_shots_remote <= expected_shots_per_worker:
if worker_state.asked_to_drop_shots or worker_state.assigned_shots <= expected_shots_per_worker:
continue
shots_to_take = worker_state.assigned_shots_remote - expected_shots_per_worker
shots_to_take = worker_state.assigned_shots - expected_shots_per_worker
assert shots_to_take > 0
worker_state.asked_to_drop_shots = shots_to_take
task_state.shot_return_requests += 1
Expand Down
34 changes: 34 additions & 0 deletions glue/sample/src/sinter/_collection/_collection_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import collections
import math
import pathlib
import sys
import tempfile
import time

import pytest
import stim
Expand Down Expand Up @@ -238,3 +241,34 @@ def test_fixed_size_sampler():
custom_decoders={'fixed_size_sampler': FixedSizeSampler()}
)
assert 100_000 <= results[0].shots <= 100_000 + 3000


class MockTimingSampler(sinter.Sampler, sinter.CompiledSampler):
    """Test double that serves as both a sampler and its own compiled form.

    Rounds every request up to a whole number of 1024-shot batches and
    sleeps in proportion to the rounded count, simulating a decoder whose
    latency scales with the number of shots taken.
    """

    def compiled_sampler_for_task(self, task: sinter.Task) -> 'sinter.CompiledSampler':
        # The mock "compiles" to itself; no per-task state is needed.
        return self

    def sample(self, suggested_shots: int) -> 'sinter.AnonTaskStats':
        # Round up to the next multiple of 1024 (integer ceiling).
        batch_shots = ((suggested_shots + 1023) // 1024) * 1024
        per_shot_seconds = 0.00001
        time.sleep(batch_shots * per_shot_seconds)
        return sinter.AnonTaskStats(
            shots=batch_shots,
            errors=5,
            seconds=batch_shots * per_shot_seconds,
        )


def test_mock_timing_sampler():
    """Collect against the timing mock and check the shot goal is met without large overshoot."""
    task = sinter.Task(
        circuit=stim.Circuit(),
        decoder='MockTimingSampler',
        json_metadata={},
    )
    stats = sinter.collect(
        num_workers=12,
        tasks=[task],
        max_shots=1_000_000,
        max_errors=10_000,
        custom_decoders={'MockTimingSampler': MockTimingSampler()},
    )
    total_shots = stats[0].shots
    # Each of the 12 workers can overshoot by at most one 1024-shot batch.
    assert 1_000_000 <= total_shots <= 1_000_000 + 12000
41 changes: 31 additions & 10 deletions glue/sample/src/sinter/_collection/_collection_worker_state.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import queue
import time
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -70,10 +71,32 @@ def __init__(
self.last_flush_message_time = time.monotonic()
self.soft_error_flush_threshold: int = 1

def _send_message_to_manager(self, message: Any):
self.out.put(message)

def state_summary(self) -> str:
    """Return a multi-line, human-readable dump of this worker's fields.

    Used for debugging (e.g. attached to error reports); the output is a
    newline-delimited listing of every piece of mutable worker state.
    """
    lines = [
        f'Worker(id={self.worker_id}) [',
        f' max_flush_period={self.max_flush_period}',
        f' cur_flush_period={self.cur_flush_period}',
        f' sampler={self.sampler}',
        f' compiled_sampler={self.compiled_sampler}',
        f' current_task={self.current_task}',
        f' current_error_cutoff={self.current_error_cutoff}',
        f' custom_error_count_key={self.custom_error_count_key}',
        f' current_task_shots_left={self.current_task_shots_left}',
        f' unflushed_results={self.unflushed_results}',
        f' last_flush_message_time={self.last_flush_message_time}',
        f' soft_error_flush_threshold={self.soft_error_flush_threshold}',
        f']',
    ]
    # Leading/trailing newlines make the block stand out when embedded in logs.
    return '\n' + '\n'.join(lines) + '\n'

def flush_results(self):
if self.unflushed_results.shots > 0:
self.last_flush_message_time = time.monotonic()
self.out.put((
self.cur_flush_period = min(self.cur_flush_period * 1.4, self.max_flush_period)
self._send_message_to_manager((
'flushed_results',
self.worker_id,
(self.current_task.strong_id(), self.unflushed_results),
Expand All @@ -85,7 +108,7 @@ def flush_results(self):
def accept_shots(self, *, shots_delta: int):
assert shots_delta >= 0
self.current_task_shots_left += shots_delta
self.out.put((
self._send_message_to_manager((
'accepted_shots',
self.worker_id,
(self.current_task.strong_id(), shots_delta),
Expand All @@ -97,17 +120,15 @@ def return_shots(self, *, requested_shots: int):
self.current_task_shots_left -= returned_shots
if self.current_task_shots_left <= 0:
self.flush_results()
if self.current_task_shots_left < 0:
self.current_task_shots_left = 0
self.out.put((
self._send_message_to_manager((
'returned_shots',
self.worker_id,
(self.current_task.strong_id(), returned_shots),
))

def compute_strong_id(self, *, new_task: Task):
strong_id = _fill_in_task(new_task).strong_id()
self.out.put((
self._send_message_to_manager((
'computed_strong_id',
self.worker_id,
strong_id,
Expand All @@ -123,7 +144,7 @@ def change_job(self, *, new_task: Task, new_collection_options: CollectionOption
self.current_task_shots_left = 0
self.last_flush_message_time = time.monotonic()

self.out.put((
self._send_message_to_manager((
'changed_job',
self.worker_id,
(self.current_task.strong_id(),),
Expand Down Expand Up @@ -196,7 +217,8 @@ def do_some_work(self) -> bool:
assert isinstance(some_work_done, AnonTaskStats)
self.current_task_shots_left -= some_work_done.shots
if self.current_error_cutoff is not None:
self.current_error_cutoff -= self.num_unflushed_errors()
errors_done = some_work_done.custom_counts[self.custom_error_count_key] if self.custom_error_count_key is not None else some_work_done.errors
self.current_error_cutoff -= errors_done
self.unflushed_results += some_work_done
did_some_work = True

Expand All @@ -208,7 +230,6 @@ def do_some_work(self) -> bool:
if self.current_task_shots_left <= 0 or self.last_flush_message_time + self.cur_flush_period < time.monotonic():
should_flush = True
if should_flush:
self.cur_flush_period = min(self.cur_flush_period * 1.4, self.max_flush_period)
did_some_work |= self.flush_results()

return did_some_work
Expand All @@ -228,7 +249,7 @@ def run_message_loop(self):

except BaseException as ex:
import traceback
self.out.put((
self._send_message_to_manager((
'stopped_due_to_exception',
self.worker_id,
(None if self.current_task is None else self.current_task.strong_id(), self.current_task_shots_left, self.unflushed_results, traceback.format_exc(), ex),
Expand Down
Loading

0 comments on commit 4269062

Please sign in to comment.