diff --git a/client/utils/storage.py b/client/utils/storage.py index 92dd71ee..38113dba 100644 --- a/client/utils/storage.py +++ b/client/utils/storage.py @@ -1,3 +1,4 @@ +import os import ctypes import shelve # persistance import hmac # security @@ -30,14 +31,26 @@ def mac(value): mac.update(repr(value).encode('utf-8')) return mac.hexdigest() +def open_db(): + try: + return shelve.open(SHELVE_FILE) + except: + # just catch everything because there are a variety of errors that a corrupt db can cause + # if there is some other error that will happen on the retry hopefully + for name in os.listdir("."): + if name.startswith(SHELVE_FILE): + os.rename(name, name + ".corrupt") + return shelve.open(SHELVE_FILE) + + def contains(root, key): key = '{}-{}'.format(root, key) - with shelve.open(SHELVE_FILE) as db: + with open_db() as db: return key in db def store(root, key, value): key = '{}-{}'.format(root, key) - with shelve.open(SHELVE_FILE) as db: + with open_db() as db: db[key] = {'value': value, 'mac': mac(value)} return value @@ -45,7 +58,7 @@ def get(root, key, default=None): if not contains(root, key): return default key = '{}-{}'.format(root, key) - with shelve.open(SHELVE_FILE) as db: + with open_db() as db: data = db[key] if not hmac.compare_digest(data['mac'], mac(data['value'])): raise ProtocolException('{} was tampered. Reverse changes, or redownload assignment'.format(SHELVE_FILE)) diff --git a/tests/end_to_end/storage_corrupt_test.py b/tests/end_to_end/storage_corrupt_test.py new file mode 100644 index 00000000..ec54623a --- /dev/null +++ b/tests/end_to_end/storage_corrupt_test.py @@ -0,0 +1,17 @@ +import json +import tempfile + +from tests.end_to_end.end_to_end_test import EndToEndTest + + +class EncryptionTest(EndToEndTest): + def testEncrypt(self): + self.copy_examples() + stdout, stderr = self.run_ok("-q", "q1") + self.assertOnlyInvalidGrant(stderr) + # mess up the shelve + self.add_file(".ok_storage.dir", "\0") + self.add_file(".ok_storage.bak", "\0") + self.add_file(".ok_storage.dat", "\0") + stdout, stderr = self.run_ok("-q", "q1") + self.assertOnlyInvalidGrant(stderr)