Skip to content

Commit

Permalink
feat(pyspark): enable reading csv and parquet globs and implement `read_json`
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Aug 21, 2023
1 parent 4ea1834 commit d487e10
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
34 changes: 34 additions & 0 deletions ibis/backends/pyspark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
from ibis.backends.pyspark.datatypes import PySparkType

if TYPE_CHECKING:
from collections.abc import Sequence

import pandas as pd
import pyarrow as pa

Expand Down Expand Up @@ -682,6 +684,38 @@ def read_csv(
spark_df.createOrReplaceTempView(table_name)
return self.table(table_name)

def read_json(
    self,
    source_list: str | Sequence[str],
    table_name: str | None = None,
    **kwargs: Any,
) -> ir.Table:
    """Register a JSON file as a table in the current database.

    Parameters
    ----------
    source_list
        The data source(s). May be a path to a file or directory of
        JSON files, or an iterable of JSON files.
    table_name
        An optional name to use for the created table. This defaults to
        a sequentially generated name.
    kwargs
        Additional keyword arguments passed to the PySpark loading function:
        https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameReader.json.html

    Returns
    -------
    ir.Table
        The just-registered table
    """
    # Fall back to a generated name for any falsy table_name (None or "").
    if not table_name:
        table_name = util.gen_name("read_json")

    # Normalize to a list of path strings before handing off to Spark.
    sources = normalize_filenames(source_list)
    dataframe = self._session.read.json(sources, **kwargs)

    # Expose the loaded data as a temp view so it is addressable by name.
    dataframe.createOrReplaceTempView(table_name)
    return self.table(table_name)

def register(
self,
source: str | Path | Any,
Expand Down
3 changes: 0 additions & 3 deletions ibis/backends/tests/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,6 @@ def ft_data(data_dir):
"mysql",
"pandas",
"postgres",
"pyspark",
"sqlite",
"trino",
]
Expand Down Expand Up @@ -485,7 +484,6 @@ def test_read_parquet_glob(con, tmp_path, ft_data):
"mysql",
"pandas",
"postgres",
"pyspark",
"sqlite",
"trino",
]
Expand Down Expand Up @@ -517,7 +515,6 @@ def test_read_csv_glob(con, tmp_path, ft_data):
"mysql",
"pandas",
"postgres",
"pyspark",
"sqlite",
"trino",
]
Expand Down

0 comments on commit d487e10

Please sign in to comment.