Skip to content

Commit

Permalink
feat: locks can never expire
Browse files Browse the repository at this point in the history
  • Loading branch information
ggicci committed Dec 10, 2021
1 parent 5ae4e94 commit 068c138
Show file tree
Hide file tree
Showing 8 changed files with 98 additions and 39 deletions.
6 changes: 3 additions & 3 deletions mutex.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ func (m *mutex) lockId() string {
return m.ns + ":" + m.id
}

func (m *mutex) GetLockId() string {
func (m *mutex) GetId() string {
return m.lockId()
}

func (m *mutex) GetLockOwner() string {
func (m *mutex) GetOwner() string {
return m.owner
}

Expand Down Expand Up @@ -71,7 +71,7 @@ func newMutex(provider Provider, id string, opts ...Option) Mutex {

// String implements print interface.
func (m *mutex) String() string {
return "Mutex(" + m.provider.Name() + ":" + m.GetLockId() + ")"
return "Mutex(" + m.provider.Name() + ":" + m.GetId() + ")"
}

// Lock locks the named resourc
Expand Down
20 changes: 18 additions & 2 deletions mutex_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,23 @@ import (
"github.com/stretchr/testify/assert"
)

func runBasicLockTests(t *testing.T, provider Provider) {
func runLockTestsWithoutLifetime(t *testing.T, provider Provider) {
factory := New(provider, WithNamespace("deadlock"))
m1 := factory.New("build-images")
m2 := factory.New("build-images")
m3 := factory.New("start-containers")

assert.NoError(t, m1.Lock())
assert.ErrorIs(t, m1.Lock(), ErrAlreadyLocked)
assert.ErrorIs(t, m2.Lock(), ErrAlreadyLocked)
assert.NoError(t, m3.Lock())

assert.NoError(t, m1.Unlock())
assert.ErrorIs(t, m2.Unlock(), ErrNotLocked)
assert.NoError(t, m3.Unlock())
}

func runLockTestsWithLifetime(t *testing.T, provider Provider) {
factory := New(provider, WithLockLifetime(1*time.Second))
m := factory.New("johndoe", WithNamespace("questions"))
expectedMutexDisplayName := fmt.Sprintf("Mutex(%s:questions:johndoe)", provider.Name())
Expand Down Expand Up @@ -51,7 +67,7 @@ func testLockContention(t *testing.T, m Mutex) {
func testUnlockAfterOwnerChange(t *testing.T, m1, m2 Mutex) {
assert.NoError(t, m1.Lock())
assert.ErrorIs(t, m2.Lock(), ErrAlreadyLocked)
time.Sleep(10 * time.Millisecond) // m1 expired (released by system)
time.Sleep(50 * time.Millisecond) // m1 expired (released by system)
assert.NoError(t, m2.Lock()) // m2 can obtain the lock, since m1 is expired
assert.ErrorIs(t, m1.Unlock(), ErrNotLocked)
}
24 changes: 15 additions & 9 deletions mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ const (

mysqlLockSQL = `INSERT INTO %s (id, owner, expire_at) VALUES (?, ?, ?)
ON DUPLICATE KEY UPDATE
owner = IF(expire_at < ?, VALUES(owner), owner),
expire_at = IF(expire_at < ?, VALUES(expire_at), expire_at);`
owner = IF(expire_at > 0 AND expire_at < ?, VALUES(owner), owner),
expire_at = IF(expire_at > 0 AND expire_at < ?, VALUES(expire_at), expire_at);`

mysqlUnlockSQL = `DELETE FROM %s WHERE id = ? AND owner = ? AND expire_at >= ?;`
mysqlUnlockSQL = `DELETE FROM %s WHERE id = ? AND owner = ? AND (expire_at = 0 OR expire_at >= ?);`
)

type mysqlProvider struct {
Expand Down Expand Up @@ -65,11 +65,10 @@ func (p *mysqlProvider) init() error {

func (p *mysqlProvider) Lock(lock NamedLock) error {
now := time.Now()
expireAt := now.Add(lock.GetLifetime())
rs, err := p.lockStmt.Exec(
lock.GetLockId(),
lock.GetLockOwner(),
expireAt.UnixNano(),
lock.GetId(),
lock.GetOwner(),
computeExpireAt(now, lock.GetLifetime()),
now.UnixNano(),
now.UnixNano(),
)
Expand All @@ -88,8 +87,8 @@ func (p *mysqlProvider) Lock(lock NamedLock) error {

func (p *mysqlProvider) Unlock(lock NamedLock) error {
rs, err := p.unlockStmt.Exec(
lock.GetLockId(),
lock.GetLockOwner(),
lock.GetId(),
lock.GetOwner(),
time.Now().UnixNano(),
)
if err != nil {
Expand All @@ -108,3 +107,10 @@ func (p *mysqlProvider) Unlock(lock NamedLock) error {
func formatSQL(sqlTemplate, tableName string) string {
return fmt.Sprintf(sqlTemplate, tableName)
}

func computeExpireAt(now time.Time, lifetime time.Duration) int64 {
if lifetime == 0 {
return 0 // never expire
}
return now.Add(lifetime).UnixNano()
}
16 changes: 8 additions & 8 deletions postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ const (

pgLockSQL = `INSERT INTO %s AS t (id, owner, expire_at) VALUES ($1, $2, $3)
ON CONFLICT (id) DO UPDATE
SET owner = $2, expire_at = $3 WHERE t.id = $1 AND t.expire_at < $4;`
SET owner = $2, expire_at = $3
WHERE t.id = $1 AND t.expire_at > 0 AND t.expire_at < $4;`

pgUnlockSQL = `DELETE FROM %s WHERE id = $1 AND owner = $2 AND expire_at >= $3;`
pgUnlockSQL = `DELETE FROM %s WHERE id = $1 AND owner = $2 AND (expire_at = 0 OR expire_at >= $3);`
)

type postgreSQLProvider mysqlProvider
Expand Down Expand Up @@ -58,11 +59,10 @@ func (p *postgreSQLProvider) init() error {

func (p *postgreSQLProvider) Lock(lock NamedLock) error {
now := time.Now()
expireAt := now.Add(lock.GetLifetime())
rs, err := p.lockStmt.Exec(
lock.GetLockId(),
lock.GetLockOwner(),
expireAt.UnixNano(),
lock.GetId(),
lock.GetOwner(),
computeExpireAt(now, lock.GetLifetime()),
now.UnixNano(),
)
if err != nil {
Expand All @@ -80,8 +80,8 @@ func (p *postgreSQLProvider) Lock(lock NamedLock) error {

func (p *postgreSQLProvider) Unlock(lock NamedLock) error {
rs, err := p.unlockStmt.Exec(
lock.GetLockId(),
lock.GetLockOwner(),
lock.GetId(),
lock.GetOwner(),
time.Now().UnixNano(),
)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package distlock
import "time"

type NamedLock interface {
GetLockId() string
GetLockOwner() string
GetId() string
GetOwner() string
GetLifetime() time.Duration
}

Expand Down
30 changes: 22 additions & 8 deletions redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,28 @@ func (p *redisProvider) Lock(lock NamedLock) error {
conn := p.pool.Get()
defer conn.Close()

// SET key value PX milliseconds NX
// PX: Set the specified expire time, in milliseconds.
// NX: Only set the key if it does not already exist.
reply, err := conn.Do(
"SET", lock.GetLockId(), lock.GetLockOwner(),
"PX", lock.GetLifetime().Nanoseconds()/int64(time.Millisecond),
"NX",
var (
reply interface{}
err error
)

lifetime := lock.GetLifetime()
if lifetime > 0 {
// SET key value PX milliseconds NX
// PX: Set the specified expire time, in milliseconds.
// NX: Only set the key if it does not already exist.
reply, err = conn.Do(
"SET", lock.GetId(), lock.GetOwner(),
"PX", lock.GetLifetime().Nanoseconds()/int64(time.Millisecond),
"NX",
)
} else { // never expire
reply, err = conn.Do(
"SET", lock.GetId(), lock.GetOwner(),
"NX",
)
}

if err != nil {
return fmt.Errorf("redis SET: %w", err)
}
Expand All @@ -64,7 +78,7 @@ func (p *redisProvider) Unlock(lock NamedLock) error {
defer conn.Close()

command := redis.NewScript(1, unlockScript)
ret, err := redis.Int(command.Do(conn, lock.GetLockId(), lock.GetLockOwner()))
ret, err := redis.Int(command.Do(conn, lock.GetId(), lock.GetOwner()))
if err != nil {
return fmt.Errorf("redis EVAL: %w", err)
}
Expand Down
12 changes: 11 additions & 1 deletion redis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,17 @@ var (
}
)

func cleanupRedis() {
conn := redisPool.Get()
defer conn.Close()

conn.Do("FLUSHDB")
}

func TestRedisProvider(t *testing.T) {
cleanupRedis()

provider, _ := NewRedisProvider(redisPool)
runBasicLockTests(t, provider)
runLockTestsWithLifetime(t, provider)
runLockTestsWithoutLifetime(t, provider)
}
25 changes: 19 additions & 6 deletions sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,29 @@ import (
_ "github.com/lib/pq"
)

const TestTableName = "ggicci_distlock_test"

func cleanupMySQL(db *sql.DB) {
_, _ = db.Exec(formatSQL("DROP TABLE IF EXISTS %s", TestTableName))
}

func cleanupPostgreSQL(db *sql.DB) {
_, _ = db.Exec(formatSQL("DROP TABLE IF EXISTS %s", TestTableName))
}

func TestMySQLProvider(t *testing.T) {
db, err := sql.Open("mysql", "root@tcp(localhost:3306)/test")
if err != nil {
t.Fatal(err)
}
cleanupMySQL(db)

provider, err := NewMySQLProvider(db, "distlocks")
provider, err := NewMySQLProvider(db, TestTableName)
if err != nil {
t.Fatalf("could not create provider: %s", err)
}

runBasicLockTests(t, provider)
runLockTestsWithoutLifetime(t, provider)
runLockTestsWithLifetime(t, provider)
}

func TestPostgreSQLProvider(t *testing.T) {
Expand All @@ -31,10 +42,12 @@ func TestPostgreSQLProvider(t *testing.T) {
t.Fatal(err)
}

provider, err := NewPostgreSQLProvider(db, "distlocks")
cleanupPostgreSQL(db)

provider, err := NewPostgreSQLProvider(db, TestTableName)
if err != nil {
t.Fatalf("could not create provider: %s", err)
}

runBasicLockTests(t, provider)
runLockTestsWithoutLifetime(t, provider)
runLockTestsWithLifetime(t, provider)
}

0 comments on commit 068c138

Please sign in to comment.