Skip to content

Commit

Permalink
Update example to return reader to give user more flexibility. Also f…
Browse files Browse the repository at this point in the history
…ix client auth middleware returning null bearer token error
  • Loading branch information
ravjotbrar committed Jan 11, 2024
1 parent 8e54411 commit 5ae5e87
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 31 deletions.
2 changes: 1 addition & 1 deletion python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -180,4 +180,4 @@ _description:_ The specific engine to run against. Only applicable to Dremio Clo
This lightweight Python client application connects to the Dremio Arrow Flight server endpoint. Developers can use token based or regular user credentials (username/password) for authentication. Please note username/password is not supported for Dremio Cloud. Dremio Cloud requires a token. Any datasets in Dremio that are accessible by the provided Dremio user can be queried. Developers can change settings by providing options in a config yaml file before running the client.
Moreover, the tls option can be provided to establish an encrypted connection.
The example includes a function called get_reader, which returns a FlightStreamReader. Users can choose to read the data based on the methods available in the [FlightStreamReader class](https://arrow.apache.org/docs/python/generated/pyarrow.flight.FlightStreamReader.html#pyarrow.flight.FlightStreamReader). In our example, we've decided to read the data into a Pandas dataframe.
1 change: 0 additions & 1 deletion python/dremio-flight/dremio/flight/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def _connect_to_software(
middleware=[client_auth_middleware, client_cookie_middleware],
**tls_args,
)

# Authenticate with the server endpoint.
password_or_token = self.password if self.password else self.token
bearer_token = client.authenticate_basic_token(
Expand Down
6 changes: 3 additions & 3 deletions python/dremio-flight/dremio/flight/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ def __init__(self, connection_args: dict) -> None:
def connect(self) -> flight.FlightClient:
return self.dremio_flight_conn.connect()

def execute_query(self, flight_client: flight.FlightClient) -> DataFrame:
def get_reader(self, client: flight.FlightClient) -> flight.FlightStreamReader:
dremio_flight_query = DremioFlightEndpointQuery(
self.connection_args.get("query"), flight_client, self.dremio_flight_conn
self.connection_args.get("query"), client, self.dremio_flight_conn
)
return dremio_flight_query.execute_query()
return dremio_flight_query.get_reader()
20 changes: 3 additions & 17 deletions python/dremio-flight/dremio/flight/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
self.client = client
self.headers = getattr(connection, "headers")

def execute_query(self) -> DataFrame:
def get_reader(self) -> flight.FlightStreamReader:
try:
options = flight.FlightCallOptions(headers=self.headers)
# Get the FlightInfo message to retrieve the Ticket corresponding
Expand All @@ -43,25 +43,11 @@ def execute_query(self) -> DataFrame:
logging.info("GetFlightInfo was successful")
logging.debug("Ticket: %s", flight_info.endpoints[0].ticket)

# Retrieve the result set as pandas DataFrame
reader = self.client.do_get(flight_info.endpoints[0].ticket, options)
return self._get_chunks(reader)
# Retrieve the reader
return self.client.do_get(flight_info.endpoints[0].ticket, options)

except Exception:
logging.exception(
"There was an error trying to get the data from the flight endpoint"
)
raise

def _get_chunks(self, reader: flight.FlightStreamReader) -> DataFrame:
dataframe = DataFrame()
while True:
try:
flight_batch = reader.read_chunk()
record_batch = flight_batch.data
data_to_pandas = record_batch.to_pandas()
dataframe = concat([dataframe, data_to_pandas])
except StopIteration:
break

return dataframe
4 changes: 4 additions & 0 deletions python/dremio-flight/dremio/middleware/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ def __init__(self, factory):
self.factory = factory

def received_headers(self, headers):
if self.factory.call_credential:
return

auth_header_key = "authorization"

authorization_header = reduce(
lambda result, header: header[1]
if header[0] == auth_header_key
Expand Down
10 changes: 5 additions & 5 deletions python/dremio-flight/tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""
from argparse import Namespace
from numpy import array, array_equal
from pyarrow.flight import FlightUnauthenticatedError, FlightUnavailableError
from pyarrow.flight import FlightUnauthenticatedError, FlightInternalError
from dotenv import load_dotenv
import certifi
import os
Expand Down Expand Up @@ -79,7 +79,7 @@ def test_simple_query():
dremio_flight_query = DremioFlightEndpointQuery(
args_dict["query"], flight_client, dremio_flight_conn
)
dataframe = dremio_flight_query.execute_query()
dataframe = dremio_flight_query.get_reader().read_pandas()
dataframe_arr = dataframe.to_numpy()
expected_arr = array([[1, 2, 3]])
assert array_equal(dataframe_arr, expected_arr)
Expand All @@ -96,7 +96,7 @@ def test_tls():
dremio_flight_query = DremioFlightEndpointQuery(
args_dict_ssl["query"], flight_client, dremio_flight_conn
)
dataframe = dremio_flight_query.execute_query()
dataframe = dremio_flight_query.get_reader().read_pandas()
dataframe_arr = dataframe.to_numpy()
expected_arr = array([[1, 2, 3]])
assert array_equal(dataframe_arr, expected_arr)
Expand All @@ -110,7 +110,7 @@ def test_bad_hostname():
args_dict_bad_hostname["hostname"] = "ha-ha!"

dremio_flight_conn = DremioFlightEndpointConnection(args_dict_bad_hostname)
with pytest.raises(FlightUnavailableError):
with pytest.raises(FlightInternalError):
dremio_flight_conn.connect()


Expand All @@ -122,7 +122,7 @@ def test_bad_port():
args_dict_bad_port["port"] = 12345

dremio_flight_conn = DremioFlightEndpointConnection(args_dict_bad_port)
with pytest.raises(FlightUnavailableError):
with pytest.raises(FlightInternalError):
dremio_flight_conn.connect()


Expand Down
8 changes: 4 additions & 4 deletions python/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
# Connect to Dremio Arrow Flight server endpoint.
flight_client = dremio_flight_endpoint.connect()

# Execute query
dataframe = dremio_flight_endpoint.execute_query(flight_client)
# Get reader
reader = dremio_flight_endpoint.get_reader(flight_client)

# Print out the data
print(dataframe)
# Print out the data as a dataframe
print(reader.read_pandas())

0 comments on commit 5ae5e87

Please sign in to comment.