#' Fit a Joint Longitudinal-Survival Model with Semi-Parametric Association Surfaces
#'
#' @description
#' Two-stage estimation framework for multi-state joint models with
#' tensor-product spline association surfaces. Stage 1 fits mixed-effects
#' longitudinal models and extracts BLUPs. Stage 2 fits transition-specific
#' penalized Cox models with tensor-product spline surfaces, estimating the
#' smoothing parameters by REML by default (see \code{method}).
#'
#' @param long_data Data frame of longitudinal biomarker measurements.
#'   Required columns: \code{patient_id}, \code{visit_time_years},
#'   \code{biomarker}, \code{value}.
#' @param surv_data Data frame of survival/transition events.
#'   Required columns: \code{patient_id}, \code{start_time}, \code{stop_time},
#'   \code{status}, \code{state_from}, \code{state_to}, \code{transition}.
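#' @param transitions Character vector of transitions to model (e.g.,
#'   \code{"CKD -> CVD"}). If \code{NULL}, every transition with at least
#'   \code{min_events} observed events (\code{status == 1}) is used.
#'   Transitions whose analysis data set has fewer than 20 rows or fewer
#'   than \code{min_events} events are skipped.
#' @param covariates Character vector of baseline covariate names present in
#'   \code{surv_data}; names not found in \code{surv_data} are ignored.
#'   Default \code{c("age_baseline", "sex")}.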
#' @param k_marginal Integer vector of length 1 or 2 giving marginal basis
#'   dimensions for the tensor-product spline. Default \code{c(5, 5)}.
#' @param k_additive Integer giving the basis dimension for the additive
#'   smooth of the third biomarker (if present). Default \code{6}.
#' @param bs Character string for the spline basis type. Default \code{"tp"}
#'   (thin-plate regression spline).
#' @param method Smoothing parameter estimation method. Default \code{"REML"}.
#' @param min_events Integer minimum number of events required to fit a
#'   transition model. Default \code{10}.
#' @param verbose Logical; print progress messages. Default \code{TRUE}.
#'
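#' @return An object of class \code{"jmSurface"} containing:
#'   \item{lme_fits}{Named list of \code{nlme::lme} objects (one per biomarker)}
#'   \item{gam_fits}{Named list of \code{mgcv::gam} objects (one per fitted transition)}
#'   \item{eta_data}{Named list of analysis data frames with latent summaries}
#'   \item{transitions}{Character vector of successfully fitted transitions}
#'   \item{biomarkers}{Character vector of biomarker names}
#'   \item{covariates}{Character vector of covariate names used}
#'   \item{edf}{Named vector giving, per transition, the EDF of the first
#'     smooth term in the GAM summary (\code{NA} if there is none)}
#'   \item{deviance_explained}{Named vector of deviance explained per transition}
#'   \item{k_marginal, k_additive, bs, method}{The smoothing settings used}
#'   \item{n_patients}{Number of patients shared by \code{long_data} and \code{surv_data}}
#'   \item{long_data, surv_data}{The input data restricted to the shared patients}
#'   \item{call}{The matched call}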
#'
#' @details
#' The model for each transition \eqn{(r,s)} is:
#' \deqn{\lambda_i^{rs}(t | \eta_i(t)) = \lambda_0^{rs}(t) \exp\{\gamma_{rs}' w_i + f_{rs}(\eta_i(t))\}}
#' where \eqn{f_{rs}} is a semi-parametric association surface represented via
#' tensor-product splines, and \eqn{\eta_i(t)} are BLUP-based latent
#' longitudinal summaries evaluated at the midpoint of each sojourn interval.
#'
#' @references
#' Bhattacharjee, A. (2025). Interpretable Multi-Biomarker Fusion in Joint
#' Longitudinal-Survival Models via Semi-Parametric Association Surfaces.
#'
#' Bhattacharjee, A. (2024). jmBIG: Scalable Joint Models for Big Data.
#'
#' Wood, S.N. (2017). Generalized Additive Models: An Introduction with R.
#' Chapman & Hall/CRC.
#'
#' Tsiatis, A.A. & Davidian, M. (2004). Joint modeling of longitudinal and
#' time-to-event data: an overview. Statistica Sinica, 14, 809-834.
#'
#' @examples
#' \donttest{
#' # Simulate data
#' sim <- simulate_jmSurface(n_patients = 300)
#'
#' # Fit the joint model
#' fit <- jmSurf(
#'   long_data = sim$long_data,
#'   surv_data = sim$surv_data,
#'   covariates = c("age_baseline", "sex")
#' )
#'
#' # Summary with EDF diagnostics
#' summary(fit)
#'
#' # Dynamic prediction for patient 1
#' pred <- dynPred(fit, patient_id = 1, landmark = 2, horizon = 3)
#'
#' # Visualize surfaces
#' plot_surface(fit, transition = "CKD -> CVD")
#' contour_heatmap(fit, transition = "CKD -> CVD")
#' marginal_slices(fit, transition = "CKD -> CVD")
#' }
#'
#' @export
jmSurf <- function(long_data,
                   surv_data,
                   transitions = NULL,
                   covariates = c("age_baseline", "sex"),
                   k_marginal = c(5, 5),
                   k_additive = 6,
                   bs = "tp",
                   method = "REML",
                   min_events = 10,
                   verbose = TRUE) {

  cl <- match.call()

  ## ── Validate inputs ──
  req_long <- c("patient_id", "visit_time_years", "biomarker", "value")
  req_surv <- c("patient_id", "start_time", "stop_time", "status",
                "state_from", "state_to", "transition")

  # Report exactly which columns are absent, keeping the original prefix so
  # existing callers matching on it still work.
  miss_long <- setdiff(req_long, names(long_data))
  if (length(miss_long) > 0) {
    stop("long_data must contain columns: ", paste(req_long, collapse = ", "),
         " (missing: ", paste(miss_long, collapse = ", "), ")", call. = FALSE)
  }
  miss_surv <- setdiff(req_surv, names(surv_data))
  if (length(miss_surv) > 0) {
    stop("surv_data must contain columns: ", paste(req_surv, collapse = ", "),
         " (missing: ", paste(miss_surv, collapse = ", "), ")", call. = FALSE)
  }

  ## Restrict both inputs to patients present in each of them
  shared_ids <- intersect(unique(long_data$patient_id),
                          unique(surv_data$patient_id))
  if (length(shared_ids) < 20) {
    stop("Fewer than 20 shared patient_ids between long_data and surv_data.",
         call. = FALSE)
  }

  long_data <- long_data[long_data$patient_id %in% shared_ids, ]
  surv_data <- surv_data[surv_data$patient_id %in% shared_ids, ]

  ## Normalize arrow encoding in transition labels (handle both → and ->)
  surv_data$transition <- gsub("\u2192", "->", surv_data$transition)

  ## Identify transitions: by default, those with >= min_events events
  if (is.null(transitions)) {
    event_rows <- surv_data[which(surv_data$status == 1), ]
    trans_tab <- table(event_rows$transition)
    transitions <- names(trans_tab[trans_tab >= min_events])
    if (length(transitions) == 0) {
      stop("No transitions with >= ", min_events, " events.", call. = FALSE)
    }
  } else {
    # Duplicated labels would otherwise be fitted (and reported) twice.
    transitions <- unique(gsub("\u2192", "->", transitions))
  }

  markers <- unique(long_data$biomarker)
  if (verbose) message("Biomarkers: ", paste(markers, collapse = ", "))
  if (verbose) message("Transitions: ", paste(transitions, collapse = "; "))

  ## Keep only covariates that actually exist in surv_data
  dropped_cov <- setdiff(covariates, names(surv_data))
  covariates <- intersect(covariates, names(surv_data))
  if (verbose && length(dropped_cov) > 0) {
    message("Covariates not found in surv_data (ignored): ",
            paste(dropped_cov, collapse = ", "))
  }

  ## ── Stage 1: Longitudinal submodels ──
  if (verbose) message("\n=== Stage 1: Fitting longitudinal submodels ===")
  lme_fits <- fit_longitudinal(long_data, markers, verbose = verbose)

  ## ── Stage 2: Transition-specific GAM-Cox surfaces ──
  if (verbose) message("\n=== Stage 2: Fitting transition-specific GAM-Cox models ===")

  gam_fits <- list()
  eta_data <- list()
  # Start as numeric so the stored summaries stay numeric even if all are NA.
  edf_vec <- numeric(0)
  dev_vec <- numeric(0)

  for (tr in transitions) {
    if (verbose) message("  Transition: ", tr)

    ## Build analysis dataset; require >= 20 rows and >= min_events events
    result <- .build_transition_data(surv_data, lme_fits, markers,
                                     tr, covariates)
    if (is.null(result) || nrow(result) < 20 || sum(result$status) < min_events) {
      if (verbose) message("    Skipped: insufficient data/events")
      next
    }

    ## Fit GAM-Cox
    gam_result <- fit_gam_cox(result, covariates,
                              k_marginal = k_marginal,
                              k_additive = k_additive,
                              bs = bs, method = method)

    if (is.null(gam_result)) {
      if (verbose) message("    GAM fitting failed")
      next
    }

    gam_fits[[tr]] <- gam_result
    eta_data[[tr]] <- result

    # Report the EDF of the first smooth term listed in the summary. Guard
    # against a NULL summary or s.table (no smooth terms): nrow(NULL) > 0
    # is logical(0), which `if` rejects.
    sm <- .safe_summary_gam(gam_result)
    s_tab <- sm$s.table
    edf_val <- if (!is.null(s_tab) && nrow(s_tab) > 0) s_tab[1, "edf"] else NA_real_
    dev_val <- if (is.null(sm$dev.expl)) NA_real_ else sm$dev.expl

    edf_vec[tr] <- edf_val
    dev_vec[tr] <- dev_val

    if (verbose) {
      message("    N = ", nrow(result), ", Events = ", sum(result$status),
              ", EDF = ", round(edf_val, 1),
              ", Deviance = ", round(dev_val * 100, 1), "%")
    }
  }

  if (length(gam_fits) == 0) {
    stop("No transitions could be fitted. Check data and parameters.",
         call. = FALSE)
  }

  ## ── Build output ──
  out <- structure(
    list(
      lme_fits = lme_fits,
      gam_fits = gam_fits,
      eta_data = eta_data,
      transitions = names(gam_fits),
      biomarkers = markers,
      covariates = covariates,
      edf = edf_vec,
      deviance_explained = dev_vec,
      k_marginal = k_marginal,
      k_additive = k_additive,
      bs = bs,
      method = method,
      n_patients = length(shared_ids),
      long_data = long_data,
      surv_data = surv_data,
      call = cl
    ),
    class = "jmSurface"
  )

  if (verbose) {
    message("\n=== Fitting complete ===")
    message("Fitted ", length(gam_fits), " transition models")
    # min()/max() of an all-NA vector with na.rm = TRUE warn and give Inf.
    if (any(!is.na(edf_vec))) {
      message("EDF range: ", round(min(edf_vec, na.rm = TRUE), 1), " - ",
              round(max(edf_vec, na.rm = TRUE), 1))
    }
  }

  out
}


## ── Internal: build transition-specific analysis dataset ──
##
## For transition `tr` ("from -> to"; "->" or "→" with any spacing), returns a
## data frame built from:
##   * event rows:     transition label equals `tr` and status == 1;
##   * censored rows:  state_from == from and status == 0;
##   * competing rows: state_from == from, status == 1, state_to != to,
##                     recoded to status 0 (competing transitions censored).
## Rows are stacked in that order and only the first row per patient_id is
## kept. Added columns: time_in_state (stop - start, floored at 0.01),
## eta_time (sojourn midpoint) and eta_<marker> for each marker with a fit in
## `lme_fits`. Rows with a missing value in any "^eta_" column are dropped
## (note that pattern also matches eta_time). Returns NULL when `tr` does not
## split into exactly two states.
.build_transition_data <- function(surv_data, lme_fits, markers, tr, covariates) {

  # Canonical label form: "→" becomes "->" with one space on each side, so
  # "A->B", "A -> B" and "A → B" all compare equal.
  canon_label <- function(x) {
    x <- gsub("\u2192", "->", x)
    gsub("[[:space:]]*->[[:space:]]*", " -> ", trimws(x))
  }

  tr <- canon_label(tr)
  tr_parts <- strsplit(tr, " -> ", fixed = TRUE)[[1]]
  if (length(tr_parts) != 2) return(NULL)
  from_state <- trimws(tr_parts[1])
  to_state <- trimws(tr_parts[2])

  # which() discards NA comparisons (NA status/state); plain logical indexing
  # would instead insert all-NA rows.
  status <- surv_data$status
  leaves_from <- surv_data$state_from == from_state
  event_idx <- which(canon_label(surv_data$transition) == tr & status == 1)
  censor_idx <- which(leaves_from & status == 0)
  compete_idx <- which(leaves_from & status == 1 & surv_data$state_to != to_state)

  base_cols <- c("patient_id", "start_time", "stop_time", "status")
  keep_cols <- c(base_cols, intersect(covariates, names(surv_data)))

  competing <- surv_data[compete_idx, keep_cols, drop = FALSE]
  # rep() keeps this valid when there are no competing rows: assigning a
  # length-1 value to a column of a zero-row data frame is an error.
  competing$status <- rep(0L, nrow(competing))

  analysis_df <- rbind(
    surv_data[event_idx, keep_cols, drop = FALSE],
    surv_data[censor_idx, keep_cols, drop = FALSE],
    competing
  )

  ## Keep the first row per patient (events, then censored, then competing)
  analysis_df <- analysis_df[!duplicated(analysis_df$patient_id), , drop = FALSE]

  ## Sojourn time and midpoint at which latent summaries are evaluated
  analysis_df$start_time <- as.numeric(analysis_df$start_time)
  analysis_df$stop_time <- as.numeric(analysis_df$stop_time)
  analysis_df$time_in_state <- pmax(analysis_df$stop_time - analysis_df$start_time, 0.01)
  analysis_df$eta_time <- (analysis_df$start_time + analysis_df$stop_time) / 2

  pid_char <- as.character(analysis_df$patient_id)
  for (mk in markers) {
    fit <- lme_fits[[mk]]
    if (is.null(fit)) next
    cm <- coef(fit)
    # NOTE(review): a marker whose cleaned name is "time" would overwrite
    # eta_time.
    eta_col <- paste0("eta_", gsub("[^A-Za-z0-9]", "", mk))
    # NOTE(review): assumes coef() columns 1 and 2 are each patient's
    # intercept and time slope, with patient ids as row names; confirm
    # against fit_longitudinal(). Unmatched patients get NA via match().
    row_idx <- match(pid_char, rownames(cm))
    analysis_df[[eta_col]] <- cm[row_idx, 1] + cm[row_idx, 2] * analysis_df$eta_time
  }

  ## Drop rows with incomplete latent summaries
  eta_cols <- grep("^eta_", names(analysis_df), value = TRUE)
  analysis_df[complete.cases(analysis_df[, eta_cols, drop = FALSE]), , drop = FALSE]
}
