time_integration.f90 Source File


Source Code

!> @file time_integration.f90
!> @brief Time integration schemes for method-of-lines ODE systems.
!!
!! SSPRK = Strong Stability Preserving Runge-Kutta (Shu & Osher, 1988).
!!
!! Supported schemes:
!!   'euler'   — Explicit Euler (1st order)
!!   'ssprk22' — SSPRK(2,2) (Shu & Osher, 1988)
!!   'rk3'     — TVD-RK3 (Shu & Osher, 1988) [default]
!!   'rk4'     — Classic RK4
!!   'ssprk54' — SSPRK(5,4) (Spiteri & Ruuth, 2002)
!!   'beuler'  — Backward Euler with Newton-Raphson (implicit, 1st order)
!!   'bdf2'    — BDF2 with Newton-Raphson (implicit, 2nd order)
!!
!! Implicit banded solver selection is controlled per-instance via
!!   state%cfg%lapack_solver = .true.  (default) — LAPACK dgbsv (pivoted, faster)
!!   state%cfg%lapack_solver = .false.           — built-in Gaussian elimination (no pivoting)

module time_integration
  use precision, only: wp
  use solver_state, only: solver_state_t, neq
  use logger, only: log_warn
  use option_registry, only: time_euler, time_ssprk22, time_rk3, time_rk4, &
                             time_ssprk54, time_beuler, time_bdf2, &
                             time_scheme_names, join_token_list
  implicit none
  private

  !> Abstract interface satisfied by every single-step stepper.
  public :: stepper_iface
  abstract interface
    subroutine stepper_iface(state)
      import :: solver_state_t
      type(solver_state_t), intent(inout) :: state
    end subroutine stepper_iface
  end interface

  !> Procedure pointer to the active time-stepping scheme.
  !! Initialised to null; set once by init_time_scheme() before the time loop.
  procedure(stepper_iface), pointer, public :: step => null()

  !> Maximum Newton-Raphson iterations per time step (beuler and bdf2).
  !! 3 iterations is sufficient for BDF2 at CFL ≤ 10; increase for stiff problems.
  integer, parameter :: n_newton = 3

  !> Newton convergence tolerance: max-norm of correction ΔQ must fall below this.
  !! 1e-10 gives full convergence well below solver tolerance for double precision.
  real(wp), parameter :: tol_newton = 1.0e-10_wp

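  !> Relative finite-difference step for Jacobian column approximation.
  !! Set to 1e-7, slightly above √(machine ε) ≈ √(2.2e-16) ≈ 1.5e-8 for double
  !! precision; a step of this order balances FD truncation error (O(h))
  !! against floating-point cancellation error (O(ε/h)).  The perturbation
  !! actually applied is eps_jac * max(1, |Q|).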
  real(wp), parameter :: eps_jac = 1.0e-7_wp

  public :: init_time_scheme, resolve_time_scheme

contains

  ! ---------------------------------------------------------------------------
  !> Resolve a scheme name to a specific stepper procedure pointer.
  !!
  !! This is the session-safe path used by `solver_runtime`, while the legacy
  !! module-global `step` pointer remains available for unit tests and older
  !! call sites through `init_time_scheme`.
  ! ---------------------------------------------------------------------------
  subroutine resolve_time_scheme(stepper, scheme)
    !! Look up the stepper procedure that corresponds to a scheme name.
    !!
    !! On return `stepper` is associated with the matching module procedure.
    !! A name that matches none of the registered tokens terminates the
    !! program with a message listing the valid names.
    procedure(stepper_iface), pointer, intent(out) :: stepper
    character(len=*), intent(in) :: scheme

    dispatch: select case (trim(scheme))
    case default
      ! Unknown name: abort and report every accepted token.
      error stop 'time_integration: unknown scheme "'//trim(scheme)// &
        '"; valid: '//trim(join_token_list(time_scheme_names))

    ! --- explicit schemes ---
    case (time_euler)
      stepper => euler_step
    case (time_ssprk22)
      stepper => ssprk22_step
    case (time_rk3)
      stepper => tvd_rk3_step
    case (time_rk4)
      stepper => rk4_step
    case (time_ssprk54)
      stepper => ssprk54_step

    ! --- implicit schemes ---
    case (time_beuler)
      stepper => beuler_step
    case (time_bdf2)
      stepper => bdf2_step
    end select dispatch
  end subroutine resolve_time_scheme

  ! ---------------------------------------------------------------------------
  !> Bind the procedure pointer @p step to the requested scheme.
  !!
  !! Valid scheme names:
  !!   'euler'   — Explicit Euler (1st order, 1 stage)
  !!   'ssprk22' — SSPRK(2,2) / Heun (2nd order, 2 stages; Shu & Osher, 1988)
  !!   'rk3'     — TVD-RK3 (Shu & Osher, 1988)  [default]
  !!   'rk4'     — Classic RK4 (4th order, 4 stages; not SSP)
  !!   'ssprk54' — SSPRK(5,4) (Spiteri & Ruuth, 2002)
  !!   'beuler'  — Backward Euler (implicit, 1st order)
  !!   'bdf2'    — BDF2 (implicit, 2nd order; Gear 1971)
  ! ---------------------------------------------------------------------------
  subroutine init_time_scheme(scheme)
    !! Associate the module-level `step` pointer with the stepper named by
    !! `scheme`.
    character(len=*), intent(in) :: scheme

    ! Name lookup (including the abort on an unknown name) is delegated to
    ! resolve_time_scheme.
    call resolve_time_scheme(stepper=step, scheme=scheme)
  end subroutine init_time_scheme

  ! ---------------------------------------------------------------------------
  !> Advance the solution by one time step using the TVD-RK3 scheme.
  !!
  !! Third-order Strong Stability Preserving Runge-Kutta (Shu & Osher, 1988):
  !!
  !!   Q^(1) = Q^n + dt * R(Q^n)
  !!   Q^(2) = 3/4 * Q^n + 1/4 * Q^(1) + 1/4 * dt * R(Q^(1))
  !!   Q^{n+1} = 1/3 * Q^n + 2/3 * Q^(2) + 2/3 * dt * R(Q^(2))
  !!
  !! Also computes the global L2 residual norm for convergence monitoring.
  ! ---------------------------------------------------------------------------
  subroutine tvd_rk3_step(state)
    !! One TVD-RK3 step: three residual evaluations, each followed by a
    !! weighted combination of the saved Q^n (scratch1) and the running
    !! stage value held in state%ub.
    use spatial_discretization, only: compute_resid
    type(solver_state_t), intent(inout) :: state

    ! Stage weights of the three-stage Shu-Osher form.
    real(wp), parameter :: one_quarter = 0.25_wp
    real(wp), parameter :: three_quarters = 0.75_wp
    real(wp), parameter :: one_third = 1.0_wp / 3.0_wp
    real(wp), parameter :: two_thirds = 2.0_wp / 3.0_wp

    integer :: alloc_stat

    ! Lazily create the Q^n buffer; later calls find it already allocated.
    if (.not. allocated(state % scratch1)) then
      allocate (state % scratch1(neq, state % n_pt), stat=alloc_stat)
      if (alloc_stat /= 0) error stop 'tvd_rk3_step: scratch1 allocation failed'
    end if

    ! Keep Q^n; state%ub is overwritten by every stage below.
    state % scratch1 = state % ub

    ! Q^(1) = Q^n + dt*R(Q^n)
    call compute_resid(state)
    state % ub = state % scratch1 + state % dt * state % resid

    ! Q^(2) = 3/4*Q^n + 1/4*Q^(1) + 1/4*dt*R(Q^(1))
    call compute_resid(state)
    state % ub = three_quarters * state % scratch1 + one_quarter * state % ub &
                 + one_quarter * state % dt * state % resid

    ! Q^{n+1} = 1/3*Q^n + 2/3*Q^(2) + 2/3*dt*R(Q^(2))
    call compute_resid(state)
    state % ub = one_third * state % scratch1 + two_thirds * state % ub &
                 + two_thirds * state % dt * state % resid

    ! Residual norm is taken from the last residual evaluated above.
    call compute_resid_glob(state)
  end subroutine tvd_rk3_step

  ! ---------------------------------------------------------------------------
  !> Advance the solution by one time step using the SSPRK(5,4) scheme.
  !!
  !! Five-stage, fourth-order SSP Runge-Kutta method.
  !! Coefficients from Spiteri & Ruuth (2002), Table 1.
  !!
  !! Also computes the global L2 residual norm for convergence monitoring.
  ! ---------------------------------------------------------------------------
  subroutine ssprk54_step(state)
    !! One step of the five-stage, fourth-order SSPRK(5,4) method in
    !! Shu-Osher form:
    !!
    !!   Q1      = Q^n + b10*dt*R(Q^n)
    !!   Q2      = a20*Q^n + a21*Q1 + b21*dt*R(Q1)
    !!   Q3      = a30*Q^n + a32*Q2 + b32*dt*R(Q2)
    !!   Q4      = a40*Q^n + a43*Q3 + b43*dt*R(Q3)
    !!   Q^{n+1} = a52*Q2 + a53*Q3 + b53*dt*R(Q3) + a54*Q4 + b54*dt*R(Q4)
    !!
    !! The coefficient set below is the fourth-order (5,4) tableau.  The
    !! set starting with b10 = b21 = 0.377268915331368 belongs to the
    !! five-stage *third-order* SSPRK(5,3) method and must not be used here.
    !! NOTE(review): the SSP coefficient of this scheme is about 1.508
    !! (= a21/b21), lower than that of SSPRK(5,3); confirm that CFL settings
    !! tuned for the old coefficients are still appropriate.
    !!
    !! Storage: scratch1 holds Q^n for the whole step.  scratch2 holds Q2
    !! until R(Q3) is available, then is overwritten with the part of the
    !! final combination that depends on Q2, Q3 and R(Q3), so no third
    !! scratch array is needed.
    use spatial_discretization, only: compute_resid
    type(solver_state_t), intent(inout) :: state

    ! Spiteri-Ruuth SSPRK(5,4) coefficients (Shu-Osher form).
    real(wp), parameter :: b10 = 0.391752226571890_wp
    real(wp), parameter :: a20 = 0.444370493651235_wp
    real(wp), parameter :: a21 = 0.555629506348765_wp
    real(wp), parameter :: b21 = 0.368410593050371_wp
    real(wp), parameter :: a30 = 0.620101851488403_wp
    real(wp), parameter :: a32 = 0.379898148511597_wp
    real(wp), parameter :: b32 = 0.251891774271694_wp
    real(wp), parameter :: a40 = 0.178079954393132_wp
    real(wp), parameter :: a43 = 0.821920045606868_wp
    real(wp), parameter :: b43 = 0.544974750228521_wp
    real(wp), parameter :: a52 = 0.517231671970585_wp
    real(wp), parameter :: a53 = 0.096059710526147_wp
    real(wp), parameter :: b53 = 0.063692468666290_wp
    real(wp), parameter :: a54 = 0.386708617503269_wp
    real(wp), parameter :: b54 = 0.226007483236906_wp

    integer :: info

    ! Ensure scratch is allocated (once; subsequent steps reuse existing allocation)
    if (.not. allocated(state % scratch1)) then
      allocate (state % scratch1(neq, state % n_pt), stat=info)
      if (info /= 0) error stop 'ssprk54_step: scratch1 allocation failed'
    end if
    if (.not. allocated(state % scratch2)) then
      allocate (state % scratch2(neq, state % n_pt), stat=info)
      if (info /= 0) error stop 'ssprk54_step: scratch2 allocation failed'
    end if

    ! Save Q^n; state%ub carries the running stage value from here on.
    state % scratch1 = state % ub

    ! Stage 1: Q1 = Q^n + b10*dt*R(Q^n)
    call compute_resid(state)
    state % ub = state % scratch1 + b10 * state % dt * state % resid

    ! Stage 2: Q2 = a20*Q^n + a21*Q1 + b21*dt*R(Q1); keep Q2 for the final stage
    call compute_resid(state)
    state % ub = a20 * state % scratch1 + a21 * state % ub + b21 * state % dt * state % resid
    state % scratch2 = state % ub

    ! Stage 3: Q3 = a30*Q^n + a32*Q2 + b32*dt*R(Q2)
    call compute_resid(state)
    state % ub = a30 * state % scratch1 + a32 * state % ub + b32 * state % dt * state % resid

    ! Stage 4: state%resid = R(Q3).  Fold the Q2/Q3/R(Q3) contribution to the
    ! final update into scratch2 *before* state%ub (= Q3) is overwritten by Q4.
    call compute_resid(state)
    state % scratch2 = a52 * state % scratch2 + a53 * state % ub &
                       + b53 * state % dt * state % resid
    state % ub = a40 * state % scratch1 + a43 * state % ub + b43 * state % dt * state % resid

    ! Stage 5: Q^{n+1} = [a52*Q2 + a53*Q3 + b53*dt*R(Q3)] + a54*Q4 + b54*dt*R(Q4)
    call compute_resid(state)
    state % ub = state % scratch2 + a54 * state % ub + b54 * state % dt * state % resid

    call compute_resid_glob(state)
  end subroutine ssprk54_step

  ! ---------------------------------------------------------------------------
  !> Advance the solution by one time step using the explicit Euler scheme.
  !!
  !! First-order, one-stage method:
  !!
  !!   Q^{n+1} = Q^n + dt * R(Q^n)
  !!
  !! No scratch arrays are required; the single-stage update is atomic.
  !! Also computes the global L2 residual norm for convergence monitoring.
  ! ---------------------------------------------------------------------------
  subroutine euler_step(state)
    !! Single-stage explicit Euler update of state%ub, followed by the
    !! residual-norm computation.
    use spatial_discretization, only: compute_resid
    type(solver_state_t), intent(inout) :: state

    ! R(Q^n) is left in state%resid by compute_resid.
    call compute_resid(state)

    ! Q^{n+1} = Q^n + dt*R(Q^n)
    state % ub = state % dt * state % resid + state % ub

    call compute_resid_glob(state)
  end subroutine euler_step

  ! ---------------------------------------------------------------------------
  !> Advance the solution by one time step using the SSPRK(2,2) scheme.
  !!
  !! Two-stage, second-order Strong Stability Preserving Runge-Kutta
  !! (Shu & Osher, 1988; also known as Heun's method in SSP form):
  !!
  !!   Q^(1)   = Q^n + dt * R(Q^n)
  !!   Q^{n+1} = 1/2 * Q^n + 1/2 * Q^(1) + 1/2 * dt * R(Q^(1))
  !!
  !! CFL stability limit: 1.
  !! Also computes the global L2 residual norm for convergence monitoring.
  ! ---------------------------------------------------------------------------
  subroutine ssprk22_step(state)
    !! Two-stage SSPRK(2,2) step: a forward-Euler predictor followed by an
    !! equal-weight average with the saved Q^n.
    use spatial_discretization, only: compute_resid
    type(solver_state_t), intent(inout) :: state

    real(wp), parameter :: half = 0.5_wp
    integer :: alloc_stat

    ! Lazily create the Q^n buffer; later calls find it already allocated.
    if (.not. allocated(state % scratch1)) then
      allocate (state % scratch1(neq, state % n_pt), stat=alloc_stat)
      if (alloc_stat /= 0) error stop 'ssprk22_step: scratch1 allocation failed'
    end if

    ! Keep Q^n; state%ub is overwritten by the predictor.
    state % scratch1 = state % ub

    ! Predictor: Q^(1) = Q^n + dt*R(Q^n)
    call compute_resid(state)
    state % ub = state % scratch1 + state % dt * state % resid

    ! Corrector: Q^{n+1} = 1/2*Q^n + 1/2*Q^(1) + 1/2*dt*R(Q^(1))
    call compute_resid(state)
    state % ub = half * state % scratch1 + half * state % ub &
                 + half * state % dt * state % resid

    call compute_resid_glob(state)
  end subroutine ssprk22_step

  ! ---------------------------------------------------------------------------
  !> Advance the solution by one time step using the classic RK4 scheme.
  !!
  !! Four-stage, fourth-order Runge-Kutta (not SSP):
  !!
  !!   k1 = R(Q^n)
  !!   k2 = R(Q^n + dt/2 * k1)
  !!   k3 = R(Q^n + dt/2 * k2)
  !!   k4 = R(Q^n + dt   * k3)
  !!   Q^{n+1} = Q^n + dt/6 * (k1 + 2*k2 + 2*k3 + k4)
  !!
  !! Storage: ub0 holds Q^n; k_sum accumulates (dt/6)*(k1+2*k2+2*k3+k4)
  !! incrementally after each compute_resid() call; ub is overwritten with
  !! the next stage input between calls.
  !!
  !! WARNING: Classic RK4 is not strong-stability preserving.  Near shocks
  !! it may amplify oscillations at large CFL numbers.  Keep CFL <= 1 with
  !! WENO spatial discretisation to avoid spurious artefacts.
  !! Also computes the global L2 residual norm for convergence monitoring.
  ! ---------------------------------------------------------------------------
  subroutine rk4_step(state)
    !! Classic four-stage RK4 step.
    !!
    !! scratch1 keeps Q^n; scratch2 accumulates the weighted increment
    !! dt/6*(k1 + 2*k2 + 2*k3 + k4) one stage at a time, so the individual
    !! k_i never need to be stored.
    use spatial_discretization, only: compute_resid
    type(solver_state_t), intent(inout) :: state

    ! Quadrature weights of the final combination and the half-step factor.
    real(wp), parameter :: w_outer = 1.0_wp / 6.0_wp
    real(wp), parameter :: w_inner = 2.0_wp / 6.0_wp
    real(wp), parameter :: half = 0.5_wp

    integer :: alloc_stat

    ! Lazily create both work buffers; later calls find them allocated.
    if (.not. allocated(state % scratch1)) then
      allocate (state % scratch1(neq, state % n_pt), stat=alloc_stat)
      if (alloc_stat /= 0) error stop 'rk4_step: scratch1 allocation failed'
    end if
    if (.not. allocated(state % scratch2)) then
      allocate (state % scratch2(neq, state % n_pt), stat=alloc_stat)
      if (alloc_stat /= 0) error stop 'rk4_step: scratch2 allocation failed'
    end if

    state % scratch1 = state % ub   ! Q^n
    state % scratch2 = 0.0_wp       ! running weighted increment

    ! k1 = R(Q^n); next stage input is Q^n + dt/2*k1
    call compute_resid(state)
    state % scratch2 = state % scratch2 + w_outer * state % dt * state % resid
    state % ub = state % scratch1 + half * state % dt * state % resid

    ! k2 = R(Q^n + dt/2*k1); next stage input is Q^n + dt/2*k2
    call compute_resid(state)
    state % scratch2 = state % scratch2 + w_inner * state % dt * state % resid
    state % ub = state % scratch1 + half * state % dt * state % resid

    ! k3 = R(Q^n + dt/2*k2); next stage input is Q^n + dt*k3
    call compute_resid(state)
    state % scratch2 = state % scratch2 + w_inner * state % dt * state % resid
    state % ub = state % scratch1 + state % dt * state % resid

    ! k4 = R(Q^n + dt*k3)
    call compute_resid(state)
    state % scratch2 = state % scratch2 + w_outer * state % dt * state % resid

    ! Q^{n+1} = Q^n + accumulated increment
    state % ub = state % scratch1 + state % scratch2

    call compute_resid_glob(state)
  end subroutine rk4_step

  ! ---------------------------------------------------------------------------
  !> Compute banded-storage dimensions for the Newton-step Jacobian.
  !!
  !! n_dof = neq * n_pt.  The Jacobian is banded with
  !! kl = ku = neq*coupling_radius because resid(i) depends on ub(j) only when
  !! |i_cell - j_cell| <= coupling_radius.
  !!
  !! LAPACK dgbsv needs kl extra pivot rows: ldab = 2*kl+ku+1, diag_row = kl+ku+1.
  !! Built-in solver uses compact storage:   ldab = kl+ku+1,   diag_row = ku+1.
  !!
  !! @param[in]  state     Solver state (reads n_pt and cfg%lapack_solver).
  !! @param[out] n_dof     Total degrees of freedom (neq * n_pt)
  !! @param[out] kl        Lower bandwidth
  !! @param[out] ku        Upper bandwidth
  !! @param[out] ldab      Leading dimension of band matrix
  !! @param[out] diag_row  Band-storage row that holds the diagonal
  ! ---------------------------------------------------------------------------
  subroutine setup_band_storage(state, n_dof, kl, ku, ldab, diag_row)
    !! Derive the band-matrix dimensions from the solver state.
    !!
    !! Lower and upper bandwidth are both neq*coupling_radius.  When
    !! state%cfg%lapack_solver is set, kl additional rows are placed above
    !! the band, which shifts the diagonal row down by kl as well.
    type(solver_state_t), intent(in) :: state
    integer, intent(out) :: n_dof, kl, ku, ldab, diag_row

    integer :: half_band    ! common value of kl and ku
    integer :: extra_rows   ! kl for the LAPACK layout, 0 for the compact one

    half_band = neq * state % coupling_radius
    n_dof = neq * state % n_pt
    kl = half_band
    ku = half_band

    extra_rows = merge(half_band, 0, state % cfg % lapack_solver)
    ldab = extra_rows + kl + ku + 1
    diag_row = extra_rows + ku + 1
  end subroutine setup_band_storage

  ! ---------------------------------------------------------------------------
  !> Banded linear solver using LAPACK dgbsv.
  !!
  !! AB must be in LAPACK band storage: AB(kl+ku+1+i-j, j) = A(i,j),
  !! leading dimension ldab = 2*kl + ku + 1 (the extra kl rows are used as
  !! pivoting workspace by LAPACK).  On exit b contains the solution x = A^{-1} b.
  !!
  !! @param[inout] ab    Banded matrix in LAPACK storage (overwritten with LU)
  !! @param[in]    kl    Lower bandwidth
  !! @param[in]    ku    Upper bandwidth
  !! @param[in]    n     Matrix order (number of unknowns)
  !! @param[inout] b     RHS on entry; solution on exit
  !! @param[out]   info  LAPACK return code (0 = success)
  ! ---------------------------------------------------------------------------
  subroutine band_lapack_solve(ab, kl, ku, n, b, info)
    !! Solve the banded system A*x = b for a single right-hand side by
    !! forwarding to LAPACK dgbsv.  `ab` and `b` are overwritten by the
    !! call; `info` is passed back unchanged from dgbsv.
    integer, intent(in) :: kl, ku, n
    real(wp), intent(inout) :: ab(2 * kl + ku + 1, n)
    real(wp), intent(inout) :: b(n)
    integer, intent(out) :: info

    integer, parameter :: n_rhs = 1   ! one right-hand-side vector
    integer :: lead_dim               ! leading dimension of ab
    integer :: pivots(n)              ! pivot indices written by dgbsv

    ! Explicit interface for the external LAPACK routine.  It is declared
    ! with double precision, so wp must be the double-precision kind.
    interface
      subroutine dgbsv(n, kl, ku, nrhs, ab, ldab, ipiv, b, ldb, info)
        integer, intent(in) :: n, kl, ku, nrhs, ldab, ldb
        double precision, intent(inout) :: ab(ldab, *)
        double precision, intent(inout) :: b(ldb, *)
        integer, intent(out) :: ipiv(*)
        integer, intent(out) :: info
      end subroutine dgbsv
    end interface

    lead_dim = 2 * kl + ku + 1
    call dgbsv(n, kl, ku, n_rhs, ab, lead_dim, pivots, b, n, info)
  end subroutine band_lapack_solve

  ! ---------------------------------------------------------------------------
  !> Advance the solution by one time step using backward (implicit) Euler.
  !!
  !! Solves Q^{n+1} = Q^n + dt · R(Q^{n+1}) by Newton-Raphson iteration:
  !!
  !!   Q^{0} = Q^n
  !!   for k = 0 .. n_newton-1:
  !!     compute  R(Q^k)
  !!     form FD Jacobian  J = ∂R/∂Q |_{Q^k}  via column-wise perturbation
  !!     assemble  A = I - dt·J  (banded structure, kl=ku=neq·coupling_radius)
  !!     solve  A·ΔQ = dt·R(Q^k) - (Q^k - Q^n)
  !!     Q^{k+1} = Q^k + ΔQ
  !!     if max|ΔQ| < tol_newton: exit
  !!
  !! WARNING: each Newton step performs (neq·n_pt + 1) residual evaluations.
  !! This is expensive; use modest grid sizes (n_cell <= 200) when testing.
  !! The main benefit is unconditional linear stability, which permits CFL > 1.
  !!
  !! Also computes the global L2 residual norm for convergence monitoring.
  !!
  !! References:
  !!   LeVeque, "Finite Volume Methods for Hyperbolic Problems" (2002), Ch. 12.
  ! ---------------------------------------------------------------------------
  subroutine beuler_step(state)
    !! Backward Euler step solved by Newton-Raphson.
    !!
    !! Each Newton iteration rebuilds a finite-difference Jacobian column by
    !! column (one residual evaluation per degree of freedom), assembles the
    !! banded matrix A = I - dt*J, solves A*dQ = dt*R(Q^k) - (Q^k - Q^n) and
    !! adds dQ to state%ub.  Iteration stops after n_newton passes or once
    !! max|dQ| drops below tol_newton; a warning is logged in the former case.
    use spatial_discretization, only: compute_resid
    type(solver_state_t), intent(inout) :: state

    integer :: n_dof, kl, ku, ldab, diag_row
    integer :: newton_it, col, i_eq, i_cell, lapack_info, mem_stat
    real(wp) :: pert, q_saved, max_dq
    logical :: converged

    real(wp), allocatable :: q_old(:, :)   ! Q^n
    real(wp), allocatable :: r_ref(:)      ! R(Q^k), flattened
    real(wp), allocatable :: r_shift(:)    ! R(Q^k + pert*e_col), flattened
    real(wp), allocatable :: band(:, :)    ! J, then A = I - dt*J, in band storage
    real(wp), allocatable :: rhs_vec(:)    ! Newton right-hand side
    real(wp), allocatable :: corr(:)       ! Newton correction dQ

    call setup_band_storage(state, n_dof, kl, ku, ldab, diag_row)

    allocate (q_old(neq, state % n_pt), stat=mem_stat)
    if (mem_stat /= 0) error stop 'beuler_step: allocation failed (ub_n)'
    allocate (r_ref(n_dof), r_shift(n_dof), stat=mem_stat)
    if (mem_stat /= 0) error stop 'beuler_step: allocation failed (resid)'
    allocate (band(ldab, n_dof), rhs_vec(n_dof), corr(n_dof), stat=mem_stat)
    if (mem_stat /= 0) error stop 'beuler_step: allocation failed (ab/rhs/dq)'

    q_old = state % ub
    converged = .false.

    do newton_it = 1, n_newton

      ! Reference residual R(Q^k)
      call compute_resid(state)
      call pack_field(state % resid, r_ref, neq, state % n_pt)

      ! Finite-difference Jacobian.  Columns are visited in pack_field order
      ! (equation index fastest), so `col` equals (i_cell-1)*neq + i_eq.
      band = 0.0_wp
      col = 0
      do i_cell = 1, state % n_pt
        do i_eq = 1, neq
          col = col + 1

          ! Perturb one unknown, re-evaluate the residual, then restore it.
          q_saved = state % ub(i_eq, i_cell)
          pert = eps_jac * max(1.0_wp, abs(q_saved))
          state % ub(i_eq, i_cell) = q_saved + pert

          call compute_resid(state)
          call pack_field(state % resid, r_shift, neq, state % n_pt)

          state % ub(i_eq, i_cell) = q_saved

          call jac_store_col(band, col, (r_shift - r_ref) / pert, kl, ku, diag_row, n_dof)
        end do
      end do

      ! A = I - dt*J: scale the stored Jacobian, then add the identity.
      band = -state % dt * band
      band(diag_row, :) = band(diag_row, :) + 1.0_wp

      ! b = dt*R(Q^k) - (Q^k - Q^n); corr temporarily holds the flattened Q^n.
      call pack_field(state % ub, rhs_vec, neq, state % n_pt)
      call pack_field(q_old, corr, neq, state % n_pt)
      rhs_vec = state % dt * r_ref - (rhs_vec - corr)

      ! Solve A*dQ = b with the configured banded solver.
      if (state % cfg % lapack_solver) then
        corr = rhs_vec   ! the LAPACK wrapper solves in place
        call band_lapack_solve(band, kl, ku, n_dof, corr, lapack_info)
        if (lapack_info /= 0) error stop 'beuler_step: LAPACK dgbsv failed (singular matrix)'
      else
        call band_lu_solve(band, n_dof, kl, ku, rhs_vec, corr)
      end if

      ! Q^{k+1} = Q^k + dQ
      max_dq = maxval(abs(corr))
      call unpack_add(corr, state % ub, neq, state % n_pt)

      converged = max_dq < tol_newton
      if (converged) exit
    end do

    if (.not. converged) then
      block
        character(len=24) :: buf
        write (buf, '(ES12.4)') max_dq
        call log_warn('beuler_step: Newton did not converge; max_norm(dQ) = '//trim(buf))
      end block
    end if

    call compute_resid_glob(state)

    deallocate (q_old, r_ref, r_shift, band, rhs_vec, corr, stat=mem_stat)
    if (mem_stat /= 0) error stop 'beuler_step: deallocation failed'

  end subroutine beuler_step

  ! ---------------------------------------------------------------------------
  !> Advance the solution by one time step using the BDF2 (Gear) scheme.
  !!
  !! Second-order two-step implicit formula (Gear, 1971):
  !!
  !!   Q^{n+1} = 4/3 * Q^n - 1/3 * Q^{n-1} + (2/3) * dt * R(Q^{n+1})
  !!
  !! Solved via Newton-Raphson with the same banded-LU infrastructure as
  !! beuler_step().  LHS: A = I - (2/3)*dt*J.  RHS:
  !!   b = (2/3)*dt*R(Q^k) - (Q^k - 4/3*Q^n + 1/3*Q^{n-1})
  !!
  !! Bootstrap: the very first call performs one backward Euler step
  !! (A = I - dt*J, standard beuler RHS) to produce Q^1 from Q^0, storing Q^0
  !! in state%bdf2_ub_prev so that subsequent calls can use the BDF2 formula.
  !!
  !! The solver selection (LAPACK vs built-in) is controlled by state%cfg%lapack_solver.
  !!
  !! WARNING: each Newton step performs (neq*n_pt + 1) residual evaluations.
  !! Use modest grid sizes (n_cell <= 200) when testing.
  !!
  !! References:
  !!   C.W. Gear, "Numerical Initial Value Problems in Ordinary Differential
  !!   Equations," Prentice-Hall, 1971.
  ! ---------------------------------------------------------------------------
  subroutine bdf2_step(state)
    !! Advance one step with BDF2; the first call after
    !! state%bdf2_initialized is .false. performs a backward Euler step
    !! instead and records Q^0 in state%bdf2_ub_prev.
    !!
    !! The history buffer state%bdf2_ub_prev is reused when it is already
    !! allocated with shape (neq, n_pt), and is reallocated when its shape
    !! differs.  This lets a state object whose bdf2_initialized flag has
    !! been reset be bootstrapped again instead of aborting on a second
    !! allocate of an already-allocated array.
    !!
    !! NOTE(review): the 4/3, -1/3, 2/3 coefficients are the constant-step
    !! BDF2 weights; they assume state%dt equals the step size used for the
    !! previous step -- confirm the caller keeps dt fixed.
    use spatial_discretization, only: compute_resid
    type(solver_state_t), intent(inout) :: state

    integer :: n_dof, kl, ku, ldab, diag_row
    real(wp) :: coeff

    real(wp), allocatable :: ub_n(:, :)       ! Q^n
    real(wp), allocatable :: ub_nm1(:, :)     ! Q^{n-1} (from state%bdf2_ub_prev)
    real(wp), allocatable :: resid_base(:)    ! R(Q^k), flattened
    real(wp), allocatable :: resid_pert(:)    ! R(Q^k + h*e_j), flattened
    real(wp), allocatable :: ab(:, :)         ! banded LHS
    real(wp), allocatable :: rhs(:)           ! Newton right-hand side
    real(wp), allocatable :: dq(:)            ! Newton correction
    integer :: info

    call setup_band_storage(state, n_dof, kl, ku, ldab, diag_row)

    allocate (ub_n(neq, state % n_pt), ub_nm1(neq, state % n_pt), stat=info)
    if (info /= 0) error stop 'bdf2_step: allocation failed (ub_n/ub_nm1)'
    allocate (resid_base(n_dof), resid_pert(n_dof), stat=info)
    if (info /= 0) error stop 'bdf2_step: allocation failed (resid)'
    allocate (ab(ldab, n_dof), rhs(n_dof), dq(n_dof), stat=info)
    if (info /= 0) error stop 'bdf2_step: allocation failed (ab/rhs/dq)'

    ! ---- Bootstrap: first call uses backward Euler to produce Q^1 ----
    if (.not. state % bdf2_initialized) then
      ! Drop a stale history buffer of the wrong shape, then allocate only
      ! if no usable buffer remains.
      if (allocated(state % bdf2_ub_prev)) then
        if (any(shape(state % bdf2_ub_prev) /= [neq, state % n_pt])) then
          deallocate (state % bdf2_ub_prev, stat=info)
          if (info /= 0) error stop 'bdf2_step: deallocation failed (bdf2_ub_prev)'
        end if
      end if
      if (.not. allocated(state % bdf2_ub_prev)) then
        allocate (state % bdf2_ub_prev(neq, state % n_pt), stat=info)
        if (info /= 0) error stop 'bdf2_step: allocation failed (bdf2_ub_prev)'
      end if

      state % bdf2_ub_prev = state % ub   ! store Q^0
      coeff = 1.0_wp                      ! backward Euler coefficient

      ! Passing Q^n for both history slots collapses the BDF2 right-hand
      ! side to the backward Euler one (4/3 - 1/3 = 1).
      ub_n = state % ub
      call run_newton(state, ub_n, ub_n, coeff, n_dof, kl, ku, ldab, diag_row, &
                      ab, rhs, dq, resid_base, resid_pert)
      state % bdf2_initialized = .true.

      deallocate (ub_n, ub_nm1, resid_base, resid_pert, ab, rhs, dq, stat=info)
      if (info /= 0) error stop 'bdf2_step: deallocation failed (bootstrap)'
      call compute_resid_glob(state)
      return
    end if

    ! ---- Normal BDF2 step ----
    coeff = 2.0_wp / 3.0_wp

    ! Take Q^{n-1} from the saved history, then shift the history forward.
    ub_nm1 = state % bdf2_ub_prev
    state % bdf2_ub_prev = state % ub       ! Q^n becomes Q^{n-1} for the next call
    ub_n = state % ub                       ! local copy of Q^n

    call run_newton(state, ub_n, ub_nm1, coeff, n_dof, kl, ku, ldab, diag_row, &
                    ab, rhs, dq, resid_base, resid_pert)

    deallocate (ub_n, ub_nm1, resid_base, resid_pert, ab, rhs, dq, stat=info)
    if (info /= 0) error stop 'bdf2_step: deallocation failed'
    call compute_resid_glob(state)

  contains

    !> Shared Newton-Raphson loop for both BDF2 and its backward Euler bootstrap.
    !!
    !! For backward Euler (coeff=1):  A = I - dt*J,
    !!   b = dt*R(Q^k) - (Q^k - Q^n)
    !! For BDF2 (coeff=2/3):  A = I - (2/3)*dt*J,
    !!   b = (2/3)*dt*R(Q^k) - (Q^k - 4/3*Q^n + 1/3*Q^{n-1})
    !!
    !! st%ub is the Newton iterate and is updated in place; the work arrays
    !! are supplied by the caller and overwritten.  A warning is logged when
    !! all n_newton iterations run without max|dQ| falling below tol_newton.
    subroutine run_newton(st, ub_n_loc, ub_nm1_loc, coeff_loc, n_dof_loc, &
                          kl_loc, ku_loc, ldab_loc, diag_row_loc, &
                          ab_loc, rhs_loc, dq_loc, &
                          resid_base_loc, resid_pert_loc)
      type(solver_state_t), intent(inout) :: st
      real(wp), intent(in) :: ub_n_loc(neq, st % n_pt), ub_nm1_loc(neq, st % n_pt)
      real(wp), intent(in) :: coeff_loc
      integer, intent(in) :: n_dof_loc, kl_loc, ku_loc, ldab_loc, diag_row_loc
      real(wp), intent(inout) :: ab_loc(ldab_loc, n_dof_loc)
      real(wp), intent(inout) :: rhs_loc(n_dof_loc), dq_loc(n_dof_loc)
      real(wp), intent(inout) :: resid_base_loc(n_dof_loc), resid_pert_loc(n_dof_loc)

      integer :: it, jj, eq_loc, cell_loc, info_la_loc
      real(wp) :: h_loc, ub_save_loc, delta_max_loc
      real(wp) :: q_n_flat(n_dof_loc), q_nm1_flat(n_dof_loc)

      ! The history terms do not change during the Newton iteration, so
      ! flatten them once up front.
      call pack_field(ub_n_loc, q_n_flat, neq, st % n_pt)
      call pack_field(ub_nm1_loc, q_nm1_flat, neq, st % n_pt)

      newton_loop: do it = 1, n_newton

        ! Reference residual R(Q^k)
        call compute_resid(st)
        call pack_field(st % resid, resid_base_loc, neq, st % n_pt)

        ! Finite-difference Jacobian, one column per degree of freedom.
        ! DOF layout matches pack_field: jj = (cell-1)*neq + eq.
        ab_loc = 0.0_wp
        do jj = 1, n_dof_loc
          eq_loc = mod(jj - 1, neq) + 1
          cell_loc = (jj - 1) / neq + 1

          ub_save_loc = st % ub(eq_loc, cell_loc)
          h_loc = eps_jac * max(1.0_wp, abs(ub_save_loc))
          st % ub(eq_loc, cell_loc) = ub_save_loc + h_loc

          call compute_resid(st)
          call pack_field(st % resid, resid_pert_loc, neq, st % n_pt)

          st % ub(eq_loc, cell_loc) = ub_save_loc   ! restore

          call jac_store_col(ab_loc, jj, (resid_pert_loc - resid_base_loc) / h_loc, &
                             kl_loc, ku_loc, diag_row_loc, n_dof_loc)
        end do

        ! LHS: A = I - coeff*dt*J
        ab_loc = -coeff_loc * st % dt * ab_loc
        do jj = 1, n_dof_loc
          ab_loc(diag_row_loc, jj) = ab_loc(diag_row_loc, jj) + 1.0_wp
        end do

        ! RHS: b = coeff*dt*R - (Q^k - 4/3*Q^n + 1/3*Q^{n-1})
        ! (For backward Euler bootstrap coeff=1 and q_nm1_flat=q_n_flat, so
        !  the term simplifies to coeff*dt*R - (Q^k - Q^n) as expected.)
        call pack_field(st % ub, rhs_loc, neq, st % n_pt)
        rhs_loc = coeff_loc * st % dt * resid_base_loc &
                  - (rhs_loc - (4.0_wp / 3.0_wp) * q_n_flat &
                     + (1.0_wp / 3.0_wp) * q_nm1_flat)

        if (st % cfg % lapack_solver) then
          dq_loc = rhs_loc   ! the LAPACK wrapper solves in place
          call band_lapack_solve(ab_loc, kl_loc, ku_loc, n_dof_loc, dq_loc, info_la_loc)
          if (info_la_loc /= 0) error stop 'bdf2_step: LAPACK dgbsv failed (singular matrix)'
        else
          call band_lu_solve(ab_loc, n_dof_loc, kl_loc, ku_loc, rhs_loc, dq_loc)
        end if

        ! Q^{k+1} = Q^k + dQ
        delta_max_loc = maxval(abs(dq_loc))
        call unpack_add(dq_loc, st % ub, neq, st % n_pt)

        if (delta_max_loc < tol_newton) exit newton_loop
      end do newton_loop

      ! The loop index exceeds n_newton only when no iteration triggered the
      ! convergence exit above.
      if (it > n_newton) then
        block
          character(len=24) :: buf
          write (buf, '(ES12.4)') delta_max_loc
          call log_warn('bdf2_step: Newton did not converge; max_norm(dQ) = '//trim(buf))
        end block
      end if

    end subroutine run_newton

  end subroutine bdf2_step

  ! ---------------------------------------------------------------------------
  !> Flatten a (neq x n_pt) field into a length-neq*n_pt 1-D array.
  !! Ordering: (eq=1,cell=1), (eq=2,cell=1), ..., (eq=neq,cell=n_pt).
  pure subroutine pack_field(src, dst, nq, np)
    !! Copy the nq-by-np array `src` into the vector `dst` in array-element
    !! (column-major) order, i.e. src(ieq, ipt) lands at (ipt-1)*nq + ieq.
    integer, intent(in) :: nq, np
    real(wp), intent(in) :: src(nq, np)
    real(wp), intent(out) :: dst(nq * np)

    dst = reshape(src, [nq * np])
  end subroutine pack_field

  ! ---------------------------------------------------------------------------
  !> Add a flattened 1-D correction to a (neq x n_pt) 2-D field in place.
  subroutine unpack_add(src, dst, nq, np)
    !! Add the vector `src` element-wise onto the nq-by-np array `dst`,
    !! reading `src` in the same ordering pack_field produces
    !! (src((ipt-1)*nq + ieq) is added to dst(ieq, ipt)).
    integer, intent(in) :: nq, np
    real(wp), intent(in) :: src(nq * np)
    real(wp), intent(inout) :: dst(nq, np)

    dst = dst + reshape(src, [nq, np])
  end subroutine unpack_add

  ! ---------------------------------------------------------------------------
  !> Store one Jacobian column into band storage AB(diag_row_loc+i-j, j) = J(i,j).
  !! Works for both the custom (diag_row_loc = ku+1) and LAPACK
  !! (diag_row_loc = kl+ku+1) storage layouts.
  pure subroutine jac_store_col(ab_loc, j_col, col, kl_loc, ku_loc, diag_row_loc, n)
    !! Copy the in-band part of a full Jacobian column into band storage.
    !!
    !! Rows max(1, j_col-ku_loc) .. min(n, j_col+kl_loc) of `col` are written
    !! to column j_col of `ab_loc`, with row i stored at band row
    !! diag_row_loc + i - j_col.  Entries of `col` outside that range are
    !! ignored, and no other element of `ab_loc` is touched.
    integer, intent(in) :: j_col, kl_loc, ku_loc, diag_row_loc, n
    real(wp), intent(in) :: col(n)
    real(wp), intent(inout) :: ab_loc(:, :)

    integer :: row_first, row_last, shift

    row_first = max(1, j_col - ku_loc)
    row_last = min(n, j_col + kl_loc)
    shift = diag_row_loc - j_col   ! matrix row -> band row offset

    ! An empty row range means there is nothing to store.
    if (row_last >= row_first) then
      ab_loc(row_first + shift:row_last + shift, j_col) = col(row_first:row_last)
    end if
  end subroutine jac_store_col

  ! ---------------------------------------------------------------------------
  !> Banded Gaussian elimination without pivoting.
  !! Band storage: AB(ku+1+i-j, j) = A(i,j).
  !! On exit x = A^{-1} b.  The matrix ab_loc is overwritten with LU factors.
  subroutine band_lu_solve(ab_loc, n, kl_loc, ku_loc, b, x)
    !! Solve A*x = b for a banded matrix by Gaussian elimination without
    !! pivoting, working in the compact layout AB(ku+1+i-j, j) = A(i,j).
    !!
    !! `ab_loc` is overwritten: multipliers below the diagonal, the reduced
    !! upper triangle on and above it.  A pivot smaller in magnitude than
    !! tiny(1.0_wp) is not eliminated with, and the matching solution
    !! component is set to zero during back substitution.
    integer, intent(in) :: n, kl_loc, ku_loc
    real(wp), intent(inout) :: ab_loc(kl_loc + ku_loc + 1, n)
    real(wp), intent(in) :: b(n)
    real(wp), intent(out) :: x(n)

    integer :: piv, row, colk
    integer :: d            ! band row that holds the diagonal
    real(wp) :: mult

    d = ku_loc + 1
    x = b

    ! --- Forward sweep: eliminate below each pivot, updating x alongside ---
    forward: do piv = 1, n
      ! The pivot entry is not modified while its own column is processed,
      ! so a single test per column is sufficient.
      if (abs(ab_loc(d, piv)) < tiny(1.0_wp)) cycle forward

      do row = piv + 1, min(n, piv + kl_loc)
        mult = ab_loc(d + row - piv, piv) / ab_loc(d, piv)
        ab_loc(d + row - piv, piv) = mult   ! keep the multiplier in place

        do colk = piv + 1, min(n, piv + ku_loc)
          ab_loc(d + row - colk, colk) = ab_loc(d + row - colk, colk) &
                                         - mult * ab_loc(d + piv - colk, colk)
        end do

        x(row) = x(row) - mult * x(piv)
      end do
    end do forward

    ! --- Backward sweep: solve the upper-triangular system ---
    backward: do piv = n, 1, -1
      if (abs(ab_loc(d, piv)) < tiny(1.0_wp)) then
        x(piv) = 0.0_wp   ! unusable pivot: zero this component
      else
        x(piv) = x(piv) / ab_loc(d, piv)
      end if

      ! Remove this unknown's contribution from the rows above it.
      do row = max(1, piv - ku_loc), piv - 1
        x(row) = x(row) - ab_loc(d + row - piv, piv) * x(piv)
      end do
    end do backward
  end subroutine band_lu_solve

  ! ---------------------------------------------------------------------------
  !> Accumulate the global L2 norm of the residual into state%resid_glob.
  !!
  !! resid_glob = sqrt( sum_{i,k} resid(i,k)^2 / (n_pt * neq) )
  !!
  !! Called at the end of every time step for convergence monitoring.
  ! ---------------------------------------------------------------------------
  subroutine compute_resid_glob(state)
    !! Store in state%resid_glob the square root of the mean of
    !! state%resid(i, ipt)**2 over i = 1..neq and ipt = 1..state%n_pt.
    type(solver_state_t), intent(inout) :: state

    integer :: i_eq, i_cell
    integer :: n_entries   ! number of residual entries in the mean

    n_entries = state % n_pt * neq

    ! Sum of squares, accumulated directly in the output field.
    state % resid_glob = 0.0_wp
    do i_cell = 1, state % n_pt
      do i_eq = 1, neq
        state % resid_glob = state % resid(i_eq, i_cell)**2 + state % resid_glob
      end do
    end do

    ! Normalise by the entry count and take the root.
    state % resid_glob = sqrt(state % resid_glob / real(n_entries, wp))
  end subroutine compute_resid_glob

end module time_integration