test: add Integrated Test for Coprocessor & fix minor bugs (#1122)
* feat: cache `Runtime`

* fix: coprstream schema not set

* test: integrated tests for Coprocessor

* fix: UDF result handling

* style: remove unused import

* chore: remove more unused imports

* feat: `filter`, (r)floordiv for Vector

* chore: CR advice

* feat: auto convert to `lit`

* chore: fix typo

* feat: from & to `pyarrow.array`

* feat: allow `pyarrow.array` as args to builtins

* chore: cargo fmt

* test: CI add `pyarrow`

* test: install Python & PyArrow in CI

* test: don't cache dependencies for now

* chore: CR advice

* test: fix name

* style: rename
discord9 authored Mar 6, 2023
1 parent ff6cfe8 commit 379f581
Showing 24 changed files with 576 additions and 186 deletions.
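As a rough illustration of the scripting features this commit touches, the sketch below combines patterns from the tests added in engine.rs with an assumed `pyarrow.array` round trip; the `pa.array` / `gt.vector` interplay is illustrative only (not a confirmed API), and the new `col`/`filter` dataframe API is exercised separately in the engine.rs test further down.

import pyarrow as pa
import greptime as gt

@copr(args=["number"], returns=["r"], sql="select number from numbers limit 100")
def even_numbers(number) -> vector[i64]:
    # keep the even values of the input column, as in the gt.vector test added below
    evens = pa.array([x for x in number if x % 2 == 0])
    # assumed: a pyarrow.array is accepted when building the returned vector
    # ("from & to pyarrow.array" / "allow pyarrow.array as args to builtins")
    return gt.vector(evens)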
6 changes: 6 additions & 0 deletions .github/workflows/develop.yml
@@ -207,6 +207,12 @@ jobs:
uses: Swatinem/rust-cache@v2
- name: Install latest nextest release
uses: taiki-e/install-action@nextest
- name: Install Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install PyArrow Package
run: pip install pyarrow
- name: Install cargo-llvm-cov
uses: taiki-e/install-action@cargo-llvm-cov
- name: Collect coverage data
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -48,7 +48,7 @@ edition = "2021"
license = "Apache-2.0"

[workspace.dependencies]
arrow = "33.0"
arrow = { version = "33.0", features = ["pyarrow"] }
arrow-array = "33.0"
arrow-flight = "33.0"
arrow-schema = { version = "33.0", features = ["serde"] }
4 changes: 3 additions & 1 deletion docker/Dockerfile
@@ -10,7 +10,9 @@ RUN apt-get update && apt-get install -y \
curl \
build-essential \
pkg-config \
python3-dev
python3 \
python3-dev \
&& pip install pyarrow

# Install Rust.
SHELL ["/bin/bash", "-c"]
16 changes: 4 additions & 12 deletions src/common/function/src/scalars/udf.rs
@@ -14,9 +14,9 @@

use std::sync::Arc;

use common_query::error::{ExecuteFunctionSnafu, FromScalarValueSnafu};
use common_query::error::FromScalarValueSnafu;
use common_query::prelude::{
ColumnarValue, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUdf, ScalarValue,
ColumnarValue, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUdf,
};
use datatypes::error::Error as DataTypeError;
use datatypes::prelude::*;
@@ -54,16 +54,8 @@ pub fn create_udf(func: FunctionRef) -> ScalarUdf {
.collect();

let result = func_cloned.eval(func_ctx, &args.context(FromScalarValueSnafu)?);

let udf = if len.is_some() {
result.map(ColumnarValue::Vector)?
} else {
ScalarValue::try_from_array(&result?.to_arrow_array(), 0)
.map(ColumnarValue::Scalar)
.context(ExecuteFunctionSnafu)?
};

Ok(udf)
let udf_result = result.map(ColumnarValue::Vector)?;
Ok(udf_result)
});

ScalarUdf::new(func.name(), &func.signature(), &return_type, &fun)
1 change: 1 addition & 0 deletions src/script/Cargo.toml
@@ -24,6 +24,7 @@ python = [
]

[dependencies]
arrow.workspace = true
async-trait.workspace = true
catalog = { path = "../catalog" }
common-catalog = { path = "../common/catalog" }
92 changes: 64 additions & 28 deletions src/script/src/python/engine.rs
@@ -30,7 +30,7 @@ use common_recordbatch::{
RecordBatch, RecordBatchStream, RecordBatches, SendableRecordBatchStream,
};
use datafusion_expr::Volatility;
use datatypes::schema::{ColumnSchema, SchemaRef};
use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
use datatypes::vectors::VectorRef;
use futures::Stream;
use query::parser::{QueryLanguageParser, QueryStatement};
@@ -40,9 +40,8 @@ use snafu::{ensure, ResultExt};
use sql::statements::statement::Statement;

use crate::engine::{CompileContext, EvalContext, Script, ScriptEngine};
use crate::python::error::{self, Result};
use crate::python::error::{self, PyRuntimeSnafu, Result};
use crate::python::ffi_types::copr::{exec_parsed, parse, AnnotationInfo, CoprocessorRef};

const PY_ENGINE: &str = "python";

#[derive(Debug)]
@@ -81,17 +80,21 @@ impl PyUDF {

/// Fake a schema; should only be used when dynamically evaluating a Python UDF
fn fake_schema(&self, columns: &[VectorRef]) -> SchemaRef {
let empty_args = vec![];
let arg_names = self
.copr
.deco_args
.arg_names
.as_ref()
.unwrap_or(&empty_args);
// try to give the schema the right argument names so the script can run as a UDF without modification,
// because when running as a PyUDF the incoming columns should have matching names to make sense
// to the Coprocessor
let args = self.copr.deco_args.arg_names.clone();
let try_get_name = |i: usize| {
if let Some(arg_name) = args.as_ref().and_then(|args| args.get(i)) {
arg_name.clone()
} else {
format!("name_{i}")
}
};
let col_sch: Vec<_> = columns
.iter()
.enumerate()
.map(|(i, col)| ColumnSchema::new(arg_names[i].clone(), col.data_type(), true))
.map(|(i, col)| ColumnSchema::new(try_get_name(i), col.data_type(), true))
.collect();
let schema = datatypes::schema::Schema::new(col_sch);
Arc::new(schema)
@@ -172,7 +175,7 @@ impl Function for PyUDF {

pub struct PyScript {
query_engine: QueryEngineRef,
copr: CoprocessorRef,
pub(crate) copr: CoprocessorRef,
}

impl PyScript {
@@ -188,12 +191,48 @@ impl PyScript {
pub struct CoprStream {
stream: SendableRecordBatchStream,
copr: CoprocessorRef,
ret_schema: SchemaRef,
params: HashMap<String, String>,
}

impl CoprStream {
fn try_new(
stream: SendableRecordBatchStream,
copr: CoprocessorRef,
params: HashMap<String, String>,
) -> Result<Self> {
let mut schema = vec![];
for (ty, name) in copr.return_types.iter().zip(&copr.deco_args.ret_names) {
let ty = ty.clone().ok_or(
PyRuntimeSnafu {
msg: "return type not annotated, can't generate schema",
}
.build(),
)?;
let is_nullable = ty.is_nullable;
let ty = ty.datatype.ok_or(
PyRuntimeSnafu {
msg: "return type not annotated, can't generate schema",
}
.build(),
)?;
let col_schema = ColumnSchema::new(name, ty, is_nullable);
schema.push(col_schema);
}
let ret_schema = Arc::new(Schema::new(schema));
Ok(Self {
stream,
copr,
ret_schema,
params,
})
}
}

impl RecordBatchStream for CoprStream {
fn schema(&self) -> SchemaRef {
self.stream.schema()
// FIXME(discord9): use copr returns for schema
self.ret_schema.clone()
}
}

@@ -207,7 +246,6 @@ impl Stream for CoprStream {
let batch = exec_parsed(&self.copr, &Some(recordbatch), &self.params)
.map_err(BoxedError::new)
.context(ExternalSnafu)?;

Poll::Ready(Some(Ok(batch)))
}
Poll::Ready(other) => Poll::Ready(other),
@@ -246,11 +284,9 @@ impl Script for PyScript {
let res = self.query_engine.execute(&plan).await?;
let copr = self.copr.clone();
match res {
Output::Stream(stream) => Ok(Output::Stream(Box::pin(CoprStream {
params,
copr,
stream,
}))),
Output::Stream(stream) => Ok(Output::Stream(Box::pin(CoprStream::try_new(
stream, copr, params,
)?))),
_ => unreachable!(),
}
} else {
@@ -296,7 +332,8 @@ impl ScriptEngine for PyEngine {
})
}
}

#[cfg(test)]
pub(crate) use tests::sample_script_engine;
#[cfg(test)]
mod tests {
use catalog::local::{MemoryCatalogProvider, MemorySchemaProvider};
@@ -311,7 +348,7 @@ mod tests {

use super::*;

fn sample_script_engine() -> PyEngine {
pub(crate) fn sample_script_engine() -> PyEngine {
let catalog_list = catalog::local::new_memory_catalog_list().unwrap();

let default_schema = Arc::new(MemorySchemaProvider::new());
@@ -340,7 +377,7 @@ mod tests {
import greptime as gt
@copr(args=["number"], returns = ["number"], sql = "select * from numbers")
def test(number)->vector[u32]:
def test(number) -> vector[u32]:
return query.sql("select * from numbers")[0][0]
"#;
let script = script_engine
@@ -367,7 +404,7 @@

let script = r#"
@copr(returns = ["number"])
def test(**params)->vector[i64]:
def test(**params) -> vector[i64]:
return int(params['a']) + int(params['b'])
"#;
let script = script_engine
@@ -396,11 +433,10 @@ def test(**params)->vector[i64]:
let script_engine = sample_script_engine();

let script = r#"
import greptime as gt
from data_frame import col
from greptime import col
@copr(args=["number"], returns = ["number"], sql = "select * from numbers")
def test(number)->vector[u32]:
def test(number) -> vector[u32]:
return dataframe.filter(col("number")==col("number")).collect()[0][0]
"#;
let script = script_engine
@@ -432,7 +468,7 @@ def add(a, b):
return a + b;
@copr(args=["a", "b", "c"], returns = ["r"], sql="select number as a,number as b,number as c from numbers limit 100")
def test(a, b, c):
def test(a, b, c) -> vector[f64]:
return add(a, b) / g.sqrt(c + 1)
"#;
let script = script_engine
@@ -470,7 +506,7 @@ def test(a, b, c):
import greptime as gt
@copr(args=["number"], returns = ["r"], sql="select number from numbers limit 100")
def test(a):
def test(a) -> vector[i64]:
return gt.vector([x for x in a if x % 2 == 0])
"#;
let script = script_engine
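Because `CoprStream::try_new` above now derives the output schema from the coprocessor's return annotations, a script run through this path must annotate every return column. A minimal sketch, with table and column names chosen only for illustration:

@copr(args=["number"], returns=["doubled"], sql="select number from numbers limit 100")
def double(number) -> vector[u32]:
    # the `-> vector[u32]` annotation is what try_new turns into the ColumnSchema for "doubled";
    # leaving it off now surfaces the "return type not annotated, can't generate schema" error
    return number + number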
28 changes: 22 additions & 6 deletions src/script/src/python/ffi_types/copr.rs
@@ -219,7 +219,7 @@ pub(crate) fn select_from_rb(rb: &RecordBatch, fetch_names: &[String]) -> Result
.iter()
.map(|name| {
let vector = rb.column_by_name(name).with_context(|| OtherSnafu {
reason: format!("Can't find field name {name}"),
reason: format!("Can't find field name {name} in all columns in {rb:?}"),
})?;
Ok(PyVector::from(vector.clone()))
})
@@ -229,15 +229,29 @@ pub(crate) fn select_from_rb(rb: &RecordBatch, fetch_names: &[String]) -> Result
/// match between arguments' real type and annotation types
/// if type anno is `vector[_]` then use real type(from RecordBatch's schema)
pub(crate) fn check_args_anno_real_type(
arg_names: &[String],
args: &[PyVector],
copr: &Coprocessor,
rb: &RecordBatch,
) -> Result<()> {
ensure!(
arg_names.len() == args.len(),
OtherSnafu {
reason: format!("arg_names:{arg_names:?} and args{args:?}'s length is different")
}
);
for (idx, arg) in args.iter().enumerate() {
let anno_ty = copr.arg_types[idx].clone();
let real_ty = arg.to_arrow_array().data_type().clone();
let real_ty = ConcreteDataType::from_arrow_type(&real_ty);
let is_nullable: bool = rb.schema.column_schemas()[idx].is_nullable();
let arg_name = arg_names[idx].clone();
let col_idx = rb.schema.column_index_by_name(&arg_name).ok_or(
OtherSnafu {
reason: format!("Can't find column by name {arg_name}"),
}
.build(),
)?;
let is_nullable: bool = rb.schema.column_schemas()[col_idx].is_nullable();
ensure!(
anno_ty
.clone()
@@ -424,11 +438,13 @@ pub fn exec_parsed(
pyo3_exec_parsed(copr, rb, params)
}
#[cfg(not(feature = "pyo3_backend"))]
OtherSnafu {
reason: "`pyo3` feature is disabled, therefore can't run scripts in cpython"
.to_string(),
{
OtherSnafu {
reason: "`pyo3` feature is disabled, therefore can't run scripts in cpython"
.to_string(),
}
.fail()
}
.fail()
}
}
}
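With the stricter `check_args_anno_real_type` above, each name in `args=[...]` is resolved against the record batch's columns by name rather than by position, so the declared argument names must match the SQL projection. A hypothetical example:

# "a" and "b" must exist as columns of the bound SQL result; a mismatch now fails with
# "Can't find column by name ..." instead of relying on a positional lookup
@copr(args=["a", "b"], returns=["r"], sql="select number as a, number as b from numbers limit 100")
def test(a, b) -> vector[u32]:
    return a + b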