Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flex extend #16

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion redlock/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from .lock import RedLock, ReentrantRedLock, RedLockFactory, RedLockError
from .lock import ( # NOQA
RedLock, ReentrantRedLock, RedLockFactory, RedLockError
)
__VERSION__ = '1.2.0'
170 changes: 150 additions & 20 deletions redlock/lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,6 @@
import redis


DEFAULT_RETRY_TIMES = 3
DEFAULT_RETRY_DELAY = 200
DEFAULT_TTL = 100000
CLOCK_DRIFT_FACTOR = 0.01

# Reference: http://redis.io/topics/distlock
# Section Correct implementation with a single instance
RELEASE_LUA_SCRIPT = """
Expand Down Expand Up @@ -75,16 +70,30 @@ class RedLock(object):
Python Standard Library.
"""

DEFAULT_RETRY_TIMES = 3
DEFAULT_RETRY_DELAY = 200
DEFAULT_TTL = 100000
CLOCK_DRIFT_FACTOR = 0.01

def __init__(self, resource, connection_details=None,
retry_times=DEFAULT_RETRY_TIMES,
retry_delay=DEFAULT_RETRY_DELAY,
ttl=DEFAULT_TTL,
created_by_factory=False):
retry_times=None, retry_delay=None, ttl=None,
created_by_factory=False,
key=None):

self.resource = resource
self.retry_times = retry_times
self.retry_delay = retry_delay
self.ttl = ttl
self.retry_times = retry_times or self.DEFAULT_RETRY_TIMES
self.retry_delay = retry_delay or self.DEFAULT_RETRY_DELAY
self.ttl = ttl or self.DEFAULT_TTL
if not key:
# lock_key should be random and unique
self.lock_key = uuid.uuid4().hex
else:
# To enable release of an externally stored key
self.lock_key = key
# In python3 the lock_key must be a instance bytes (which is
# in python2 the same as str)
if isinstance(self.lock_key, str):
self.lock_key = self.lock_key.encode()

if created_by_factory:
self.factory = None
Expand All @@ -110,7 +119,10 @@ def __init__(self, resource, connection_details=None,
node = redis.StrictRedis(**conn)
node._release_script = node.register_script(RELEASE_LUA_SCRIPT)
self.redis_nodes.append(node)
self.quorum = len(self.redis_nodes) // 2 + 1
self.min_quorum = len(self.redis_nodes) // 2 + 1
self.max_quorum = len(self.redis_nodes)
self.quorum = self.min_quorum
self.lock_acquired = False

def __enter__(self):
acquired, validity = self.acquire_with_validity()
Expand All @@ -129,13 +141,22 @@ def _total_ms(self, delta):
delta_seconds = delta.seconds + delta.days * 24 * 3600
return (delta.microseconds + delta_seconds * 10**6) / 10**3

def set_quorum(self, quorum=None):
"""
Set the quorum
"""
if self.lock_acquired:
raise RedLockError("Cannot set quorum when lock is acquired.")
self.quorum = min(max(self.max_quorum, quorum), self.max_quorum)

def acquire_node(self, node):
"""
acquire a single redis node
"""
try:
return node.set(self.resource, self.lock_key, nx=True, px=self.ttl)
except (redis.exceptions.ConnectionError, redis.exceptions.TimeoutError):
except (redis.exceptions.ConnectionError,
redis.exceptions.TimeoutError):
return False

def release_node(self, node):
Expand All @@ -145,8 +166,40 @@ def release_node(self, node):
# use the lua script to release the lock in a safe way
try:
node._release_script(keys=[self.resource], args=[self.lock_key])
except (redis.exceptions.ConnectionError, redis.exceptions.TimeoutError):
except (redis.exceptions.ConnectionError,
redis.exceptions.TimeoutError):
pass

def info_node(self, node):
"""
Get lock info from a single node
"""
value = ttl = None
try:
value = node.get(self.resource)
ttl = node.pttl(self.resource)
except (redis.exceptions.ConnectionError,
redis.exceptions.TimeoutError):
pass
return (value, ttl)

def extend_node(self, node, ttl=None):
"""
Extend lock (set new ttl to ttl or self.ttl)
"""
if not ttl:
ttl = self.ttl
acquired = self.acquire_node(node)
try:
if acquired or node.get(self.resource) == self.lock_key:
if node.pexpire(self.resource, ttl):
return True
except (redis.exceptions.ConnectionError,
redis.exceptions.TimeoutError):
pass
if acquired:
return True
return False

def acquire(self):
acquired, validity = self._acquire()
Expand All @@ -156,10 +209,11 @@ def acquire_with_validity(self):
return self._acquire()

def _acquire(self):
"""
acquire a lock on at least quorum redis nodes
"""

# lock_key should be random and unique
self.lock_key = uuid.uuid4().hex

self.lock_acquired = False
for retry in range(self.retry_times + 1):
acquired_node_count = 0
start_time = datetime.utcnow()
Expand All @@ -173,12 +227,13 @@ def _acquire(self):
elapsed_milliseconds = self._total_ms(end_time - start_time)

# Add 2 milliseconds to the drift to account for Redis expires
# precision, which is 1 milliescond, plus 1 millisecond min drift
# precision, which is 1 milliscond, plus 1 millisecond min drift
# for small TTLs.
drift = (self.ttl * CLOCK_DRIFT_FACTOR) + 2
drift = (self.ttl * self.CLOCK_DRIFT_FACTOR) + 2

validity = self.ttl - (elapsed_milliseconds + drift)
if acquired_node_count >= self.quorum and validity > 0:
self.lock_acquired = True
return True, validity
else:
for node in self.redis_nodes:
Expand All @@ -187,11 +242,76 @@ def _acquire(self):
return False, 0

def release(self):
"""
Release lock on all redis nodes
"""
for node in self.redis_nodes:
self.release_node(node)
self.lock_acquired = False

def holding(self):
"""
Check if this lock is acquired
"""
acquired_node_count = 0
for node in self.redis_nodes:
value, ttl = self.info_node(node)
if value == self.lock_key:
acquired_node_count += 1
if acquired_node_count >= self.quorum:
self.lock_acquired = True
return True
for node in self.redis_nodes:
self.release_node(node)
self.lock_acquired = False
return False

def info(self):
"""
Get lock_key and remaining ttl of the lock on all redis nodes
"""
return list([
self.info_node(node)
for node in self.redis_nodes
])

def extend(self, ttl=None):
acquired, validity = self._extend(ttl)
return acquired

def extend_with_validity(self, ttl=None):
return self._extend(ttl)

def _extend(self, ttl=None):
"""
Extend the previously acquired lock on all redis nodes
"""
acquired_node_count = 0
start_time = datetime.utcnow()
# Extend the lock ttls
for node in self.redis_nodes:
if self.extend_node(node, ttl):
acquired_node_count += 1
end_time = datetime.utcnow()
elapsed_milliseconds = self._total_ms(end_time - start_time)

# Add 2 milliseconds to the drift to account for Redis expires
# precision, which is 1 milliscond, plus 1 millisecond min drift
# for small TTLs.
drift = (self.ttl * self.CLOCK_DRIFT_FACTOR) + 2

validity = self.ttl - (elapsed_milliseconds + drift)
if acquired_node_count >= self.quorum and validity > 0:
self.lock_acquired = True
return True, validity
for node in self.redis_nodes:
self.release_node(node)
self.lock_acquired = False
return False, 0


class ReentrantRedLock(RedLock):

def __init__(self, *args, **kwargs):
super(ReentrantRedLock, self).__init__(*args, **kwargs)
self._acquired = 0
Expand All @@ -213,3 +333,13 @@ def release(self):
return super(ReentrantRedLock, self).release()
return True
return False

def extend(self, ttl=None):
if self._acquired == 0:
result = super(ReentrantRedLock, self).extend(ttl)
if result:
self._acquired += 1
return result
else:
self._acquired += 1
return True
4 changes: 3 additions & 1 deletion tests/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
def test_factory_create():
factory = RedLockFactory([{"host": "localhost"}])

lock = factory.create_lock("test_factory_create", ttl=500, retry_times=5, retry_delay=100)
lock = factory.create_lock(
"test_factory_create", ttl=500, retry_times=5, retry_delay=100
)

assert factory.redis_nodes == lock.redis_nodes
assert factory.quorum == lock.quorum
Expand Down
Loading