gradient.adjoint (::catalyst::gradient::AdjointOp)

Perform quantum AD using the adjoint method on a device.
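The adjoint method obtains every parameter derivative from one forward simulation plus one backward sweep over the circuit. The plain-Python sketch below illustrates that technique on a toy single-qubit circuit built only from RY rotations, measuring the expectation value of Z. It is not Catalyst's runtime implementation. The gate set and all function names (`ry`, `dry`, `adjoint_gradient`, and so on) are hypothetical.

```python
import math


def ry(theta):
    """Real 2x2 matrix of an RY(theta) rotation."""
    c, s = math.cos(theta / 2), math.sin(theta / 2)
    return [[c, -s], [s, c]]


def dry(theta):
    """Derivative of RY(theta) with respect to theta."""
    c, s = math.cos(theta / 2), math.sin(theta / 2)
    return [[-s / 2, -c / 2], [c / 2, -s / 2]]


def apply(mat, vec):
    return [mat[0][0] * vec[0] + mat[0][1] * vec[1],
            mat[1][0] * vec[0] + mat[1][1] * vec[1]]


def inner(a, b):
    # States stay real here because RY and Z are real matrices.
    return a[0] * b[0] + a[1] * b[1]


Z = [[1.0, 0.0], [0.0, -1.0]]


def adjoint_gradient(thetas):
    """Return (<Z>, [d<Z>/d theta_i]) for the state RY(theta_n)...RY(theta_1)|0>."""
    psi = [1.0, 0.0]
    for theta in thetas:                      # forward simulation
        psi = apply(ry(theta), psi)
    expval = inner(psi, apply(Z, psi))

    lam = apply(Z, psi)                       # bra side: Z|psi>
    phi = psi                                 # ket side, undone gate by gate
    grads = [0.0] * len(thetas)
    for i in reversed(range(len(thetas))):
        phi = apply(ry(-thetas[i]), phi)      # RY(theta)^dagger == RY(-theta)
        mu = apply(dry(thetas[i]), phi)
        grads[i] = 2 * inner(lam, mu)         # 2 * Re<lam|mu>
        lam = apply(ry(-thetas[i]), lam)
    return expval, grads
```

For two rotations the state is RY(a + b)|0>, so `adjoint_gradient([0.3, 0.5])` should return cos(0.8) and two identical derivatives equal to -sin(0.8).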

Syntax:

operation ::= `gradient.adjoint` $callee `(` $args `)`
              `size` `(` $gradSize `)`
              ( `in` `(` $data_in^ `:` type($data_in) `)` )?
              attr-dict `:` functional-type($args, results)

Traits: AttrSizedOperandSegments

Attributes:

Attribute | MLIR Type | Description
callee | ::mlir::SymbolRefAttr | symbol reference attribute

Operands:

Operand | Description
gradSize | index
args | variadic of any type
data_in | variadic of memref of floating-point values

Results:

Result | Description
«unnamed» | variadic of floating-point or ranked tensor of floating-point values

gradient.backprop (::catalyst::gradient::BackpropOp)

Perform classic automatic differentiation using Enzyme AD.
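Classic reverse-mode AD records the forward computation and then propagates a cotangent (the seed supplied through the `cotangents` operand) backwards to obtain a gradient for each differentiable argument. The toy scalar tape below illustrates the idea in plain Python. It is only a sketch: Enzyme itself differentiates LLVM IR at compile time rather than building a runtime graph, and the `Var` class and `backprop` helper are hypothetical. Returning the primal value alongside the gradients loosely mirrors the `vals` result.

```python
import math


class Var:
    """Scalar node of a reverse-mode AD graph."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # pairs of (parent Var, local partial derivative)
        self.grad = 0.0

    def __add__(self, other):
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))


def sin(x):
    return Var(math.sin(x.value), ((x, math.cos(x.value)),))


def backprop(f, args, cotangent):
    """Return (value, gradients) of f at args, seeded with a scalar cotangent."""
    inputs = [Var(a) for a in args]
    out = f(*inputs)

    # Topologically order the graph (parents before children).
    order, seen = [], set()

    def visit(node):
        if id(node) not in seen:
            seen.add(id(node))
            for parent, _ in node.parents:
                visit(parent)
            order.append(node)

    visit(out)
    out.grad = cotangent
    for node in reversed(order):              # reverse sweep
        for parent, local in node.parents:
            parent.grad += node.grad * local
    return out.value, [v.grad for v in inputs]
```

For f(x, y) = x * y + sin(x) at (2.0, 3.0) with cotangent 1.0, the gradients are (3 + cos 2, 2); doubling the cotangent doubles both.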

Syntax:

operation ::= `gradient.backprop` $callee `(` $args `)`
              ( `grad_out` `(` $diffArgShadows^ `:` type($diffArgShadows) `)` )?
              ( `callee_out` `(` $calleeResults^ `:` type($calleeResults) `)` )?
              `cotangents` `(` $cotangents `:` type($cotangents) `)`
              attr-dict `:` functional-type($args, results)

Traits: AttrSizedOperandSegments, AttrSizedResultSegments

Interfaces: SymbolUserOpInterface

Attributes:

Attribute | MLIR Type | Description
callee | ::mlir::SymbolRefAttr | symbol reference attribute
diffArgIndices | ::mlir::DenseIntElementsAttr | integer elements attribute
keepValueResults | ::mlir::BoolAttr | bool attribute

Operands:

Operand | Description
args | variadic of any type
diffArgShadows | variadic of memref of floating-point values
calleeResults | variadic of memref of floating-point values
cotangents | variadic of ranked tensor of floating-point values or memref of floating-point values

Results:

Result | Description
vals | variadic of floating-point or ranked tensor of floating-point values
gradients | variadic of floating-point or ranked tensor of floating-point values

gradient.custom_grad (::catalyst::gradient::CustomGradOp)

Operation denoting the registration of the custom gradient with Enzyme.

Syntax:

operation ::= `gradient.custom_grad` $callee $forward $reverse attr-dict

The operation references three functions: the function itself (callee), its forward pass (forward), and its reverse pass (reverse).
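A custom gradient pairs a forward pass, which computes the result and saves whatever the reverse pass needs on a tape, with a reverse pass that turns that tape plus an output cotangent into an input cotangent. The plain-Python sketch below shows this protocol for `exp`. The registry dictionary and every function name in it are hypothetical; they only mirror the (callee, forward, reverse) triple that gradient.custom_grad records.

```python
import math

# Hypothetical registry: callee name -> (forward, reverse).
CUSTOM_GRADS = {}


def register_custom_grad(name, forward, reverse):
    CUSTOM_GRADS[name] = (forward, reverse)


def exp_forward(x):
    y = math.exp(x)
    tape = y            # save exp(x) so the reverse pass can reuse it
    return y, tape


def exp_reverse(tape, cotangent):
    # d/dx exp(x) = exp(x), which is already stored on the tape
    return cotangent * tape


register_custom_grad("exp", exp_forward, exp_reverse)


def vjp_with_custom_grad(name, x, cotangent):
    """Run the registered forward pass, then feed its tape to the reverse pass."""
    forward, reverse = CUSTOM_GRADS[name]
    y, tape = forward(x)
    return y, reverse(tape, cotangent)
```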

Interfaces: SymbolUserOpInterface

Attributes:

Attribute | MLIR Type | Description
callee | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
forward | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
reverse | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute

gradient.forward (::catalyst::gradient::ForwardOp)

Operation denoting the forward pass that is registered with Enzyme.

It wraps the concrete function and ensures that it follows Enzyme's calling convention.

This function matches the calling convention Enzyme expects. Enzyme's calling convention requires a shadow argument for every pointer. Because the callbacks all take tensors, every argument is a pointer. The callbacks also pass their results through out-parameters, so those are duplicated as well.

After lowering to LLVM, this function has the following parameters:

@foo(%inp0: !llvm.ptr, %diff0: !llvm.ptr,
     … %inpArgc-1: !llvm.ptr, %diffArgc-1: !llvm.ptr,
     %out0: !llvm.ptr, %cotangent0: !llvm.ptr,
     … %outputResc-1: !llvm.ptr, %cotangentResc-1: !llvm.ptr)
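The parameter layout above follows mechanically from the argc and resc attributes: each input is followed by its shadow, then each output is followed by its cotangent. The Python helper below, whose name is hypothetical, simply builds that list of parameter names as a sketch of the convention; it uses `%outN` for every output, following the first entry of the listing.

```python
def forward_llvm_params(argc, resc):
    """List the pointer parameters of a lowered forward function, in order."""
    params = []
    for i in range(argc):
        params += [f"%inp{i}", f"%diff{i}"]        # primal input, then its shadow
    for j in range(resc):
        params += [f"%out{j}", f"%cotangent{j}"]   # output buffer, then its shadow
    return [f"{p}: !llvm.ptr" for p in params]
```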

Enzyme expects the return value to be the tape. Enzyme’s documentation has the following to say:

The return type of the augmented forward pass is a struct type containing first the tape type,
followed by the original return type, if any.
If the return type is a duplicated type,
then there is a third argument which contains the shadow of the return.

Let’s just break this down a bit:

The return type of the augmented forward pass is a struct type containing first the tape type,

This means that the return type of function foo is the following, in pseudocode:

%tape0Type = { memref elements }
...
%tapeTapec-1Type = { memref elements }
%tape = { %tape0Type, ... %tapeTapec-1Type }
%returnTy = { %tape, ... }

Then:

followed by the original return type, if any.

Since there is no original return value, this reduces to:

%returnTy = { %tape }

Then:

  If the return type is a duplicated type,
  then there is a third argument which contains the shadow of the return.


This case is also of no concern for the current implementation, because there are no return values.

It was found experimentally, and through tests in Enzyme, that the tape can also be a pointer.
We use this when there is no tape to return: instead of returning an empty struct, we return a null
pointer that is never dereferenced.
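The return-type rules above can be summarized as a small function of the list of tape element types. The Python sketch below is hypothetical: it returns strings in the same pseudocode notation used earlier in this section, except for the empty case, which yields the opaque pointer type that stands in for the null tape.

```python
def forward_return_type(tape_types):
    """Sketch of the forward pass's return type, in this section's pseudocode."""
    if not tape_types:
        # No tape: return a null pointer that the reverse pass never dereferences.
        return "!llvm.ptr"
    tape = "{ " + ", ".join(tape_types) + " }"
    # No original return value and no shadow return, so only the tape remains.
    return "{ " + tape + " }"
```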

Traits: IsolatedFromAbove

Interfaces: CallableOpInterface, FunctionOpInterface, Symbol

Attributes:

Attribute | MLIR Type | Description
sym_name | ::mlir::StringAttr | string attribute
function_type | ::mlir::TypeAttr | type attribute of function type
implementation | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
argc | ::mlir::IntegerAttr | 64-bit signless integer attribute
resc | ::mlir::IntegerAttr | 64-bit signless integer attribute
tape | ::mlir::IntegerAttr | 64-bit signless integer attribute
arg_attrs | ::mlir::ArrayAttr | Array of dictionary attributes
res_attrs | ::mlir::ArrayAttr | Array of dictionary attributes

gradient.grad (::catalyst::gradient::GradOp)

Compute the gradient of a function.

Syntax:

operation ::= `gradient.grad` $method $callee `(` $operands `)` attr-dict `:` functional-type($operands, results)

The gradient.grad operation computes the gradient of a function using the method given by $method, for example the finite difference method.

This operation acts much like the func.call operation, taking a symbol reference and the arguments of the original function as input. However, instead of the function result, it returns the gradient of the function.

Example:

func.func @foo(%arg0: f64) -> f64 {
    %res = arith.mulf %arg0, %arg0 : f64
    func.return %res : f64
}

%0 = arith.constant 2.0 : f64
%1 = gradient.grad "fd" @foo(%0) : (f64) -> f64
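For the finite difference method, each partial derivative is approximated by perturbing one argument by a small step, which is presumably what the finiteDiffParam attribute controls. The Python sketch below uses central differences; the scheme, the arbitrary default step, and the `fd_grad` name are assumptions for illustration, not Catalyst's lowering. The `diff_arg_indices` parameter plays the role of diffArgIndices.

```python
def fd_grad(f, args, diff_arg_indices=None, h=1e-6):
    """Approximate the gradient of scalar f with central finite differences."""
    if diff_arg_indices is None:
        diff_arg_indices = range(len(args))
    grads = []
    for i in diff_arg_indices:
        plus = list(args)
        minus = list(args)
        plus[i] += h
        minus[i] -= h
        grads.append((f(*plus) - f(*minus)) / (2 * h))
    return grads
```

Mirroring the @foo example, `fd_grad(lambda x: x * x, [2.0])` gives approximately `[4.0]`.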

Interfaces: CallOpInterface, GradientOpInterface, SymbolUserOpInterface

Attributes:

Attribute | MLIR Type | Description
method | ::mlir::StringAttr | string attribute
callee | ::mlir::SymbolRefAttr | symbol reference attribute
diffArgIndices | ::mlir::DenseIntElementsAttr | integer elements attribute
finiteDiffParam | ::mlir::FloatAttr | An Attribute containing a floating-point value
arg_attrs | ::mlir::ArrayAttr | Array of dictionary attributes
res_attrs | ::mlir::ArrayAttr | Array of dictionary attributes

Operands:

Operand | Description
operands | variadic of any type

Results:

Result | Description
«unnamed» | variadic of floating-point or ranked tensor of floating-point values

gradient.jvp (::catalyst::gradient::JVPOp)

Compute the Jacobian-vector product (JVP) of a function.
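A JVP pushes a tangent vector t forward through the Jacobian J of the callee, producing J·t alongside the callee's own results (the calleeResults and jvps results). The Python sketch below approximates J·t with a single central difference along t for a function from a list of floats to a list of floats; the finite-difference scheme and the `fd_jvp` name are assumptions for illustration.

```python
def fd_jvp(f, params, tangents, h=1e-6):
    """Return (f(params), J @ tangents) via one central difference along tangents."""
    plus = [p + h * t for p, t in zip(params, tangents)]
    minus = [p - h * t for p, t in zip(params, tangents)]
    jvp = [(a - b) / (2 * h) for a, b in zip(f(plus), f(minus))]
    return f(params), jvp
```

For f(x, y) = [x * y, x + y] at (2, 3), J = [[3, 2], [1, 1]], so the tangent (1, 0) yields J·t = [3, 1].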

Syntax:

operation ::= `gradient.jvp` $method $callee `(` $params `)` `tangents` `(` $tangents `)`
              attr-dict `:` functional-type(operands, results)

Traits: AttrSizedOperandSegments, SameVariadicResultSize

Interfaces: CallOpInterface, GradientOpInterface, SymbolUserOpInterface

Attributes:

Attribute | MLIR Type | Description
method | ::mlir::StringAttr | string attribute
callee | ::mlir::SymbolRefAttr | symbol reference attribute
diffArgIndices | ::mlir::DenseIntElementsAttr | integer elements attribute
finiteDiffParam | ::mlir::FloatAttr | An Attribute containing a floating-point value
arg_attrs | ::mlir::ArrayAttr | Array of dictionary attributes
res_attrs | ::mlir::ArrayAttr | Array of dictionary attributes

Operands:

Operand | Description
params | variadic of any type
tangents | variadic of any type

Results:

Result | Description
calleeResults | variadic of floating-point or ranked tensor of floating-point values
jvps | variadic of floating-point or ranked tensor of floating-point values

gradient.return (::catalyst::gradient::ReturnOp)

Return tapes or nothing

Syntax:

operation ::= `gradient.return` attr-dict ($tape^ `:` type($tape))?

Traits: HasParent<ForwardOp, ReverseOp>, ReturnLike, Terminator

Interfaces: RegionBranchTerminatorOpInterface

Attributes:

Attribute | MLIR Type | Description
empty | ::mlir::IntegerAttr | 1-bit signless integer attribute

Operands:

Operand | Description
tape | variadic of ranked tensor of any type values or memref of any type values

gradient.reverse (::catalyst::gradient::ReverseOp)

Operation denoting the reverse pass that is registered with Enzyme.

It wraps the concrete function and ensures that it follows Enzyme's calling convention.

This matches Enzyme’s calling convention. From the documentation:

The final argument is a custom “tape” type that can be used to pass information from the forward to the reverse pass.

Experimentally, whenever there are no return values, the tape argument passed to this function has the following type, which matches the return type of the forward op. This is somewhat ambiguous with respect to what the documentation says.

%returnTy = { %tape }

Traits: IsolatedFromAbove

Interfaces: CallableOpInterface, FunctionOpInterface, Symbol

Attributes:

Attribute | MLIR Type | Description
sym_name | ::mlir::StringAttr | string attribute
function_type | ::mlir::TypeAttr | type attribute of function type
implementation | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
argc | ::mlir::IntegerAttr | 64-bit signless integer attribute
resc | ::mlir::IntegerAttr | 64-bit signless integer attribute
tape | ::mlir::IntegerAttr | 64-bit signless integer attribute
arg_attrs | ::mlir::ArrayAttr | Array of dictionary attributes
res_attrs | ::mlir::ArrayAttr | Array of dictionary attributes

gradient.vjp (::catalyst::gradient::VJPOp)

Compute the vector-Jacobian product (VJP) of a function.
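A VJP pulls a cotangent vector c back through the Jacobian J of the callee, producing cᵀ·J alongside the callee's own results (the calleeResults and vjps results). The Python sketch below builds J one column at a time with central differences; the scheme and the `fd_vjp` name are assumptions for illustration, and a reverse-mode implementation would avoid materializing J at all.

```python
def fd_vjp(f, params, cotangents, h=1e-6):
    """Return (f(params), cotangents^T @ J) using a finite-difference Jacobian."""
    vjp = []
    for j in range(len(params)):
        plus = list(params)
        minus = list(params)
        plus[j] += h
        minus[j] -= h
        column = [(a - b) / (2 * h) for a, b in zip(f(plus), f(minus))]  # dF/dp_j
        vjp.append(sum(c * d for c, d in zip(cotangents, column)))
    return f(params), vjp
```

For f(x, y) = [x * y, x + y] at (2, 3), J = [[3, 2], [1, 1]], so the cotangent (1, 1) yields cᵀ·J = [4, 3].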

Syntax:

operation ::= `gradient.vjp` $method $callee `(` $params `)` `cotangents` `(` $cotangents `)`
              attr-dict `:` functional-type(operands, results)

Traits: AttrSizedOperandSegments, AttrSizedResultSegments

Interfaces: CallOpInterface, GradientOpInterface, SymbolUserOpInterface

Attributes:

Attribute | MLIR Type | Description
method | ::mlir::StringAttr | string attribute
callee | ::mlir::SymbolRefAttr | symbol reference attribute
diffArgIndices | ::mlir::DenseIntElementsAttr | integer elements attribute
finiteDiffParam | ::mlir::FloatAttr | An Attribute containing a floating-point value
arg_attrs | ::mlir::ArrayAttr | Array of dictionary attributes
res_attrs | ::mlir::ArrayAttr | Array of dictionary attributes

Operands:

Operand | Description
params | variadic of any type
cotangents | variadic of any type

Results:

Result | Description
calleeResults | variadic of floating-point or ranked tensor of floating-point values
vjps | variadic of floating-point or ranked tensor of floating-point values

gradient.value_and_grad (::catalyst::gradient::ValueAndGradOp)

Compute the value and gradient of a function.
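This operation combines an ordinary call with a gradient computation, returning both the callee's value (vals) and its gradient (gradients). The Python sketch below does this with central finite differences; the scheme, the default step, and the `fd_value_and_grad` name are assumptions for illustration only.

```python
def fd_value_and_grad(f, args, h=1e-6):
    """Return (f(*args), gradient of scalar f) using central finite differences."""
    grads = []
    for i in range(len(args)):
        plus = list(args)
        minus = list(args)
        plus[i] += h
        minus[i] -= h
        grads.append((f(*plus) - f(*minus)) / (2 * h))
    return f(*args), grads
```

For f(x, y) = x * x + 3 * y at (2, 1), this returns the value 7 and a gradient of approximately [4, 3].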

Syntax:

operation ::= `gradient.value_and_grad` $method $callee `(` $operands `)`
              attr-dict `:` functional-type(operands, results)

Traits: AttrSizedResultSegments

Interfaces: CallOpInterface, GradientOpInterface, SymbolUserOpInterface

Attributes:

Attribute | MLIR Type | Description
method | ::mlir::StringAttr | string attribute
callee | ::mlir::SymbolRefAttr | symbol reference attribute
diffArgIndices | ::mlir::DenseIntElementsAttr | integer elements attribute
finiteDiffParam | ::mlir::FloatAttr | An Attribute containing a floating-point value
arg_attrs | ::mlir::ArrayAttr | Array of dictionary attributes
res_attrs | ::mlir::ArrayAttr | Array of dictionary attributes

Operands:

Operand | Description
operands | variadic of any type

Results:

Result | Description
vals | variadic of floating-point or ranked tensor of floating-point values
gradients | variadic of floating-point or ranked tensor of floating-point values