diff --git a/privacyidea_pam.py b/privacyidea_pam.py index 622798d..7fad4f2 100644 --- a/privacyidea_pam.py +++ b/privacyidea_pam.py @@ -72,6 +72,7 @@ class Authenticator(object): def __init__(self, pamh, config): self.pamh = pamh self.user = pamh.get_user(None) + self.rhost = pamh.rhost self.URL = config.get("url", "https://localhost") self.sslverify = not config.get("nosslverify", False) cacerts = config.get("cacerts") @@ -208,7 +209,7 @@ def authenticate(self, password): self.pamh.conversation(pam_message) # Save history - save_history_item(self.sqlfile, self.user, serial, (True if rval == self.pamh.PAM_SUCCESS else False)) + save_history_item(self.sqlfile, self.user, self.rhost, serial, (True if rval == self.pamh.PAM_SUCCESS else False)) return rval def challenge_response(self, transaction_id, message, attributes): @@ -318,7 +319,7 @@ def pam_sm_authenticate(pamh, flags, argv): config = _get_config(argv) debug = config.get("debug") try_first_pass = config.get("try_first_pass") - prompt = config.get("prompt", "Your OTP") + prompt = config.get("prompt", "Your OTP").replace("_", " ") grace_time = config.get("grace") if prompt[-1] != ":": prompt += ":" @@ -485,7 +486,7 @@ def save_auth_item(sqlfile, user, serial, tokentype, authitem): # Just be sure any changes have been committed or they will be lost. conn.close() -def check_last_history(sqlfile, user, grace_time, window=10): +def check_last_history(sqlfile, user, rhost, grace_time, window=10): """ Get the last event for this user. @@ -495,6 +496,7 @@ def check_last_history(sqlfile, user, grace_time, window=10): :param sqlfile: An SQLite file. If it does not exist, it will be generated. :type sqlfile: basestring :param user: The PAM user + :param rhost: The PAM user rhost value :param serial: The serial number of the token :param success: Boolean @@ -508,15 +510,15 @@ def check_last_history(sqlfile, user, grace_time, window=10): res = False events = [] - for row in c.execute("SELECT user, serial, last_success, last_error FROM history " - "WHERE user=? ORDER by last_success " + for row in c.execute("SELECT user, rhost, serial, last_success, last_error FROM history " + "WHERE user=? AND rhost=? ORDER by last_success " "LIMIT ?", - (user, window)): + (user, rhost, window)): events.append(row) if len(events)>0: for event in events: - last_success = event[2] + last_success = event[3] if last_success is not None: # Get the elapsed time in minutes since last success last_success_delta = datetime.datetime.now() - last_success @@ -540,7 +542,7 @@ def check_last_history(sqlfile, user, grace_time, window=10): return res -def save_history_item(sqlfile, user, serial, success): +def save_history_item(sqlfile, user, rhost, serial, success): """ Save the given success/error event. @@ -550,6 +552,7 @@ def save_history_item(sqlfile, user, serial, success): :param sqlfile: An SQLite file. If it does not exist, it will be generated. :type sqlfile: basestring :param user: The PAM user + :param rhost: The PAM user rhost value :param serial: The serial number of the token :param success: Boolean @@ -564,21 +567,21 @@ def save_history_item(sqlfile, user, serial, success): __name__, ("success" if success else "error"))) if success: # Insert the Event - c.execute("INSERT OR REPLACE INTO history (user, serial," - "error_counter, last_success) VALUES (?,?,?,?)", - (user, serial, 0, datetime.datetime.now())) + c.execute("INSERT OR REPLACE INTO history (user, rhost, serial," + "error_counter, last_success) VALUES (?,?,?,?,?)", + (user, rhost, serial, 0, datetime.datetime.now())) else: # Insert the Event c.execute("UPDATE history SET error_counter = error_counter + 1, " " serial = ? , last_error = ? " - " WHERE user = ? ", - (serial, datetime.datetime.now(), user)) + " WHERE user = ? AND rhost = ? ", + (serial, datetime.datetime.now(), user, rhost)) syslog.syslog(syslog.LOG_DEBUG,"Rows affected : %d " % c.rowcount) if c.rowcount == 0: - c.execute("INSERT INTO history (user, serial," - "error_counter, last_error) VALUES (?,?,?,?)", - (user, serial, 1, datetime.datetime.now())) + c.execute("INSERT INTO history (user, rhost, serial," + "error_counter, last_error) VALUES (?,?,?,?,?)", + (user, rhost, serial, 1, datetime.datetime.now())) # Save (commit) the changes @@ -610,7 +613,7 @@ def _create_table(c): try: # create history table c.execute("CREATE TABLE IF NOT EXISTS history " - "(user text, serial text, error_counter int, " + "(user text, rhost text, serial text, error_counter int, " "last_success timestamp, last_error timestamp)") c.execute("CREATE UNIQUE INDEX idx_user " "ON history (user);") diff --git a/tests/test_pam_module.py b/tests/test_pam_module.py index c810a05..fead1a3 100644 --- a/tests/test_pam_module.py +++ b/tests/test_pam_module.py @@ -98,9 +98,10 @@ class PAMH(object): exception = Exception - def __init__(self, user, password): + def __init__(self, user, password, rhost): self.authtok = password self.user = user + self.rhost = user def get_user(self, dummy): return self.user @@ -149,7 +150,7 @@ def test_02_authenticate_offline(self): body=json.dumps(SUCCESS_BODY), content_type="application/json") - pamh = PAMH("cornelius", "test100001") + pamh = PAMH("cornelius", "test100001", "192.168.0.1") flags = None argv = ["url=http://my.privacyidea.server", "sqlfile=%s" % SQLFILE, @@ -158,7 +159,7 @@ def test_02_authenticate_offline(self): self.assertEqual(r, PAMH.PAM_SUCCESS) # Authenticate the second time offline - pamh = PAMH("cornelius", "test100002") + pamh = PAMH("cornelius", "test100002", "192.168.0.1") flags = None argv = ["url=http://my.privacyidea.server", "sqlfile=%s" % SQLFILE, @@ -175,7 +176,7 @@ def test_03_authenticate_online(self): "http://my.privacyidea.server/validate/check", body=json.dumps(SUCCESS_BODY), content_type="application/json") - pamh = PAMH("cornelius", "test999999") + pamh = PAMH("cornelius", "test999999", "192.168.0.1") flags = None argv = ["url=http://my.privacyidea.server", "sqlfile=%s" % SQLFILE, @@ -186,7 +187,7 @@ def test_03_authenticate_online(self): def test_04_authenticate_offline(self): # and authenticate offline again. - pamh = PAMH("cornelius", "test100000") + pamh = PAMH("cornelius", "test100000", "192.168.0.1") flags = None argv = ["url=http://my.privacyidea.server", "sqlfile=%s" % SQLFILE, @@ -213,7 +214,7 @@ def test_05_two_tokens(self): ] }) - pamh = PAMH("cornelius", "test100001") + pamh = PAMH("cornelius", "test100001", "192.168.0.1") flags = None argv = ["url=http://my.privacyidea.server", "sqlfile=%s" % SQLFILE, @@ -222,7 +223,7 @@ def test_05_two_tokens(self): self.assertEqual(r, PAMH.PAM_SUCCESS) # An older OTP value of the first token is deleted - pamh = PAMH("cornelius", "test100000") + pamh = PAMH("cornelius", "test100000", "192.168.0.1") flags = None argv = ["url=http://my.privacyidea.server", "sqlfile=%s" % SQLFILE, @@ -231,7 +232,7 @@ def test_05_two_tokens(self): self.assertNotEqual(r, PAMH.PAM_SUCCESS) # An older value with another token can authenticate! - pamh = PAMH("cornelius", "TEST100000") + pamh = PAMH("cornelius", "TEST100000", "192.168.0.1") flags = None argv = ["url=http://my.privacyidea.server", "sqlfile=%s" % SQLFILE, @@ -247,7 +248,7 @@ def test_06_refill(self): body=json.dumps(SUCCESS_BODY), content_type="application/json") - pamh = PAMH("cornelius", "test100000") + pamh = PAMH("cornelius", "test100000", "192.168.0.1") flags = None argv = ["url=http://my.privacyidea.server", "sqlfile=%s" % SQLFILE, @@ -256,7 +257,7 @@ def test_06_refill(self): self.assertEqual(r, PAMH.PAM_SUCCESS) # OTP value not known yet, online auth does not work - pamh = PAMH("cornelius", "test100004") + pamh = PAMH("cornelius", "test100004", "192.168.0.1") flags = None argv = ["url=http://my.privacyidea.server", "sqlfile=%s" % SQLFILE, @@ -271,7 +272,7 @@ def test_06_refill(self): body=json.dumps(REFILL_BODY), content_type="application/json") - pamh = PAMH("cornelius", "test100001") + pamh = PAMH("cornelius", "test100001", "192.168.0.1") flags = None argv = ["url=http://my.privacyidea.server", "sqlfile=%s" % SQLFILE, @@ -284,7 +285,7 @@ def test_06_refill(self): # authenticate with refilled with responses.RequestsMock() as rsps: - pamh = PAMH("cornelius", "test100004") + pamh = PAMH("cornelius", "test100004", "192.168.0.1") flags = None argv = ["url=http://my.privacyidea.server", "sqlfile=%s" % SQLFILE, @@ -297,12 +298,10 @@ def test_06_refill(self): rsps.calls[0].request.body) # ... but not twice - pamh = PAMH("cornelius", "test100004") + pamh = PAMH("cornelius", "test100004", "192.168.0.1") flags = None argv = ["url=http://my.privacyidea.server", "sqlfile=%s" % SQLFILE, "try_first_pass"] r = pam_sm_authenticate(pamh, flags, argv) self.assertNotEqual(r, PAMH.PAM_SUCCESS) - -