#' Cut simulated trial data at a calendar date
#'
#' Applies an administrative data cut: follow-up is censored at a given
#' calendar time and the recurrent-event records are collapsed into summary
#' rows. Only subjects randomized before the cut date are retained.
#'
#' This is an S3 generic; the work is done by the method matching the class
#' of `data`.
#'
#' @param data Data generated by [nb_sim()].
#' @param cut_date Calendar time (relative to trial start) at which to censor
#'   follow-up.
#' @param event_gap Gap duration after each event during which no new events
#'   are counted. Can be a numeric value (default `0`) or a function returning
#'   a numeric value. The time at risk is reduced by the sum of these gaps
#'   (truncated by the cut date).
#' @param ... Additional arguments passed to methods.
#'
#' @return For `nb_sim_data` input, a data frame with one row per subject
#'   randomized prior to `cut_date` containing:
#' \describe{
#'   \item{id}{Subject identifier}
#'   \item{treatment}{Treatment group}
#'   \item{enroll_time}{Time of enrollment relative to trial start}
#'   \item{tte}{Time at risk (total follow-up minus event gap periods)}
#'   \item{tte_total}{Total follow-up time (calendar time, not adjusted for gaps)}
#'   \item{events}{Number of observed events}
#' }
#'   For `nb_sim_seasonal` input the rows are per subject and season, and a
#'   `season` column is included.
#'
#' @export
#'
#' @examples
#' enroll_rate <- data.frame(rate = 20 / (5 / 12), duration = 5 / 12)
#' fail_rate <- data.frame(treatment = c("Control", "Experimental"), rate = c(0.5, 0.3))
#' dropout_rate <- data.frame(
#'   treatment = c("Control", "Experimental"),
#'   rate = c(0.1, 0.05), duration = c(100, 100)
#' )
#' sim <- nb_sim(enroll_rate, fail_rate, dropout_rate, max_followup = 2, n = 20)
#' cut_data_by_date(sim, cut_date = 1)
cut_data_by_date <- function(data, cut_date, event_gap = 0, ...) {
  # Dispatch on the class of `data` (the first argument).
  UseMethod("cut_data_by_date", data)
}

#' @describeIn cut_data_by_date Default method; the fallback for classes
#'   without a dedicated method.
#'
#' @return The default method does not return a value: it always stops with
#'   an error naming the first class of `data`.
#' @export
cut_data_by_date.default <- function(data, cut_date, event_gap = 0, ...) {
  # Only the leading class is reported, which is the one dispatch tried first.
  leading_class <- class(data)[[1L]]
  msg <- sprintf(
    "No cut_data_by_date() method for objects of class %s",
    leading_class
  )
  stop(msg, call. = FALSE)
}

#' @describeIn cut_data_by_date Method for `nb_sim` data.
#'
#' @return A data frame with one row per subject randomized prior to `cut_date`,
#'   holding the event count, the gap-adjusted time at risk (`tte`) and the
#'   unadjusted follow-up (`tte_total`) observed up to the cut.
#' @export
cut_data_by_date.nb_sim_data <- function(data, cut_date, event_gap = 0, ...) {
  cut_date_ok <- !is.null(cut_date) && length(cut_date) == 1L && is.finite(cut_date)
  if (!cut_date_ok) {
    stop("cut_date must be a single finite numeric value", call. = FALSE)
  }

  records <- data.table::as.data.table(data)
  # Calendar time of each record = enrollment time + time since enrollment.
  records[, calendar_time := enroll_time + tte]
  # Keep only subjects already enrolled at the cut.
  records <- records[enroll_time < cut_date]

  if (nrow(records) < 1L) {
    empty <- data.frame(
      id = integer(0),
      treatment = character(0),
      enroll_time = numeric(0),
      tte = numeric(0),
      tte_total = numeric(0),
      events = integer(0)
    )
    return(empty)
  }

  # Walk one subject's event times in increasing order. An event falling
  # inside the gap opened by the previously counted event is ignored; each
  # counted event opens a new gap whose length is `gap_rule` (or one call to
  # `gap_rule()` when it is a function). Only the part of a gap inside
  # [0, horizon] is removed from the time at risk.
  tally_events <- function(times, horizon, gap_rule) {
    times <- sort(times)
    # Defensive: callers already restrict to events inside the window.
    times <- times[times <= horizon]

    n_counted <- 0L
    time_in_gaps <- 0
    blocked_until <- -Inf

    for (evt in times) {
      if (evt < blocked_until) {
        next
      }
      n_counted <- n_counted + 1L
      gap_len <- if (is.function(gap_rule)) gap_rule() else gap_rule
      blocked_until <- evt + gap_len
      time_in_gaps <- time_in_gaps + max(0, min(blocked_until, horizon) - evt)
    }

    list(events = n_counted, tte = horizon - time_in_gaps)
  }

  per_subject <- records[,
    {
      seen <- calendar_time <= cut_date
      seen_tte <- tte[seen]

      # Follow-up available between enrollment and the cut.
      window <- max(cut_date - enroll_time[[1L]], 0)
      latest_seen <- if (length(seen_tte) > 0) max(seen_tte) else 0

      # Follow-up ends at the subject's last record or at the cut, whichever
      # comes first, but never before the latest record classified as
      # inside the window.
      horizon <- max(min(max(tte), window), latest_seen)

      tally <- tally_events(tte[event == 1 & seen], horizon, event_gap)

      list(
        treatment = treatment[[1L]],
        enroll_time = enroll_time[[1L]],
        tte = tally$tte,
        tte_total = horizon,
        events = as.integer(tally$events)
      )
    },
    by = id
  ]

  data.table::setorder(per_subject, id)
  as.data.frame(per_subject)
}

#' @describeIn cut_data_by_date Method for `nb_sim_seasonal` data.
#'
#' @return A data frame with one row per subject and season, for subjects
#'   randomized prior to `cut_date`. Columns are `id`, `treatment`,
#'   `enroll_time`, `season`, `events`, `tte` (follow-up in that season minus
#'   the time spent in event gaps) and `tte_total` (follow-up in that season,
#'   not adjusted for gaps).
#' @export
cut_data_by_date.nb_sim_seasonal <- function(data, cut_date, event_gap = 0, ...) {
  if (is.null(cut_date) || length(cut_date) != 1L || !is.finite(cut_date)) {
    stop("cut_date must be a single finite numeric value", call. = FALSE)
  }

  dt <- data.table::as.data.table(data)

  # Drop subjects not yet enrolled at the cut.
  dt <- dt[enroll_time < cut_date]

  if (nrow(dt) == 0) {
    # Column order here differs from the non-empty result; kept as-is for
    # backward compatibility.
    return(data.frame(
      id = integer(0), treatment = character(0), season = character(0),
      enroll_time = numeric(0), tte = numeric(0), tte_total = numeric(0), events = integer(0)
    ))
  }

  # Cut time on each subject's own clock (time since enrollment).
  dt[, cut_rel := cut_date - enroll_time]

  # Drop intervals that begin at or after the cut.
  dt <- dt[start < cut_rel]

  # Intervals straddling the cut are shortened to end at the cut; whatever
  # ended them originally happened after the cut, so it is not an event.
  dt[end > cut_rel, `:=`(end = cut_rel, event = 0)]

  # Collapse gap windows [g_start, g_end] (at least one) into disjoint spans.
  # After sorting by start, a window opens a new span when it starts at or
  # beyond the furthest end reached so far; otherwise it extends the current
  # span.
  merge_gap_windows <- function(g_start, g_end) {
    ord <- order(g_start)
    g_start <- g_start[ord]
    g_end <- g_end[ord]
    n <- length(g_start)
    reach <- cummax(g_end)
    opens <- c(TRUE, g_start[-1L] >= reach[-n])
    closes <- c(opens[-1L], TRUE)
    list(start = g_start[opens], end = reach[closes])
  }

  # Length of each interval [s, e] that lies outside the disjoint gap spans.
  time_outside_gaps <- function(s, e, spans) {
    vapply(seq_along(s), function(i) {
      span_len <- e[i] - s[i]
      if (span_len <= 0) {
        return(0)
      }
      lost <- sum(pmax(0, pmin(e[i], spans$end) - pmax(s[i], spans$start)))
      max(0, span_len - lost)
    }, numeric(1))
  }

  # Step 1: exposure per interval, computed per SUBJECT. An event time is the
  # `end` of its interval, so the gap it opens always falls in later
  # intervals, which may belong to a different season. Building the gaps
  # within a subject-by-season group would therefore lose every gap that
  # crosses a season boundary.
  # NOTE(review): unlike the nb_sim_data method, events that fall inside an
  # earlier event's gap are still counted here and still open their own gap
  # (gaps are unioned). Confirm this is intended for seasonal data.
  per_interval <- dt[,
    {
      event_times <- end[event == 1]

      at_risk <- if (length(event_times) > 0) {
        # A function-valued gap is drawn once per event, matching the
        # nb_sim_data method.
        gap_len <- if (is.function(event_gap)) {
          vapply(seq_along(event_times), function(i) event_gap(), numeric(1))
        } else {
          event_gap
        }
        spans <- merge_gap_windows(event_times, event_times + gap_len)
        time_outside_gaps(start, end, spans)
      } else {
        pmax(end - start, 0)
      }

      list(
        treatment = treatment,
        enroll_time = enroll_time,
        season = season,
        event = event,
        at_risk = at_risk,
        duration = end - start
      )
    },
    by = id
  ]

  # Step 2: collapse the intervals to one row per subject and season.
  agg <- per_interval[,
    list(
      events = sum(event),
      tte = sum(at_risk),
      tte_total = sum(duration)
    ),
    by = .(id, treatment, enroll_time, season)
  ]

  data.table::setorder(agg, id, season)
  as.data.frame(agg)
}
