Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored computeStandardizedDifference #228

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
351 changes: 254 additions & 97 deletions R/CompareCohorts.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Observational Health Data Sciences and Informatics
# Copyright 2023 Observational Health Data Sciences and Informatics
#
# This file is part of FeatureExtraction
#
Expand Down Expand Up @@ -46,113 +46,270 @@
#' )
#' }
#' @export
computeStandardizedDifference <- function(covariateData1, covariateData2, cohortId1 = NULL, cohortId2 = NULL) {
if (!isCovariateData(covariateData1)) {
stop("covariateData1 is not of type 'covariateData'")
}
if (!isCovariateData(covariateData1)) {
stop("covariateData2 is not of type 'covariateData'")

computeStandardizedDifference <-
function(covariateData1,
covariateData2,
cohortId1 = NULL,
cohortId2 = NULL) {
isTypeAndAggregated(covariateData1, "covariateData1")
isTypeAndAggregated(covariateData2, "covariateData2")

if (!setequal(colnames(covariateData1$covariates),
colnames(covariateData2$covariates))) {
stop("Covariate1 and Covariate2 do not have the same structure")
}

covariateDataHasTimeId <- "timeId" %in% colnames(covariateData1$covariates)

result <- dplyr::tibble()

if (!is.null(covariateData1$covariates) &&
!is.null(covariateData2$covariates)) {
covariates1 <-
prepareCovariates(
covariates = covariateData1$covariates,
cohortId = cohortId1,
hasTimeId = covariateDataHasTimeId
)
covariates2 <-
prepareCovariates(
covariates = covariateData2$covariates,
cohortId = cohortId2,
hasTimeId = covariateDataHasTimeId
)

n1 <- getPopulationSize(covariateData1, cohortId1)
n2 <- getPopulationSize(covariateData2, cohortId2)

if (covariateDataHasTimeId) {
m <-
dplyr::bind_rows(
covariates1 %>% dplyr::select(timeId, covariateId),
covariates2 %>% dplyr::select(timeId, covariateId)
) %>%
dplyr::distinct() %>%
dplyr::left_join(covariates1 %>%
dplyr::rename(count1 = count),
by = c("timeId",
"covariateId")) %>%
dplyr::left_join(covariates2 %>%
dplyr::rename(count2 = count),
by = c("timeId",
"covariateId"))
} else {
m <-
dplyr::bind_rows(
covariates1 %>% dplyr::select(covariateId),
covariates2 %>% dplyr::select(covariateId)
) %>%
dplyr::distinct() %>%
dplyr::left_join(covariates1 %>%
dplyr::rename(count1 = count),
by = c("covariateId")) %>%
dplyr::left_join(covariates2 %>%
dplyr::rename(count2 = count),
by = c("covariateId"))
}
m <- m %>%
dplyr::distinct() %>%
tidyr::replace_na(replace = list(count1 = 0,
count2 = 0)) %>%
dplyr::mutate(
mean1 = .data$count1 / !!n1,
mean2 = .data$count2 / !!n2,
sd1 = sqrt(mean1 * (1 - mean1)),
sd2 = sqrt(mean2 * (1 - mean2))
) %>%
dplyr::mutate(sd = sqrt((sd1 ^ 2 + sd2 ^ 2) / 2),
stdDiff = (mean2 - mean1) / sd)
result <-
bindStandardizedDiff(result, m, covariateDataHasTimeId)
}

if (!is.null(covariateData1$covariatesContinuous) &&
!is.null(covariateData2$covariatesContinuous)) {
covariates1 <-
prepareContinuousCovariates(covariateData1$covariatesContinuous,
cohortId1,
covariateDataHasTimeId)
covariates2 <-
prepareContinuousCovariates(covariateData2$covariatesContinuous,
cohortId2,
covariateDataHasTimeId)

m <- dplyr::bind_rows(
covariates1 %>% dplyr::select(timeId, covariateId),
covariates2 %>% dplyr::select(timeId, covariateId)
) %>%
dplyr::distinct() %>%
dplyr::left_join(covariates1 %>%
dplyr::rename(mean1 = mean,
sd1 = sd),
by = c("timeId",
"covariateId")) %>%
dplyr::left_join(covariates2 %>%
dplyr::rename(mean2 = mean,
sd2 = sd),
by = c("timeId",
"covariateId")) %>%
dplyr::distinct() %>%
tidyr::replace_na(replace = list(
mean1 = 0,
mean2 = 0,
sd1 = 0,
sd2 = 0
)) %>%
dplyr::mutate(sd = sqrt((sd1 ^ 2 + sd2 ^ 2) / 2),
stdDiff = (mean2 - mean1) / sd)

result <-
bindStandardizedDiff(result, m, covariateDataHasTimeId)
}

result <-
joinAndArrange(result,
covariateData1,
covariateData2,
covariateDataHasTimeId)

return(result)
}
if (!isAggregatedCovariateData(covariateData1)) {
stop("Covariate1 data is not aggregated")

isTypeAndAggregated <- function(data, name) {
if (!isCovariateData(data)) {
stop(paste(name, "is not of type 'covariateData'"))
}
if (!isAggregatedCovariateData(covariateData2)) {
stop("Covariate2 data is not aggregated")
if (!isAggregatedCovariateData(data)) {
stop(paste("Covariate data in", name, "is not aggregated"))
}
result <- tibble()
if (!is.null(covariateData1$covariates) && !is.null(covariateData2$covariates)) {
covariates1 <- covariateData1$covariates
if (!is.null(cohortId1)) {
covariates1 <- covariates1 %>%
filter(.data$cohortDefinitionId == cohortId1)
}
covariates1 <- covariates1 %>%
select(
covariateId = "covariateId",
count1 = "sumValue"
) %>%
collect()
}

covariates2 <- covariateData2$covariates
if (!is.null(cohortId2)) {
covariates2 <- covariates2 %>%
filter(.data$cohortDefinitionId == cohortId2)
}
covariates2 <- covariates2 %>%
select(
covariateId = "covariateId",
count2 = "sumValue"
) %>%
collect()
prepareCovariates <- function(covariates, cohortId, hasTimeId) {
if (!is.null(cohortId)) {
covariates <-
covariates %>% dplyr::filter(.data$cohortDefinitionId == cohortId)
}
if (hasTimeId) {
covariates <-
covariates %>% dplyr::select(timeId, covariateId, sumValue) %>% dplyr::rename(count = sumValue) %>% dplyr::collect()
} else {
covariates <-
covariates %>% dplyr::select(covariateId, sumValue) %>% dplyr::rename(count = sumValue) %>% dplyr::collect()
}
return(covariates)
}

n1 <- attr(covariateData1, "metaData")$populationSize
if (!is.null(cohortId1)) {
n1 <- n1[as.character(cohortId1)]
prepareContinuousCovariates <-
function(covariates, cohortId, hasTimeId) {
if (!is.null(cohortId)) {
covariates <-
covariates %>% filter(.data$cohortDefinitionId == cohortId)
}
n2 <- attr(covariateData2, "metaData")$populationSize
if (!is.null(cohortId2)) {
n2 <- n2[as.character(cohortId2)]
if (hasTimeId) {
covariates <- covariates %>%
select(
timeId = "timeId",
covariateId = "covariateId",
mean = "averageValue",
sd = "standardDeviation"
) %>%
collect()
} else {
covariates <- covariates %>%
select(covariateId = "covariateId",
mean = "averageValue",
sd = "standardDeviation") %>%
collect()
}
m <- merge(covariates1, covariates2, all = T)
m$count1[is.na(m$count1)] <- 0
m$count2[is.na(m$count2)] <- 0
m$mean1 <- m$count1 / n1
m$mean2 <- m$count2 / n2
m$sd1 <- sqrt(m$mean1 * (1 - m$mean1))
m$sd2 <- sqrt(m$mean2 * (1 - m$mean2))
m$sd <- sqrt((m$sd1^2 + m$sd2^2) / 2)
m$stdDiff <- (m$mean2 - m$mean1) / m$sd
result <- bind_rows(result, m[, c("covariateId", "mean1", "sd1", "mean2", "sd2", "sd", "stdDiff")])
return(covariates)
}
if (!is.null(covariateData1$covariatesContinuous) && !is.null(covariateData2$covariatesContinuous)) {
covariates1 <- covariateData1$covariatesContinuous
if (!is.null(cohortId1)) {
covariates1 <- covariates1 %>%
filter(.data$cohortDefinitionId == cohortId1)
}
covariates1 <- covariates1 %>%
select(
covariateId = "covariateId",
mean1 = "averageValue",
sd1 = "standardDeviation"
) %>%
collect()

covariates2 <- covariateData2$covariatesContinuous
if (!is.null(cohortId2)) {
covariates2 <- covariates2 %>%
filter(.data$cohortDefinitionId == cohortId2)
}
covariates2 <- covariates2 %>%
select(
covariateId = "covariateId",
mean2 = "averageValue",
sd2 = "standardDeviation"
) %>%
collect()

m <- merge(covariates1, covariates2, all = T)
m$mean1[is.na(m$mean1)] <- 0
m$sd1[is.na(m$sd1)] <- 0
m$mean2[is.na(m$mean2)] <- 0
m$sd2[is.na(m$sd2)] <- 0
m$sd <- sqrt((m$sd1^2 + m$sd2^2) / 2)
m$stdDiff <- (m$mean2 - m$mean1) / m$sd
result <- bind_rows(result, m[, c("covariateId", "mean1", "sd1", "mean2", "sd2", "sd", "stdDiff")])
getPopulationSize <- function(covariateData, cohortId) {
populationSize <- attr(covariateData, "metaData")$populationSize
if (!is.null(cohortId)) {
populationSize <- populationSize[as.character(cohortId)]
}
covariateRef1 <- covariateData1$covariateRef %>%
collect()
covariateRef2 <- covariateData2$covariateRef %>%
collect()
return(populationSize)
}

result <- result %>%
left_join(select(covariateRef1, covariateId = "covariateId", covariateName1 = "covariateName"), by = "covariateId") %>%
left_join(select(covariateRef2, covariateId = "covariateId", covariateName2 = "covariateName"), by = "covariateId") %>%
mutate(covariateName = case_when(
is.na(covariateName1) ~ covariateName2,
TRUE ~ covariateName1
)) %>%
select(-rlang::sym("covariateName1"), -rlang::sym("covariateName2")) %>%
arrange(desc(abs(!!rlang::sym("stdDiff"))))
bindStandardizedDiff <- function(result, newData, hasTimeId) {
selectedCols <- if (hasTimeId) {
c("covariateId",
"timeId",
"mean1",
"sd1",
"mean2",
"sd2",
"sd",
"stdDiff")
} else {
c("covariateId",
"mean1",
"sd1",
"mean2",
"sd2",
"sd",
"stdDiff")
}
result <- dplyr::bind_rows(result, newData[, selectedCols])
return(result)
}


joinAndArrange <-
function(result,
covariateData1,
covariateData2,
hasTimeId) {
covariateRef1 <- covariateData1$covariateRef %>% dplyr::collect()
covariateRef2 <- covariateData2$covariateRef %>% dplyr::collect()

# unchecked assumption covariateRef's ae the same. They should be same if the output is from same feature extraction run.

covariateRef <- dplyr::bind_rows(covariateRef1,
covariateRef2) %>%
dplyr::distinct()

if (any(duplicated(covariateRef$covariateId))) {
stop("CovariateRef's are not compatible")
}

if (hasTimeId) {
timeRef1 <- covariateData1$timeRef %>% dplyr::collect()
timeRef2 <- covariateData2$timeRef %>% dplyr::collect()

# unchecked assumption timeRef's are the same. They should be same if the output is from same feature extraction run.

timeRef <- dplyr::bind_rows(timeRef1,
timeRef2) %>%
dplyr::distinct()

if (any(duplicated(covariateRef$covariateId))) {
stop("timeRef's are not compatible")
}

result <- result %>%
dplyr::left_join(covariateRef %>%
dplyr::select(covariateId, covariateName),
by = "covariateId") %>%
dplyr::left_join(timeRef %>%
dplyr::select(timeId, startDay, endDay),
by = "timeId") %>%
dplyr::arrange(dplyr::desc(abs(stdDiff)))
} else {
result <- result %>%
dplyr::left_join(covariateRef %>%
dplyr::select(covariateId,
covariateName),
by = "covariateId") %>%
dplyr::left_join(covariateRef2 %>%
dplyr::select(covariateId,
covariateName),
by = "covariateId") %>%
dplyr::arrange(dplyr::desc(abs(stdDiff)))
}

return(result)
}
Loading