From ab8e480ed6663370e92b4a5c4f6959f13c8927ae Mon Sep 17 00:00:00 2001
From: Benjamin Pritchard
Date: Fri, 11 Aug 2023 11:16:36 -0400
Subject: [PATCH] Rework updating of dataset records a bit

---
 qcportal/qcportal/dataset_models.py | 46 ++++++++++++++++++++++++++--------------------
 1 file changed, 26 insertions(+), 20 deletions(-)

diff --git a/qcportal/qcportal/dataset_models.py b/qcportal/qcportal/dataset_models.py
index 5f729ac24..581b43817 100644
--- a/qcportal/qcportal/dataset_models.py
+++ b/qcportal/qcportal/dataset_models.py
@@ -677,6 +677,8 @@ def _internal_update_records(
             Names of the entries whose records to update. If None, fetch all entries
         specification_names
             Names of the specifications whose records to update. If None, fetch all specifications
+        status
+            Update records that have this status on the server. If None, update records with any status on the server
         include
             Additional fields/data to include when fetch the entry
         """
@@ -685,17 +687,24 @@
             return
 
         # Get all the record ids that we store that correspond to the entries/specs
-        existing_record_ids = []
-        for entry_name, spec_name in itertools.product(entry_names, specification_names):
-            existing_rec = self.record_map_.get((entry_name, spec_name), None)
-            if existing_rec is not None:
-                existing_record_ids.append(existing_rec.id)
+        existing_record_info = [
+            (e, s, r) for (e, s), r in self.record_map_.items() if e in entry_names and s in specification_names
+        ]
+
+        # Subset that we should check for updated records on the server
+        # (completed and invalid rarely change)
+        updateable_record_ids = [
+            r.id
+            for e, s, r in existing_record_info
+            if r.status not in (RecordStatusEnum.complete, RecordStatusEnum.invalid)
+        ]
 
         # Do a raw call to the records/bulkGet endpoint. This allows us to only get
         # the 'modified_on' and 'status' fields
         batch_size = self._client.api_limits["get_records"] // 4
-        minfo_dict = {}
-        for record_id_batch in chunk_iterable(existing_record_ids, batch_size):
+        minfo_dict: Dict[int, datetime] = {}  # record_id -> modified time
+
+        for record_id_batch in chunk_iterable(updateable_record_ids, batch_size):
             body = CommonBulkGetBody(ids=record_id_batch, include=["modified_on", "status"])
 
             modified_info = self._client.make_request(
@@ -707,31 +716,28 @@
 
             # Too lazy to look up how pydantic stores datetime, so use pydantic to parse it
             for m in modified_info:
+                # Only store if the status on the server matches what the caller wants
                 if status is None or m["status"] in status:
                     minfo_dict[m["id"]] = pydantic.parse_obj_as(datetime, m["modified_on"])
 
-        # Which ones need to be updated
+        # Which ones need to be fully updated
         need_updating = []
-        for entry_name, spec_name in itertools.product(entry_names, specification_names):
-            existing_record = self.record_map_.get((entry_name, spec_name), None)
-
-            if existing_record is None:
-                continue
-
-            server_rec_mtime = minfo_dict.get(existing_record.id, None)
+        for entry_name, spec_name, existing_record in existing_record_info:
+            server_mtime = minfo_dict.get(existing_record.id, None)
 
-            # Maybe mismatched status
-            if server_rec_mtime is None:
+            # Perhaps a status mismatch (status on server isn't one we want)
+            if server_mtime is None:
                 continue
 
-            if existing_record.modified_on < server_rec_mtime:
+            if existing_record.modified_on < server_mtime:
                 need_updating.append((entry_name, spec_name))
 
         # Go via one spec at a time
         for spec_name in specification_names:
             entries_to_update = [x[0] for x in need_updating if x[1] == spec_name]
-            if entries_to_update:
-                self._internal_fetch_records(entries_to_update, [spec_name], None, include)
+
+            for entries_to_update_batch in chunk_iterable(entries_to_update, batch_size):
+                self._internal_fetch_records(entries_to_update_batch, [spec_name], None, include)
 
     def fetch_records(
         self,
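Reviewer note: below is a minimal, self-contained sketch of the "cheap metadata poll, then selective full fetch" pattern this patch applies, for anyone unfamiliar with the batching helper. Nothing in it is QCPortal API: chunk_iterable here only mirrors the spirit of the real qcportal helper, and stale_record_ids plus the toy data are hypothetical stand-ins for illustration.

from datetime import datetime, timedelta
from itertools import islice
from typing import Dict, Iterable, Iterator, List, TypeVar

T = TypeVar("T")

def chunk_iterable(it: Iterable[T], chunk_size: int) -> Iterator[List[T]]:
    # Yield successive lists of at most chunk_size items, as the patch does
    # to stay under the server's per-request limit. The real qcportal helper
    # may be implemented differently.
    iterator = iter(it)
    while batch := list(islice(iterator, chunk_size)):
        yield batch

def stale_record_ids(
    local_mtimes: Dict[int, datetime],
    server_mtimes: Dict[int, datetime],
    batch_size: int,
) -> List[int]:
    # Cheap comparison pass: a record needs a full re-fetch only when the
    # server's modified_on is newer than the locally cached one. A missing
    # server entry stands in for the status-mismatch case in the patch.
    stale: List[int] = []
    for batch in chunk_iterable(list(local_mtimes), batch_size):
        for rid in batch:
            server_mtime = server_mtimes.get(rid)
            if server_mtime is not None and local_mtimes[rid] < server_mtime:
                stale.append(rid)
    return stale

# Toy usage with fabricated timestamps: record 2 changed on the server,
# record 3 was filtered out (e.g. wrong status), record 1 is unchanged.
now = datetime(2023, 8, 11, 12, 0, 0)
local = {1: now, 2: now, 3: now}
server = {1: now, 2: now + timedelta(minutes=5)}
print(stale_record_ids(local, server, batch_size=2))  # -> [2]

Batching the final full fetch per specification, as the patch does, keeps every request under the server's get_records limit while still skipping complete/invalid records and records whose modified_on has not moved.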