Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cleanup & test fix #14

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 31 additions & 55 deletions pyaes/aes.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
# See the README.md for API details and general information.


import copy
import struct

__all__ = ["AES", "AESModeOfOperationCTR", "AESModeOfOperationCBC", "AESModeOfOperationCFB",
Expand All @@ -61,36 +60,13 @@
def _compact_word(word):
return (word[0] << 24) | (word[1] << 16) | (word[2] << 8) | word[3]

def _string_to_bytes(text):
return list(ord(c) for c in text)

def _bytes_to_string(binary):
return "".join(chr(b) for b in binary)

def _concat_list(a, b):
return a + b


# Python 3 compatibility
try:
xrange
except Exception:
except NameError:
xrange = range

# Python 3 supports bytes, which is already an array of integers
def _string_to_bytes(text):
if isinstance(text, bytes):
return text
return [ord(c) for c in text]

# In Python 3, we return bytes
def _bytes_to_string(binary):
return bytes(binary)

# Python 3 cannot concatenate a list onto a bytes, so we bytes-ify it first
def _concat_list(a, b):
return a + bytes(b)


# Based *largely* on the Rijndael implementation
# See: http://csrc.nist.gov/publications/fips/fips197/fips-197.pdf
Expand Down Expand Up @@ -161,7 +137,7 @@ def __init__(self, key):
tk[0] ^= ((self.S[(tt >> 16) & 0xFF] << 24) ^
(self.S[(tt >> 8) & 0xFF] << 16) ^
(self.S[ tt & 0xFF] << 8) ^
self.S[(tt >> 24) & 0xFF] ^
self.S[(tt >> 24) ] ^
(self.rcon[rconpointer] << 24))
rconpointer += 1

Expand All @@ -178,7 +154,7 @@ def __init__(self, key):
tk[KC // 2] ^= (self.S[ tt & 0xFF] ^
(self.S[(tt >> 8) & 0xFF] << 8) ^
(self.S[(tt >> 16) & 0xFF] << 16) ^
(self.S[(tt >> 24) & 0xFF] << 24))
(self.S[(tt >> 24) ] << 24))

for i in xrange(KC // 2 + 1, KC):
tk[i] ^= tk[i - 1]
Expand All @@ -195,7 +171,7 @@ def __init__(self, key):
for r in xrange(1, rounds):
for j in xrange(0, 4):
tt = self._Kd[r][j]
self._Kd[r][j] = (self.U1[(tt >> 24) & 0xFF] ^
self._Kd[r][j] = (self.U1[(tt >> 24) ] ^
self.U2[(tt >> 16) & 0xFF] ^
self.U3[(tt >> 8) & 0xFF] ^
self.U4[ tt & 0xFF])
Expand All @@ -216,18 +192,18 @@ def encrypt(self, plaintext):
# Apply round transforms
for r in xrange(1, rounds):
for i in xrange(0, 4):
a[i] = (self.T1[(t[ i ] >> 24) & 0xFF] ^
a[i] = (self.T1[(t[ i ] >> 24) ] ^
self.T2[(t[(i + s1) % 4] >> 16) & 0xFF] ^
self.T3[(t[(i + s2) % 4] >> 8) & 0xFF] ^
self.T4[ t[(i + s3) % 4] & 0xFF] ^
self._Ke[r][i])
t = copy.copy(a)
t = a[:]

# The last round is special
result = [ ]
for i in xrange(0, 4):
tt = self._Ke[rounds][i]
result.append((self.S[(t[ i ] >> 24) & 0xFF] ^ (tt >> 24)) & 0xFF)
result.append((self.S[(t[ i ] >> 24) ] ^ (tt >> 24)) & 0xFF)
result.append((self.S[(t[(i + s1) % 4] >> 16) & 0xFF] ^ (tt >> 16)) & 0xFF)
result.append((self.S[(t[(i + s2) % 4] >> 8) & 0xFF] ^ (tt >> 8)) & 0xFF)
result.append((self.S[ t[(i + s3) % 4] & 0xFF] ^ tt ) & 0xFF)
Expand All @@ -250,18 +226,18 @@ def decrypt(self, ciphertext):
# Apply round transforms
for r in xrange(1, rounds):
for i in xrange(0, 4):
a[i] = (self.T5[(t[ i ] >> 24) & 0xFF] ^
a[i] = (self.T5[(t[ i ] >> 24) ] ^
self.T6[(t[(i + s1) % 4] >> 16) & 0xFF] ^
self.T7[(t[(i + s2) % 4] >> 8) & 0xFF] ^
self.T8[ t[(i + s3) % 4] & 0xFF] ^
self._Kd[r][i])
t = copy.copy(a)
t = a[:]

# The last round is special
result = [ ]
for i in xrange(0, 4):
tt = self._Kd[rounds][i]
result.append((self.Si[(t[ i ] >> 24) & 0xFF] ^ (tt >> 24)) & 0xFF)
result.append((self.Si[(t[ i ] >> 24) ] ^ (tt >> 24)) & 0xFF)
result.append((self.Si[(t[(i + s1) % 4] >> 16) & 0xFF] ^ (tt >> 16)) & 0xFF)
result.append((self.Si[(t[(i + s2) % 4] >> 8) & 0xFF] ^ (tt >> 8)) & 0xFF)
result.append((self.Si[ t[(i + s3) % 4] & 0xFF] ^ tt ) & 0xFF)
Expand Down Expand Up @@ -341,15 +317,15 @@ def encrypt(self, plaintext):
if len(plaintext) != 16:
raise ValueError('plaintext block must be 16 bytes')

plaintext = _string_to_bytes(plaintext)
return _bytes_to_string(self._aes.encrypt(plaintext))
plaintext = bytearray(plaintext)
return bytes(bytearray(self._aes.encrypt(plaintext)))

def decrypt(self, ciphertext):
if len(ciphertext) != 16:
raise ValueError('ciphertext block must be 16 bytes')

ciphertext = _string_to_bytes(ciphertext)
return _bytes_to_string(self._aes.decrypt(ciphertext))
ciphertext = bytearray(ciphertext)
return bytes(bytearray(self._aes.decrypt(ciphertext)))



Expand Down Expand Up @@ -380,29 +356,29 @@ def __init__(self, key, iv = None):
elif len(iv) != 16:
raise ValueError('initialization vector must be 16 bytes')
else:
self._last_cipherblock = _string_to_bytes(iv)
self._last_cipherblock = bytearray(iv)

AESBlockModeOfOperation.__init__(self, key)

def encrypt(self, plaintext):
if len(plaintext) != 16:
raise ValueError('plaintext block must be 16 bytes')

plaintext = _string_to_bytes(plaintext)
plaintext = bytearray(plaintext)
precipherblock = [ (p ^ l) for (p, l) in zip(plaintext, self._last_cipherblock) ]
self._last_cipherblock = self._aes.encrypt(precipherblock)

return _bytes_to_string(self._last_cipherblock)
return bytes(bytearray(self._last_cipherblock))

def decrypt(self, ciphertext):
if len(ciphertext) != 16:
raise ValueError('ciphertext block must be 16 bytes')

cipherblock = _string_to_bytes(ciphertext)
cipherblock = bytearray(ciphertext)
plaintext = [ (p ^ l) for (p, l) in zip(self._aes.decrypt(cipherblock), self._last_cipherblock) ]
self._last_cipherblock = cipherblock

return _bytes_to_string(plaintext)
return bytes(bytearray(plaintext))



Expand All @@ -427,7 +403,7 @@ def __init__(self, key, iv, segment_size = 1):
elif len(iv) != 16:
raise ValueError('initialization vector must be 16 bytes')
else:
self._shift_register = _string_to_bytes(iv)
self._shift_register = bytearray(iv)

self._segment_bytes = segment_size

Expand All @@ -439,7 +415,7 @@ def encrypt(self, plaintext):
if len(plaintext) % self._segment_bytes != 0:
raise ValueError('plaintext block must be a multiple of segment_size')

plaintext = _string_to_bytes(plaintext)
plaintext = bytearray(plaintext)

# Break block into segments
encrypted = [ ]
Expand All @@ -449,17 +425,17 @@ def encrypt(self, plaintext):
cipher_segment = [ (p ^ x) for (p, x) in zip(plaintext_segment, xor_segment) ]

# Shift the top bits out and the ciphertext in
self._shift_register = _concat_list(self._shift_register[len(cipher_segment):], cipher_segment)
self._shift_register = self._shift_register[len(cipher_segment):] + bytearray(cipher_segment)

encrypted.extend(cipher_segment)

return _bytes_to_string(encrypted)
return bytes(bytearray(encrypted))

def decrypt(self, ciphertext):
if len(ciphertext) % self._segment_bytes != 0:
raise ValueError('ciphertext block must be a multiple of segment_size')

ciphertext = _string_to_bytes(ciphertext)
ciphertext = bytearray(ciphertext)

# Break block into segments
decrypted = [ ]
Expand All @@ -469,11 +445,11 @@ def decrypt(self, ciphertext):
plaintext_segment = [ (p ^ x) for (p, x) in zip(cipher_segment, xor_segment) ]

# Shift the top bits out and the ciphertext in
self._shift_register = _concat_list(self._shift_register[len(cipher_segment):], cipher_segment)
self._shift_register = self._shift_register[len(cipher_segment):] + bytearray(cipher_segment)

decrypted.extend(plaintext_segment)

return _bytes_to_string(decrypted)
return bytes(bytearray(decrypted))



Expand All @@ -499,15 +475,15 @@ def __init__(self, key, iv = None):
elif len(iv) != 16:
raise ValueError('initialization vector must be 16 bytes')
else:
self._last_precipherblock = _string_to_bytes(iv)
self._last_precipherblock = bytearray(iv)

self._remaining_block = [ ]

AESBlockModeOfOperation.__init__(self, key)

def encrypt(self, plaintext):
encrypted = [ ]
for p in _string_to_bytes(plaintext):
for p in bytearray(plaintext):
if len(self._remaining_block) == 0:
self._remaining_block = self._aes.encrypt(self._last_precipherblock)
self._last_precipherblock = [ ]
Expand All @@ -516,7 +492,7 @@ def encrypt(self, plaintext):
cipherbyte = p ^ precipherbyte
encrypted.append(cipherbyte)

return _bytes_to_string(encrypted)
return bytes(bytearray(encrypted))

def decrypt(self, ciphertext):
# AES-OFB is symetric
Expand Down Expand Up @@ -567,12 +543,12 @@ def encrypt(self, plaintext):
self._remaining_counter += self._aes.encrypt(self._counter.value)
self._counter.increment()

plaintext = _string_to_bytes(plaintext)
plaintext = bytearray(plaintext)

encrypted = [ (p ^ c) for (p, c) in zip(plaintext, self._remaining_counter) ]
self._remaining_counter = self._remaining_counter[len(encrypted):]

return _bytes_to_string(encrypted)
return bytes(bytearray(encrypted))

def decrypt(self, crypttext):
# AES-CTR is symetric
Expand Down
18 changes: 9 additions & 9 deletions pyaes/blockfeeder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
# THE SOFTWARE.


from .aes import AESBlockModeOfOperation, AESSegmentModeOfOperation, AESStreamModeOfOperation
from .util import append_PKCS7_padding, strip_PKCS7_padding, to_bufferable
from pyaes.aes import AESBlockModeOfOperation, AESSegmentModeOfOperation, AESStreamModeOfOperation
from pyaes.util import append_PKCS7_padding, strip_PKCS7_padding


# First we inject three functions to each of the modes of operations
Expand Down Expand Up @@ -99,17 +99,17 @@ def _segment_final_encrypt(self, data, padding = PADDING_DEFAULT):
if padding != PADDING_DEFAULT:
raise Exception('invalid padding option')

faux_padding = (chr(0) * (self.segment_bytes - (len(data) % self.segment_bytes)))
padded = data + to_bufferable(faux_padding)
faux_padding = (b'\x00' * (self.segment_bytes - (len(data) % self.segment_bytes)))
padded = data + bytes(faux_padding)
return self.encrypt(padded)[:len(data)]

# CFB can handle a non-segment-sized block at the end using the remaining cipherblock
def _segment_final_decrypt(self, data, padding = PADDING_DEFAULT):
if padding != PADDING_DEFAULT:
raise Exception('invalid padding option')

faux_padding = (chr(0) * (self.segment_bytes - (len(data) % self.segment_bytes)))
padded = data + to_bufferable(faux_padding)
faux_padding = (b'\x00' * (self.segment_bytes - (len(data) % self.segment_bytes)))
padded = data + bytes(faux_padding)
return self.decrypt(padded)[:len(data)]

AESSegmentModeOfOperation._can_consume = _segment_can_consume
Expand Down Expand Up @@ -150,7 +150,7 @@ def __init__(self, mode, feed, final, padding = PADDING_DEFAULT):
self._mode = mode
self._feed = feed
self._final = final
self._buffer = to_bufferable("")
self._buffer = b""
self._padding = padding

def feed(self, data = None):
Expand All @@ -170,10 +170,10 @@ def feed(self, data = None):
self._buffer = None
return result

self._buffer += to_bufferable(data)
self._buffer += bytes(data)

# We keep 16 bytes around so we can determine padding
result = to_bufferable('')
result = b''
while len(self._buffer) > 16:
can_consume = self._mode._can_consume(len(self._buffer) - 16)
if can_consume == 0: break
Expand Down
21 changes: 2 additions & 19 deletions pyaes/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,33 +26,16 @@
# represent arbitrary binary data, we must use the "bytes" object. This method
# ensures the object behaves as we need it to.

def to_bufferable(binary):
return binary

def _get_byte(c):
return ord(c)

try:
xrange
except:

def to_bufferable(binary):
if isinstance(binary, bytes):
return binary
return bytes(ord(b) for b in binary)

def _get_byte(c):
return c

def append_PKCS7_padding(data):
pad = 16 - (len(data) % 16)
return data + to_bufferable(chr(pad) * pad)
return data + bytes(bytearray([pad])) * pad

def strip_PKCS7_padding(data):
if len(data) % 16 != 0:
raise ValueError("invalid length")

pad = _get_byte(data[-1])
pad = bytearray(data)[-1]

if pad > 16:
raise ValueError("invalid padding byte")
Expand Down
6 changes: 2 additions & 4 deletions tests/test-aes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


import sys
sys.path.append('../pyaes')
sys.path.append('..')

from pyaes import *

Expand All @@ -33,8 +33,6 @@
xrange
except NameError:
xrange = range
else:
pass

# compare against a known working implementation
from Crypto.Cipher import AES as KAES
Expand Down Expand Up @@ -141,7 +139,7 @@
tt_kdecrypt += time.time() - t0

t0 = time.time()
dt2 = [aes2.decrypt(k) for k in kenc]
dt2 = [aes2.decrypt(k) for k in enc]
tt_decrypt += time.time() - t0

if plaintext != dt2:
Expand Down
Loading