Skip to content
This repository has been archived by the owner on Mar 8, 2023. It is now read-only.

Commit

Permalink
Add rhost to the history item to record IP address of the client
Browse files Browse the repository at this point in the history
  • Loading branch information
Quentin Lux committed Aug 11, 2020
1 parent 752fa3d commit 3c431bb
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 32 deletions.
37 changes: 20 additions & 17 deletions privacyidea_pam.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class Authenticator(object):
def __init__(self, pamh, config):
self.pamh = pamh
self.user = pamh.get_user(None)
self.rhost = pamh.rhost
self.URL = config.get("url", "https://localhost")
self.sslverify = not config.get("nosslverify", False)
cacerts = config.get("cacerts")
Expand Down Expand Up @@ -208,7 +209,7 @@ def authenticate(self, password):
self.pamh.conversation(pam_message)

# Save history
save_history_item(self.sqlfile, self.user, serial, (True if rval == self.pamh.PAM_SUCCESS else False))
save_history_item(self.sqlfile, self.user, self.rhost, serial, (True if rval == self.pamh.PAM_SUCCESS else False))
return rval

def challenge_response(self, transaction_id, message, attributes):
Expand Down Expand Up @@ -318,7 +319,7 @@ def pam_sm_authenticate(pamh, flags, argv):
config = _get_config(argv)
debug = config.get("debug")
try_first_pass = config.get("try_first_pass")
prompt = config.get("prompt", "Your OTP")
prompt = config.get("prompt", "Your OTP").replace("_", " ")
grace_time = config.get("grace")
if prompt[-1] != ":":
prompt += ":"
Expand Down Expand Up @@ -485,7 +486,7 @@ def save_auth_item(sqlfile, user, serial, tokentype, authitem):
# Just be sure any changes have been committed or they will be lost.
conn.close()

def check_last_history(sqlfile, user, grace_time, window=10):
def check_last_history(sqlfile, user, rhost, grace_time, window=10):
"""
Get the last event for this user.
Expand All @@ -495,6 +496,7 @@ def check_last_history(sqlfile, user, grace_time, window=10):
:param sqlfile: An SQLite file. If it does not exist, it will be generated.
:type sqlfile: basestring
:param user: The PAM user
:param rhost: The PAM user rhost value
:param serial: The serial number of the token
:param success: Boolean
Expand All @@ -508,15 +510,15 @@ def check_last_history(sqlfile, user, grace_time, window=10):
res = False
events = []

for row in c.execute("SELECT user, serial, last_success, last_error FROM history "
"WHERE user=? ORDER by last_success "
for row in c.execute("SELECT user, rhost, serial, last_success, last_error FROM history "
"WHERE user=? AND rhost=? ORDER by last_success "
"LIMIT ?",
(user, window)):
(user, rhost, window)):
events.append(row)

if len(events)>0:
for event in events:
last_success = event[2]
last_success = event[3]
if last_success is not None:
# Get the elapsed time in minutes since last success
last_success_delta = datetime.datetime.now() - last_success
Expand All @@ -540,7 +542,7 @@ def check_last_history(sqlfile, user, grace_time, window=10):
return res


def save_history_item(sqlfile, user, serial, success):
def save_history_item(sqlfile, user, rhost, serial, success):
"""
Save the given success/error event.
Expand All @@ -550,6 +552,7 @@ def save_history_item(sqlfile, user, serial, success):
:param sqlfile: An SQLite file. If it does not exist, it will be generated.
:type sqlfile: basestring
:param user: The PAM user
:param rhost: The PAM user rhost value
:param serial: The serial number of the token
:param success: Boolean
Expand All @@ -564,21 +567,21 @@ def save_history_item(sqlfile, user, serial, success):
__name__, ("success" if success else "error")))
if success:
# Insert the Event
c.execute("INSERT OR REPLACE INTO history (user, serial,"
"error_counter, last_success) VALUES (?,?,?,?)",
(user, serial, 0, datetime.datetime.now()))
c.execute("INSERT OR REPLACE INTO history (user, rhost, serial,"
"error_counter, last_success) VALUES (?,?,?,?,?)",
(user, rhost, serial, 0, datetime.datetime.now()))
else:
# Insert the Event
c.execute("UPDATE history SET error_counter = error_counter + 1, "
" serial = ? , last_error = ? "
" WHERE user = ? ",
(serial, datetime.datetime.now(), user))
" WHERE user = ? AND rhost = ? ",
(serial, datetime.datetime.now(), user, rhost))

syslog.syslog(syslog.LOG_DEBUG,"Rows affected : %d " % c.rowcount)
if c.rowcount == 0:
c.execute("INSERT INTO history (user, serial,"
"error_counter, last_error) VALUES (?,?,?,?)",
(user, serial, 1, datetime.datetime.now()))
c.execute("INSERT INTO history (user, rhost, serial,"
"error_counter, last_error) VALUES (?,?,?,?,?)",
(user, rhost, serial, 1, datetime.datetime.now()))


# Save (commit) the changes
Expand Down Expand Up @@ -610,7 +613,7 @@ def _create_table(c):
try:
# create history table
c.execute("CREATE TABLE IF NOT EXISTS history "
"(user text, serial text, error_counter int, "
"(user text, rhost text, serial text, error_counter int, "
"last_success timestamp, last_error timestamp)")
c.execute("CREATE UNIQUE INDEX idx_user "
"ON history (user);")
Expand Down
29 changes: 14 additions & 15 deletions tests/test_pam_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,10 @@ class PAMH(object):

exception = Exception

def __init__(self, user, password):
def __init__(self, user, password, rhost):
self.authtok = password
self.user = user
self.rhost = user

def get_user(self, dummy):
return self.user
Expand Down Expand Up @@ -149,7 +150,7 @@ def test_02_authenticate_offline(self):
body=json.dumps(SUCCESS_BODY),
content_type="application/json")

pamh = PAMH("cornelius", "test100001")
pamh = PAMH("cornelius", "test100001", "192.168.0.1")
flags = None
argv = ["url=http://my.privacyidea.server",
"sqlfile=%s" % SQLFILE,
Expand All @@ -158,7 +159,7 @@ def test_02_authenticate_offline(self):
self.assertEqual(r, PAMH.PAM_SUCCESS)

# Authenticate the second time offline
pamh = PAMH("cornelius", "test100002")
pamh = PAMH("cornelius", "test100002", "192.168.0.1")
flags = None
argv = ["url=http://my.privacyidea.server",
"sqlfile=%s" % SQLFILE,
Expand All @@ -175,7 +176,7 @@ def test_03_authenticate_online(self):
"http://my.privacyidea.server/validate/check",
body=json.dumps(SUCCESS_BODY),
content_type="application/json")
pamh = PAMH("cornelius", "test999999")
pamh = PAMH("cornelius", "test999999", "192.168.0.1")
flags = None
argv = ["url=http://my.privacyidea.server",
"sqlfile=%s" % SQLFILE,
Expand All @@ -186,7 +187,7 @@ def test_03_authenticate_online(self):

def test_04_authenticate_offline(self):
# and authenticate offline again.
pamh = PAMH("cornelius", "test100000")
pamh = PAMH("cornelius", "test100000", "192.168.0.1")
flags = None
argv = ["url=http://my.privacyidea.server",
"sqlfile=%s" % SQLFILE,
Expand All @@ -213,7 +214,7 @@ def test_05_two_tokens(self):
]
})

pamh = PAMH("cornelius", "test100001")
pamh = PAMH("cornelius", "test100001", "192.168.0.1")
flags = None
argv = ["url=http://my.privacyidea.server",
"sqlfile=%s" % SQLFILE,
Expand All @@ -222,7 +223,7 @@ def test_05_two_tokens(self):
self.assertEqual(r, PAMH.PAM_SUCCESS)

# An older OTP value of the first token is deleted
pamh = PAMH("cornelius", "test100000")
pamh = PAMH("cornelius", "test100000", "192.168.0.1")
flags = None
argv = ["url=http://my.privacyidea.server",
"sqlfile=%s" % SQLFILE,
Expand All @@ -231,7 +232,7 @@ def test_05_two_tokens(self):
self.assertNotEqual(r, PAMH.PAM_SUCCESS)

# An older value with another token can authenticate!
pamh = PAMH("cornelius", "TEST100000")
pamh = PAMH("cornelius", "TEST100000", "192.168.0.1")
flags = None
argv = ["url=http://my.privacyidea.server",
"sqlfile=%s" % SQLFILE,
Expand All @@ -247,7 +248,7 @@ def test_06_refill(self):
body=json.dumps(SUCCESS_BODY),
content_type="application/json")

pamh = PAMH("cornelius", "test100000")
pamh = PAMH("cornelius", "test100000", "192.168.0.1")
flags = None
argv = ["url=http://my.privacyidea.server",
"sqlfile=%s" % SQLFILE,
Expand All @@ -256,7 +257,7 @@ def test_06_refill(self):
self.assertEqual(r, PAMH.PAM_SUCCESS)

# OTP value not known yet, online auth does not work
pamh = PAMH("cornelius", "test100004")
pamh = PAMH("cornelius", "test100004", "192.168.0.1")
flags = None
argv = ["url=http://my.privacyidea.server",
"sqlfile=%s" % SQLFILE,
Expand All @@ -271,7 +272,7 @@ def test_06_refill(self):
body=json.dumps(REFILL_BODY),
content_type="application/json")

pamh = PAMH("cornelius", "test100001")
pamh = PAMH("cornelius", "test100001", "192.168.0.1")
flags = None
argv = ["url=http://my.privacyidea.server",
"sqlfile=%s" % SQLFILE,
Expand All @@ -284,7 +285,7 @@ def test_06_refill(self):

# authenticate with refilled
with responses.RequestsMock() as rsps:
pamh = PAMH("cornelius", "test100004")
pamh = PAMH("cornelius", "test100004", "192.168.0.1")
flags = None
argv = ["url=http://my.privacyidea.server",
"sqlfile=%s" % SQLFILE,
Expand All @@ -297,12 +298,10 @@ def test_06_refill(self):
rsps.calls[0].request.body)

# ... but not twice
pamh = PAMH("cornelius", "test100004")
pamh = PAMH("cornelius", "test100004", "192.168.0.1")
flags = None
argv = ["url=http://my.privacyidea.server",
"sqlfile=%s" % SQLFILE,
"try_first_pass"]
r = pam_sm_authenticate(pamh, flags, argv)
self.assertNotEqual(r, PAMH.PAM_SUCCESS)


0 comments on commit 3c431bb

Please sign in to comment.