
# Draw `n` indices with replacement from 1..length(q), weighted by `q`, and
# return how often each index was drawn (an integer count vector of length
# length(q)); when `prob` is TRUE the counts are divided by `n` instead.
emp_vec <- \(n, q, prob = TRUE) {
    n_bins <- length(q)
    draws <- sample(n_bins, size = n, replace = TRUE, prob = q)
    counts <- tabulate(draws, nbins = n_bins)
    if (prob) {
        return(counts / n)
    }
    counts
}

# Row-wise version of emp_vec(): row i of the result is the empirical count
# (or, if `prob` is TRUE, probability) vector of n[i] draws from q[i, ].
#
# `q` is a matrix with one probability vector per row; `n` must supply exactly
# one sample size per row. The result is a plain numeric matrix with the same
# dimensions as `q` (dimnames are not carried over).
emp_mat <- \(n, q, prob = TRUE) {
    # Without this guard a short `n` leaves the trailing rows silently at zero
    # and a long `n` fails later with an opaque subscript error.
    stopifnot(length(n) == nrow(q))
    p <- matrix(0, nrow(q), ncol(q))
    # Rows are sampled in order, so results are reproducible under set.seed().
    for (i in seq_len(nrow(q))) {
        p[i, ] <- emp_vec(n[i], q[i, ], prob)
    }
    p
}

# Nested-list version of emp_vec(): every leaf probability vector of `mu` is
# replaced by its empirical counterpart. The sample size for a leaf is looked
# up in `n` at the same position, using the `.xpos` index vector that
# rrapply() supplies, so `n` must mirror the nesting of `mu`.
emp_llst <- \(n, mu, prob = TRUE) {
    draw_leaf <- \(x, .xpos) {
        emp_vec(n[[.xpos]], x, prob)
    }
    rrapply(mu, f = draw_leaf, how = "replace")
}

# Dispatch empirical sampling on the container type of `mu`: nested lists go
# to emp_llst(), everything else (a matrix of row-wise probabilities) goes to
# emp_mat().
emp_measure <- \(n, mu, prob = TRUE) {
    sampler <- if (is.list(mu)) emp_llst else emp_mat
    sampler(n, mu, prob)
}

#' @title Generate tabulated samples from probability vectors
#' @description Generate count vectors instead of raw samples, i.e., vectors giving the number of times each support point was observed in the sample.
#' @param n vector or nested list of sample sizes.
#' @param mu matrix (one probability vector per row) or nested list of probability vectors to sample from. `n` must match its structure: one sample size per row of a matrix `mu`, or the same nesting as a list `mu`.
#' @param prob logical value indicating whether probabilities (instead of counts) should be returned.
#' @returns The count vectors (or, if `prob = TRUE`, the empirical probability vectors) corresponding to the generated samples, with the same structure as `mu`.
#' @examples
#' ## matrix example
#'
#' mu <- matrix(c(0.01, 0.99, 0.5, 0.5), 2, 2, TRUE)
#' n <- c(80, 20)
#'
#' set.seed(123)
#' cv <- tab_sample(n, mu)
#' print(cv)
#' # sample sizes are rowsums
#' print(rowSums(cv))
#' # empirical probability vectors
#' print(sweep(cv, 1, n, "/"))
#' set.seed(123)
#' # same result
#' print(tab_sample(n, mu, prob = TRUE))
#'
#' ## list example
#'
#' mu <- list(
#'    list(c(0.3, 0.7), c(0.25, 0.75)),
#'    list(c(0, 1), c(0.5, 0.5))
#' )
#' n <- list(list(100, 120), list(80, 90))
#'
#' set.seed(123)
#' cv <- tab_sample(n, mu)
#' print(cv)
#' # empirical probability vectors
#' print(rapply(cv, \(x) x / sum(x), how = "replace"))
#' set.seed(123)
#' print(tab_sample(n, mu, prob = TRUE))
#' @export
tab_sample <- \(n, mu, prob = FALSE) {
    # `prob` is the only argument checked here; the actual sampling is
    # delegated to emp_measure(), which dispatches on the type of `mu`.
    stopifnot(is_log_scalar(prob))
    emp_measure(n = n, mu = mu, prob = prob)
}

##

# Covariance matrix of a single multinomial trial with cell probabilities `p`:
# off-diagonal entries are -p[i] * p[j], diagonal entries are p[i] * (1 - p[i]).
multinom_cov_mat <- \(p) {
    cov <- -outer(p, p)
    diag(cov) <- p * (1 - p)
    cov
}

# Square-root factor of a symmetric positive semi-definite matrix `Sigma`:
# returns A = V %*% diag(sqrt(lambda)) from the eigendecomposition
# Sigma = V diag(lambda) t(V), so that A %*% t(A) equals Sigma up to rounding.
# Multiplying A by a standard normal vector therefore yields an N(0, Sigma) draw.
#
# Errors (via stopifnot) if any eigenvalue is more negative than -1e-6 times
# the largest one; smaller negative values are treated as rounding and
# clamped to zero.
mvrnorm_mat <- \(Sigma) {
    e <- eigen(Sigma, symmetric = TRUE)
    # For symmetric input eigen() sorts values in decreasing order, so
    # e$values[1] is the largest and sets the scale of the tolerance.
    stopifnot(all(e$values >= -1e-6 * abs(e$values[1])))
    root_lambda <- sqrt(pmax(e$values, 0))
    # Scale column j of V by sqrt(lambda_j). This equals
    # V %*% diag(root_lambda), but it avoids diag()'s scalar special case
    # (diag(x) with a length-one x builds an identity matrix of size x, which
    # broke 1x1 inputs). It also avoids a needless O(N^3) matrix product.
    e$vectors * rep(root_lambda, each = nrow(e$vectors))
}

# Build a generator of K x N Gaussian matrices that share one covariance
# factor. A is computed once from the multinomial covariance of the FIRST row
# of `mu`. Each call returns a matrix whose row k is sqrt(delta[k]) * A %*% z_k
# for fresh standard normal z_k.
get_gen_G_mat_one <- \(mu, delta, N, K) {
    row_scale <- sqrt(delta)
    root_t <- t(mvrnorm_mat(multinom_cov_mat(mu[1, ])))
    \() {
        # Filled row by row so row k uses the k-th block of N normal draws.
        z <- matrix(stats::rnorm(K * N), nrow = K, ncol = N, byrow = TRUE)
        correlated <- z %*% root_t
        sweep(correlated, MARGIN = 1, STATS = row_scale, FUN = "*")
    }
}

# Build a generator of K x N Gaussian matrices with one covariance per row.
# For each k a factor sqrt(delta[k]) * A_k is precomputed, where A_k is the
# square-root factor of the multinomial covariance of mu[k, ]. Each call
# returns a matrix whose row k is that factor applied to N fresh standard
# normals.
get_gen_G_mat <- \(mu, delta, N, K) {
    factors <- lapply(seq_along(delta), \(k) {
        sqrt(delta[k]) * mvrnorm_mat(multinom_cov_mat(mu[k, ]))
    })
    \() {
        # Row-wise fill (byrow = TRUE) keeps the draw order identical to
        # get_gen_G_llst(), which consumes N normals per leaf in sequence.
        z <- matrix(stats::rnorm(K * N), nrow = K, ncol = N, byrow = TRUE)
        for (k in seq_along(factors)) {
            z[k, ] <- drop(factors[[k]] %*% z[k, ])
        }
        z
    }
}

# Nested-list counterpart of get_gen_G_mat(). For every leaf probability
# vector x of `mu`, sqrt(delta[[pos]]) * A_x is precomputed once. Here pos is
# the leaf's position (rrapply's `.xpos`) and A_x is the square-root factor of
# the multinomial covariance of x. Each call of the returned generator applies
# every factor to N fresh standard normals, keeping the nesting of `mu`.
# `K` is accepted for interface parity with get_gen_G_mat() and is unused here.
get_gen_G_llst <- \(mu, delta, N, K) {
    scaled_root <- \(x, .xpos) {
        sqrt(delta[[.xpos]]) * mvrnorm_mat(multinom_cov_mat(x))
    }
    roots <- rrapply(mu, f = scaled_root, how = "replace")
    draw_leaf <- \(A) drop(A %*% stats::rnorm(N))
    \() rapply(roots, draw_leaf, how = "replace")
}

# Choose the Gaussian generator factory matching the container type of `mu`
# (nested list vs. matrix) and build the generator from the given arguments.
get_gen_G <- \(mu, delta, N, K) {
    factory <- if (is.list(mu)) get_gen_G_llst else get_gen_G_mat
    factory(mu, delta, N, K)
}

# Bootstrap replacement for the Gaussian generators. Each call re-samples
# empirical probability vectors with sample sizes `n` from `mu` and returns
# sqrt(rho) * (empirical - mu), in the same container type as `mu`. For
# nested lists the leaves are matched by their rrapply `.xpos` position.
get_gen_G_boot <- \(mu, rho, n) {
    root_rho <- sqrt(rho)
    if (!is.list(mu)) {
        return(\() {
            mu_hat <- emp_mat(n, mu)
            root_rho * (mu_hat - mu)
        })
    }
    \() {
        mu_hat <- emp_llst(n, mu)
        centre <- \(x, .xpos) root_rho * (mu_hat[[.xpos]] - x)
        rrapply(mu, f = centre, how = "replace")
    }
}

#

# Wrap a generator factory so that every generated draw is shifted by `mean`.
#
# `get.gen_G` is a factory: calling it (with any arguments, forwarded via
# `...`) returns a zero-argument generator. If `mean` is NULL the factory is
# returned unchanged. Otherwise the returned factory produces generators that
# add `mean` to each draw. Nested-list draws are shifted leaf by leaf, matched
# by position through rrapply's `.xpos`. Any other draw is shifted with `+`.
add_mean_to_gen <- \(get.gen_G, mean) {
    if (is.null(mean)) {
        return(get.gen_G)
    }
    shift <- \(G) {
        if (is.list(G)) {
            rrapply(G, f = \(x, .xpos) x + mean[[.xpos]], how = "replace")
        } else {
            G + mean
        }
    }
    \(...) {
        gen_G <- get.gen_G(...)
        \() shift(gen_G())
    }
}

##

# Build a permutation-test generator for row-wise empirical measures.
#
# The rows of `samples` are pooled by summing over rows, and every observation
# is expanded into its support index (column number). Each call of the
# returned function randomly permutes the pooled observations and splits them
# into groups using group_idx(n). It returns a matrix with one row per group,
# holding that group's empirical probability vector over the ncol(samples)
# support points (counts divided by n[i]).
# NOTE(review): assumes `samples` holds integer counts (rep() truncates
# non-integer `times`) and that group_idx(n) returns a list of index vectors
# partitioning 1..sum(n) — confirm against group_idx().
get_permute_emp_mat <- \(samples, n) {
    N <- ncol(samples)
    combined <- colSums(samples)
    supp <- rep(seq_along(combined), times = combined)
    idx <- group_idx(n)
    \() {
        # Permute via indices rather than sample(supp): for a length-one
        # numeric `supp` with value k >= 2, sample() would return a
        # permutation of 1:k. For every other length this is exactly the call
        # sample() makes internally, so the random stream is unchanged.
        perm <- supp[sample.int(length(supp))]
        do.call(rbind, lapply(seq_along(idx), \(i) tabulate(perm[idx[[i]]], N) / n[i]))
    }
}
