Link back information about position in the original string #228

Open
ilyakochik opened this issue Feb 6, 2024 · 4 comments

Comments

@ilyakochik

To analyse SQL in the editor and provide some syntax checking, it would be great if each branch or leaf in the tree knew where it starts and ends in the original document.

Is there a quick way to return, e.g., TreeDict(dict) and TreeList(list) classes that carry start and end locations?

I can try to do it myself if you can point me in the right direction. It seems the scrub() function is where this should happen, but it's not obvious how.

@klahnakoski
Owner

That sounds like an interesting feature! Looking at the scrub method, it does appear to have all the information you need to get this done.

I am not sure how you would like to include the start/end properties. Maybe you can make scrub emit objects instead of plain dicts/lists: the keys work as usual, but you also have attributes to store start/end.

>>> class A(dict):
...     pass
...
>>> a = A()
>>> setattr(a, "start", 42)
>>> a['start']=43
>>> a['start']
43
>>> a.start
42
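Building on the REPL session above, the TreeDict/TreeList classes asked about in the question could be sketched as dict/list subclasses that keep the span in attributes. This is a minimal sketch, not mo_sql_parsing code: the `TreeNode` mixin and its `with_span`/`source` helpers are hypothetical names.

```python
import json
from typing import Optional


class TreeNode:
    """Mixin that keeps a source span in attributes, separate from keys/items."""

    start: Optional[int] = None
    end: Optional[int] = None

    def with_span(self, start: int, end: int):
        # Hypothetical helper: record the span and return self for chaining
        self.start, self.end = start, end
        return self

    def source(self, text: str) -> str:
        # Slice the original string back out using the stored span
        return text[self.start:self.end]


class TreeDict(dict, TreeNode):
    pass


class TreeList(list, TreeNode):
    pass


sql = "SELECT a FROM t"
node = TreeDict({"select": {"value": "a"}, "from": "t"}).with_span(0, len(sql))

print(node.start, node.end)  # 0 15
print(node.source(sql))      # SELECT a FROM t
# json only sees the dict contents; the span attributes are not serialized
print(json.dumps(node))      # {"select": {"value": "a"}, "from": "t"}
```

Because the span lives in attributes, existing consumers that treat the result as plain JSON-like data keep working unchanged.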

Whatever you choose, you can monkey-patch the scrub method without making a PR:

>>> from mo_sql_parsing import utils
>>> def my_new_scrub_method(result):
...     # build position-aware nodes from the raw parse result here
...     return result
...
>>> utils.scrub = my_new_scrub_method
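One caveat with this kind of patch: assigning `utils.scrub` only affects callers that look the function up on the module at call time. Code that did `from ... import scrub` earlier keeps a reference to the old function. The sketch below shows the difference with a stand-in module built via `types.ModuleType`; whether mo_sql_parsing's own `parse()` goes through the module attribute is an assumption to verify, not something shown here.

```python
import types

# Stand-in for a library module such as mo_sql_parsing.utils (contents hypothetical)
utils = types.ModuleType("utils")
utils.scrub = lambda result: ("original", result)


def parse_via_module(tokens):
    # Looks `scrub` up on the module at call time, so a later patch is seen
    return utils.scrub(tokens)


# What `from utils import scrub` would capture: a reference to the old function
early_bound_scrub = utils.scrub


def parse_via_name(tokens):
    return early_bound_scrub(tokens)


def my_new_scrub_method(result):
    return ("patched", result)


utils.scrub = my_new_scrub_method

print(parse_via_module([1]))  # ('patched', [1])
print(parse_via_name([1]))    # ('original', [1])
```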

If you do make a PR, be sure to make it optional, for example as one of:

  • a parameter in the parser call:
      parse(sql, add_positions=True)
  • a method you run once to apply the override above:
      add_positions()
      parse(sql)
  • a context manager:
      with add_positions():
          parse(sql)
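The context-manager option could look roughly like this. It is a sketch under assumptions: `add_positions` and `scrub_with_positions` are hypothetical names, and a stand-in module replaces mo_sql_parsing.utils so the example runs on its own.

```python
import contextlib
import types

# Stand-in for mo_sql_parsing.utils; the real module is not imported here
utils = types.ModuleType("utils")
utils.scrub = lambda result: ("plain", result)


def scrub_with_positions(result):
    # Hypothetical replacement that would keep start/end information
    return ("with positions", result)


@contextlib.contextmanager
def add_positions():
    """Swap in the position-aware scrub for the duration of the with-block."""
    original = utils.scrub
    utils.scrub = scrub_with_positions
    try:
        yield
    finally:
        # Restore the original even if parsing raises
        utils.scrub = original


def parse(sql):
    # Stand-in for a parser that calls utils.scrub on its result
    return utils.scrub(sql)


with add_positions():
    inside = parse("SELECT 1")
outside = parse("SELECT 1")
print(inside)   # ('with positions', 'SELECT 1')
print(outside)  # ('plain', 'SELECT 1')
```

`unittest.mock.patch.object(utils, "scrub", scrub_with_positions)` provides the same swap-and-restore behaviour ready-made. Either way the swap is global, so it is not safe if other threads parse concurrently; the parameter option avoids that.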

I will leave it with you for now.

@ilyakochik
Author

ilyakochik commented Feb 7, 2024

Thanks! I don't have time for a proper PR yet, but here is the patch I made in case you or anyone else needs it.

There are cases where there is no correct start and end (some results from Call, identifiers, etc.). I try to fix some of these cases with _fix().

from __future__ import annotations
import mo_sql_parsing
import mo_parsing.results  # ParseResults and ForwardResults are used in _transpile
from mo_future import text, number_types, binary_type
from mo_parsing import *
from mo_sql_parsing.utils import scrub_op, SQL_NULL, Call
from collections.abc import Generator


class SqlTree:
    start: int = None
    end: int = None

    @staticmethod
    def get_meta_keys() -> tuple[str, str]:
        return ("start", "end")

    def set_meta(self, **kwargs) -> SqlTree:
        for k in kwargs:
            assert k in SqlTree.get_meta_keys()
            setattr(self, k, kwargs[k])
        return self

    def get_meta(self) -> dict[str, object]:
        return {k: getattr(self, k, None) for k in SqlTree.get_meta_keys()}

    def copy_meta(self, other: SqlTree) -> SqlTree:
        return self.set_meta(**other.get_meta())


class SqlList(list, SqlTree):
    def items(self) -> Generator[tuple[int, object]]:
        return enumerate(self)


class SqlDict(dict, SqlTree):
    pass


class SqlValue(SqlTree):
    value: str | int = None

    def __init__(self, value: str | int):
        self.value = value

    def __str__(self) -> str:
        return self.value.__str__()

    def __repr__(self) -> str:
        return self.value.__repr__()

    def items(self) -> Generator[tuple[int, str | int]]:
        return enumerate([self.value])


flat_keys = "name", "value", "all_columns"
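# float and bool leaves need wrapping too; otherwise _squash raises NotImplementedError on them
transpile_map = {list: SqlList, dict: SqlDict, int: SqlValue, float: SqlValue, bool: SqlValue, str: SqlValue}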


# The results of the original `scrub` are processed further inside `_parse`, so turn `scrub`
# into a no-op here and do the conversion after `_parse` instead (see `_transpile`)
mo_sql_parsing.utils.scrub = lambda x: x


def parse(sql: str) -> SqlTree:
    parsed = mo_sql_parsing.parse(sql)
    tree = _transpile(parsed)
    tree = _squash(tree)
    _fix(tree)
    _check(tree, sql)
    return tree


def _check(tree: SqlTree, sql: str) -> None:
    meta = tree.get_meta()
    if all(v is not None for v in meta.values()):
        print(str(tree)[0:20], " === ", sql[meta["start"] : meta["end"]])

    if isinstance(tree, (SqlDict, SqlList)):
        for k, v in tree.items():
            _check(v, sql)


def _fix(tree: SqlTree) -> None:
    # TODO: fix parser to always have start and end without these hacks
    if isinstance(tree, SqlList):
        start = (r.start for r in tree if r.start is not None)
        end = (r.end for r in tree if r.end is not None)
        tree.set_meta(start=min(start, default=None), end=max(end, default=None))
    elif isinstance(tree, SqlDict) and len(tree) == 1:
        child = list(tree.keys()).pop()
        if not all(v is not None for v in tree[child].get_meta().values()):
            tree[child].copy_meta(tree)

    if isinstance(tree, SqlTree):
        for k, v in tree.items():
            _fix(v)


def _squash(dirty: SqlTree, parent=None) -> SqlTree:
    global flat_keys

    # Recursively clean up the tree
    if isinstance(dirty, SqlList):
        clean = [_squash(r, dirty) for r in dirty]
        clean = [r for r in clean if r is not None]
        if len(clean) > 1:
            clean = SqlList(clean)
        elif len(clean) == 1 and clean[0] is not None:
            clean = clean[0]
        else:
            clean = None
    elif isinstance(dirty, SqlDict):
        clean = {k: _squash(v, dirty) for k, v in dirty.items()}
        clean = {
            k: v if isinstance(v, list) or k in flat_keys else SqlList([v]).copy_meta(v)
            for k, v in clean.items()
            if v is not None
        }
        clean = SqlDict(clean)
    elif dirty is None or isinstance(dirty, SqlValue):
        clean = dirty
    else:
        raise NotImplementedError(f"Not implemented for {dirty.__class__}")

    # Preserve meta attributes
    if clean is None:
        return clean
    elif all(v is not None for v in clean.get_meta().values()):
        return clean
    elif all(v is not None for v in dirty.get_meta().values()):
        return clean.copy_meta(dirty)
    # parent is None for the root call, so guard before reading its meta
    elif parent is not None and all(v is not None for v in parent.get_meta().values()):
        return clean.copy_meta(parent)
    else:
        return clean


def _transpile(dirty: object) -> SqlTree:
    global transpile_map
    loc_attrs, loc = ("start", "end"), {}
    clean = None

    # Parse depending on type
    if dirty is SQL_NULL or dirty is None:
        clean = None
    elif isinstance(dirty, (text, number_types)):
        # TODO: Simple tokens do not have `start` and `end`
        clean = dirty
    elif isinstance(dirty, binary_type):
        clean = dirty.decode("utf8")
    elif isinstance(dirty, list):
        clean = [_transpile(r) for r in dirty]
    elif isinstance(dirty, dict):
        clean = {k: _transpile(v) for k, v in dirty.items()}
    elif isinstance(dirty, Call):
        kwargs = _transpile(dirty.kwargs)
        args = _transpile(dirty.args)
        clean = scrub_op(dirty.op, args, kwargs)
        # TODO: Call object has no `start` and `end`
    elif isinstance(dirty, mo_parsing.results.ForwardResults):
        loc = {a: getattr(dirty, a, None) for a in loc_attrs}
        clean = _transpile(dirty.tokens)
    elif isinstance(dirty, mo_parsing.results.ParseResults):
        loc = {a: getattr(dirty, a, None) for a in loc_attrs}
        tokens = dict(dirty.items()) or dirty.tokens
        clean = _transpile(tokens)
        # TODO: "*" is {all_columns: {}}, while "tbl.*" is {all_columns: "tbl"}
        #       for consistency better {all_columns: ''}
        # TODO: ParseResults often has `start=-1` and `end=0`
    else:
        raise NotImplementedError(f"Not implemented for {dirty.__class__}")

    # Transpile to Sql classes
    if clean.__class__ in transpile_map:
        clean = transpile_map[clean.__class__](clean)

    # Update meta attributes if captured
    if loc and all(loc[v] is not None and loc[v] >= 0 for v in loc_attrs):
        clean.set_meta(**loc)

    return clean
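The rule _fix applies to SqlList (a list's span is the minimum start and maximum end of its children) can be isolated on a plain tree. This is a variation rather than the patch itself: it runs post-order so nested lists are filled from the bottom up, and it only fills missing values instead of overwriting existing ones. `Node` and `fill_spans` are hypothetical names.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Node:
    # Minimal stand-in for a SqlTree node: an optional span plus child nodes
    start: Optional[int] = None
    end: Optional[int] = None
    children: List["Node"] = field(default_factory=list)


def fill_spans(node: Node) -> Node:
    """Give span-less nodes the min start / max end of their children (post-order)."""
    for child in node.children:
        fill_spans(child)
    starts = [c.start for c in node.children if c.start is not None]
    ends = [c.end for c in node.children if c.end is not None]
    if node.start is None and starts:
        node.start = min(starts)
    if node.end is None and ends:
        node.end = max(ends)
    return node


sql = "SELECT a, b FROM t"
# Column list whose own span is missing, as can happen for results built by scrub_op
columns = Node(children=[Node(7, 8), Node(10, 11)])
fill_spans(columns)
print(columns.start, columns.end)      # 7 11
print(sql[columns.start:columns.end])  # a, b
```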

@klahnakoski
Owner

Thank you. I made a branch: https://github.com/klahnakoski/mo-sql-parsing/tree/add-start-end

@klahnakoski
Owner

It will need tests.
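One invariant such tests could assert is that every node with a span lies inside the SQL string and inside its parent's span. Below is a sketch on a hand-built tree; `check_spans`, `Spanned` and `spanned` are hypothetical helpers. `check_spans` only uses `getattr` plus dict/list iteration, so it should also accept the SqlDict/SqlList/SqlValue nodes from the patch above, though I have not run it against real parser output.

```python
import unittest


def iter_children(node):
    # Dict- and list-based nodes (such as SqlDict/SqlList) have children; leaves do not
    if isinstance(node, dict):
        return list(node.values())
    if isinstance(node, list):
        return list(node)
    return []


def check_spans(node, sql, lo=0, hi=None):
    """Assert every known span lies inside `sql` and inside its parent's span."""
    hi = len(sql) if hi is None else hi
    start, end = getattr(node, "start", None), getattr(node, "end", None)
    if start is not None and end is not None:
        assert lo <= start <= end <= hi, f"span {start}:{end} outside {lo}:{hi}"
        lo, hi = start, end
    for child in iter_children(node):
        check_spans(child, sql, lo, hi)


class Spanned(dict):
    # Hypothetical stand-in for a parser node that carries start/end attributes
    start = None
    end = None


def spanned(data, start, end):
    node = Spanned(data)
    node.start, node.end = start, end
    return node


class SpanTests(unittest.TestCase):
    sql = "SELECT a FROM t"

    def test_nested_spans_are_consistent(self):
        column = spanned({"value": "a"}, 7, 8)
        tree = spanned({"select": column}, 0, len(self.sql))
        check_spans(tree, self.sql)
        self.assertEqual(self.sql[column.start:column.end], "a")

    def test_child_outside_parent_is_reported(self):
        table = spanned({"value": "t"}, 14, 20)  # runs past the end of the string
        tree = spanned({"from": table}, 0, len(self.sql))
        with self.assertRaises(AssertionError):
            check_spans(tree, self.sql)
```

Run it with `python -m unittest` or pytest. For the real feature, the same check could be applied to the trees returned for a set of sample queries.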
