Skip to content

Commit

Permalink
Merge pull request #6 from andrewwhitehead/logging-perf
Browse files Browse the repository at this point in the history
Logging performance improvements
  • Loading branch information
andrewwhitehead authored Mar 4, 2021
2 parents cc0de8b + 3000d50 commit 978f30f
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 40 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "aries-askar"
version = "0.1.2"
version = "0.1.3"
authors = ["Hyperledger Aries Contributors <[email protected]>"]
edition = "2018"
description = "Hyperledger Aries Askar secure storage"
Expand Down
46 changes: 35 additions & 11 deletions src/ffi/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@ use std::ptr;
use log::{LevelFilter, Metadata, Record};

use super::error::ErrorCode;
use crate::error::Error;

pub type EnabledCallback =
extern "C" fn(context: *const c_void, level: u32, target: *const c_char) -> bool;
pub type EnabledCallback = extern "C" fn(context: *const c_void, level: i32) -> i8;

pub type LogCallback = extern "C" fn(
context: *const c_void,
level: u32,
level: i32,
target: *const c_char,
message: *const c_char,
module_path: *const c_char,
file: *const c_char,
line: u32,
line: i32,
);

pub type FlushCallback = extern "C" fn(context: *const c_void);
Expand Down Expand Up @@ -47,10 +47,7 @@ impl CustomLogger {
impl log::Log for CustomLogger {
fn enabled(&self, metadata: &Metadata<'_>) -> bool {
if let Some(enabled_cb) = self.enabled {
let level = metadata.level() as u32;
let target = CString::new(metadata.target()).unwrap();

enabled_cb(self.context, level, target.as_ptr())
enabled_cb(self.context, metadata.level() as i32) != 0
} else {
true
}
Expand All @@ -59,13 +56,13 @@ impl log::Log for CustomLogger {
fn log(&self, record: &Record<'_>) {
let log_cb = self.log;

let level = record.level() as u32;
let level = record.level() as i32;
let target = CString::new(record.target()).unwrap();
let message = CString::new(record.args().to_string()).unwrap();

let module_path = record.module_path().map(|s| CString::new(s).unwrap());
let file = record.file().map(|s| CString::new(s).unwrap());
let line = record.line().unwrap_or(0);
let line = record.line().unwrap_or(0) as i32;

log_cb(
self.context,
Expand Down Expand Up @@ -97,11 +94,13 @@ pub extern "C" fn askar_set_custom_logger(
log: LogCallback,
enabled: Option<EnabledCallback>,
flush: Option<FlushCallback>,
max_level: i32,
) -> ErrorCode {
catch_err! {
let max_level = get_level_filter(max_level)?;
let logger = CustomLogger::new(context, enabled, log, flush);
log::set_boxed_logger(Box::new(logger)).map_err(err_map!(Unexpected))?;
log::set_max_level(LevelFilter::Trace);
log::set_max_level(max_level);
debug!("Initialized custom logger");
Ok(ErrorCode::Success)
}
Expand All @@ -115,3 +114,28 @@ pub extern "C" fn askar_set_default_logger() -> ErrorCode {
Ok(ErrorCode::Success)
}
}

#[no_mangle]
pub extern "C" fn askar_set_max_log_level(max_level: i32) -> ErrorCode {
catch_err! {
log::set_max_level(get_level_filter(max_level)?);
Ok(ErrorCode::Success)
}
}

fn get_level_filter(max_level: i32) -> Result<LevelFilter, Error> {
Ok(match max_level {
-1 => {
// load from RUST_LOG environment variable
// defaults to ERROR if unspecified
env_logger::Logger::from_default_env().filter()
}
0 => LevelFilter::Off,
1 => LevelFilter::Error,
2 => LevelFilter::Warn,
3 => LevelFilter::Info,
4 => LevelFilter::Debug,
5 => LevelFilter::Trace,
_ => return Err(err_msg!(Input, "Invalid log level")),
})
}
14 changes: 11 additions & 3 deletions src/postgres/provision.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::borrow::Cow;
use std::str::FromStr;
use std::time::Duration;

use sqlx::{
postgres::{PgConnection, PgPool, PgPoolOptions, Postgres},
Connection, Error as SqlxError, Executor, Row, Transaction,
postgres::{PgConnectOptions, PgConnection, PgPool, PgPoolOptions, Postgres},
ConnectOptions, Connection, Error as SqlxError, Executor, Row, Transaction,
};

use crate::db_utils::{init_keys, random_profile_name};
Expand Down Expand Up @@ -110,13 +111,20 @@ impl PostgresStoreOptions {
}

async fn pool(&self) -> std::result::Result<PgPool, SqlxError> {
#[allow(unused_mut)]
let mut conn_opts = PgConnectOptions::from_str(self.uri.as_str())?;
#[cfg(feature = "log")]
{
conn_opts.log_statements(log::LevelFilter::Debug);
conn_opts.log_slow_statements(log::LevelFilter::Debug, Default::default());
}
PgPoolOptions::default()
.connect_timeout(self.connect_timeout)
.idle_timeout(self.idle_timeout)
.max_connections(self.max_connections)
.min_connections(self.min_connections)
.test_before_acquire(false)
.connect(self.uri.as_str())
.connect_with(conn_opts)
.await
}

Expand Down
10 changes: 8 additions & 2 deletions src/sqlite/provision.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::str::FromStr;

use sqlx::{
sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions},
Error as SqlxError, Row,
ConnectOptions, Error as SqlxError, Row,
};

use super::SqliteStore;
Expand Down Expand Up @@ -48,8 +48,14 @@ impl SqliteStoreOptions {
}

async fn pool(&self, auto_create: bool) -> std::result::Result<SqlitePool, SqlxError> {
let conn_opts =
#[allow(unused_mut)]
let mut conn_opts =
SqliteConnectOptions::from_str(self.path.as_ref())?.create_if_missing(auto_create);
#[cfg(feature = "log")]
{
conn_opts.log_statements(log::LevelFilter::Debug);
conn_opts.log_slow_statements(log::LevelFilter::Debug, Default::default());
}
SqlitePoolOptions::default()
// maintains at least 1 connection.
// for an in-memory database this is required to avoid dropping the database,
Expand Down
72 changes: 53 additions & 19 deletions wrappers/python/aries_askar/bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
c_ubyte,
)
from ctypes.util import find_library
from typing import Optional, Sequence, Union
from typing import Optional, Sequence, Tuple, Union

from .error import StoreError, StoreErrorCode
from .types import Entry, EntryOperation, KeyAlg
Expand All @@ -29,6 +29,13 @@
CALLBACKS = {}
LIB: CDLL = None
LOGGER = logging.getLogger(__name__)
LOG_LEVELS = {
1: logging.ERROR,
2: logging.WARNING,
3: logging.INFO,
4: logging.DEBUG,
}
MODULE_NAME = __name__.split(".")[0]


class StoreHandle(c_int64):
Expand Down Expand Up @@ -200,7 +207,7 @@ def get_library() -> CDLL:
global LIB
if LIB is None:
LIB = _load_library("aries_askar")
_set_logger()
_init_logger()
return LIB


Expand Down Expand Up @@ -237,16 +244,14 @@ def _load_library(lib_name: str) -> CDLL:
) from e


def _set_logger():
logger = logging.getLogger("aries_askar")
logging.addLevelName(5, "TRACE")
level_mapping = {
1: logging.ERROR,
2: logging.WARNING,
3: logging.INFO,
4: logging.DEBUG,
5: 5,
}
def _init_logger():
logger = logging.getLogger(MODULE_NAME)
if logging.getLevelName("TRACE") == "Level TRACE":
# avoid redefining TRACE if another library has added it
logging.addLevelName(5, "TRACE")

def _enabled(_context, level: int) -> bool:
return logger.isEnabledFor(LOG_LEVELS.get(level, level))

def _log(
_context,
Expand All @@ -258,26 +263,55 @@ def _log(
line: int,
):
logger.getChild("native." + target.decode().replace("::", ".")).log(
level_mapping[level],
LOG_LEVELS.get(level, level),
"\t%s:%d | %s",
file_name.decode() if file_name else None,
line,
message.decode(),
)

_set_logger.callback = CFUNCTYPE(
_init_logger.enabled_cb = CFUNCTYPE(c_int8, c_void_p, c_int32)(_enabled)

_init_logger.log_cb = CFUNCTYPE(
None, c_void_p, c_int32, c_char_p, c_char_p, c_char_p, c_char_p, c_int32
)(_log)

if os.getenv("RUST_LOG"):
# level from environment
level = -1
else:
# inherit current level from logger
level = _convert_log_level(logger.level or logger.parent.level)

do_call(
"askar_set_custom_logger",
c_void_p(), # context
_set_logger.callback,
c_void_p(), # enabled
_init_logger.log_cb,
_init_logger.enabled_cb,
c_void_p(), # flush
c_int32(level),
)


def set_max_log_level(level: Union[str, int, None]):
get_library() # ensure logger is initialized
set_level = _convert_log_level(level)
do_call("askar_set_max_log_level", c_int32(set_level))


def _convert_log_level(level: Union[str, int, None]):
if level is None or level == "-1":
return -1
else:
if isinstance(level, str):
level = level.upper()
name = logging.getLevelName(level)
for k, v in LOG_LEVELS.items():
if logging.getLevelName(v) == name:
return k
return 0


def _fulfill_future(fut: asyncio.Future, result, err: Exception = None):
"""Resolve a callback future given the result and exception, if any."""
if fut.cancelled():
Expand Down Expand Up @@ -396,7 +430,7 @@ def get_current_error(expect: bool = False) -> Optional[StoreError]:
return StoreError(StoreErrorCode.WRAPPER, "Unknown error")


async def derive_verkey(key_alg: KeyAlg, seed: [str, bytes]) -> str:
async def derive_verkey(key_alg: KeyAlg, seed: Union[str, bytes]) -> str:
"""Derive a verification key from a seed."""
return str(
await do_call_async(
Expand All @@ -408,7 +442,7 @@ async def derive_verkey(key_alg: KeyAlg, seed: [str, bytes]) -> str:
)


async def generate_raw_key(seed: [str, bytes] = None) -> str:
async def generate_raw_key(seed: Union[str, bytes] = None) -> str:
"""Generate a new raw store wrapping key."""
return str(
await do_call_async(
Expand Down Expand Up @@ -730,7 +764,7 @@ async def session_pack_message(
async def session_unpack_message(
handle: SessionHandle,
message: Union[str, bytes],
) -> (ByteBuffer, str, Optional[str]):
) -> Tuple[ByteBuffer, str, Optional[str]]:
message = encode_bytes(message)
result = await do_call_async(
"askar_session_unpack_message", handle, message, return_type=lib_unpack_result
Expand Down
2 changes: 1 addition & 1 deletion wrappers/python/aries_askar/version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""aries_askar library wrapper version."""

__version__ = "0.1.2"
__version__ = "0.1.3"
9 changes: 8 additions & 1 deletion wrappers/python/demo/perf.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
import asyncio
import logging
import os
import sys
import time

from aries_askar.bindings import generate_raw_key, version
from aries_askar.bindings import (
generate_raw_key,
version,
)
from aries_askar import Store

logging.basicConfig(level=os.getenv("LOG_LEVEL", "").upper() or None)

if len(sys.argv) > 1:
REPO_URI = sys.argv[1]
if REPO_URI == "postgres":
Expand Down
4 changes: 2 additions & 2 deletions wrappers/python/demo/test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import logging
import os
import sys

from aries_askar.bindings import (
Expand All @@ -10,8 +11,7 @@
)
from aries_askar import KeyAlg, Store

logging.basicConfig(level=logging.INFO)
# logging.getLogger("aries_askar").setLevel(5)
logging.basicConfig(level=os.getenv("LOG_LEVEL", "").upper() or None)

if len(sys.argv) > 1:
REPO_URI = sys.argv[1]
Expand Down

0 comments on commit 978f30f

Please sign in to comment.