Skip to content

Commit

Permalink
Conditionally use timeId
Browse files Browse the repository at this point in the history
  • Loading branch information
gowthamrao committed Feb 1, 2024
1 parent 6a7e90a commit d0c310e
Showing 1 changed file with 80 additions and 28 deletions.
108 changes: 80 additions & 28 deletions R/CompareCohorts.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,31 +59,52 @@ computeStandardizedDifference <- function(covariateData1, covariateData2, cohort
if (!isAggregatedCovariateData(covariateData2)) {
stop("Covariate2 data is not aggregated")
}
if (colnames(covariateData1$covariates) |> sort() != colnames(covariateData1$covariates) |> sort()) {
stop("Covariate1 and Covariate2 do not have the same structure")
}
covariateDataHasTimeId <- FALSE
if ("timeId" %in% colnames(covariateData1$covariates)) {
covariateDataHasTimeId <- TRUE
}

result <- tibble()
if (!is.null(covariateData1$covariates) && !is.null(covariateData2$covariates)) {
covariates1 <- covariateData1$covariates
if (!is.null(cohortId1)) {
covariates1 <- covariates1 %>%
filter(cohortDefinitionId == cohortId1)
}
covariates1 <- covariates1 %>%
select(
covariateId = "covariateId",
count1 = "sumValue"
) %>%
collect()

if (covariateDataHasTimeId) {
covariates1 <- covariates1 %>%
select(timeId = "timeId",
covariateId = "covariateId",
count1 = "sumValue") %>%
collect()
} else {
covariates1 <- covariates1 %>%
select(covariateId = "covariateId",
count1 = "sumValue") %>%
collect()
}

covariates2 <- covariateData2$covariates
if (!is.null(cohortId2)) {
covariates2 <- covariates2 %>%
filter(cohortDefinitionId == cohortId2)
}
covariates2 <- covariates2 %>%
select(
covariateId = "covariateId",
count2 = "sumValue"
) %>%
collect()
if (covariateDataHasTimeId) {
covariates2 <- covariates2 %>%
select(timeId = "timeId",
covariateId = "covariateId",
count2 = "sumValue") %>%
collect()
} else {
covariates2 <- covariates2 %>%
select(covariateId = "covariateId",
count2 = "sumValue") %>%
collect()
}

n1 <- attr(covariateData1, "metaData")$populationSize
if (!is.null(cohortId1)) {
Expand All @@ -102,34 +123,59 @@ computeStandardizedDifference <- function(covariateData1, covariateData2, cohort
m$sd2 <- sqrt(m$mean2 * (1 - m$mean2))
m$sd <- sqrt((m$sd1^2 + m$sd2^2) / 2)
m$stdDiff <- (m$mean2 - m$mean1) / m$sd
result <- bind_rows(result, m[, c("covariateId", "mean1", "sd1", "mean2", "sd2", "sd", "stdDiff")])
if (covariateDataHasTimeId) {
result <-
bind_rows(result, m[, c("covariateId", "timeId", "mean1", "sd1", "mean2", "sd2", "sd", "stdDiff")])
} else {
result <-
bind_rows(result, m[, c("covariateId", "mean1", "sd1", "mean2", "sd2", "sd", "stdDiff")])
}
}
if (!is.null(covariateData1$covariatesContinuous) && !is.null(covariateData2$covariatesContinuous)) {
covariates1 <- covariateData1$covariatesContinuous
if (!is.null(cohortId1)) {
covariates1 <- covariates1 %>%
filter(cohortDefinitionId == cohortId1)
}
covariates1 <- covariates1 %>%
select(
covariateId = "covariateId",
mean1 = "averageValue",
sd1 = "standardDeviation"
) %>%
collect()

if (covariateDataHasTimeId) {
covariates1 <- covariates1 %>%
select(
timeId = "timeId",
covariateId = "covariateId",
mean1 = "averageValue",
sd1 = "standardDeviation"
) %>%
collect()
} else {
covariates1 <- covariates1 %>%
select(covariateId = "covariateId",
mean1 = "averageValue",
sd1 = "standardDeviation") %>%
collect()
}

covariates2 <- covariateData2$covariatesContinuous
if (!is.null(cohortId2)) {
covariates2 <- covariates2 %>%
filter(cohortDefinitionId == cohortId2)
}
covariates2 <- covariates2 %>%
select(
covariateId = "covariateId",
mean2 = "averageValue",
sd2 = "standardDeviation"
) %>%
collect()
if (covariateDataHasTimeId) {
covariates2 <- covariates2 %>%
select(
timeId = "timeId",
covariateId = "covariateId",
mean2 = "averageValue",
sd2 = "standardDeviation"
) %>%
collect()
} else {
covariates2 <- covariates2 %>%
select(covariateId = "covariateId",
mean2 = "averageValue",
sd2 = "standardDeviation") %>%
collect()
}

m <- merge(covariates1, covariates2, all = T)
m$mean1[is.na(m$mean1)] <- 0
Expand All @@ -138,7 +184,13 @@ computeStandardizedDifference <- function(covariateData1, covariateData2, cohort
m$sd2[is.na(m$sd2)] <- 0
m$sd <- sqrt(m$sd1^2 + m$sd2^2)
m$stdDiff <- (m$mean2 - m$mean1) / m$sd
result <- bind_rows(result, m[, c("covariateId", "mean1", "sd1", "mean2", "sd2", "sd", "stdDiff")])
if (covariateDataHasTimeId) {
result <-
bind_rows(result, m[, c("covariateId", "timeId", "mean1", "sd1", "mean2", "sd2", "sd", "stdDiff")])
} else {
result <-
bind_rows(result, m[, c("covariateId", "mean1", "sd1", "mean2", "sd2", "sd", "stdDiff")])
}
}
covariateRef1 <- covariateData1$covariateRef %>%
collect()
Expand Down

0 comments on commit d0c310e

Please sign in to comment.