diff --git a/mutex.go b/mutex.go index 9455aa3..f85c1cd 100644 --- a/mutex.go +++ b/mutex.go @@ -18,11 +18,11 @@ func (m *mutex) lockId() string { return m.ns + ":" + m.id } -func (m *mutex) GetLockId() string { +func (m *mutex) GetId() string { return m.lockId() } -func (m *mutex) GetLockOwner() string { +func (m *mutex) GetOwner() string { return m.owner } @@ -71,7 +71,7 @@ func newMutex(provider Provider, id string, opts ...Option) Mutex { // String implements print interface. func (m *mutex) String() string { - return "Mutex(" + m.provider.Name() + ":" + m.GetLockId() + ")" + return "Mutex(" + m.provider.Name() + ":" + m.GetId() + ")" } // Lock locks the named resourc diff --git a/mutex_test.go b/mutex_test.go index c177ff5..230a009 100644 --- a/mutex_test.go +++ b/mutex_test.go @@ -8,7 +8,23 @@ import ( "github.com/stretchr/testify/assert" ) -func runBasicLockTests(t *testing.T, provider Provider) { +func runLockTestsWithoutLifetime(t *testing.T, provider Provider) { + factory := New(provider, WithNamespace("deadlock")) + m1 := factory.New("build-images") + m2 := factory.New("build-images") + m3 := factory.New("start-containers") + + assert.NoError(t, m1.Lock()) + assert.ErrorIs(t, m1.Lock(), ErrAlreadyLocked) + assert.ErrorIs(t, m2.Lock(), ErrAlreadyLocked) + assert.NoError(t, m3.Lock()) + + assert.NoError(t, m1.Unlock()) + assert.ErrorIs(t, m2.Unlock(), ErrNotLocked) + assert.NoError(t, m3.Unlock()) +} + +func runLockTestsWithLifetime(t *testing.T, provider Provider) { factory := New(provider, WithLockLifetime(1*time.Second)) m := factory.New("johndoe", WithNamespace("questions")) expectedMutexDisplayName := fmt.Sprintf("Mutex(%s:questions:johndoe)", provider.Name()) @@ -51,7 +67,7 @@ func testLockContention(t *testing.T, m Mutex) { func testUnlockAfterOwnerChange(t *testing.T, m1, m2 Mutex) { assert.NoError(t, m1.Lock()) assert.ErrorIs(t, m2.Lock(), ErrAlreadyLocked) - time.Sleep(10 * time.Millisecond) // m1 expired (released by system) + time.Sleep(50 * time.Millisecond) // m1 expired (released by system) assert.NoError(t, m2.Lock()) // m2 can obtain the lock, since m1 is expired assert.ErrorIs(t, m1.Unlock(), ErrNotLocked) } diff --git a/mysql.go b/mysql.go index 492e651..34eb13c 100644 --- a/mysql.go +++ b/mysql.go @@ -15,10 +15,10 @@ const ( mysqlLockSQL = `INSERT INTO %s (id, owner, expire_at) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE - owner = IF(expire_at < ?, VALUES(owner), owner), - expire_at = IF(expire_at < ?, VALUES(expire_at), expire_at);` + owner = IF(expire_at > 0 AND expire_at < ?, VALUES(owner), owner), + expire_at = IF(expire_at > 0 AND expire_at < ?, VALUES(expire_at), expire_at);` - mysqlUnlockSQL = `DELETE FROM %s WHERE id = ? AND owner = ? AND expire_at >= ?;` + mysqlUnlockSQL = `DELETE FROM %s WHERE id = ? AND owner = ? AND (expire_at = 0 OR expire_at >= ?);` ) type mysqlProvider struct { @@ -65,11 +65,10 @@ func (p *mysqlProvider) init() error { func (p *mysqlProvider) Lock(lock NamedLock) error { now := time.Now() - expireAt := now.Add(lock.GetLifetime()) rs, err := p.lockStmt.Exec( - lock.GetLockId(), - lock.GetLockOwner(), - expireAt.UnixNano(), + lock.GetId(), + lock.GetOwner(), + computeExpireAt(now, lock.GetLifetime()), now.UnixNano(), now.UnixNano(), ) @@ -88,8 +87,8 @@ func (p *mysqlProvider) Lock(lock NamedLock) error { func (p *mysqlProvider) Unlock(lock NamedLock) error { rs, err := p.unlockStmt.Exec( - lock.GetLockId(), - lock.GetLockOwner(), + lock.GetId(), + lock.GetOwner(), time.Now().UnixNano(), ) if err != nil { @@ -108,3 +107,10 @@ func (p *mysqlProvider) Unlock(lock NamedLock) error { func formatSQL(sqlTemplate, tableName string) string { return fmt.Sprintf(sqlTemplate, tableName) } + +func computeExpireAt(now time.Time, lifetime time.Duration) int64 { + if lifetime == 0 { + return 0 // never expire + } + return now.Add(lifetime).UnixNano() +} diff --git a/postgres.go b/postgres.go index a53fc8c..0319448 100644 --- a/postgres.go +++ b/postgres.go @@ -15,9 +15,10 @@ const ( pgLockSQL = `INSERT INTO %s AS t (id, owner, expire_at) VALUES ($1, $2, $3) ON CONFLICT (id) DO UPDATE - SET owner = $2, expire_at = $3 WHERE t.id = $1 AND t.expire_at < $4;` + SET owner = $2, expire_at = $3 + WHERE t.id = $1 AND t.expire_at > 0 AND t.expire_at < $4;` - pgUnlockSQL = `DELETE FROM %s WHERE id = $1 AND owner = $2 AND expire_at >= $3;` + pgUnlockSQL = `DELETE FROM %s WHERE id = $1 AND owner = $2 AND (expire_at = 0 OR expire_at >= $3);` ) type postgreSQLProvider mysqlProvider @@ -58,11 +59,10 @@ func (p *postgreSQLProvider) init() error { func (p *postgreSQLProvider) Lock(lock NamedLock) error { now := time.Now() - expireAt := now.Add(lock.GetLifetime()) rs, err := p.lockStmt.Exec( - lock.GetLockId(), - lock.GetLockOwner(), - expireAt.UnixNano(), + lock.GetId(), + lock.GetOwner(), + computeExpireAt(now, lock.GetLifetime()), now.UnixNano(), ) if err != nil { @@ -80,8 +80,8 @@ func (p *postgreSQLProvider) Lock(lock NamedLock) error { func (p *postgreSQLProvider) Unlock(lock NamedLock) error { rs, err := p.unlockStmt.Exec( - lock.GetLockId(), - lock.GetLockOwner(), + lock.GetId(), + lock.GetOwner(), time.Now().UnixNano(), ) if err != nil { diff --git a/provider.go b/provider.go index 616c1b9..adc8993 100644 --- a/provider.go +++ b/provider.go @@ -3,8 +3,8 @@ package distlock import "time" type NamedLock interface { - GetLockId() string - GetLockOwner() string + GetId() string + GetOwner() string GetLifetime() time.Duration } diff --git a/redis.go b/redis.go index 15ad85b..6ad7760 100644 --- a/redis.go +++ b/redis.go @@ -41,14 +41,28 @@ func (p *redisProvider) Lock(lock NamedLock) error { conn := p.pool.Get() defer conn.Close() - // SET key value PX milliseconds NX - // PX: Set the specified expire time, in milliseconds. - // NX: Only set the key if it does not already exist. - reply, err := conn.Do( - "SET", lock.GetLockId(), lock.GetLockOwner(), - "PX", lock.GetLifetime().Nanoseconds()/int64(time.Millisecond), - "NX", + var ( + reply interface{} + err error ) + + lifetime := lock.GetLifetime() + if lifetime > 0 { + // SET key value PX milliseconds NX + // PX: Set the specified expire time, in milliseconds. + // NX: Only set the key if it does not already exist. + reply, err = conn.Do( + "SET", lock.GetId(), lock.GetOwner(), + "PX", lock.GetLifetime().Nanoseconds()/int64(time.Millisecond), + "NX", + ) + } else { // never expire + reply, err = conn.Do( + "SET", lock.GetId(), lock.GetOwner(), + "NX", + ) + } + if err != nil { return fmt.Errorf("redis SET: %w", err) } @@ -64,7 +78,7 @@ func (p *redisProvider) Unlock(lock NamedLock) error { defer conn.Close() command := redis.NewScript(1, unlockScript) - ret, err := redis.Int(command.Do(conn, lock.GetLockId(), lock.GetLockOwner())) + ret, err := redis.Int(command.Do(conn, lock.GetId(), lock.GetOwner())) if err != nil { return fmt.Errorf("redis EVAL: %w", err) } diff --git a/redis_test.go b/redis_test.go index 1fcd601..48521e3 100644 --- a/redis_test.go +++ b/redis_test.go @@ -25,7 +25,17 @@ var ( } ) +func cleanupRedis() { + conn := redisPool.Get() + defer conn.Close() + + conn.Do("FLUSHDB") +} + func TestRedisProvider(t *testing.T) { + cleanupRedis() + provider, _ := NewRedisProvider(redisPool) - runBasicLockTests(t, provider) + runLockTestsWithLifetime(t, provider) + runLockTestsWithoutLifetime(t, provider) } diff --git a/sql_test.go b/sql_test.go index fd66a11..50688d8 100644 --- a/sql_test.go +++ b/sql_test.go @@ -8,18 +8,29 @@ import ( _ "github.com/lib/pq" ) +const TestTableName = "ggicci_distlock_test" + +func cleanupMySQL(db *sql.DB) { + _, _ = db.Exec(formatSQL("DROP TABLE IF EXISTS %s", TestTableName)) +} + +func cleanupPostgreSQL(db *sql.DB) { + _, _ = db.Exec(formatSQL("DROP TABLE IF EXISTS %s", TestTableName)) +} + func TestMySQLProvider(t *testing.T) { db, err := sql.Open("mysql", "root@tcp(localhost:3306)/test") if err != nil { t.Fatal(err) } + cleanupMySQL(db) - provider, err := NewMySQLProvider(db, "distlocks") + provider, err := NewMySQLProvider(db, TestTableName) if err != nil { t.Fatalf("could not create provider: %s", err) } - - runBasicLockTests(t, provider) + runLockTestsWithoutLifetime(t, provider) + runLockTestsWithLifetime(t, provider) } func TestPostgreSQLProvider(t *testing.T) { @@ -31,10 +42,12 @@ func TestPostgreSQLProvider(t *testing.T) { t.Fatal(err) } - provider, err := NewPostgreSQLProvider(db, "distlocks") + cleanupPostgreSQL(db) + + provider, err := NewPostgreSQLProvider(db, TestTableName) if err != nil { t.Fatalf("could not create provider: %s", err) } - - runBasicLockTests(t, provider) + runLockTestsWithoutLifetime(t, provider) + runLockTestsWithLifetime(t, provider) }