checkpoint.f90 Source File


Source Code

!> @file checkpoint.f90
!> @brief Checkpoint write and read for the 1D Euler solver.
!!
!! Checkpoints are Fortran unformatted (stream) binary files that capture
!! everything needed to resume a run exactly:
!!
!!   - A magic integer (42) and version (1) for format identification.
!!   - Scalars: iter, t, n_pt, neq, dt.
!!   - The conserved-variable array ub(neq, n_pt).
!!   - An optional BDF2 previous-step array bdf2_ub_prev(neq, n_pt),
!!     preceded by a presence flag (integer 1 = present, 0 = absent).
!!
!! File naming: `<base>_NNNNNN.bin` where NNNNNN is the zero-padded
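!! iteration number.  A companion text file `latest_checkpoint`, written to
!! the current working directory (not necessarily the checkpoint's own
!! directory), records the path of the most recent checkpoint so that
!! restart drivers need not track the iteration themselves.  Passing the
!! name 'latest_checkpoint' to read_checkpoint makes it follow this pointer.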
!!
!! Typical usage in the driver:
!! @code
!!   use checkpoint, only: write_checkpoint, read_checkpoint
!!   ! --- writing ---
!!   call write_checkpoint(state, cfg%checkpoint_file, t, iter)
!!   ! --- reading ---
!!   call read_checkpoint(state, cfg%restart_file, t, iter)
!! @endcode

module checkpoint
  use precision, only: wp
  use solver_state, only: solver_state_t, neq
  use logger, only: log_info, log_warn
  use iso_fortran_env, only: int32, int64
  implicit none
  private
  public :: write_checkpoint, read_checkpoint

  integer, parameter :: CKPT_MAGIC = 42   !< File-identification magic number
  integer, parameter :: CKPT_VERSION = 1    !< Format version

  !> Name of the pointer file (relative to the current working directory)
  !! that records the path of the most recent checkpoint.
  character(len=*), parameter :: LATEST_FILE = 'latest_checkpoint'

contains

  ! ---------------------------------------------------------------------------
  !> Write a checkpoint file for the current solver state.
  !!
  !! The file is written as an unformatted stream binary.  After a successful
  !! write, the plain text file `latest_checkpoint` is updated to the new
  !! file's path.  Failures are never fatal here: they are logged as warnings
  !! and reported through the optional is_ok / message arguments.  If an I/O
  !! error occurs part-way through the data section, the partial file is
  !! deleted so it cannot later be mistaken for a valid checkpoint.
  !!
  !! @param[in]  state    Current solver state (must have ub allocated).
  !! @param[in]  base     Base name (e.g. 'checkpoint'); file = base_NNNNNN.bin,
  !!                      where NNNNNN is iter zero-padded to at least 6 digits.
  !! @param[in]  t        Current simulation time [s].
  !! @param[in]  iter     Current iteration number.
  !! @param[out] is_ok    Optional; .false. if any step failed.
  !! @param[out] message  Optional; error description ('' on success).
  subroutine write_checkpoint(state, base, t, iter, is_ok, message)
    type(solver_state_t), intent(in) :: state
    character(len=*), intent(in) :: base
    real(wp), intent(in) :: t
    integer, intent(in) :: iter
    logical, intent(out), optional :: is_ok
    character(len=*), intent(out), optional :: message

    character(len=:), allocatable :: fname
    character(len=32) :: iter_str
    integer :: u, info
    integer :: has_bdf2

    if (present(is_ok)) is_ok = .true.
    if (present(message)) message = ''

    ! I0.6 gives at least six digits but widens as needed; I6.6 would print
    ! '******' for iter >= 1000000 and make all such checkpoints collide.
    write (iter_str, '(I0.6)') iter
    ! Deferred-length name: no truncation/overflow for long base paths.
    fname = trim(base)//'_'//trim(iter_str)//'.bin'

    open (newunit=u, file=fname, status='replace', action='write', &
          form='unformatted', access='stream', iostat=info)
    if (info /= 0) then
      call fail_write('checkpoint: cannot open "'//fname//'" for writing')
      return
    end if

    ! Header, then scalars, then the primary solution; stop at the first error.
    write (u, iostat=info) CKPT_MAGIC, CKPT_VERSION
    if (info == 0) write (u, iostat=info) iter, t, state % n_pt, neq, state % dt
    if (info == 0) write (u, iostat=info) state % ub

    ! Optional BDF2 previous-step array, preceded by its presence flag.
    if (info == 0) then
      has_bdf2 = merge(1, 0, allocated(state % bdf2_ub_prev))
      write (u, iostat=info) has_bdf2
      if (info == 0 .and. has_bdf2 == 1) write (u, iostat=info) state % bdf2_ub_prev
    end if

    if (info /= 0) then
      ! Remove the incomplete file rather than leave a valid-looking header.
      close (u, status='delete', iostat=info)
      call fail_write('checkpoint: write failed for "'//fname//'"')
      return
    end if

    close (u, iostat=info)
    if (info /= 0) then
      call fail_write('checkpoint: close failed for "'//fname//'"')
      return
    end if

    call log_info('checkpoint: wrote "'//fname//'"')
    call update_latest(fname, is_ok, message)

  contains

    !> Log a write failure and report it through the optional outputs.
    subroutine fail_write(err_msg)
      character(len=*), intent(in) :: err_msg

      call log_warn(err_msg)
      if (present(is_ok)) is_ok = .false.
      if (present(message)) message = err_msg
    end subroutine fail_write

  end subroutine write_checkpoint

  ! ---------------------------------------------------------------------------
  !> Read a checkpoint file and restore solver state.
  !!
  !! Validates the magic number, version, grid size, equation count and the
  !! BDF2 presence flag, and checks every read for I/O errors (e.g. a
  !! truncated file).  On failure the error is returned through is_ok /
  !! message when either is present; if neither is present the run is
  !! terminated with `error stop`.  state%dt is only updated once all array
  !! data has been read successfully.
  !!
  !! @param[in,out] state    Solver state with ub already allocated (from setup_solver).
  !! @param[in]     fname    Path to the checkpoint file (or 'latest_checkpoint').
  !! @param[out]    t        Restored simulation time [s].
  !! @param[out]    iter     Restored iteration count.
  !! @param[out]    is_ok    Optional; .false. if restoring failed.
  !! @param[out]    message  Optional; error description ('' on success).
  subroutine read_checkpoint(state, fname, t, iter, is_ok, message)
    type(solver_state_t), intent(inout) :: state
    character(len=*), intent(in) :: fname
    real(wp), intent(out) :: t
    integer, intent(out) :: iter
    logical, intent(out), optional :: is_ok
    character(len=*), intent(out), optional :: message

    character(len=512) :: actual_file
    integer :: u, info
    integer :: magic, version, n_pt_ck, neq_ck, has_bdf2
    real(wp) :: dt_ck
    logical :: ok, unit_open
    character(len=256) :: err

    if (present(is_ok)) is_ok = .true.
    if (present(message)) message = ''
    ! Tracks whether `u` refers to a unit we opened (see fail_read).
    unit_open = .false.

    ! If fname is the pointer file, read the real path from it
    actual_file = resolve_latest(fname, ok, err)
    if (.not. ok) then
      call fail_read(trim(err))
      return
    end if

    open (newunit=u, file=trim(actual_file), status='old', action='read', &
          form='unformatted', access='stream', iostat=info)
    if (info /= 0) then
      call fail_read('checkpoint: cannot open restart file "'//trim(actual_file)//'"')
      return
    end if
    unit_open = .true.

    ! Header
    read (u, iostat=info) magic, version
    if (info /= 0) then
      call fail_read('checkpoint: cannot read header of "'//trim(actual_file)//'"')
      return
    end if
    if (magic /= CKPT_MAGIC) then
      call fail_read('checkpoint: bad magic number — not a valid checkpoint file')
      return
    end if
    if (version /= CKPT_VERSION) then
      call fail_read('checkpoint: unsupported checkpoint version')
      return
    end if

    ! Scalars
    read (u, iostat=info) iter, t, n_pt_ck, neq_ck, dt_ck
    if (info /= 0) then
      call fail_read('checkpoint: truncated scalar section in "'//trim(actual_file)//'"')
      return
    end if

    if (n_pt_ck /= state % n_pt) then
      call fail_read('checkpoint: n_pt mismatch between checkpoint and current grid')
      return
    end if
    if (neq_ck /= neq) then
      call fail_read('checkpoint: neq mismatch in checkpoint file')
      return
    end if

    ! Primary solution
    read (u, iostat=info) state % ub
    if (info /= 0) then
      call fail_read('checkpoint: truncated solution array in "'//trim(actual_file)//'"')
      return
    end if

    ! Optional BDF2 previous-step
    read (u, iostat=info) has_bdf2
    if (info /= 0) then
      call fail_read('checkpoint: missing BDF2 presence flag in "'//trim(actual_file)//'"')
      return
    end if

    select case (has_bdf2)
    case (1)
      if (.not. allocated(state % bdf2_ub_prev)) &
        allocate (state % bdf2_ub_prev(neq, state % n_pt))
      read (u, iostat=info) state % bdf2_ub_prev
      if (info /= 0) then
        call fail_read('checkpoint: truncated BDF2 history array in "'//trim(actual_file)//'"')
        return
      end if
    case (0)
      ! No BDF2 history in the file.  NOTE(review): an already-allocated
      ! state%bdf2_ub_prev is left untouched here, as before -- confirm the
      ! time stepper does not treat its (stale) contents as valid history.
      continue
    case default
      call fail_read('checkpoint: invalid BDF2 presence flag (corrupt checkpoint file?)')
      return
    end select

    ! Commit the scalar state only after all array data was read cleanly.
    state % dt = dt_ck

    close (u, iostat=info)
    unit_open = .false.
    if (info /= 0) &
      call log_warn('checkpoint: close failed for "'//trim(actual_file)//'"')

    call log_info('checkpoint: resumed from "'//trim(actual_file)//'"')

  contains

    !> Close the restart unit (only if it was actually opened), then report
    !! the error via is_ok/message or stop if the caller supplied neither.
    subroutine fail_read(err_msg)
      character(len=*), intent(in) :: err_msg
      integer :: close_info

      ! After a failed OPEN(NEWUNIT=u) the value of u is undefined; closing
      ! it could disconnect an unrelated unit, so guard with unit_open.
      if (unit_open) then
        close (u, iostat=close_info)
        unit_open = .false.
      end if
      if (.not. present(is_ok) .and. .not. present(message)) error stop trim(err_msg)
      call log_warn(err_msg)
      if (present(is_ok)) is_ok = .false.
      if (present(message)) message = trim(err_msg)
    end subroutine fail_read

  end subroutine read_checkpoint

  ! ---------------------------------------------------------------------------
  ! Write the path of the most recent checkpoint to `latest_checkpoint` in the
  ! current working directory.  Failures (open, write or close) are logged and
  ! reported through the optional is_ok / message arguments; never fatal.
  ! ---------------------------------------------------------------------------
  subroutine update_latest(fname, is_ok, message)
    character(len=*), intent(in) :: fname
    logical, intent(out), optional :: is_ok
    character(len=*), intent(out), optional :: message
    integer :: u, info

    if (present(is_ok)) is_ok = .true.
    if (present(message)) message = ''

    open (newunit=u, file=LATEST_FILE, status='replace', &
          action='write', form='formatted', iostat=info)
    if (info /= 0) then
      call fail_update('checkpoint: cannot update latest_checkpoint pointer file')
      return
    end if

    write (u, '(A)', iostat=info) trim(fname)
    if (info /= 0) then
      close (u, iostat=info)
      call fail_update('checkpoint: cannot write latest_checkpoint pointer file')
      return
    end if

    close (u, iostat=info)
    if (info /= 0) &
      call fail_update('checkpoint: close failed for latest_checkpoint pointer file')

  contains

    !> Log a pointer-file failure and report it through the optional outputs.
    subroutine fail_update(err_msg)
      character(len=*), intent(in) :: err_msg

      call log_warn(err_msg)
      if (present(is_ok)) is_ok = .false.
      if (present(message)) message = err_msg
    end subroutine fail_update

  end subroutine update_latest

  ! ---------------------------------------------------------------------------
  ! If fname == 'latest_checkpoint', read the real path from that file
  ! (leading blanks removed); otherwise return fname unchanged.  On failure
  ! (pointer file missing, unreadable or empty) returns '' and reports via
  ! is_ok / message, or stops if the caller supplied neither.
  ! ---------------------------------------------------------------------------
  function resolve_latest(fname, is_ok, message) result(actual)
    character(len=*), intent(in) :: fname
    logical, intent(out), optional :: is_ok
    character(len=*), intent(out), optional :: message
    character(len=512) :: actual
    integer :: u, info, close_info

    if (present(is_ok)) is_ok = .true.
    if (present(message)) message = ''

    ! Not the pointer file: the caller named the checkpoint directly.
    if (fname /= LATEST_FILE) then
      actual = fname
      return
    end if

    actual = ''
    open (newunit=u, file=LATEST_FILE, status='old', &
          action='read', form='formatted', iostat=info)
    if (info /= 0) then
      call fail_resolve('checkpoint: cannot open latest_checkpoint pointer file')
      return
    end if

    read (u, '(A)', iostat=info) actual
    close (u, iostat=close_info)
    if (info /= 0) then
      actual = ''
      call fail_resolve('checkpoint: cannot read latest_checkpoint pointer file')
      return
    end if

    actual = adjustl(actual)
    if (len_trim(actual) == 0) then
      call fail_resolve('checkpoint: latest_checkpoint pointer file is empty')
    end if

  contains

    !> Report a resolution failure, or stop if no error outputs were given.
    subroutine fail_resolve(err_msg)
      character(len=*), intent(in) :: err_msg

      if (present(is_ok)) is_ok = .false.
      if (present(message)) message = err_msg
      if (.not. present(is_ok) .and. .not. present(message)) error stop trim(err_msg)
    end subroutine fail_resolve

  end function resolve_latest

end module checkpoint