Skip to content

Commit

Permalink
Add join_by_columns function in Legend API
Browse files Browse the repository at this point in the history
  • Loading branch information
gs-ssh16 authored and gs-ssh16 committed Oct 18, 2023
1 parent 867f12a commit c1a9313
Show file tree
Hide file tree
Showing 8 changed files with 808 additions and 12 deletions.
2 changes: 1 addition & 1 deletion pylegend/core/databse/sql_to_string/db_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ def join_processor(
config: SqlToStringConfig
) -> str:
left = extension.process_relation(join.left, config)
right = extension.process_relation(join.right, config)
right = extension.process_relation(join.right, config.push_indent())
join_type = join.type_
condition = "ON ({op})".format(op=extension.process_join_criteria(join.criteria, config)) if join.criteria else ""
if join_type == JoinType.CROSS:
Expand Down
7 changes: 7 additions & 0 deletions pylegend/core/request/legend_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,10 @@ def execute_sql_string(
stream=True
).iter_content(chunk_size=chunk_size)
return ResponseReader(iter_content)

def __eq__(self, other: 'object') -> bool:
    """Compare two clients by the host and port they report.

    :param other: Any object.
    :return: True when ``other`` is this very instance, or is a ``LegendClient``
        whose ``get_host()`` and ``get_port()`` values both equal this client's;
        False for everything else (including non-``LegendClient`` objects).
    """
    if self is other:
        return True
    if isinstance(other, LegendClient):
        return self.get_host() == other.get_host() and self.get_port() == other.get_port()
    return False

def __hash__(self) -> int:
    """Hash on the same (host, port) pair that ``__eq__`` compares.

    Overriding ``__eq__`` alone makes Python set the implicit ``__hash__`` to
    None, which would leave clients unusable as dict keys or set members.
    NOTE(review): assumes get_host()/get_port() return hashable values that do
    not change over the client's lifetime - confirm against their definitions.
    """
    return hash((self.get_host(), self.get_port()))
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
# Copyright 2023 Goldman Sachs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from pylegend._typing import (
PyLegendList,
PyLegendSequence,
)
from pylegend.core.tds.legend_api.frames.legend_api_applied_function_tds_frame import LegendApiAppliedFunction
from pylegend.core.tds.sql_query_helpers import copy_query, create_sub_query, extract_columns_for_subquery
from pylegend.core.sql.metamodel import (
QuerySpecification,
Select,
SelectItem,
SingleColumn,
AliasedRelation,
TableSubquery,
Query,
Join,
JoinType,
JoinOn,
QualifiedNameReference,
QualifiedName,
LogicalBinaryExpression,
LogicalBinaryType,
ComparisonExpression,
ComparisonOperator,
Expression,
)
from pylegend.core.tds.tds_column import TdsColumn
from pylegend.core.tds.tds_frame import FrameToSqlConfig
from pylegend.core.tds.legend_api.frames.legend_api_base_tds_frame import LegendApiBaseTdsFrame
from pylegend.core.tds.legend_api.frames.legend_api_tds_frame import LegendApiTdsFrame


__all__: PyLegendSequence[str] = [
"JoinByColumnsFunction"
]


class JoinByColumnsFunction(LegendApiAppliedFunction):
    """Applied function that joins a base frame with another frame on paired key columns.

    The i-th name in ``column_names_self`` is equated with the i-th name in
    ``column_names_other``. A key pair whose two names are identical is treated
    as a "common" join column and is projected only once in the result.
    """

    __base_frame: LegendApiBaseTdsFrame
    __other_frame: LegendApiBaseTdsFrame
    __column_names_self: PyLegendList[str]
    __column_names_other: PyLegendList[str]
    __join_type: str

    @classmethod
    def name(cls) -> str:
        """Return the identifier of this function: ``"join_by_columns"``."""
        return "join_by_columns"

    def __init__(
            self,
            base_frame: LegendApiBaseTdsFrame,
            other_frame: LegendApiTdsFrame,
            column_names_self: PyLegendList[str],
            column_names_other: PyLegendList[str],
            join_type: str
    ) -> None:
        """Store the two frames, the paired key column names and the join type.

        :raises ValueError: if ``other_frame`` is not a ``LegendApiBaseTdsFrame``.
        """
        self.__base_frame = base_frame
        if not isinstance(other_frame, LegendApiBaseTdsFrame):
            raise ValueError("Expected LegendApiBaseTdsFrame")  # pragma: no cover
        self.__other_frame = other_frame
        self.__column_names_self = column_names_self
        self.__column_names_other = column_names_other
        self.__join_type = join_type

    def __shared_join_keys(self) -> PyLegendList[str]:
        """Return, sorted, the key names that are identical on both sides of a pair."""
        return sorted(
            self_key
            for self_key, other_key in zip(self.__column_names_self, self.__column_names_other)
            if self_key == other_key
        )

    def to_sql(self, config: FrameToSqlConfig) -> QuerySpecification:
        """Build the SQL query object for the join.

        Both input queries are copied, wrapped as sub-queries aliased ``left`` and
        ``right``, joined on the AND of all key equalities, and the result is
        wrapped once more as a sub-query aliased ``root``.
        """
        extension = config.sql_to_string_generator().get_db_extension()
        left_query = copy_query(self.__base_frame.to_sql_query_object(config))
        right_query = copy_query(self.__other_frame.to_sql_query_object(config))
        left_alias = extension.quote_identifier('left')
        right_alias = extension.quote_identifier('right')
        requested_join = self.__join_type.lower()

        # Anything other than 'inner' / 'left_outer' maps to a RIGHT join here.
        if requested_join == 'inner':
            join_type = JoinType.INNER
        elif requested_join == 'left_outer':
            join_type = JoinType.LEFT
        else:
            join_type = JoinType.RIGHT

        def column_ref(table_alias: str, quoted_name: str) -> QualifiedNameReference:
            return QualifiedNameReference(name=QualifiedName(parts=[table_alias, quoted_name]))

        def as_aliased_subquery(body: QuerySpecification, table_alias: str) -> AliasedRelation:
            return AliasedRelation(
                relation=TableSubquery(query=Query(queryBody=body, limit=None, offset=None, orderBy=[])),
                alias=table_alias,
                columnNames=extract_columns_for_subquery(body)
            )

        def conjoin(left: Expression, right: Expression) -> Expression:
            return LogicalBinaryExpression(type_=LogicalBinaryType.AND, left=left, right=right)

        key_comparisons = [
            ComparisonExpression(
                left=column_ref(left_alias, extension.quote_identifier(self_key)),
                right=column_ref(right_alias, extension.quote_identifier(other_key)),
                operator=ComparisonOperator.EQUAL
            )
            for self_key, other_key in zip(self.__column_names_self, self.__column_names_other)
        ]
        join_expr = functools.reduce(
            conjoin,  # type: ignore
            key_comparisons
        )

        shared_keys = self.__shared_join_keys()
        left_names = [col.get_name() for col in self.__base_frame.columns()]
        right_names = [col.get_name() for col in self.__other_frame.columns()]
        # Common key columns are read from the right side only for a right outer join.
        shared_alias = right_alias if requested_join == 'right_outer' else left_alias

        # Projection order: left-only columns, common key columns, right-only columns.
        projections = (
            [(left_alias, col_name) for col_name in left_names if col_name not in shared_keys] +
            [(shared_alias, col_name) for col_name in left_names if col_name in shared_keys] +
            [(right_alias, col_name) for col_name in right_names if col_name not in shared_keys]
        )
        new_select_items: PyLegendList[SelectItem] = []
        for table_alias, col_name in projections:
            quoted = extension.quote_identifier(col_name)
            new_select_items.append(SingleColumn(quoted, column_ref(table_alias, quoted)))

        left_relation = as_aliased_subquery(left_query, left_alias)
        right_relation = as_aliased_subquery(right_query, right_alias)
        join_query = QuerySpecification(
            select=Select(selectItems=new_select_items, distinct=False),
            from_=[
                Join(
                    type_=join_type,
                    left=left_relation,
                    right=right_relation,
                    criteria=JoinOn(expression=join_expr)
                )
            ],
            where=None,
            groupBy=[],
            having=None,
            orderBy=[],
            limit=None,
            offset=None
        )

        return create_sub_query(join_query, config, "root")

    def base_frame(self) -> LegendApiBaseTdsFrame:
        """Return the frame this function was applied on (the left side of the join)."""
        return self.__base_frame

    def tds_frame_parameters(self) -> PyLegendList["LegendApiBaseTdsFrame"]:
        """Return the frames passed as arguments: just the other (right) frame."""
        return [self.__other_frame]

    def calculate_columns(self) -> PyLegendSequence["TdsColumn"]:
        """Return copies of the result columns.

        Order: base-frame columns that are not common keys, base-frame columns that
        are common keys, then other-frame columns that are not common keys.
        """
        shared_keys = self.__shared_join_keys()
        left_only: PyLegendList[TdsColumn] = []
        left_shared: PyLegendList[TdsColumn] = []
        for col in self.__base_frame.columns():
            bucket = left_shared if col.get_name() in shared_keys else left_only
            bucket.append(col.copy())
        right_only = [
            col.copy() for col in self.__other_frame.columns() if col.get_name() not in shared_keys
        ]
        return left_only + left_shared + right_only

    def validate(self) -> bool:
        """Check the join arguments and return True when they are all acceptable.

        Checks run in this order, raising ``ValueError`` on the first failure:
        key columns exist in the left frame, key columns exist in the right frame,
        both key lists have the same size, the lists are non-empty, paired columns
        have equal types, the joined column names are unique, the join type is one
        of inner / left_outer / right_outer (case-insensitive).
        """
        left_cols = [col.get_name() for col in self.__base_frame.columns()]
        for key in self.__column_names_self:
            if key in left_cols:
                continue
            raise ValueError(
                "Column - '{col}' in join columns list doesn't exist in the left frame being joined. "
                "Current left frame columns: {cols}".format(col=key, cols=left_cols)
            )

        right_cols = [col.get_name() for col in self.__other_frame.columns()]
        for key in self.__column_names_other:
            if key in right_cols:
                continue
            raise ValueError(
                "Column - '{col}' in join columns list doesn't exist in the right frame being joined. "
                "Current right frame columns: {cols}".format(col=key, cols=right_cols)
            )

        left_key_count = len(self.__column_names_self)
        right_key_count = len(self.__column_names_other)
        if left_key_count != right_key_count:
            raise ValueError(
                "For join_by_columns function, column lists should be of same size. "
                "Passed column list sizes - Left: {l}, Right: {r}".format(l=left_key_count, r=right_key_count)
            )

        if left_key_count == 0:
            raise ValueError("For join_by_columns function, column lists should not be empty")

        key_pairs = list(zip(self.__column_names_self, self.__column_names_other))
        for left_key, right_key in key_pairs:
            # Existence of both names was established above, so a match is always found.
            left_col = next(c for c in self.__base_frame.columns() if c.get_name() == left_key)
            right_col = next(c for c in self.__other_frame.columns() if c.get_name() == right_key)
            if left_col.get_type() == right_col.get_type():
                continue
            raise ValueError(
                "Trying to join on columns with different types - Left Col: {l}, Right Col: {r}".format(
                    l=left_col,
                    r=right_col
                )
            )

        # Kept in pair order (not sorted) because the list is echoed in the error below.
        common_join_cols = [left_key for left_key, right_key in key_pairs if left_key == right_key]
        final_cols = [n for n in left_cols + right_cols if n not in common_join_cols] + common_join_cols

        if len(set(final_cols)) != len(final_cols):
            raise ValueError(
                "Found duplicate columns in joined frames (which are not join keys). "
                "Columns - Left Frame: {l}, Right Frame: {r}, Common Join Keys: {j}".format(
                    l=left_cols,
                    r=right_cols,
                    j=common_join_cols
                )
            )

        supported_join_types = ('inner', 'left_outer', 'right_outer')
        if self.__join_type.lower() not in supported_join_types:
            raise ValueError(
                "Unknown join type - {j}. Supported types are - INNER, LEFT_OUTER, RIGHT_OUTER".format(
                    j=self.__join_type
                )
            )

        return True
17 changes: 17 additions & 0 deletions pylegend/core/tds/legend_api/frames/legend_api_base_tds_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,23 @@ def extend(
)
return LegendApiAppliedFunctionTdsFrame(ExtendFunction(self, functions_list, column_names_list))

def join_by_columns(
        self,
        other: "LegendApiTdsFrame",
        self_columns: PyLegendList[str],
        other_columns: PyLegendList[str],
        join_type: str = 'LEFT_OUTER'
) -> "LegendApiTdsFrame":
    """Return a new frame that applies ``JoinByColumnsFunction`` to this frame.

    :param other: Frame to join with.
    :param self_columns: Key column names on this frame.
    :param other_columns: Key column names on ``other``.
    :param join_type: Join type string, ``'LEFT_OUTER'`` by default.
    :return: A ``LegendApiAppliedFunctionTdsFrame`` wrapping the join function.
    """
    # Imported inside the method, as in the original; presumably to avoid a circular import.
    from pylegend.core.tds.legend_api.frames.legend_api_applied_function_tds_frame import (
        LegendApiAppliedFunctionTdsFrame
    )
    from pylegend.core.tds.legend_api.frames.functions.join_by_columns_function import (
        JoinByColumnsFunction
    )
    join_function = JoinByColumnsFunction(self, other, self_columns, other_columns, join_type)
    return LegendApiAppliedFunctionTdsFrame(join_function)

@abstractmethod
def to_sql_query_object(self, config: FrameToSqlConfig) -> QuerySpecification:
pass
Expand Down
10 changes: 10 additions & 0 deletions pylegend/core/tds/legend_api/frames/legend_api_tds_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,13 @@ def extend(
column_names_list: PyLegendList[str]
) -> "LegendApiTdsFrame":
pass

@abstractmethod
def join_by_columns(
        self,
        other: "LegendApiTdsFrame",
        self_columns: PyLegendList[str],
        other_columns: PyLegendList[str],
        join_type: str = 'LEFT_OUTER'
) -> "LegendApiTdsFrame":
    """Join this frame with ``other`` on paired key columns (abstract).

    :param other: Frame to join with.
    :param self_columns: Key column names on this frame.
    :param other_columns: Key column names on ``other``.
    :param join_type: Join type string, ``'LEFT_OUTER'`` by default.
    :return: The joined frame, as produced by the implementing class.
    """
25 changes: 15 additions & 10 deletions pylegend/core/tds/sql_query_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
__all__: PyLegendSequence[str] = [
"create_sub_query",
"copy_query",
"extract_columns_for_subquery",
]


Expand All @@ -45,16 +46,7 @@ def create_sub_query(
query = copy_query(base_query)
table_alias = config.sql_to_string_generator().get_db_extension().quote_identifier(alias)

columns = []
for col in query.select.selectItems:
if not isinstance(col, SingleColumn):
raise ValueError("Subquery creation not supported for queries "
"with columns other than SingleColumn") # pragma: no cover
if col.alias is None:
raise ValueError("Subquery creation not supported for queries "
"with SingleColumns with missing alias") # pragma: no cover
columns.append(col.alias)

columns = extract_columns_for_subquery(query)
outer_query_columns = columns_to_retain if columns_to_retain else columns
unordered_select_items_with_index = [
(
Expand Down Expand Up @@ -104,6 +96,19 @@ def copy_query(query: QuerySpecification) -> QuerySpecification:
)


def extract_columns_for_subquery(query: QuerySpecification) -> PyLegendList[str]:
    """Collect the aliases of every select item of ``query``, in select order.

    :param query: Query whose ``select.selectItems`` are inspected.
    :return: The alias of each select item.
    :raises ValueError: if an item is not a ``SingleColumn`` or has no alias.
    """
    aliases: PyLegendList[str] = []
    for item in query.select.selectItems:
        if isinstance(item, SingleColumn):
            if item.alias is not None:
                aliases.append(item.alias)
                continue
            raise ValueError("Subquery creation not supported for queries "
                             "with SingleColumns with missing alias")  # pragma: no cover
        raise ValueError("Subquery creation not supported for queries "
                         "with columns other than SingleColumn")  # pragma: no cover
    return aliases


def copy_select(select: Select) -> Select:
return Select(
distinct=select.distinct,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def test_sql_gen_filter_function_chained_with_top(self) -> None:
(("root"."col2" LIKE \'A%\') AND ("root"."col1" > 10))'''
assert frame.to_sql_query(FrameToSqlConfig()) == dedent(expected)

@pytest.mark.skip
@pytest.mark.skip(reason="Literal not supported ")
def test_e2e_filter_function_literal(self, legend_test_server: PyLegendDict[str, PyLegendUnion[int, ]]) -> None:
frame: LegendApiTdsFrame = simple_person_service_frame(legend_test_server["engine_port"])
frame = frame.filter(lambda r: 1 == 2) # type: ignore
Expand Down
Loading

0 comments on commit c1a9313

Please sign in to comment.