Skip to content

Commit

Permalink
feat: add seal_gpu and the related tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dtera committed May 22, 2024
1 parent 462875a commit fffd431
Show file tree
Hide file tree
Showing 20 changed files with 1,977 additions and 34 deletions.
63 changes: 31 additions & 32 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,34 +9,36 @@ if (APPLE)
project(heu VERSION 1.0 LANGUAGES C CXX)
option(ENABLE_GPU "whether enable gpu" OFF)
else ()
set(CMAKE_CUDA_ARCHITECTURES 70 72 75 80 86)
project(heu VERSION 1.0 LANGUAGES C CXX CUDA)
option(ENABLE_GPU "whether enable gpu" ON)
add_definitions(-DENABLE_GPAILLIER=true)

find_package(CUDAToolkit REQUIRED)
set(CMAKE_CUDA_ARCHITECTURES 70 72 75 80 86)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -O3")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++17")
set(gcc_like_cxx "$<COMPILE_LANG_AND_ID:CXX,ARMClang,AppleClang,Clang,GNU>")
set(nvcc_cxx "$<COMPILE_LANG_AND_ID:CUDA,NVIDIA>")
endif ()

option(ENABLE_IPCL "whether enable ipcl" OFF)
option(ENABLE_IC "whether enable ic" OFF)
add_definitions(-DMSGPACK_NO_BOOST -DUSE_CMAKE -DNO_USE_MSGPACK)

file(GLOB_RECURSE SRCS ${PROJECT_NAME}/library/*.cpp ${PROJECT_NAME}/library/*.cc ${PROJECT_NAME}/api/*.cc
${PROJECT_NAME}/algorithms/*.cpp ${PROJECT_NAME}/algorithms/*.cc
${PROJECT_NAME}/spi/*.cpp ${PROJECT_NAME}/spi/*.cc)
file(GLOB_RECURSE IPCL_SRCS ${PROJECT_NAME}/library/algorithms/paillier_ipcl/*.cc)
file(GLOB_RECURSE GPU_SRCS ${PROJECT_NAME}/library/algorithms/paillier_gpu/*.cc
${PROJECT_NAME}/library/algorithms/paillier_gpu/gpulib/*.cu)
file(GLOB_RECURSE SEAL_GPU_SRCS ${PROJECT_NAME}/library/algorithms/seal_gpu/*)
file(GLOB_RECURSE PAILLIER_IC_SRCS ${PROJECT_NAME}/library/algorithms/paillier_ic/*.cc)
file(GLOB_RECURSE NP_SRCS ${PROJECT_NAME}/library/numpy/*.cc)
file(GLOB_RECURSE TEST_SRCS ${PROJECT_NAME}/library/*_test*.cc
${PROJECT_NAME}/spi/*_test*.cc ${PROJECT_NAME}/spi/test_*.cc)
file(GLOB_RECURSE BENCH_SRCS ${PROJECT_NAME}/library/*_bench.cc)
list(REMOVE_ITEM SRCS ${TEST_SRCS} ${BENCH_SRCS} ${NP_SRCS} ${SEAL_GPU_SRCS})
file(GLOB_RECURSE SRCS ${PROJECT_NAME}/library/*.c* ${PROJECT_NAME}/algorithms/*.c* ${PROJECT_NAME}/spi/*.c*
${PROJECT_NAME}/api/*.c*)
file(GLOB_RECURSE IPCL_SRCS ${PROJECT_NAME}/library/algorithms/paillier_ipcl/*.c*)
file(GLOB_RECURSE GPU_SRCS ${PROJECT_NAME}/library/algorithms/paillier_gpu/*.c*)
file(GLOB_RECURSE PAILLIER_IC_SRCS ${PROJECT_NAME}/library/algorithms/paillier_ic/*.c*)
file(GLOB_RECURSE NP_SRCS ${PROJECT_NAME}/library/numpy/*.c*)
file(GLOB_RECURSE TEST_SRCS ${PROJECT_NAME}/library/*_test*.c* ${PROJECT_NAME}/spi/*test*.c*)
file(GLOB_RECURSE BENCH_SRCS ${PROJECT_NAME}/library/*_bench.c*)
file(GLOB_RECURSE SEAL_GPU_SRCS ${PROJECT_NAME}/library/algorithms/seal_gpu/*.c*)
file(GLOB_RECURSE SEAL_FHE_GPU_SRCS ${PROJECT_NAME}/algorithms/seal_fhe/gpu/*.c*)
file(GLOB_RECURSE RM_SRCS ${PROJECT_NAME}/algorithms/seal_fhe/*.c*)

list(REMOVE_ITEM SRCS ${TEST_SRCS} ${BENCH_SRCS} ${NP_SRCS} ${SEAL_GPU_SRCS} ${SEAL_FHE_SRCS} ${RM_SRCS})
list(REMOVE_ITEM GPU_SRCS ${TEST_SRCS} ${BENCH_SRCS} ${NP_SRCS})
list(REMOVE_ITEM TEST_SRCS ${NP_SRCS})
list(REMOVE_ITEM BENCH_SRCS ${NP_SRCS})
Expand All @@ -47,7 +49,6 @@ if (NOT ENABLE_IPCL)
list(REMOVE_ITEM BENCH_SRCS ${IPCL_SRCS})
endif ()
if (NOT ENABLE_GPU)
list(REMOVE_ITEM SRCS ${GPU_SRCS})
list(REMOVE_ITEM TEST_SRCS ${GPU_SRCS})
list(REMOVE_ITEM BENCH_SRCS ${GPU_SRCS})
endif ()
Expand Down Expand Up @@ -79,6 +80,7 @@ else ()
message("OpenMP Found")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
set(COMPILE_OPTIONS --expt-extended-lambda -lineinfo --Werror all-warnings)
endif ()
endif ()

Expand All @@ -87,6 +89,7 @@ include_directories(. include third_party/yacl third_party third_party/include)
link_directories(third_party/lib)

# Building tests
enable_testing()
include(FetchContent)
FetchContent_Declare(
googletest
Expand All @@ -105,18 +108,24 @@ if (ENABLE_GPU)
target_link_libraries(${PROJECT_NAME}_GPU)

add_subdirectory(heu/library/algorithms/seal_gpu)

add_library(${PROJECT_NAME}_seal_fhe_gpu SHARED ${SEAL_FHE_GPU_SRCS})
target_link_libraries(${PROJECT_NAME}_seal_fhe_gpu seal_gpu)

add_executable(${PROJECT_NAME}_seal_gpu_timetest heu/tests/seal_gpu_timetest.cu)
target_link_libraries(${PROJECT_NAME}_seal_gpu_timetest seal_gpu)

list(APPEND EXE_TARGETS ${PROJECT_NAME}_seal_gpu_timetest)
list(APPEND GPU_LIB ${PROJECT_NAME}_GPU)
#list(APPEND GPU_LIB ${PROJECT_NAME}_GPU)
endif ()
# link libraries for the target

# add project lib
add_library(${PROJECT_NAME} SHARED ${SRCS})
#target_sources(${PROJECT_NAME} PRIVATE)
# link libraries for the target
target_link_libraries(${PROJECT_NAME} yacl ${ABSL_LIB} ${GPU_LIB} # protobuf::libprotobuf protobuf::libprotobuf-lite
tommath fmt blake3 sodium curve25519 gmp GTest::gtest GTest::gmock)


# Building bench
foreach (bench_file ${BENCH_SRCS})
string(REPLACE ".cc" "" bench_name ${bench_file})
Expand All @@ -126,7 +135,6 @@ foreach (bench_file ${BENCH_SRCS})
list(APPEND EXE_TARGETS ${PROJECT_NAME}_${bench_name})
endforeach ()

enable_testing()
# Building test
foreach (test_file ${TEST_SRCS})
string(REPLACE ".cc" "" test_name ${test_file})
Expand All @@ -139,26 +147,17 @@ foreach (test_file ${TEST_SRCS})
endforeach ()

add_executable(${PROJECT_NAME}_tests ${PROJECT_NAME}/tests.cc)
target_link_libraries(${PROJECT_NAME}_tests ${PROJECT_NAME}
GTest::GTest
)
file(GLOB TEST_SUITS ${PROJECT_NAME}/tests/*_test.cc)
target_link_libraries(${PROJECT_NAME}_tests ${PROJECT_NAME} ${PROJECT_NAME}_seal_fhe_gpu GTest::GTest)
file(GLOB TEST_SUITS ${PROJECT_NAME}/tests/*_test.c*)
target_sources(${PROJECT_NAME}_tests PRIVATE ${TEST_SUITS})
list(APPEND EXE_TARGETS ${PROJECT_NAME}_tests)

foreach (target ${EXE_TARGETS})
message("target: ${target}")
endforeach ()
set_target_properties(${PROJECT_NAME} ${EXE_TARGETS} PROPERTIES
set_target_properties(${EXE_TARGETS} ${PROJECT_NAME} ${PROJECT_NAME}_GPU ${PROJECT_NAME}_seal_fhe_gpu PROPERTIES
CXX_STANDARD ${CMAKE_CXX_STANDARD}
CXX_EXTENSIONS OFF
CXX_STANDARD_REQUIRED ON
CUDA_SEPARABLE_COMPILATION ON
)
if (ENABLE_GPU)
set_target_properties(${PROJECT_NAME}_GPU PROPERTIES
CXX_STANDARD ${CMAKE_CXX_STANDARD}
CXX_EXTENSIONS OFF
CXX_STANDARD_REQUIRED ON
CUDA_SEPARABLE_COMPILATION ON
)
endif ()
79 changes: 79 additions & 0 deletions heu/algorithms/seal_fhe/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright 2024 Ant Group Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@yacl//bazel:yacl.bzl", "yacl_cc_library", "yacl_cc_test")

package(default_visibility = ["//visibility:public"])

yacl_cc_library(
name = "seal_fhe",
srcs = ["he_kit.cc"],
hdrs = ["he_kit.h"],
deps = [
":decryptor",
":encoders",
":encryptor",
":evaluator",
],
alwayslink = 1,
)

yacl_cc_library(
name = "base",
srcs = ["base.cc"],
hdrs = ["base.h"],
deps = [
"//heu/spi/he/sketches/scalar",
"//heu/spi/utils:formater",
"@yacl//yacl/utils:serializer",
],
)

yacl_cc_library(
name = "encoders",
srcs = ["encoders.cc"],
hdrs = ["encoders.h"],
deps = [
":base",
],
)

yacl_cc_library(
name = "encryptor",
srcs = ["encryptor.cc"],
hdrs = ["encryptor.h"],
deps = [
":base",
],
)

yacl_cc_library(
name = "decryptor",
srcs = ["decryptor.cc"],
hdrs = ["decryptor.h"],
deps = [
":base",
],
)

yacl_cc_library(
name = "evaluator",
srcs = ["evaluator.cc"],
hdrs = ["evaluator.h"],
deps = [
":base",
"//heu/algorithms/common:type_alias",
"//heu/spi/utils:math_tool",
],
)
61 changes: 61 additions & 0 deletions heu/algorithms/seal_fhe/base.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Copyright 2024 Ant Group Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "heu/algorithms/seal_fhe/base.cuh"

#include <vector>

#include "yacl/utils/serializer.h"

#include "heu/spi/utils/formater.h"

namespace heu::algos::seal_fhe {

Plaintext ItemTool::Clone(const Plaintext &pt) const { return pt; }

Ciphertext ItemTool::Clone(const Ciphertext &ct) const { return ct; }

size_t ItemTool::Serialize(const Plaintext &pt, uint8_t *buf,
size_t buf_len) const {
std::stringstream ss;
pt.save(ss);
ss >> buf;
return buf_len;
}

size_t ItemTool::Serialize(const Ciphertext &ct, uint8_t *buf,
size_t buf_len) const {
std::stringstream ss;
ct.save(ss);
ss >> buf;
return buf_len;
}

Plaintext ItemTool::DeserializePT(yacl::ByteContainerView buffer) const {
Plaintext res;
std::stringstream ss;
ss << buffer.data();
res.load(ss);
return res;
}

Ciphertext ItemTool::DeserializeCT(yacl::ByteContainerView buffer) const {
Ciphertext res;
std::stringstream ss;
ss << buffer.data();
res.load(ss);
return res;
}

} // namespace heu::algos::seal_fhe
103 changes: 103 additions & 0 deletions heu/algorithms/seal_fhe/base.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// Copyright 2024 Ant Group Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <vector>

#include "heu/library/algorithms/seal_gpu/seal_cuda.cuh"
#include "heu/spi/he/sketches/common/keys.h"
#include "heu/spi/he/sketches/scalar/item_tool.h"

using seal_gpun::Ciphertext;
using seal_gpun::Plaintext;
using seal_gpun::SEALContext;

namespace heu::algos::seal_fhe {

class SecretKey : public spi::KeySketch<spi::HeKeyType::SecretKey>,
public seal_gpun::SecretKey {
public:
SecretKey() = default;

SecretKey(const seal_gpun::SecretKey &secretKey)
: seal_gpun::SecretKey(secretKey) {}

std::map<std::string, std::string> ListParams() const override {
auto sk = data();
return {
{"is_ntt_form_", fmt::to_string(sk.isNttForm())},
{"coeff_count_", fmt::to_string(sk.coeffCount())},
{"scale_", fmt::to_string(sk.scale())},
};
}
};

class PublicKey : public spi::KeySketch<spi::HeKeyType::PublicKey>,
public seal_gpun::PublicKey {
public:
std::map<std::string, std::string> ListParams() const override {
auto pk = data();
return {
{"is_ntt_form_", fmt::to_string(pk.isNttForm())},
{"size_", fmt::to_string(pk.size())},
{"poly_modulus_degree_", fmt::to_string(pk.polyModulusDegree())},
{"coeff_modulus_size_", fmt::to_string(pk.coeffModulusSize())},
{"scale_", fmt::to_string(pk.scale())},
{"correction_factor_", fmt::to_string(pk.correctionFactor())},
{"seed_", fmt::to_string(pk.seed())},
};
}
};

class RelinKeys : public spi::KeySketch<spi::HeKeyType::RelinKeys>,
public seal_gpun::RelinKeys {
public:
std::map<std::string, std::string> ListParams() const override {
return {
{"num_of_keyswitch_", fmt::to_string(size())},
};
}
};

class GaloisKeys : public spi::KeySketch<spi::HeKeyType::GaloisKeys>,
public seal_gpun::GaloisKeys {
public:
std::map<std::string, std::string> ListParams() const override {
return {
{"num_of_keyswitch_", fmt::to_string(size())},
};
}
};

class BootstrapKey : public spi::EmptyKeySketch<spi::HeKeyType::BootstrapKey> {
};

class ItemTool : public spi::ItemToolScalarSketch<Plaintext, Ciphertext,
SecretKey, PublicKey> {
public:
Plaintext Clone(const Plaintext &pt) const override;
Ciphertext Clone(const Ciphertext &ct) const override;

size_t Serialize(const Plaintext &pt, uint8_t *buf,
size_t buf_len) const override;
size_t Serialize(const Ciphertext &ct, uint8_t *buf,
size_t buf_len) const override;

Plaintext DeserializePT(yacl::ByteContainerView buffer) const override;
Ciphertext DeserializeCT(yacl::ByteContainerView buffer) const override;
};

} // namespace heu::algos::seal_fhe
Loading

0 comments on commit fffd431

Please sign in to comment.