diff --git a/examples/bin/dsql b/examples/bin/dsql index b402e17534d4..dd02a9e893c9 100755 --- a/examples/bin/dsql +++ b/examples/bin/dsql @@ -21,9 +21,16 @@ PWD="$(pwd)" WHEREAMI="$(dirname "$0")" WHEREAMI="$(cd "$WHEREAMI" && pwd)" -if [ -x "$(command -v python2)" ] +if [ -x "$(command -v python3)" ] then - exec python2 "$WHEREAMI/dsql-main" "$@" + exec python3 "$WHEREAMI/dsql-main-py3" "$@" +elif [ -x "$(command -v python2)" ] +then + echo "Warning: Support for Python 2 will be removed in the future. Please consider upgrading to Python 3" + exec python2 "$WHEREAMI/dsql-main-py2" "$@" +elif [ -x "$(command -v python)" ] +then + exec python "$WHEREAMI/dsql-main-py3" "$@" else - exec "$WHEREAMI/dsql-main" "$@" + echo "python interepreter not found" fi diff --git a/examples/bin/dsql-main b/examples/bin/dsql-main-py2 old mode 100755 new mode 100644 similarity index 99% rename from examples/bin/dsql-main rename to examples/bin/dsql-main-py2 index c24602739df3..d7325447c65c --- a/examples/bin/dsql-main +++ b/examples/bin/dsql-main-py2 @@ -17,6 +17,11 @@ # specific language governing permissions and limitations # under the License. +# NOTE: +# Any feature updates to this script must also be reflected in +# `dsql-main-py3` so that intended changes work for users using +# Python 2 or 3. + from __future__ import print_function import argparse diff --git a/examples/bin/dsql-main-py3 b/examples/bin/dsql-main-py3 new file mode 100755 index 000000000000..bc573f439142 --- /dev/null +++ b/examples/bin/dsql-main-py3 @@ -0,0 +1,523 @@ +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# NOTE: +# Any feature updates to this script must also be reflected in +# `dsql-main-py2` so that intended changes work for users using +# Python 2 or 3. + +import argparse +import base64 +import collections +import csv +import errno +import json +import numbers +import os +import re +import readline +import ssl +import sys +import time +import unicodedata +import urllib.request +import urllib.error + +class DruidSqlException(Exception): + def friendly_message(self): + return getattr(self, 'message', 'Query failed') + + def write_to(self, f): + f.write('\x1b[31m') + f.write(self.friendly_message()) + f.write('\x1b[0m') + f.write('\n') + f.flush() + +def do_query_with_args(url, sql, context, args): + return do_query(url, sql, context, args.timeout, args.user, args.ignore_ssl_verification, args.cafile, args.capath, args.certchain, args.keyfile, args.keypass) + +def do_query(url, sql, context, timeout, user, ignore_ssl_verification, ca_file, ca_path, cert_chain, key_file, key_pass): + json_decoder = json.JSONDecoder(object_pairs_hook=collections.OrderedDict) + try: + if timeout <= 0: + timeout = None + query_context = context + elif int(context.get('timeout', 0)) / 1000. < timeout: + query_context = context.copy() + query_context['timeout'] = timeout * 1000 + + sql_json = json.dumps({'query' : sql, 'context' : query_context}) + + # SSL stuff + ssl_context = None + if ignore_ssl_verification or ca_file is not None or ca_path is not None or cert_chain is not None: + ssl_context = ssl.create_default_context() + if ignore_ssl_verification: + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + elif ca_path is not None: + ssl_context.load_verify_locations(cafile=ca_file, capath=ca_path) + else: + ssl_context.load_cert_chain(certfile=cert_chain, keyfile=key_file, password=key_pass) + + req = urllib.request.Request(url, sql_json.encode('utf-8'), {'Content-Type' : 'application/json'}) + + if user: + req.add_header("Authorization", "Basic %s" % base64.b64encode(user.encode('utf-8')).decode('utf-8')) + + response = urllib.request.urlopen(req, None, timeout, context=ssl_context) + + first_chunk = True + eof = False + buf = '' + + while not eof or len(buf) > 0: + while True: + try: + # Remove starting ',' + buf = buf.lstrip(',') + obj, sz = json_decoder.raw_decode(buf) + yield obj + buf = buf[sz:] + except ValueError as e: + # Maybe invalid JSON, maybe partial object; it's hard to tell with this library. + if eof and buf.rstrip() == ']': + # Stream done and all objects read. + buf = '' + break + elif eof or len(buf) > 256 * 1024: + # If we read more than 256KB or if it's eof then report the parse error. + raise + else: + # Stop reading objects, get more from the stream instead. + break + + # Read more from the http stream + if not eof: + chunk = response.read(8192).decode('utf-8') + if chunk: + buf = buf + chunk + if first_chunk: + # Remove starting '[' + buf = buf.lstrip('[') + else: + # Stream done. Keep reading objects out of buf though. + eof = True + + except urllib.error.URLError as e: + raise_friendly_error(e) + +def raise_friendly_error(e): + if isinstance(e, urllib.error.HTTPError): + text = e.read().strip() + error_obj = {} + try: + error_obj = dict(json.loads(text)) + except: + pass + if e.code == 500 and 'errorMessage' in error_obj: + error_text = '' + if error_obj['error'] != 'Unknown exception': + error_text = error_text + error_obj['error'] + ': ' + if error_obj['errorClass']: + error_text = error_text + str(error_obj['errorClass']) + ': ' + error_text = error_text + str(error_obj['errorMessage']) + if error_obj['host']: + error_text = error_text + ' (' + str(error_obj['host']) + ')' + raise DruidSqlException(error_text) + elif e.code == 405: + error_text = 'HTTP Error {0}: {1}\n{2}'.format(e.code, e.reason + " - Are you using the correct broker URL and " +\ + "is druid.sql.enabled set to true on your broker?", text) + raise DruidSqlException(error_text) + else: + raise DruidSqlException("HTTP Error {0}: {1}\n{2}".format(e.code, e.reason, text)) + else: + raise DruidSqlException(str(e)) + +def to_utf8(value): + if value is None: + return b"" + elif isinstance(value, str): + return value.encode("utf-8") + else: + return str(value).encode("utf-8") + + +def to_tsv(values, delimiter): + return delimiter.join(to_utf8(v).replace(delimiter, '') for v in values) + +def print_csv(rows, header): + csv_writer = csv.writer(sys.stdout) + first = True + for row in rows: + if first and header: + csv_writer.writerow(list(to_utf8(k) for k in row.keys())) + first = False + + values = [] + for key, value in row.iteritems(): + values.append(to_utf8(value)) + + csv_writer.writerow(values) + +def print_tsv(rows, header, tsv_delimiter): + first = True + for row in rows: + if first and header: + print(to_tsv(row.keys(), tsv_delimiter)) + first = False + + values = [] + for key, value in row.iteritems(): + values.append(value) + + print(to_tsv(values, tsv_delimiter)) + +def print_json(rows): + for row in rows: + print(json.dumps(row)) + +def table_to_printable_value(value): + # Unicode string, trimmed with control characters removed + if value is None: + return u"NULL" + else: + return to_utf8(value).strip().decode('utf-8').translate(dict.fromkeys(range(32))) + +def table_compute_string_width(v): + normalized = unicodedata.normalize('NFC', v) + width = 0 + for c in normalized: + ccategory = unicodedata.category(c) + cwidth = unicodedata.east_asian_width(c) + if ccategory == 'Cf': + # Formatting control, zero width + pass + elif cwidth == 'F' or cwidth == 'W': + # Double-wide character, prints in two columns + width = width + 2 + else: + # All other characters + width = width + 1 + return width + +def table_compute_column_widths(row_buffer): + widths = None + for values in row_buffer: + values_widths = [table_compute_string_width(v) for v in values] + if not widths: + widths = values_widths + else: + i = 0 + for v in values: + widths[i] = max(widths[i], values_widths[i]) + i = i + 1 + return widths + +def table_print_row(values, column_widths, column_types): + vertical_line = '\u2502' + for i in range(len(values)): + padding = ' ' * max(0, column_widths[i] - table_compute_string_width(values[i])) + if column_types and column_types[i] == 'n': + print(vertical_line + ' ' + padding + values[i] + ' ', end="") + else: + print(vertical_line + ' ' + values[i] + padding + ' ', end="") + print(vertical_line) + +def table_print_header(values, column_widths): + # Line 1 + left_corner = '\u250C' + horizontal_line = '\u2500' + top_tee = '\u252C' + right_corner = '\u2510' + print(left_corner, end="") + for i in range(0, len(column_widths)): + print(horizontal_line * max(0, column_widths[i] + 2), end="") + if i + 1 < len(column_widths): + print(top_tee, end="") + print(right_corner) + + # Line 2 + table_print_row(values, column_widths, None) + + # Line 3 + left_tee = '\u251C' + cross = '\u253C' + right_tee = '\u2524' + print(left_tee, end="") + for i in range(0, len(column_widths)): + print(horizontal_line * max(0, column_widths[i] + 2), end="") + if i + 1 < len(column_widths): + print(cross, end="") + print(right_tee) + + +def table_print_bottom(column_widths): + left_corner = '\u2514' + right_corner = '\u2518' + bottom_tee = '\u2534' + horizontal_line = '\u2500' + print(left_corner, end="") + for i in range(0, len(column_widths)): + print(horizontal_line * max(0, column_widths[i] + 2), end="") + if i + 1 < len(column_widths): + print(bottom_tee, end="") + print(right_corner) + + +def table_print_row_buffer(row_buffer, column_widths, column_types): + first = True + for values in row_buffer: + if first: + table_print_header(values, column_widths) + first = False + else: + table_print_row(values, column_widths, column_types) + +def print_table(rows): + start = time.time() + nrows = 0 + first = True + + # Buffer some rows before printing. + rows_to_buffer = 500 + row_buffer = [] + column_types = [] + column_widths = None + + for row in rows: + nrows = nrows + 1 + + if first: + row_buffer.append([table_to_printable_value(k) for k in row.keys()]) + for k in row.keys(): + if isinstance(row[k], numbers.Number): + column_types.append('n') + else: + column_types.append('s') + first = False + + values = [table_to_printable_value(v) for k, v in row.items()] + if rows_to_buffer > 0: + row_buffer.append(values) + rows_to_buffer = rows_to_buffer - 1 + else: + if row_buffer: + column_widths = table_compute_column_widths(row_buffer) + table_print_row_buffer(row_buffer, column_widths, column_types) + del row_buffer[:] + table_print_row(values, column_widths, column_types) + + if row_buffer: + column_widths = table_compute_column_widths(row_buffer) + table_print_row_buffer(row_buffer, column_widths, column_types) + + if column_widths: + table_print_bottom(column_widths) + + print("Retrieved {0:,d} row{1:s} in {2:.2f}s.".format(nrows, 's' if nrows != 1 else '', time.time() - start)) + print("") + +def display_query(url, sql, context, args): + rows = do_query_with_args(url, sql, context, args) + + if args.format == 'csv': + print_csv(rows, args.header) + elif args.format == 'tsv': + print_tsv(rows, args.header, args.tsv_delimiter) + elif args.format == 'json': + print_json(rows) + elif args.format == 'table': + print_table(rows) + +def sql_literal_escape(s): + if s is None: + return "''" + elif isinstance(s, str): + ustr = s + else: + ustr = str(s) + + escaped = ["U&'"] + + for c in ustr: + ccategory = unicodedata.category(c) + if ccategory.startswith('L') or ccategory.startswith('N') or c == ' ': + escaped.append(c) + else: + escaped.append(u'\\') + escaped.append('%04x' % ord(c)) + + escaped.append("'") + return ''.join(escaped) + +def make_readline_completer(url, context, args): + starters = [ + 'EXPLAIN PLAN FOR', + 'SELECT' + ] + + middlers = [ + 'FROM', + 'WHERE', + 'GROUP BY', + 'ORDER BY', + 'LIMIT' + ] + + def readline_completer(text, state): + if readline.get_begidx() == 0: + results = [x for x in starters if x.startswith(text.upper())] + [None] + else: + results = ([x for x in middlers if x.startswith(text.upper())] + [None]) + + return results[state] + " " + + print("Connected to [" + args.host + "].") + print("") + + return readline_completer + +def main(): + parser = argparse.ArgumentParser(description='Druid SQL command-line client.') + parser_cnn = parser.add_argument_group('Connection options') + parser_fmt = parser.add_argument_group('Formatting options') + parser_oth = parser.add_argument_group('Other options') + parser_cnn.add_argument('--host', '-H', type=str, default='http://localhost:8082/', help='Druid query host or url, like https://localhost:8282/') + parser_cnn.add_argument('--user', '-u', type=str, help='HTTP basic authentication credentials, like user:password') + parser_cnn.add_argument('--timeout', type=int, default=0, help='Timeout in seconds') + parser_cnn.add_argument('--cafile', type=str, help='Path to SSL CA file for validating server certificates. See load_verify_locations() in https://docs.python.org/3/library/ssl.html#ssl.SSLContext.') + parser_cnn.add_argument('--capath', type=str, help='SSL CA path for validating server certificates. See load_verify_locations() in https://docs.python.org/3/library/ssl.html#ssl.SSLContext.') + parser_cnn.add_argument('--ignore-ssl-verification', '-k', action='store_true', default=False, help='Skip verification of SSL certificates.') + parser_fmt.add_argument('--format', type=str, default='table', choices=('csv', 'tsv', 'json', 'table'), help='Result format') + parser_fmt.add_argument('--header', action='store_true', help='Include header row for formats "csv" and "tsv"') + parser_fmt.add_argument('--tsv-delimiter', type=str, default='\t', help='Delimiter for format "tsv"') + parser_oth.add_argument('--context-option', '-c', type=str, action='append', help='Set context option for this connection, see https://druid.apache.org/docs/latest/querying/sql.html#connection-context for options') + parser_oth.add_argument('--execute', '-e', type=str, help='Execute single SQL query') + parser_cnn.add_argument('--certchain', type=str, help='Path to SSL certificate used to connect to server. See load_cert_chain() in https://docs.python.org/3/library/ssl.html#ssl.SSLContext.') + parser_cnn.add_argument('--keyfile', type=str, help='Path to private SSL key used to connect to server. See load_cert_chain() in https://docs.python.org/3/library/ssl.html#ssl.SSLContext.') + parser_cnn.add_argument('--keypass', type=str, help='Password to private SSL key file used to connect to server. See load_cert_chain() in https://docs.python.org/3/library/ssl.html#ssl.SSLContext.') + args = parser.parse_args() + + # Build broker URL + url = args.host.rstrip('/') + '/druid/v2/sql/' + if not url.startswith('http:') and not url.startswith('https:'): + url = 'http://' + url + + # Build context + context = {} + if args.context_option: + for opt in args.context_option: + kv = opt.split("=", 1) + if len(kv) != 2: + raise ValueError('Invalid context option, should be key=value: ' + opt) + if re.match(r"^\d+$", kv[1]): + context[kv[0]] = int(kv[1]) + else: + context[kv[0]] = kv[1] + + if args.execute: + display_query(url, args.execute, context, args) + else: + # interactive mode + print("Welcome to dsql, the command-line client for Druid SQL.") + + readline_history_file = os.path.expanduser("~/.dsql_history") + readline.parse_and_bind('tab: complete') + readline.set_history_length(500) + readline.set_completer(make_readline_completer(url, context, args)) + + try: + readline.read_history_file(readline_history_file) + except IOError: + # IOError can happen if the file doesn't exist. + pass + + print("Type \"\\h\" for help.") + + while True: + sql = '' + while not sql.endswith(';'): + prompt = "dsql> " if sql == '' else 'more> ' + try: + more_sql = input(prompt) + except EOFError: + sys.stdout.write('\n') + sys.exit(1) + if sql == '' and more_sql.startswith('\\'): + # backslash command + dmatch = re.match(r'^\\d(S?)(\+?)(\s+.*?|)\s*$', more_sql) + if dmatch: + include_system = dmatch.group(1) + extra_info = dmatch.group(2) + arg = dmatch.group(3).strip() + if arg: + sql = "SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = " + sql_literal_escape(arg) + if not include_system: + sql = sql + " AND TABLE_SCHEMA = 'druid'" + # break to execute sql + break + else: + sql = "SELECT TABLE_SCHEMA, TABLE_NAME FROM INFORMATION_SCHEMA.TABLES" + if not include_system: + sql = sql + " WHERE TABLE_SCHEMA = 'druid'" + # break to execute sql + break + + hmatch = re.match(r'^\\h\s*$', more_sql) + if hmatch: + print("Commands:") + print(" \\d show tables") + print(" \\dS show tables, including system tables") + print(" \\d table_name describe table") + print(" \\h show this help") + print(" \\q exit this program") + print("Or enter a SQL query ending with a semicolon (;).") + continue + + qmatch = re.match(r'^\\q\s*$', more_sql) + if qmatch: + sys.exit(0) + + print("No such command: " + more_sql) + else: + sql = (sql + ' ' + more_sql).strip() + + try: + readline.write_history_file(readline_history_file) + display_query(url, sql.rstrip(';'), context, args) + except DruidSqlException as e: + e.write_to(sys.stdout) + except KeyboardInterrupt: + sys.stdout.write("Query interrupted\n") + sys.stdout.flush() + +try: + main() +except DruidSqlException as e: + e.write_to(sys.stderr) + sys.exit(1) +except KeyboardInterrupt: + sys.exit(1) +except IOError as e: + if e.errno == errno.EPIPE: + sys.exit(1) + else: + raise