Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Created a vignette for making a new random forest and ref scores #24

Merged
merged 1 commit into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ Imports:
Config/testthat/edition: 3
URL: https://github.com/CSAFE-ISU/handwriterRF
BugReports: https://github.com/CSAFE-ISU/handwriterRF/issues
VignetteBuilder: knitr
8 changes: 4 additions & 4 deletions R/compare.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,20 +52,20 @@
#' @examples
#' \donttest{
#' # Compare two documents from the same writer with a similarity score
#' s1 <- system.file(file.path("extdata", "docs", "w0030_s01_pWOZ_r01.png"),
#' s1 <- system.file(file.path("extdata", "docs", "w0005_s01_pLND_r03.png"),
#' package = "handwriterRF"
#' )
#' s2 <- system.file(file.path("extdata", "docs", "w0030_s01_pWOZ_r02.png"),
#' s2 <- system.file(file.path("extdata", "docs", "w0005_s02_pWOZ_r02.png"),
#' package = "handwriterRF"
#' )
#' compare_documents(s1, s2, score_only = TRUE)
#'
#' # Compare two documents from the same writer with a score-based
#' # likelihood ratio
#' s1 <- system.file(file.path("extdata", "docs", "w0030_s01_pWOZ_r01.png"),
#' s1 <- system.file(file.path("extdata", "docs", "w0005_s01_pLND_r03.png"),
#' package = "handwriterRF"
#' )
#' s2 <- system.file(file.path("extdata", "docs", "w0030_s01_pWOZ_r02.png"),
#' s2 <- system.file(file.path("extdata", "docs", "w0005_s02_pWOZ_r02.png"),
#' package = "handwriterRF"
#' )
#' compare_documents(s1, s2, score_only = FALSE)
Expand Down
7 changes: 6 additions & 1 deletion R/plots.R
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,13 @@ plot_scores <- function(scores, obs_score = NULL, n_bins = 50) {
) + # add text
ggplot2::labs(title = "The observed similarity score compared to reference similarity scores", x = "Score", y = "Rate")
} else {
p <- p + ggplot2::labs(title = "Reference similarity scores", x = "Score", y = "Rate")
p <- p +
ggplot2::labs(title = "Reference similarity scores", x = "Score", y = "Rate")
}
p <- p +
ggplot2::theme(legend.position = "bottom",
legend.text = ggplot2::element_text(size = 6),
legend.title = ggplot2::element_text(size = 8))

return(p)
}
Expand Down
6 changes: 3 additions & 3 deletions R/slrs.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,16 @@
#' @examples
#' \donttest{
#' # Compare two samples from the same writer
#' s1 <- system.file(file.path("extdata", "docs", "w0030_s01_pWOZ_r01.png"),
#' s1 <- system.file(file.path("extdata", "docs", "w0005_s01_pLND_r03.png"),
#' package = "handwriterRF"
#' )
#' s2 <- system.file(file.path("extdata", "docs", "w0030_s01_pWOZ_r02.png"),
#' s2 <- system.file(file.path("extdata", "docs", "w0005_s02_pWOZ_r02.png"),
#' package = "handwriterRF"
#' )
#' calculate_slr(s1, s2)
#'
#' # Compare samples from two writers
#' s1 <- system.file(file.path("extdata", "docs", "w0030_s01_pWOZ_r01.png"),
#' s1 <- system.file(file.path("extdata", "docs", "w0005_s02_pWOZ_r02.png"),
#' package = "handwriterRF"
#' )
#' s2 <- system.file(file.path("extdata", "docs", "w0238_s01_pWOZ_r02.png"),
Expand Down
29 changes: 27 additions & 2 deletions R/train.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,34 @@

#' Train a Random Forest
#'
#' Train a random forest with \pkg{ranger} from a data frame of cluster fill rates.
#' Train a random forest with \pkg{ranger} from a data frame of writer profiles
#' estimated with \code{\link{get_cluster_fill_rates}}. `train_rf` calculates
#' the distance between all pairs of writer profiles using one or more distance
#' measures. Currently, the available distance measures are absolute, Manhattan,
#' Euclidean, maximum, and cosine.
#'
#' @param df A data frame of cluster fill rates created with
#' The absolute distance between two n-length vectors of cluster fill rates, a
#' and b, is a vector of the same length as a and b. It can be calculated as
#' abs(a-b) where subtraction is performed element-wise, then the absolute
#' value of each element is returned. More specifically, element i of the vector is \eqn{|a_i
#' - b_i|} for \eqn{i=1,2,...,n}.
#'
#' The Manhattan distance between two n-length vectors of cluster fill rates, a and b, is
#' \eqn{\sum_{i=1}^n |a_i - b_i|}. In other words, it is the sum of the absolute
#' distance vector.
#'
#' The Euclidean distance between two n-length vectors of cluster fill rates, a and b, is
#' \eqn{\sqrt{\sum_{i=1}^n (a_i - b_i)^2}}. In other words, it is the sum of the elements of the
#' absolute distance vector.
#'
#' The maximum distance between two n-length vectors of cluster fill rates, a and b, is
#' \eqn{\max_{1 \leq i \leq n}{\{|a_i - b_i|\}}}. In other words, it is the sum of the elements of the
#' absolute distance vector.
#'
#' The cosine distance between two n-length vectors of cluster fill rates, a and b, is
#' \eqn{\sum_{i=1}^n (a_i - b_i)^2 / (\sqrt{\sum_{i=1}^n a_i^2}\sqrt{\sum_{i=1}^n b_i^2})}.
#'
#' @param df A data frame of writer profiles created with
#' \code{\link{get_cluster_fill_rates}}
#' @param ntrees An integer number of decision trees to use
#' @param distance_measures A vector of distance measures. Any combination of
Expand Down
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ Compare 2 of these samples. In this case, both samples are from writer
30.

``` r
sample1 <- system.file(file.path("extdata", "docs", "w0030_s01_pWOZ_r01.png"), package = "handwriterRF")
sample2 <- system.file(file.path("extdata", "docs", "w0030_s01_pWOZ_r02.png"), package = "handwriterRF")
sample1 <- system.file(file.path("extdata", "docs", "w0005_s01_pLND_r03.png"), package = "handwriterRF")
sample2 <- system.file(file.path("extdata", "docs", "w0005_s02_pWOZ_r02.png"), package = "handwriterRF")
slr <- calculate_slr(sample1, sample2)
```

Expand Down Expand Up @@ -84,8 +84,8 @@ the data frame fits on this page.
slr
```

docname1 writer1 docname2 writer2 score slr
1 w0030_s01_pWOZ_r01 unknown1 w0030_s01_pWOZ_r02 unknown2 0.955 135.159
docname1 writer1 docname2 writer2 score slr
1 w0005_s01_pLND_r03 unknown1 w0005_s02_pWOZ_r02 unknown2 0.635 1.482318

### Interpret the Score-base Likelihood Ratio

Expand All @@ -95,4 +95,4 @@ View a verbal interpretation of the score-based likelihood ratio.
interpret_slr(slr)
```

[1] "A score-based likelihood ratio of 135.2 means the likelihood of observing a similarity score of 0.955 if the documents were written by the same person is 135.2 times greater than the likelihood of observing this score if the documents were written by different writers."
[1] "A score-based likelihood ratio of 1.5 means the likelihood of observing a similarity score of 0.635 if the documents were written by the same person is 1.5 times greater than the likelihood of observing this score if the documents were written by different writers."
4 changes: 2 additions & 2 deletions README.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ library(handwriterRF)

The package includes 4 example handwriting samples from the [CSAFE Handwriting Database](https://forensicstats.org/handwritingdatabase/). Compare 2 of these samples. In this case, both samples are from writer 30.
```{r calculate1, message=FALSE}
sample1 <- system.file(file.path("extdata", "docs", "w0030_s01_pWOZ_r01.png"), package = "handwriterRF")
sample2 <- system.file(file.path("extdata", "docs", "w0030_s01_pWOZ_r02.png"), package = "handwriterRF")
sample1 <- system.file(file.path("extdata", "docs", "w0005_s01_pLND_r03.png"), package = "handwriterRF")
sample2 <- system.file(file.path("extdata", "docs", "w0005_s02_pWOZ_r02.png"), package = "handwriterRF")
slr <- calculate_slr(sample1, sample2)
```

Expand Down
31 changes: 0 additions & 31 deletions data-raw/make_clusters.R

This file was deleted.

98 changes: 52 additions & 46 deletions data-raw/train_valid_test.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
devtools::load_all()
install.packages("handwriter")
devtools::install_github("CSAFE-ISU/handwriterRF")

library(handwriter)
library(handwriterRF)

# Helper Functions --------------------------------------------------------

Expand Down Expand Up @@ -41,11 +44,15 @@ make_csafe_sets <- function(rates, prompts = c("pWOZ", "pLND"), num_per_prompt =
}

drop_columns <- function(df) {
df$doc <- paste(df$session, df$prompt, df$rep, sep = "_")
df <- df %>%
dplyr::ungroup() %>%
dplyr::select(-session, -prompt, -rep)
dplyr::select(-tidyselect::any_of(c("session", "prompt", "rep")))
}

# drop writer and doc column to prevent error with expand_docnames
rates <- rates %>% dplyr::select(-tidyselect::any_of(c("writer", "doc")))

# split writers train, validation, and test sets
all_writers <- find_writers_with_27_docs(df = rates)
writers <- split_writers(
Expand All @@ -71,7 +78,6 @@ make_cvl_sets <- function(rates, num_per_writer = 4, use_German_prompt = FALSE,

find_writers_with_5plus_docs <- function(df) {
# Filter cvl data frame for writers with 5 or more docs
df <- expand_cvl_docnames(df)
writers <- df %>%
dplyr::group_by(writer) %>%
dplyr::summarize(n = dplyr::n()) %>%
Expand All @@ -80,43 +86,44 @@ make_cvl_sets <- function(rates, num_per_writer = 4, use_German_prompt = FALSE,
return(writers)
}

sample_prompts <- function(df, writers, num_per_writer, use_German_prompt) {
df <- expand_cvl_docnames(df = df)

sample_cvl_prompts <- function(df, set_writers, num_per_writer, use_German_prompt) {
if (!use_German_prompt) {
df <- df %>% dplyr::filter(prompt != "6-cropped")
}

df <- df %>%
dplyr::filter(writer %in% writers) %>%
dplyr::filter(writer %in% set_writers) %>%
dplyr::group_by(writer) %>%
dplyr::slice_sample(n = num_per_writer)

return(df)
}

drop_prompt_column <- function(df) {
drop_cvl_prompt_column <- function(df) {
df <- df %>%
dplyr::ungroup() %>%
dplyr::select(-prompt)
return(df)
}

rates <- expand_cvl_docnames(df = rates)
rates$writer <- paste0("c", rates$writer)

all_writers <- find_writers_with_5plus_docs(df = rates)
writers <- split_writers(
all_writers = all_writers,
num_train_writers = num_train_writers,
num_validation_writers = num_validation_writers
)
docs <- lapply(writers, function(w) {
sample_prompts(
sample_cvl_prompts(
df = rates,
writers = w,
set_writers = w,
num_per_writer = num_per_writer,
use_German_prompt = use_German_prompt
)
})
docs <- lapply(docs, drop_prompt_column)
docs <- lapply(docs, drop_cvl_prompt_column)

return(docs)
}
Expand Down Expand Up @@ -145,54 +152,53 @@ split_writers <- function(all_writers, num_train_writers, num_validation_writers

set.seed(100)

# Create data frames of csafe and cvl cluster fill rates
# csafe <- load_cluster_fill_rates(clusters_dir = "/Users/stephanie/Documents/handwriting_datasets/CSAFE_Handwriting_Database/clusters")
# saveRDS(csafe, "data-raw/csafe_cfr.rds")
#
# cvl <- load_cluster_fill_rates("/Users/stephanie/Documents/handwriting_datasets/CVL/clusters")
# saveRDS(cvl, "data-raw/cvl_cfr.rds")
# If you need cluster assignments CVL data
handwriter::get_clusters_batch(input_dir = "path/to/cvl/graphs/dir",
output_dir = "path/to/cvl/clusters/dir",
template = "path/to/template.rds",
writer_indices = c(1,4),
doc_indices = c(6,6),
num_cores = 4)

# Load cluster fill rates
csafe <- readRDS("data-raw/csafe_cfr.rds")
cvl <- readRDS("data-raw/cvl_cfr.rds")
# Create data frames of csafe and cvl cluster fill rates
csafe <- load_cluster_fill_rates(clusters_dir = "/Users/stephanie/Documents/handwriting_datasets/CSAFE_Handwriting_Database/300dpi/clusters")
cvl <- load_cluster_fill_rates("/Users/stephanie/Documents/handwriting_datasets/CVL/300dpi/clusters")

# Make sets
# Make sets. Feel free to change num_train_writers and num_validation_writers.
# Writers not assigned to either of these sets will be placed in the test set.
csafe <- make_csafe_sets(
rates = csafe, prompts = c("pWOZ", "pLND"), num_per_prompt = 2,
num_train_writers = 100, num_validation_writers = 150
rates = csafe,
prompts = c("pWOZ", "pLND"),
num_per_prompt = 2,
num_train_writers = 100,
num_validation_writers = 150
)
saveRDS(csafe, "data-raw/csafe_sets.rds")

cvl <- make_cvl_sets(
rates = cvl, num_per_writer = 4, use_German_prompt = FALSE,
num_train_writers = 100, num_validation_writers = 150
rates = cvl,
num_per_writer = 4,
use_German_prompt = FALSE,
num_train_writers = 100,
num_validation_writers = 150
)
saveRDS(cvl, "data-raw/cvl_sets.rds")

# csafe <- readRDS("data-raw/csafe_sets.rds")
# cvl <- readRDS("data-raw/cvl_sets.rds")

# Train random forest
train <- rbind(csafe$train, cvl$train)
saveRDS(train, "data-raw/train.rds")
usethis::use_data(train, overwrite = TRUE)

random_forest <- train_rf(train, ntrees = 200, distance_measures = c("abs", "euc"))
saveRDS(random_forest, "data-raw/random_forest.rds")
usethis::use_data(random_forest, overwrite = TRUE)
# Choose distance measures
rf <- train_rf(train,
ntrees = 200,
distance_measures = c("abs", "euc"))

# Get similarity scores on validation set
validation <- rbind(csafe$validation, cvl$validation)
saveRDS(validation, "data-raw/validation.rds")
usethis::use_data(validation, overwrite = TRUE)

ref_scores <- get_ref_scores(rforest = random_forest, df = validation)
saveRDS(ref_scores, "data-raw/ref_scores.rds")
usethis::use_data(ref_scores, overwrite = TRUE)
# Get similarity scores on validation set. Note: there will be many times more
# 'different writer' scores compared to 'same writer' scores.
validation <- rbind(csafe$validation, cvl$validation)
rscores <- get_ref_scores(rforest = rf, df = validation)

# Test set
test <- rbind(csafe$test, cvl$test)
saveRDS(test, "data-raw/test.rds")
usethis::use_data(test, overwrite = TRUE)

plot_scores(scores = ref_scores)
results <- compare_writer_profiles(writer_profiles = test,
score_only = FALSE,
rforest = rf,
reference_scores = rscores)
Binary file removed inst/extdata/clusters/w0030_s01_pWOZ_r01.rds
Binary file not shown.
Binary file removed inst/extdata/clusters/w0030_s01_pWOZ_r02.rds
Binary file not shown.
Binary file added inst/extdata/docs/w0005_s01_pLND_r03.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added inst/extdata/docs/w0005_s02_pWOZ_r02.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed inst/extdata/docs/w0030_s01_pWOZ_r01.png
Binary file not shown.
Binary file removed inst/extdata/docs/w0030_s01_pWOZ_r02.png
Binary file not shown.
6 changes: 3 additions & 3 deletions man/calculate_slr.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions man/compare_documents.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading