Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,10 @@
(r'py:.*', r'pyadjoint\..*'),
(r'py:.*', r'tsfc\..*'),
(r'py:.*', r'ufl\..*'),
(r'py:.*', r'PETSc\..*'),
(r'py:.*', r'progress\..*'),
# Ignore undocumented PyOP2
('py:class', 'pyop2.caching.Cached'),
('py:class', 'pyop2.types.mat.Mat'),
# Ignore mission docs from Firedrake internal "private" code
# Any "Base" class eg:
# firedrake.adjoint.checkpointing.CheckpointBase
Expand Down Expand Up @@ -416,6 +416,7 @@
'ufl': ('https://docs.fenicsproject.org/ufl/main/', None),
'FIAT': ('https://firedrakeproject.org/fiat', None),
'petsctools': ('https://firedrakeproject.org/petsctools/', None),
'petsc4py': ('https://petsc.org/release/petsc4py/', None),
'mpi4py': ('https://mpi4py.readthedocs.io/en/stable/', None),
'h5py': ('http://docs.h5py.org/en/latest/', None),
'h5py.h5p': ('https://api.h5py.org/', None),
Expand Down
55 changes: 31 additions & 24 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
from tsfc.ufl_utils import extract_firedrake_constants
import ufl
import finat.ufl
from firedrake import (extrusion_utils as eutils, matrix, parameters, solving,
from firedrake import (extrusion_utils as eutils, parameters, solving,
tsfc_interface, utils)
from firedrake.adjoint_utils import annotate_assemble
from firedrake.ufl_expr import extract_domains
from firedrake.bcs import DirichletBC, EquationBC, EquationBCSplit
from firedrake.matrix import MatrixBase, Matrix, ImplicitMatrix
from firedrake.functionspaceimpl import WithGeometry, FunctionSpace, FiredrakeDualSpace
from firedrake.functionspacedata import entity_dofs_key, entity_permutations_key
from firedrake.interpolation import get_interpolator
Expand Down Expand Up @@ -238,7 +239,7 @@ def assemble(self, tensor=None, current_state=None):
if isinstance(expr, ufl.algebra.Sum):
a, b = [assemble(e) for e in expr.ufl_operands]
# Only Expr resulting in a Matrix if assembled are BaseFormOperator
if not all(isinstance(op, matrix.AssembledMatrix) for op in (a, b)):
if not all(isinstance(op, MatrixBase) for op in (a, b)):
raise TypeError('Mismatching Sum shapes')
return assemble(ufl.FormSum((a, 1), (b, 1)), tensor=tensor)
elif isinstance(expr, ufl.algebra.Product):
Expand Down Expand Up @@ -356,7 +357,7 @@ def __init__(self,
def allocate(self):
rank = len(self._form.arguments())
if rank == 2 and not self._diagonal:
if isinstance(self._form, matrix.MatrixBase):
if isinstance(self._form, MatrixBase):
return self._form
elif self._mat_type == "matfree":
return MatrixFreeAssembler(self._form, bcs=self._bcs, form_compiler_parameters=self._form_compiler_params,
Expand All @@ -365,9 +366,8 @@ def allocate(self):
else:
test, trial = self._form.arguments()
sparsity = ExplicitMatrixAssembler._make_sparsity(test, trial, self._mat_type, self._sub_mat_type, self.maps_and_regions)
return matrix.Matrix(self._form, self._bcs, self._mat_type, sparsity, ScalarType,
sub_mat_type=self._sub_mat_type,
options_prefix=self._options_prefix)
op2mat = op2.Mat(sparsity, mat_type=self._mat_type, sub_mat_type=self._sub_mat_type, dtype=ScalarType)
return Matrix(self._form, op2mat, bcs=self._bcs, options_prefix=self._options_prefix, fc_params=self._form_compiler_params)
else:
raise NotImplementedError("Only implemented for rank = 2 and diagonal = False")

Expand Down Expand Up @@ -474,13 +474,14 @@ def base_form_assembly_visitor(self, expr, tensor, bcs, *args):
# Out-of-place Hermitian transpose
mat.petscmat.hermitianTranspose(out=result)
if tensor is None:
tensor = self.assembled_matrix(expr, bcs, result)
tensor = Matrix(expr, result, bcs=bcs,
options_prefix=self._options_prefix, fc_params=self._form_compiler_params)
return tensor
elif isinstance(expr, ufl.Action):
if len(args) != 2:
raise TypeError("Not enough operands for Action")
lhs, rhs = args
if isinstance(lhs, matrix.MatrixBase):
if isinstance(lhs, MatrixBase):
if isinstance(rhs, (firedrake.Cofunction, firedrake.Function)):
petsc_mat = lhs.petscmat
(row, col) = lhs.arguments()
Expand All @@ -489,11 +490,11 @@ def base_form_assembly_visitor(self, expr, tensor, bcs, *args):
with rhs.dat.vec_ro as v_vec, res.dat.vec as res_vec:
petsc_mat.mult(v_vec, res_vec)
return res
elif isinstance(rhs, matrix.MatrixBase):
elif isinstance(rhs, MatrixBase):
result = tensor.petscmat if tensor else PETSc.Mat()
lhs.petscmat.matMult(rhs.petscmat, result=result)
if tensor is None:
tensor = self.assembled_matrix(expr, bcs, result)
tensor = Matrix(expr, result, bcs=bcs, options_prefix=self._options_prefix)
return tensor
else:
raise TypeError("Incompatible RHS for Action.")
Expand All @@ -503,7 +504,7 @@ def base_form_assembly_visitor(self, expr, tensor, bcs, *args):
with lhs.dat.vec_ro as x, rhs.dat.vec_ro as y:
res = x.dot(y)
return res
elif isinstance(rhs, matrix.MatrixBase):
elif isinstance(rhs, MatrixBase):
# Compute action(Cofunc, Mat) => Mat^* @ Cofunc
petsc_mat = rhs.petscmat
(_, col) = rhs.arguments()
Expand Down Expand Up @@ -584,7 +585,8 @@ def base_form_assembly_visitor(self, expr, tensor, bcs, *args):
op.handle.copy(result=result)
result.scale(w)
if tensor is None:
tensor = self.assembled_matrix(expr, bcs, result)
tensor = Matrix(expr.arguments(), result, bcs=bcs,
options_prefix=self._options_prefix, fc_params=self._form_compiler_params)
return tensor
else:
raise TypeError("Mismatching FormSum shapes")
Expand Down Expand Up @@ -633,10 +635,6 @@ def base_form_assembly_visitor(self, expr, tensor, bcs, *args):
else:
raise TypeError(f"Unrecognised BaseForm instance: {expr}")

def assembled_matrix(self, expr, bcs, petscmat):
return matrix.AssembledMatrix(expr.arguments(), bcs, petscmat,
options_prefix=self._options_prefix)

@staticmethod
def base_form_postorder_traversal(expr, visitor, visited={}):
if expr in visited:
Expand Down Expand Up @@ -1381,10 +1379,12 @@ def allocate(self):
self._mat_type,
self._sub_mat_type,
self._make_maps_and_regions())
return matrix.Matrix(self._form, self._bcs, self._mat_type, sparsity, ScalarType,
sub_mat_type=self._sub_mat_type,
options_prefix=self._options_prefix,
fc_params=self._form_compiler_params)
op2mat = op2.Mat(
sparsity, mat_type=self._mat_type, sub_mat_type=self._sub_mat_type,
dtype=ScalarType
)
return Matrix(self._form, op2mat, bcs=self._bcs,
fc_params=self._form_compiler_params, options_prefix=self._options_prefix)

@staticmethod
def _make_sparsity(test, trial, mat_type, sub_mat_type, maps_and_regions):
Expand Down Expand Up @@ -1583,10 +1583,17 @@ def __init__(self, form, bcs=None, form_compiler_parameters=None,
self._appctx = appctx

def allocate(self):
return matrix.ImplicitMatrix(self._form, self._bcs,
fc_params=self._form_compiler_params,
options_prefix=self._options_prefix,
appctx=self._appctx or {})
from firedrake.matrix_free.operators import ImplicitMatrixContext
ctx = ImplicitMatrixContext(
self._form, row_bcs=self._bcs, col_bcs=self._bcs,
fc_params=self._form_compiler_params,
appctx=self._appctx
)
return ImplicitMatrix(
self._form, ctx, self._bcs,
fc_params=self._form_compiler_params,
options_prefix=self._options_prefix
)

def assemble(self, tensor=None, current_state=None):
if tensor is None:
Expand Down
13 changes: 6 additions & 7 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,13 @@
from pyop2.utils import as_tuple

import firedrake
import firedrake.matrix as matrix
import firedrake.utils as utils
from firedrake import ufl_expr
from firedrake import slate
from firedrake import solving
from firedrake import ufl_expr, slate, solving
from firedrake.formmanipulation import ExtractSubBlock
from firedrake.adjoint_utils.dirichletbc import DirichletBCMixin
from firedrake.petsc import PETSc
from firedrake.function import Function
from firedrake.cofunction import Cofunction

__all__ = ['DirichletBC', 'homogenize', 'EquationBC']

Expand Down Expand Up @@ -186,8 +185,8 @@ def zero(self, r):
boundary condition should be applied.

"""
if isinstance(r, matrix.MatrixBase):
raise NotImplementedError("Zeroing bcs on a Matrix is not supported")
if not isinstance(r, Function | Cofunction):
raise NotImplementedError(f"Zeroing bcs not supported for {type(r).__name__}")

for idx in self._indices:
r = r.sub(idx)
Expand Down Expand Up @@ -411,7 +410,7 @@ def apply(self, r, u=None):
corresponding rows and columns.

"""
if isinstance(r, matrix.MatrixBase):
if isinstance(r, ufl.Matrix):
raise NotImplementedError("Capability to delay bc application has been dropped. Use assemble(a, bcs=bcs, ...) to obtain a fully assembled matrix")

fs = self._function_space
Expand Down
10 changes: 3 additions & 7 deletions firedrake/external_operators/ml_operator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from firedrake.external_operators import AbstractExternalOperator, assemble_method
from firedrake.matrix import AssembledMatrix
from firedrake.matrix import Matrix


class MLOperator(AbstractExternalOperator):
Expand Down Expand Up @@ -58,20 +58,16 @@ def assemble_jacobian(self, *args, **kwargs):
"""Assemble the Jacobian using the AD engine of the ML framework."""
# Delegate computation to the ML framework.
J = self._jac()
# Set bcs
bcs = ()
return AssembledMatrix(self, bcs, J)
return Matrix(self, J)

@assemble_method(1, (1, 0))
def assemble_jacobian_adjoint(self, *args, **kwargs):
"""Assemble the Jacobian Hermitian transpose using the AD engine of the ML framework."""
# Delegate computation to the ML framework.
J = self._jac()
# Set bcs
bcs = ()
# Take the adjoint (Hermitian transpose)
J.hermitianTranspose()
return AssembledMatrix(self, bcs, J)
return Matrix(self, J)

@assemble_method(1, (0, None))
def assemble_jacobian_action(self, *args, **kwargs):
Expand Down
5 changes: 2 additions & 3 deletions firedrake/formmanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from firedrake.functionspace import MixedFunctionSpace
from firedrake.cofunction import Cofunction
from firedrake.ufl_expr import Coargument
from firedrake.matrix import AssembledMatrix


def subspace(V, indices):
Expand Down Expand Up @@ -161,6 +160,7 @@ def cofunction(self, o):
return Cofunction(W, val=MixedDat(o.dat[i] for i in indices))

def matrix(self, o):
from firedrake.matrix import AssembledMatrix
ises = []
args = []
for a in o.arguments():
Expand All @@ -180,8 +180,7 @@ def matrix(self, o):
args.append(asplit)

submat = o.petscmat.createSubMatrix(*ises)
bcs = ()
return AssembledMatrix(tuple(args), bcs, submat)
return AssembledMatrix(tuple(args), submat)

def zero_base_form(self, o):
return ZeroBaseForm(tuple(map(self, o.arguments())))
Expand Down
23 changes: 16 additions & 7 deletions firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
from firedrake.petsc import PETSc
from firedrake.halo import _get_mtype
from firedrake.functionspaceimpl import WithGeometry
from firedrake.matrix import MatrixBase, AssembledMatrix
from firedrake.matrix import ImplicitMatrix, MatrixBase, Matrix
from firedrake.matrix_free.operators import ImplicitMatrixContext
from firedrake.bcs import DirichletBC
from firedrake.formmanipulation import split_form
from firedrake.functionspace import VectorFunctionSpace, TensorFunctionSpace, FunctionSpace
Expand Down Expand Up @@ -339,8 +340,8 @@ def assemble(
specified. By default None.
mat_type
The PETSc matrix type to use when assembling a rank 2 interpolation.
For cross-mesh interpolation, only ``"aij"`` is supported. For same-mesh
interpolation, ``"aij"`` and ``"baij"`` are supported. For same/cross mesh interpolation
For cross-mesh interpolation, ``"aij"`` and ``"matfree"`` are supported. For same-mesh
interpolation, ``"aij"``, ``"baij"``, and ``"matfree"`` are supported. For same/cross mesh interpolation
between :func:`.MixedFunctionSpace`, ``"aij"`` and ``"nest"`` are supported.
For interpolation between input-ordering linked :func:`.VertexOnlyMesh`,
``"aij"``, ``"baij"``, and ``"matfree"`` are supported.
Expand All @@ -356,15 +357,23 @@ def assemble(
"""
self._check_mat_type(mat_type)

if mat_type == "matfree" and self.rank == 2:
ctx = ImplicitMatrixContext(
self.ufl_interpolate, row_bcs=bcs, col_bcs=bcs,
)
return ImplicitMatrix(self.ufl_interpolate, ctx, bcs=bcs)

result = self._get_callable(tensor=tensor, bcs=bcs, mat_type=mat_type, sub_mat_type=sub_mat_type)()

if self.rank == 2:
# Assembling the operator
assert isinstance(tensor, MatrixBase | None)
assert isinstance(result, PETSc.Mat)
if tensor:
result.copy(tensor.petscmat)
return tensor
return AssembledMatrix(self.interpolate_args, bcs, result)
else:
return Matrix(self.ufl_interpolate, result, bcs=bcs)
else:
assert isinstance(tensor, Function | Cofunction | None)
return tensor.assign(result) if tensor else result
Expand Down Expand Up @@ -591,7 +600,7 @@ def callable() -> Function | Number:

@property
def _allowed_mat_types(self):
return {"aij", None}
return {"aij", "matfree", None}


class SameMeshInterpolator(Interpolator):
Expand Down Expand Up @@ -754,7 +763,7 @@ def callable() -> Function | Cofunction | PETSc.Mat | Number:

@property
def _allowed_mat_types(self):
return {"aij", "baij", None}
return {"aij", "baij", "matfree", None}


class VomOntoVomInterpolator(SameMeshInterpolator):
Expand Down Expand Up @@ -1642,4 +1651,4 @@ def callable() -> Number:

@property
def _allowed_mat_types(self):
return {"aij", "nest", None}
return {"aij", "nest", "matfree", None}
2 changes: 1 addition & 1 deletion firedrake/linear_solver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from firedrake.function import Function
from firedrake.cofunction import Cofunction
from firedrake.matrix import MatrixBase
from firedrake.petsc import PETSc
from firedrake.variational_solver import LinearVariationalProblem, LinearVariationalSolver

Expand Down Expand Up @@ -38,6 +37,7 @@ def __init__(self, A, *, P=None, **kwargs):
Any boundary conditions for this solve *must* have been
applied when assembling the operator.
"""
from firedrake.matrix import MatrixBase
if not isinstance(A, MatrixBase):
raise TypeError("Provided operator is a '%s', not a MatrixBase" % type(A).__name__)
if P is not None and not isinstance(P, MatrixBase):
Expand Down
Loading
Loading