From d4046bb85e82e74f297cd5ca40a5aa9f4f39a86d Mon Sep 17 00:00:00 2001
From: Henry Weller <http://cfd.direct>
Date: Thu, 24 Mar 2016 19:13:04 +0000
Subject: [PATCH] LLTMatrix, LUscalarMatrix, QRMatrix: Provided consistent
 'solve' interface

---
 applications/test/Matrix/Test-Matrix.C        | 77 ++++++++-----------
 src/OpenFOAM/matrices/LLTMatrix/LLTMatrix.C   |  6 ++
 .../matrices/LUscalarMatrix/LUscalarMatrix.H  | 14 +++-
 .../LUscalarMatrix/LUscalarMatrixTemplates.C  | 58 ++++++++++----
 src/OpenFOAM/matrices/QRMatrix/QRMatrix.C     |  2 +-
 src/OpenFOAM/matrices/QRMatrix/QRMatrix.H     |  2 +-
 .../lduMatrix/solvers/GAMG/GAMGSolverSolve.C  |  3 +-
 7 files changed, 91 insertions(+), 71 deletions(-)

diff --git a/applications/test/Matrix/Test-Matrix.C b/applications/test/Matrix/Test-Matrix.C
index 87a39386fc1..908c364d9c5 100644
--- a/applications/test/Matrix/Test-Matrix.C
+++ b/applications/test/Matrix/Test-Matrix.C
@@ -24,6 +24,7 @@ License
 \*---------------------------------------------------------------------------*/
 
 #include "scalarMatrices.H"
+#include "LUscalarMatrix.H"
 #include "LLTMatrix.H"
 #include "QRMatrix.H"
 #include "vector.H"
@@ -113,70 +114,53 @@ int main(int argc, char *argv[])
         Info<< "Solution = " << rhs << endl;
     }
 
-    {
-        scalarSquareMatrix squareMatrix(3, Zero);
 
-        squareMatrix(0, 0) = 4;
-        squareMatrix(0, 1) = 12;
-        squareMatrix(0, 2) = -16;
-        squareMatrix(1, 0) = 12;
-        squareMatrix(1, 1) = 37;
-        squareMatrix(1, 2) = -43;
-        squareMatrix(2, 0) = -16;
-        squareMatrix(2, 1) = -43;
-        squareMatrix(2, 2) = 98;
+    scalarSquareMatrix squareMatrix(3, Zero);
+
+    squareMatrix(0, 0) = 4;
+    squareMatrix(0, 1) = 12;
+    squareMatrix(0, 2) = -16;
+    squareMatrix(1, 0) = 12;
+    squareMatrix(1, 1) = 37;
+    squareMatrix(1, 2) = -43;
+    squareMatrix(2, 0) = -16;
+    squareMatrix(2, 1) = -43;
+    squareMatrix(2, 2) = 98;
 
-        const scalarSquareMatrix squareMatrixCopy = squareMatrix;
-        Info<< nl << "Square Matrix = " << squareMatrix << endl;
+    Info<< nl << "Square Matrix = " << squareMatrix << endl;
 
-        Info<< "det = " << det(squareMatrixCopy) << endl;
+    const scalarField source(3, 1);
+
+    {
+        {
+            scalarSquareMatrix sm(squareMatrix);
+            Info<< "det = " << det(sm) << endl;
+        }
 
+        scalarSquareMatrix sm(squareMatrix);
         labelList rhs(3, 0);
         label sign;
-        LUDecompose(squareMatrix, rhs, sign);
+        LUDecompose(sm, rhs, sign);
 
-        Info<< "Decomposition = " << squareMatrix << endl;
+        Info<< "Decomposition = " << sm << endl;
         Info<< "Pivots = " << rhs << endl;
         Info<< "Sign = " << sign << endl;
-        Info<< "det = " << detDecomposed(squareMatrix, sign) << endl;
+        Info<< "det = " << detDecomposed(sm, sign) << endl;
     }
 
     {
-        scalarSquareMatrix squareMatrix(3, Zero);
-
-        squareMatrix(0, 0) = 4;
-        squareMatrix(0, 1) = 12;
-        squareMatrix(0, 2) = -16;
-        squareMatrix(1, 0) = 12;
-        squareMatrix(1, 1) = 37;
-        squareMatrix(1, 2) = -43;
-        squareMatrix(2, 0) = -16;
-        squareMatrix(2, 1) = -43;
-        squareMatrix(2, 2) = 98;
-
-        scalarField source(3, 1);
+        LUscalarMatrix LU(squareMatrix);
+        scalarField x((LU.solve(source));
+        Info<< "LU solve residual " << (squareMatrix*x - source) << endl;
+    }
 
+    {
         LLTMatrix<scalar> LLT(squareMatrix);
         scalarField x(LLT.solve(source));
-
         Info<< "LLT solve residual " << (squareMatrix*x - source) << endl;
     }
 
     {
-        scalarSquareMatrix squareMatrix(3, Zero);
-
-        squareMatrix(0, 0) = 4;
-        squareMatrix(0, 1) = 12;
-        squareMatrix(0, 2) = -16;
-        squareMatrix(1, 0) = 12;
-        squareMatrix(1, 1) = 37;
-        squareMatrix(1, 2) = -43;
-        squareMatrix(2, 0) = -16;
-        squareMatrix(2, 1) = -43;
-        squareMatrix(2, 2) = 98;
-
-        scalarField source(3, 1);
-
         QRMatrix<scalarSquareMatrix> QR(squareMatrix);
         scalarField x(QR.solve(source));
 
@@ -184,8 +168,7 @@ int main(int argc, char *argv[])
             << (squareMatrix*x - source) << endl;
 
         Info<< "QR inverse solve residual "
-            << (x - QR.inverse()*source) << endl;
-
+            << (x - QR.inv()*source) << endl;
     }
 
     Info<< "\nEnd\n" << endl;
diff --git a/src/OpenFOAM/matrices/LLTMatrix/LLTMatrix.C b/src/OpenFOAM/matrices/LLTMatrix/LLTMatrix.C
index 2e8f7421d95..d0f828d4f33 100644
--- a/src/OpenFOAM/matrices/LLTMatrix/LLTMatrix.C
+++ b/src/OpenFOAM/matrices/LLTMatrix/LLTMatrix.C
@@ -95,6 +95,12 @@ void Foam::LLTMatrix<Type>::solve
     const Field<Type>& source
 ) const
 {
+    // If x and source are different initialize x = source
+    if (&x != &source)
+    {
+        x = source;
+    }
+
     const SquareMatrix<Type>& LLT = *this;
     const label m = LLT.m();
 
diff --git a/src/OpenFOAM/matrices/LUscalarMatrix/LUscalarMatrix.H b/src/OpenFOAM/matrices/LUscalarMatrix/LUscalarMatrix.H
index c8a83e5e5a8..0e0bc010136 100644
--- a/src/OpenFOAM/matrices/LUscalarMatrix/LUscalarMatrix.H
+++ b/src/OpenFOAM/matrices/LUscalarMatrix/LUscalarMatrix.H
@@ -116,10 +116,16 @@ public:
         //- Perform the LU decomposition of the matrix M
         void decompose(const scalarSquareMatrix& M);
 
-        //- Solve the matrix using the LU decomposition with pivoting
-        //  returning the solution in the source
-        template<class T>
-        void solve(Field<T>& source) const;
+        //- Solve the linear system with the given source
+        //  and returning the solution in the Field argument x.
+        //  This function may be called with the same field for x and source.
+        template<class Type>
+        void solve(Field<Type>& x, const Field<Type>& source) const;
+
+        //- Solve the linear system with the given source
+        //  returning the solution
+        template<class Type>
+        tmp<Field<Type>> solve(const Field<Type>& source) const;
 };
 
 
diff --git a/src/OpenFOAM/matrices/LUscalarMatrix/LUscalarMatrixTemplates.C b/src/OpenFOAM/matrices/LUscalarMatrix/LUscalarMatrixTemplates.C
index 4d327faeee5..605babb8523 100644
--- a/src/OpenFOAM/matrices/LUscalarMatrix/LUscalarMatrixTemplates.C
+++ b/src/OpenFOAM/matrices/LUscalarMatrix/LUscalarMatrixTemplates.C
@@ -24,23 +24,34 @@ License
 \*---------------------------------------------------------------------------*/
 
 #include "LUscalarMatrix.H"
+#include "SubField.H"
 
 // * * * * * * * * * * * * * * * Member Functions  * * * * * * * * * * * * * //
 
 template<class Type>
-void Foam::LUscalarMatrix::solve(Field<Type>& sourceSol) const
+void Foam::LUscalarMatrix::solve
+(
+    Field<Type>& x,
+    const Field<Type>& source
+) const
 {
+    // If x and source are different initialize x = source
+    if (&x != &source)
+    {
+        x = source;
+    }
+
     if (Pstream::parRun())
     {
-        Field<Type> completeSourceSol(m());
+        Field<Type> X(m());
 
         if (Pstream::master(comm_))
         {
             typename Field<Type>::subField
             (
-                completeSourceSol,
-                sourceSol.size()
-            ).assign(sourceSol);
+                X,
+                x.size()
+            ).assign(x);
 
             for
             (
@@ -55,7 +66,7 @@ void Foam::LUscalarMatrix::solve(Field<Type>& sourceSol) const
                     slave,
                     reinterpret_cast<char*>
                     (
-                        &(completeSourceSol[procOffsets_[slave]])
+                        &(X[procOffsets_[slave]])
                     ),
                     (procOffsets_[slave+1]-procOffsets_[slave])*sizeof(Type),
                     Pstream::msgType(),
@@ -69,8 +80,8 @@ void Foam::LUscalarMatrix::solve(Field<Type>& sourceSol) const
             (
                 Pstream::scheduled,
                 Pstream::masterNo(),
-                reinterpret_cast<const char*>(sourceSol.begin()),
-                sourceSol.byteSize(),
+                reinterpret_cast<const char*>(x.begin()),
+                x.byteSize(),
                 Pstream::msgType(),
                 comm_
             );
@@ -78,12 +89,12 @@ void Foam::LUscalarMatrix::solve(Field<Type>& sourceSol) const
 
         if (Pstream::master(comm_))
         {
-            LUBacksubstitute(*this, pivotIndices_, completeSourceSol);
+            LUBacksubstitute(*this, pivotIndices_, X);
 
-            sourceSol = typename Field<Type>::subField
+            x = typename Field<Type>::subField
             (
-                completeSourceSol,
-                sourceSol.size()
+                X,
+                x.size()
             );
 
             for
@@ -99,7 +110,7 @@ void Foam::LUscalarMatrix::solve(Field<Type>& sourceSol) const
                     slave,
                     reinterpret_cast<const char*>
                     (
-                        &(completeSourceSol[procOffsets_[slave]])
+                        &(X[procOffsets_[slave]])
                     ),
                     (procOffsets_[slave + 1]-procOffsets_[slave])*sizeof(Type),
                     Pstream::msgType(),
@@ -113,8 +124,8 @@ void Foam::LUscalarMatrix::solve(Field<Type>& sourceSol) const
             (
                 Pstream::scheduled,
                 Pstream::masterNo(),
-                reinterpret_cast<char*>(sourceSol.begin()),
-                sourceSol.byteSize(),
+                reinterpret_cast<char*>(x.begin()),
+                x.byteSize(),
                 Pstream::msgType(),
                 comm_
             );
@@ -122,9 +133,24 @@ void Foam::LUscalarMatrix::solve(Field<Type>& sourceSol) const
     }
     else
     {
-        LUBacksubstitute(*this, pivotIndices_, sourceSol);
+        LUBacksubstitute(*this, pivotIndices_, x);
     }
 }
 
 
+template<class Type>
+Foam::tmp<Foam::Field<Type>> Foam::LUscalarMatrix::solve
+(
+    const Field<Type>& source
+) const
+{
+    tmp<Field<Type>> tx(new Field<Type>(m()));
+    Field<Type>& x = tx.ref();
+
+    solve(x, source);
+
+    return tx;
+}
+
+
 // ************************************************************************* //
diff --git a/src/OpenFOAM/matrices/QRMatrix/QRMatrix.C b/src/OpenFOAM/matrices/QRMatrix/QRMatrix.C
index 80e33b0fb2a..165056d68ed 100644
--- a/src/OpenFOAM/matrices/QRMatrix/QRMatrix.C
+++ b/src/OpenFOAM/matrices/QRMatrix/QRMatrix.C
@@ -225,7 +225,7 @@ Foam::QRMatrix<MatrixType>::solve
 
 template<class MatrixType>
 typename Foam::QRMatrix<MatrixType>::QMatrixType
-Foam::QRMatrix<MatrixType>::inverse() const
+Foam::QRMatrix<MatrixType>::inv() const
 {
     const label m = Q_.m();
 
diff --git a/src/OpenFOAM/matrices/QRMatrix/QRMatrix.H b/src/OpenFOAM/matrices/QRMatrix/QRMatrix.H
index bbad659be26..e3dab46b4f2 100644
--- a/src/OpenFOAM/matrices/QRMatrix/QRMatrix.H
+++ b/src/OpenFOAM/matrices/QRMatrix/QRMatrix.H
@@ -108,7 +108,7 @@ public:
         tmp<Field<cmptType>> solve(const Field<cmptType>& source) const;
 
         //- Return the inverse of a square matrix
-        QMatrixType inverse() const;
+        QMatrixType inv() const;
 };
 
 
diff --git a/src/OpenFOAM/matrices/lduMatrix/solvers/GAMG/GAMGSolverSolve.C b/src/OpenFOAM/matrices/lduMatrix/solvers/GAMG/GAMGSolverSolve.C
index 1a547204a74..3e357cdbe70 100644
--- a/src/OpenFOAM/matrices/lduMatrix/solvers/GAMG/GAMGSolverSolve.C
+++ b/src/OpenFOAM/matrices/lduMatrix/solvers/GAMG/GAMGSolverSolve.C
@@ -533,8 +533,7 @@ void Foam::GAMGSolver::solveCoarsestLevel
 
     if (directSolveCoarsest_)
     {
-        coarsestCorrField = coarsestSource;
-        coarsestLUMatrixPtr_->solve(coarsestCorrField);
+        coarsestLUMatrixPtr_->solve(coarsestCorrField, coarsestSource);
     }
     //else if
     //(
-- 
GitLab