Skip to content

Commit

Permalink
feat(vector): add conversion between vector and string (#5029)
Browse files Browse the repository at this point in the history
* feat(vector): add conversion between vector and string

Signed-off-by: Zhenchi <[email protected]>

* fix sqlness

Signed-off-by: Zhenchi <[email protected]>

* address comments

Signed-off-by: Zhenchi <[email protected]>

---------

Signed-off-by: Zhenchi <[email protected]>
  • Loading branch information
zhongzc authored Nov 20, 2024
1 parent 027284e commit 0aab68c
Show file tree
Hide file tree
Showing 17 changed files with 656 additions and 170 deletions.
14 changes: 9 additions & 5 deletions src/common/function/src/scalars/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.

mod convert;
mod distance;

use std::sync::Arc;

use distance::{CosDistanceFunction, DotProductFunction, L2SqDistanceFunction};

use crate::function_registry::FunctionRegistry;

pub(crate) struct VectorFunction;

impl VectorFunction {
pub fn register(registry: &FunctionRegistry) {
registry.register(Arc::new(CosDistanceFunction));
registry.register(Arc::new(DotProductFunction));
registry.register(Arc::new(L2SqDistanceFunction));
// conversion
registry.register(Arc::new(convert::ParseVectorFunction));
registry.register(Arc::new(convert::VectorToStringFunction));

// distance
registry.register(Arc::new(distance::CosDistanceFunction));
registry.register(Arc::new(distance::DotProductFunction));
registry.register(Arc::new(distance::L2SqDistanceFunction));
}
}
19 changes: 19 additions & 0 deletions src/common/function/src/scalars/vector/convert.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

mod parse_vector;
mod vector_to_string;

pub use parse_vector::ParseVectorFunction;
pub use vector_to_string::VectorToStringFunction;
160 changes: 160 additions & 0 deletions src/common/function/src/scalars/vector/convert/parse_vector.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::fmt::Display;

use common_query::error::{InvalidFuncArgsSnafu, InvalidVectorStringSnafu, Result};
use common_query::prelude::{Signature, Volatility};
use datatypes::prelude::ConcreteDataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::types::parse_string_to_vector_type_value;
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
use snafu::{ensure, ResultExt};

use crate::function::{Function, FunctionContext};

const NAME: &str = "parse_vec";

#[derive(Debug, Clone, Default)]
pub struct ParseVectorFunction;

impl Function for ParseVectorFunction {
fn name(&self) -> &str {
NAME
}

fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
Ok(ConcreteDataType::binary_datatype())
}

fn signature(&self) -> Signature {
Signature::exact(
vec![ConcreteDataType::string_datatype()],
Volatility::Immutable,
)
}

fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 1,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly one, have: {}",
columns.len()
),
}
);

let column = &columns[0];
let size = column.len();

let mut result = BinaryVectorBuilder::with_capacity(size);
for i in 0..size {
let value = column.get(i).as_string();
if let Some(value) = value {
let res = parse_string_to_vector_type_value(&value, None)
.context(InvalidVectorStringSnafu { vec_str: &value })?;
result.push(Some(&res));
} else {
result.push_null();
}
}

Ok(result.to_vector())
}
}

impl Display for ParseVectorFunction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", NAME.to_ascii_uppercase())
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use common_base::bytes::Bytes;
use datatypes::value::Value;
use datatypes::vectors::StringVector;

use super::*;

#[test]
fn test_parse_vector() {
let func = ParseVectorFunction;

let input = Arc::new(StringVector::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
None,
]));

let result = func.eval(FunctionContext::default(), &[input]).unwrap();

let result = result.as_ref();
assert_eq!(result.len(), 3);
assert_eq!(
result.get(0),
Value::Binary(Bytes::from(
[1.0f32, 2.0, 3.0]
.iter()
.flat_map(|e| e.to_le_bytes())
.collect::<Vec<u8>>()
))
);
assert_eq!(
result.get(1),
Value::Binary(Bytes::from(
[4.0f32, 5.0, 6.0]
.iter()
.flat_map(|e| e.to_le_bytes())
.collect::<Vec<u8>>()
))
);
assert!(result.get(2).is_null());
}

#[test]
fn test_parse_vector_error() {
let func = ParseVectorFunction;

let input = Arc::new(StringVector::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
Some("[7.0,8.0,9.0".to_string()),
]));

let result = func.eval(FunctionContext::default(), &[input]);
assert!(result.is_err());

let input = Arc::new(StringVector::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
Some("7.0,8.0,9.0]".to_string()),
]));

let result = func.eval(FunctionContext::default(), &[input]);
assert!(result.is_err());

let input = Arc::new(StringVector::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
Some("[7.0,hello,9.0]".to_string()),
]));

let result = func.eval(FunctionContext::default(), &[input]);
assert!(result.is_err());
}
}
139 changes: 139 additions & 0 deletions src/common/function/src/scalars/vector/convert/vector_to_string.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::fmt::Display;

use common_query::error::{InvalidFuncArgsSnafu, Result};
use common_query::prelude::{Signature, Volatility};
use datatypes::prelude::ConcreteDataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::types::vector_type_value_to_string;
use datatypes::value::Value;
use datatypes::vectors::{MutableVector, StringVectorBuilder, VectorRef};
use snafu::ensure;

use crate::function::{Function, FunctionContext};

const NAME: &str = "vec_to_string";

#[derive(Debug, Clone, Default)]
pub struct VectorToStringFunction;

impl Function for VectorToStringFunction {
fn name(&self) -> &str {
NAME
}

fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
Ok(ConcreteDataType::string_datatype())
}

fn signature(&self) -> Signature {
Signature::exact(
vec![ConcreteDataType::binary_datatype()],
Volatility::Immutable,
)
}

fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 1,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly one, have: {}",
columns.len()
),
}
);

let column = &columns[0];
let size = column.len();

let mut result = StringVectorBuilder::with_capacity(size);
for i in 0..size {
let value = column.get(i);
match value {
Value::Binary(bytes) => {
let len = bytes.len();
if len % std::mem::size_of::<f32>() != 0 {
return InvalidFuncArgsSnafu {
err_msg: format!("Invalid binary length of vector: {}", len),
}
.fail();
}

let dim = len / std::mem::size_of::<f32>();
// Safety: `dim` is calculated from the length of `bytes` and is guaranteed to be valid
let res = vector_type_value_to_string(&bytes, dim as _).unwrap();
result.push(Some(&res));
}
Value::Null => {
result.push_null();
}
_ => {
return InvalidFuncArgsSnafu {
err_msg: format!("Invalid value type: {:?}", value.data_type()),
}
.fail();
}
}
}

Ok(result.to_vector())
}
}

impl Display for VectorToStringFunction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", NAME.to_ascii_uppercase())
}
}

#[cfg(test)]
mod tests {
use datatypes::value::Value;
use datatypes::vectors::BinaryVectorBuilder;

use super::*;

#[test]
fn test_vector_to_string() {
let func = VectorToStringFunction;

let mut builder = BinaryVectorBuilder::with_capacity(3);
builder.push(Some(
[1.0f32, 2.0, 3.0]
.iter()
.flat_map(|e| e.to_le_bytes())
.collect::<Vec<_>>()
.as_slice(),
));
builder.push(Some(
[4.0f32, 5.0, 6.0]
.iter()
.flat_map(|e| e.to_le_bytes())
.collect::<Vec<_>>()
.as_slice(),
));
builder.push_null();
let vector = builder.to_vector();

let result = func.eval(FunctionContext::default(), &[vector]).unwrap();

assert_eq!(result.len(), 3);
assert_eq!(result.get(0), Value::String("[1,2,3]".to_string().into()));
assert_eq!(result.get(1), Value::String("[4,5,6]".to_string().into()));
assert_eq!(result.get(2), Value::Null);
}
}
11 changes: 10 additions & 1 deletion src/common/query/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,14 @@ pub enum Error {
#[snafu(implicit)]
location: Location,
},

#[snafu(display("Invalid vector string: {}", vec_str))]
InvalidVectorString {
vec_str: String,
source: DataTypeError,
#[snafu(implicit)]
location: Location,
},
}

pub type Result<T> = std::result::Result<T, Error>;
Expand Down Expand Up @@ -273,7 +281,8 @@ impl ErrorExt for Error {
| Error::IntoVector { source, .. }
| Error::FromScalarValue { source, .. }
| Error::ConvertArrowSchema { source, .. }
| Error::FromArrowArray { source, .. } => source.status_code(),
| Error::FromArrowArray { source, .. }
| Error::InvalidVectorString { source, .. } => source.status_code(),

Error::MissingTableMutationHandler { .. }
| Error::MissingProcedureServiceHandler { .. }
Expand Down
Loading

0 comments on commit 0aab68c

Please sign in to comment.