From b2bbc15463411a24a93ab2878e3d5eb5d3a5c6fe Mon Sep 17 00:00:00 2001
From: Mark Olesen <Mark.Olesen@esi-group.com>
Date: Thu, 25 Apr 2024 18:54:22 +0200
Subject: [PATCH] ENH: use MPI Get_elements_x() for message sizes (#3152)

- ensures more accurate values for message sizes than using
  MPI Get_count(), which truncates at INT_MAX

- add more/better error messages when trying to receive messages
  that exceed INT_MAX or the char buffer lengths
---
 src/Pstream/mpi/UIPstreamRead.C | 60 +++++++++++++++++++++++++++++++++
 src/Pstream/mpi/UPstream.C      | 32 +++++++++++++++++-
 2 files changed, 91 insertions(+), 1 deletion(-)

diff --git a/src/Pstream/mpi/UIPstreamRead.C b/src/Pstream/mpi/UIPstreamRead.C
index c8e6de0d747..2a690f4b9f8 100644
--- a/src/Pstream/mpi/UIPstreamRead.C
+++ b/src/Pstream/mpi/UIPstreamRead.C
@@ -35,6 +35,8 @@ License
 // - as of 2023-06 appears to be broken with INTELMPI + PMI-2 (slurm)
 //   and perhaps other places so currently avoid
 
+#undef Pstream_use_MPI_Get_count
+
 // * * * * * * * * * * * * * * * Local Functions * * * * * * * * * * * * * * //
 
 // General blocking/non-blocking MPI receive
@@ -128,8 +130,30 @@ static std::streamsize UPstream_mpi_receive
         }
 
         // Check size of message read
+        #ifdef Pstream_use_MPI_Get_count
         int count(0);
         MPI_Get_count(&status, MPI_BYTE, &count);
+        #else
+        MPI_Count count(0);
+        MPI_Get_elements_x(&status, MPI_BYTE, &count);
+        #endif
+
+        // Errors
+        if (count == MPI_UNDEFINED || int64_t(count) < 0)
+        {
+            FatalErrorInFunction
+                << "MPI_Get_count() or MPI_Get_elements_x() : "
+                   "returned undefined or negative value"
+                << Foam::abort(FatalError);
+        }
+        else if (int64_t(count) > int64_t(UList<char>::max_size()))
+        {
+            FatalErrorInFunction
+                << "MPI_Get_count() or MPI_Get_elements_x() : "
+                   "count is larger than UList<char>::max_size() bytes"
+                << Foam::abort(FatalError);
+        }
+
 
         if (bufSize < std::streamsize(count))
         {
@@ -240,8 +264,30 @@ void Foam::UIPstream::bufferIPCrecv()
 
         profilingPstream::addProbeTime();
 
+
+        #ifdef Pstream_use_MPI_Get_count
         int count(0);
         MPI_Get_count(&status, MPI_BYTE, &count);
+        #else
+        MPI_Count count(0);
+        MPI_Get_elements_x(&status, MPI_BYTE, &count);
+        #endif
+
+        // Errors
+        if (count == MPI_UNDEFINED || int64_t(count) < 0)
+        {
+            FatalErrorInFunction
+                << "MPI_Get_count() or MPI_Get_elements_x() : "
+                   "returned undefined or negative value"
+                << Foam::abort(FatalError);
+        }
+        else if (int64_t(count) > int64_t(UList<char>::max_size()))
+        {
+            FatalErrorInFunction
+                << "MPI_Get_count() or MPI_Get_elements_x() : "
+                   "count is larger than UList<char>::max_size() bytes"
+                << Foam::abort(FatalError);
+        }
 
         if (UPstream::debug)
         {
@@ -264,6 +310,20 @@ void Foam::UIPstream::bufferIPCrecv()
         nullptr   // UPstream::Request
     );
 
+    if (count < 0)
+    {
+        FatalErrorInFunction
+            << "MPI_recv() with negative size??"
+            << Foam::abort(FatalError);
+    }
+    else if (int64_t(count) > int64_t(UList<char>::max_size()))
+    {
+        FatalErrorInFunction
+            << "MPI_recv() larger than "
+                "UList<char>::max_size() bytes"
+            << Foam::abort(FatalError);
+    }
+
     // Set addressed size. Leave actual allocated memory intact.
     recvBuf_.resize(label(count));
     messageSize_ = label(count);
diff --git a/src/Pstream/mpi/UPstream.C b/src/Pstream/mpi/UPstream.C
index a9f19d3b3f7..b23c85d66c2 100644
--- a/src/Pstream/mpi/UPstream.C
+++ b/src/Pstream/mpi/UPstream.C
@@ -40,6 +40,8 @@ License
 #include <numeric>
 #include <string>
 
+#undef Pstream_use_MPI_Get_count
+
 // * * * * * * * * * * * * * * Static Data Members * * * * * * * * * * * * * //
 
 // The min value and default for MPI buffer length
@@ -838,8 +840,36 @@ Foam::UPstream::probeMessage
 
     if (flag)
     {
+        // Unlikely to be used with large amounts of data,
+        // but use MPI_Get_elements_x() instead of MPI_Get_count() anyhow
+
+        #ifdef Pstream_use_MPI_Get_count
+        int count(0);
+        MPI_Get_count(&status, MPI_BYTE, &count);
+        #else
+        MPI_Count count(0);
+        MPI_Get_elements_x(&status, MPI_BYTE, &count);
+        #endif
+
+        // Errors
+        if (count == MPI_UNDEFINED || int64_t(count) < 0)
+        {
+            FatalErrorInFunction
+                << "MPI_Get_count() or MPI_Get_elements_x() : "
+                   "returned undefined or negative value"
+                << Foam::abort(FatalError);
+        }
+        else if (int64_t(count) > int64_t(INT_MAX))
+        {
+            FatalErrorInFunction
+                << "MPI_Get_count() or MPI_Get_elements_x() : "
+                   "count is larger than INT_MAX bytes"
+                << Foam::abort(FatalError);
+        }
+
+
         result.first = status.MPI_SOURCE;
-        MPI_Get_count(&status, MPI_BYTE, &result.second);
+        result.second = int(count);
     }
 
     return result;
-- 
GitLab