diff --git a/coredis/__init__.py b/coredis/__init__.py index c21888b6..ee5bec03 100644 --- a/coredis/__init__.py +++ b/coredis/__init__.py @@ -3,6 +3,9 @@ from coredis.connection import ClusterConnection, Connection, UnixDomainSocketConnection from coredis.exceptions import ( AskError, + AuthenticationFailureError, + AuthenticationRequiredError, + AuthorizationError, BusyLoadingError, CacheError, ClusterCrossSlotError, @@ -41,6 +44,9 @@ "ConnectionPool", "ClusterConnectionPool", "AskError", + "AuthenticationFailureError", + "AuthenticationRequiredError", + "AuthorizationError", "BusyLoadingError", "CacheError", "ClusterCrossSlotError", diff --git a/coredis/client.py b/coredis/client.py index e8b44de5..87ee33a1 100644 --- a/coredis/client.py +++ b/coredis/client.py @@ -140,6 +140,7 @@ def __init__( host="localhost", port=6379, db=0, + username=None, password=None, stream_timeout=None, connect_timeout=None, @@ -164,6 +165,7 @@ def __init__( if not connection_pool: kwargs = { "db": db, + "username": username, "password": password, "encoding": encoding, "stream_timeout": stream_timeout, diff --git a/coredis/connection.py b/coredis/connection.py index 93703779..656f9dc0 100755 --- a/coredis/connection.py +++ b/coredis/connection.py @@ -10,6 +10,9 @@ from coredis.exceptions import ( AskError, + AuthenticationFailureError, + AuthenticationRequiredError, + AuthorizationError, BusyLoadingError, ClusterCrossSlotError, ClusterDownError, @@ -73,6 +76,7 @@ async def _read_from_socket(self, length=None): while True: data = await self._stream.read(self.read_size) # an empty string indicates the server shutdown the socket + if isinstance(data, bytes) and len(data) == 0: raise ConnectionError("Socket closed on remote end") buf.write(data) @@ -82,6 +86,7 @@ async def _read_from_socket(self, length=None): if length is not None and length > marker: continue + break except socket.error: e = sys.exc_info()[1] @@ -90,6 +95,7 @@ async def _read_from_socket(self, length=None): async def read(self, length): length = length + 2 # make sure to read the \r\n terminator # make sure we've read enough data from the socket + if length > self.length: await self._read_from_socket(length - self.length) @@ -99,6 +105,7 @@ async def read(self, length): # purge the buffer when we've consumed it all so it doesn't # grow forever + if self.bytes_read == self.bytes_written: self.purge() @@ -108,6 +115,7 @@ async def readline(self): buf = self._buffer buf.seek(self.bytes_read) data = buf.readline() + while not data.endswith(SYM_CRLF): # there's more data in the socket that we need await self._read_from_socket() @@ -118,6 +126,7 @@ async def readline(self): # purge the buffer when we've consumed it all so it doesn't # grow forever + if self.bytes_read == self.bytes_written: self.purge() @@ -158,17 +167,24 @@ class BaseParser: "MOVED": MovedError, "CLUSTERDOWN": ClusterDownError, "CROSSSLOT": ClusterCrossSlotError, + "WRONGPASS": AuthenticationFailureError, + "NOAUTH": AuthenticationRequiredError, + "NOPERM": AuthorizationError, } def parse_error(self, response): """Parse an error response""" error_code = response.split(" ")[0] + if error_code in self.EXCEPTION_CLASSES: response = response[len(error_code) + 1 :] exception_class = self.EXCEPTION_CLASSES[error_code] + if isinstance(exception_class, dict): exception_class = exception_class.get(response, ResponseError) + return exception_class(response) + return ResponseError(response) @@ -189,13 +205,16 @@ def on_connect(self, connection): """Called when the stream connects""" self._stream = connection._reader self._buffer = SocketBuffer(self._stream, self._read_size) + if connection.decode_responses: self.encoding = connection.encoding def on_disconnect(self): """Called when the stream disconnects""" + if self._stream is not None: self._stream = None + if self._buffer is not None: self._buffer.close() self._buffer = None @@ -208,6 +227,7 @@ async def read_response(self): if not self._buffer: raise ConnectionError("Socket closed on remote end") response = await self._buffer.readline() + if not response: raise ConnectionError("Socket closed on remote end") @@ -217,17 +237,20 @@ async def read_response(self): raise InvalidResponse("Protocol Error: %s, %s" % (str(byte), str(response))) # server returned an error + if byte == "-": response = response.decode() error = self.parse_error(response) # if the error is a ConnectionError, raise immediately so the user # is notified + if isinstance(error, ConnectionError): raise error # otherwise, we're dealing with a ResponseError that might belong # inside a pipeline response. the connection's read_response() # and/or the pipeline's execute() will raise this error if # necessary, so just return the exception instance here. + return error # single value elif byte == "+": @@ -238,19 +261,24 @@ async def read_response(self): # bulk response elif byte == "$": length = int(response) + if length == -1: return None response = await self._buffer.read(length) # multi-bulk response elif byte == "*": length = int(response) + if length == -1: return None response = [] + for i in range(length): response.append(await self.read_response()) + if isinstance(response, bytes) and self.encoding: response = response.decode(self.encoding) + return response @@ -276,6 +304,7 @@ def can_read(self): if self._next_response is False: self._next_response = self._reader.gets() + return self._next_response is not False def on_connect(self, connection): @@ -284,6 +313,7 @@ def on_connect(self, connection): "protocolError": InvalidResponse, "replyError": ResponseError, } + if connection.decode_responses: kwargs["encoding"] = connection.encoding self._reader = hiredis.Reader(**kwargs) @@ -300,12 +330,15 @@ async def read_response(self): raise ConnectionError("Socket closed on remote end") # _next_response might be cached from a can_read() call + if self._next_response is not False: response = self._next_response self._next_response = False + return response response = self._reader.gets() + while response is False: try: buffer = await self._stream.read(self._read_size) @@ -318,12 +351,15 @@ async def read_response(self): raise ConnectionError( "Error {} while reading from stream: {}".format(type(e), e.args) ) + if not buffer: raise ConnectionError("Socket closed on remote end") self._reader.feed(buffer) response = self._reader.gets() + if isinstance(response, ResponseError): response = self.parse_error(response.args[0]) + return response @@ -337,6 +373,7 @@ class RedisSSLContext: def __init__(self, keyfile=None, certfile=None, cert_reqs=None, ca_certs=None): self.keyfile = keyfile self.certfile = certfile + if cert_reqs is None: self.cert_reqs = ssl.CERT_NONE elif isinstance(cert_reqs, str): @@ -345,6 +382,7 @@ def __init__(self, keyfile=None, certfile=None, cert_reqs=None, ca_certs=None): "optional": ssl.CERT_OPTIONAL, "required": ssl.CERT_REQUIRED, } + if cert_reqs not in CERT_REQS: raise RedisError( "Invalid SSL Certificate Requirements Flag: %s" % cert_reqs @@ -361,6 +399,7 @@ def get(self): self.context.verify_mode = self.cert_reqs self.context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile) self.context.load_verify_locations(self.ca_certs) + return self.context @@ -418,8 +457,10 @@ def clear_connect_callbacks(self): async def can_read(self): """Checks for data that can be read""" + if not self.is_connected: await self.connect() + return self._parser.can_read() async def connect(self): @@ -431,11 +472,13 @@ async def connect(self): raise ConnectionError() # run any user callbacks. right now the only internal callback # is for pubsub channel/pattern resubscription + for callback in self._connect_callbacks: task = callback(self) # typing.Awaitable is not available in Python3.5 # so use inspect.isawaitable instead # according to issue https://github.com/alisaifee/coredis/issues/77 + if inspect.isawaitable(task): await task @@ -445,22 +488,31 @@ async def _connect(self): async def on_connect(self): self._parser.on_connect(self) - # if a password is specified, authenticate - if self.password: - await self.send_command("AUTH", self.password) + if self.username or self.password: + if self.username or self.password: + await self.send_command("AUTH", self.username, self.password) + elif self.password: + await self.send_command("AUTH", self.password) + if nativestr(await self.read_response()) != "OK": - raise ConnectionError("Invalid Password") + raise ConnectionError( + f"Failed to authenticate: username={self.username} & password={self.password}" + ) # if a database is specified, switch to it + if self.db: await self.send_command("SELECT", self.db) + if nativestr(await self.read_response()) != "OK": raise ConnectionError("Invalid Database") if self.client_name is not None: await self.send_command("CLIENT SETNAME", self.client_name) + if nativestr(await self.read_response()) != "OK": raise ConnectionError(f"Failed to set client name: {self.client_name}") + self.last_active_at = time.time() async def read_response(self): @@ -472,13 +524,16 @@ async def read_response(self): except TimeoutError: self.disconnect() raise + if isinstance(response, RedisError): raise response self.awaiting_response = False + return response async def send_packed_command(self, command): """Sends an already packed command to the Redis server""" + if not self._writer: await self.connect() try: @@ -491,6 +546,7 @@ async def send_packed_command(self, command): except Exception: e = sys.exc_info()[1] self.disconnect() + if len(e.args) == 1: errno, errmsg = "UNKNOWN", e.args[0] else: @@ -512,6 +568,7 @@ async def send_command(self, *args): def encode(self, value): """Returns a bytestring representation of the value""" + if isinstance(value, bytes): return value elif isinstance(value, int): @@ -520,8 +577,10 @@ def encode(self, value): value = b(repr(value)) elif not isinstance(value, str): value = str(value) + if isinstance(value, str): value = value.encode(self.encoding) + return value def disconnect(self): @@ -543,15 +602,18 @@ def pack_command(self, *args): # manually. All of these arguements get wrapped in the Token class # to prevent them from being encoded. command = args[0] + if " " in command: args = tuple([b(s) for s in command.split()]) + args[1:] else: args = (b(command),) + args[1:] buff = SYM_EMPTY.join((SYM_STAR, b(str(len(args))), SYM_CRLF)) + for arg in map(self.encode, args): # to avoid large string mallocs, chunk the command into the # output list if we're sending large values + if len(buff) > 6000 or len(arg) > 6000: buff = SYM_EMPTY.join((buff, SYM_DOLLAR, b(str(len(arg))), SYM_CRLF)) output.append(buff) @@ -562,6 +624,7 @@ def pack_command(self, *args): (buff, SYM_DOLLAR, b(str(len(arg))), SYM_CRLF, b(arg), SYM_CRLF) ) output.append(buff) + return output def pack_commands(self, commands): @@ -582,6 +645,7 @@ def pack_commands(self, commands): if pieces: output.append(SYM_EMPTY.join(pieces)) + return output @@ -592,6 +656,7 @@ def __init__( self, host="127.0.0.1", port=6379, + username=None, password=None, db=0, retry_on_timeout=False, @@ -620,6 +685,7 @@ def __init__( ) self.host = host self.port = port + self.username = username self.password = password self.db = db self.ssl_context = ssl_context @@ -643,12 +709,15 @@ async def _connect(self): self._reader = reader self._writer = writer sock = writer.transport.get_extra_info("socket") + if sock is not None: sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) try: # TCP_KEEPALIVE + if self.socket_keepalive: sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + for k, v in self.socket_keepalive_options.items(): sock.setsockopt(socket.SOL_TCP, k, v) except (socket.error, TypeError): @@ -665,6 +734,7 @@ class UnixDomainSocketConnection(BaseConnection): def __init__( self, path="", + username=None, password=None, db=0, retry_on_timeout=False, @@ -691,6 +761,7 @@ def __init__( ) self.path = path self.db = db + self.username = username self.password = password self.ssl_context = ssl_context self._connect_timeout = connect_timeout @@ -726,11 +797,14 @@ async def on_connect(self): Initialize the connection, authenticate and select a database and send READONLY if it is set during object initialization. """ + if self.db: warnings.warn("SELECT DB is not allowed in cluster mode") self.db = "" await super(ClusterConnection, self).on_connect() + if self.readonly: await self.send_command("READONLY") + if nativestr(await self.read_response()) != "OK": raise ConnectionError("READONLY command failed") diff --git a/coredis/exceptions.py b/coredis/exceptions.py index 234ff9b7..56212083 100644 --- a/coredis/exceptions.py +++ b/coredis/exceptions.py @@ -144,3 +144,29 @@ class MovedError(AskError): A request sent to a node that doesn't serve this key will be replayed with a ``MOVED`` error that points to the correct node. """ + + +class AuthenticationError(ResponseError): + """ + Base class for authentication errors + """ + + +class AuthenticationFailureError(AuthenticationError): + """ + Raised when authentication parameters were provided + but were invalid + """ + + +class AuthenticationRequiredError(AuthenticationError): + """ + Raised when authentication parameters are required + but not provided + """ + + +class AuthorizationError(RedisError): + """ + Base class for authorization errors + """ diff --git a/coredis/pool.py b/coredis/pool.py index 6f48881f..a19130c7 100644 --- a/coredis/pool.py +++ b/coredis/pool.py @@ -25,8 +25,10 @@ def to_bool(value): if value is None or value == "": return None + if isinstance(value, str) and value.upper() in FALSE_STRINGS: return False + return bool(value) @@ -90,6 +92,7 @@ def from_url(cls, url, db=None, decode_components=False, **kwargs): for name, value in iter(parse_qs(qs).items()): if value and len(value) > 0: parser = URL_QUERY_ARGUMENT_PARSERS.get(name) + if parser: try: url_options[name] = parser(value[0]) @@ -102,19 +105,23 @@ def from_url(cls, url, db=None, decode_components=False, **kwargs): else: url_options[name] = value[0] + username = url.username + password = url.password + path = url.path + hostname = url.hostname + if decode_components: - password = unquote(url.password) if url.password else None - path = unquote(url.path) if url.path else None - hostname = unquote(url.hostname) if url.hostname else None - else: - password = url.password - path = url.path - hostname = url.hostname + username = unquote(username) if username else None + password = unquote(password) if password else None + path = unquote(path) if path else None + hostname = unquote(hostname) if hostname else None # We only support redis:// and unix:// schemes. + if url.scheme == "unix": url_options.update( { + "username": username, "password": password, "path": path, "connection_class": UnixDomainSocketConnection, @@ -123,11 +130,17 @@ def from_url(cls, url, db=None, decode_components=False, **kwargs): else: url_options.update( - {"host": hostname, "port": int(url.port or 6379), "password": password} + { + "host": hostname, + "port": int(url.port or 6379), + "username": username, + "password": password, + } ) # If there's a path argument, use it as the db argument if a # querystring value wasn't specified + if "db" not in url_options and path: try: url_options["db"] = int(path.replace("/", "")) @@ -148,6 +161,7 @@ def from_url(cls, url, db=None, decode_components=False, **kwargs): # update the arguments from the URL values kwargs.update(url_options) + return cls(**kwargs) def __init__( @@ -169,6 +183,7 @@ def __init__( connection_class. """ max_connections = max_connections or 2 ** 31 + if not isinstance(max_connections, int) or max_connections < 0: raise ValueError('"max_connections" must be a positive integer') @@ -199,6 +214,7 @@ async def disconnect_on_idle_time_exceeded(self, connection): except ValueError: pass self._created_connections -= 1 + break await asyncio.sleep(self.idle_check_interval) @@ -215,6 +231,7 @@ def _checkpid(self): if self.pid == os.getpid(): # another thread already did the work while we waited # on the lock. + return self.disconnect() self.reset() @@ -227,26 +244,32 @@ def get_connection(self, *args, **kwargs): except IndexError: connection = self.make_connection() self._in_use_connections.add(connection) + return connection def make_connection(self): """Creates a new connection""" + if self._created_connections >= self.max_connections: raise ConnectionError("Too many connections") self._created_connections += 1 connection = self.connection_class(**self.connection_kwargs) + if self.max_idle_time > self.idle_check_interval > 0: # do not await the future asyncio.ensure_future(self.disconnect_on_idle_time_exceeded(connection)) + return connection def release(self, connection): """Releases the connection back to the pool""" self._checkpid() + if connection.pid != self.pid: return self._in_use_connections.remove(connection) # discard connection with unread response + if connection.awaiting_response: connection.disconnect() self._created_connections -= 1 @@ -256,6 +279,7 @@ def release(self, connection): def disconnect(self): """Closes all connections in the pool""" all_conns = chain(self._available_connections, self._in_use_connections) + for connection in all_conns: connection.disconnect() self._created_connections -= 1 @@ -296,6 +320,7 @@ def __init__( # Special case to make from_url method compliant with cluster setting. # from_url method will send in the ip and port through a different variable then the # regular startup_nodes variable. + if startup_nodes is None: if "port" in connection_kwargs and "host" in connection_kwargs: startup_nodes = [ @@ -335,6 +360,7 @@ def __repr__(self): Returns a string with all unique ip:port combinations that this pool is connected to """ + return "{0}<{1}>".format( type(self).__name__, ", ".join( @@ -360,6 +386,7 @@ async def disconnect_on_idle_time_exceeded(self, connection): node = connection.node self._available_connections[node["name"]].remove(connection) self._created_connections_per_node[node["name"]] -= 1 + break await asyncio.sleep(self.idle_check_interval) @@ -378,12 +405,14 @@ def _checkpid(self): if self.pid == os.getpid(): # another thread already did the work while we waited # on the lockself. + return self.disconnect() self.reset() def get_connection(self, command_name, *keys, **options): # Only pubsub command/connection should be allowed here + if command_name != "pubsub": raise RedisClusterException( "Only 'pubsub' commands can use get_connection()" @@ -413,6 +442,7 @@ def get_connection(self, command_name, *keys, **options): def make_connection(self, node): """Creates a new connection""" + if self.count_all_num_connections(node) >= self.max_connections: if self.max_connections_per_node: raise RedisClusterException( @@ -431,14 +461,17 @@ def make_connection(self, node): # Must store node in the connection to make it eaiser to track connection.node = node + if self.max_idle_time > self.idle_check_interval > 0: # do not await the future asyncio.ensure_future(self.disconnect_on_idle_time_exceeded(connection)) + return connection def release(self, connection): """Releases the connection back to the pool""" self._checkpid() + if connection.pid != self.pid: return @@ -446,14 +479,17 @@ def release(self, connection): # pool. There is cases where the connection is to be removed but it will not exist and # there must be a safe way to remove i_c = self._in_use_connections.get(connection.node["name"], set()) + if connection in i_c: i_c.remove(connection) else: pass # discard connection with unread response + if connection.awaiting_response: connection.disconnect() # reduce node connection count in case of too many connection error raised + if ( self.max_connections_per_node and self._created_connections_per_node.get(connection.node["name"]) @@ -483,12 +519,15 @@ def count_all_num_connections(self, node): def get_random_connection(self): """Opens new connection to random redis server""" + if self._available_connections: node_name = random.choice(list(self._available_connections.keys())) conn_list = self._available_connections[node_name] # check it in case of empty connection list + if conn_list: return conn_list.pop() + for node in self.nodes.random_startup_node_iter(): connection = self.get_connection_by_node(node) @@ -538,4 +577,5 @@ def get_master_node_by_slot(self, slot): def get_node_by_slot(self, slot): if self.readonly: return random.choice(self.nodes.slots[slot]) + return self.get_master_node_by_slot(slot) diff --git a/docs/source/api.rst b/docs/source/api.rst index 7eb4b56d..ac869117 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -40,11 +40,19 @@ Exceptions .. currentmodule:: coredis -.. autoexception:: AskError +Authentication & Authorization +------------------------------ + +.. autoexception:: AuthenticationFailureError :no-inherited-members: -.. autoexception:: BusyLoadingError +.. autoexception:: AuthenticationRequiredError :no-inherited-members: -.. autoexception:: CacheError +.. autoexception:: AuthorizationError + :no-inherited-members: + +Cluster +------- +.. autoexception:: AskError :no-inherited-members: .. autoexception:: ClusterCrossSlotError :no-inherited-members: @@ -54,6 +62,17 @@ Exceptions :no-inherited-members: .. autoexception:: ClusterTransactionError :no-inherited-members: +.. autoexception:: MovedError + :no-inherited-members: +.. autoexception:: RedisClusterException + :no-inherited-members: + +General Exceptions +------------------- +.. autoexception:: BusyLoadingError + :no-inherited-members: +.. autoexception:: CacheError + :no-inherited-members: .. autoexception:: CompressError :no-inherited-members: .. autoexception:: ConnectionError @@ -66,16 +85,12 @@ Exceptions :no-inherited-members: .. autoexception:: LockError :no-inherited-members: -.. autoexception:: MovedError - :no-inherited-members: .. autoexception:: NoScriptError :no-inherited-members: .. autoexception:: PubSubError :no-inherited-members: .. autoexception:: ReadOnlyError :no-inherited-members: -.. autoexception:: RedisClusterException - :no-inherited-members: .. autoexception:: RedisError :no-inherited-members: .. autoexception:: ResponseError diff --git a/tests/client/test_connection_pool.py b/tests/client/test_connection_pool.py index 62661da3..40686d15 100644 --- a/tests/client/test_connection_pool.py +++ b/tests/client/test_connection_pool.py @@ -113,6 +113,7 @@ def test_defaults(self): "host": "localhost", "port": 6379, "db": 0, + "username": None, "password": None, } @@ -123,6 +124,7 @@ def test_hostname(self): "host": "myhost", "port": 6379, "db": 0, + "username": None, "password": None, } @@ -135,6 +137,7 @@ def test_quoted_hostname(self): "host": "my / host +=+", "port": 6379, "db": 0, + "username": None, "password": None, } @@ -145,6 +148,7 @@ def test_port(self): "host": "localhost", "port": 6380, "db": 0, + "username": None, "password": None, } @@ -155,6 +159,7 @@ def test_password(self): "host": "localhost", "port": 6379, "db": 0, + "username": "", "password": "mypassword", } @@ -167,6 +172,7 @@ def test_quoted_password(self): "host": "localhost", "port": 6379, "db": 0, + "username": None, "password": "/mypass/+ word=$+", } @@ -177,6 +183,7 @@ def test_db_as_argument(self): "host": "localhost", "port": 6379, "db": 1, + "username": None, "password": None, } @@ -187,6 +194,7 @@ def test_db_in_path(self): "host": "localhost", "port": 6379, "db": 2, + "username": None, "password": None, } @@ -197,6 +205,7 @@ def test_db_in_querystring(self): "host": "localhost", "port": 6379, "db": 3, + "username": None, "password": None, } @@ -212,6 +221,7 @@ def test_extra_typed_querystring_options(self): "db": 2, "stream_timeout": 20.0, "connect_timeout": 10.0, + "username": None, "password": None, } @@ -257,6 +267,7 @@ def test_extra_querystring_options(self): "host": "localhost", "port": 6379, "db": 0, + "username": None, "password": None, "a": "1", "b": "2", @@ -269,6 +280,7 @@ def test_client_creates_connection_pool(self): "host": "myhost", "port": 6379, "db": 0, + "username": None, "password": None, } @@ -280,6 +292,7 @@ def test_defaults(self): assert pool.connection_kwargs == { "path": "/socket", "db": 0, + "username": None, "password": None, } @@ -289,6 +302,7 @@ def test_password(self): assert pool.connection_kwargs == { "path": "/socket", "db": 0, + "username": "", "password": "mypassword", } @@ -300,6 +314,7 @@ def test_quoted_password(self): assert pool.connection_kwargs == { "path": "/socket", "db": 0, + "username": None, "password": "/mypass/+ word=$+", } @@ -312,6 +327,7 @@ def test_quoted_path(self): assert pool.connection_kwargs == { "path": "/my/path/to/../+_+=$ocket", "db": 0, + "username": None, "password": "mypassword", } @@ -321,6 +337,7 @@ def test_db_as_argument(self): assert pool.connection_kwargs == { "path": "/socket", "db": 1, + "username": None, "password": None, } @@ -330,6 +347,7 @@ def test_db_in_querystring(self): assert pool.connection_kwargs == { "path": "/socket", "db": 2, + "username": None, "password": None, } @@ -339,6 +357,7 @@ def test_extra_querystring_options(self): assert pool.connection_kwargs == { "path": "/socket", "db": 0, + "username": None, "password": None, "a": "1", "b": "2", @@ -354,6 +373,7 @@ def test_defaults(self): "host": "localhost", "port": 6379, "db": 0, + "username": None, "password": None, }