gradient.adjoint (::catalyst::gradient::AdjointOp)
Perform quantum AD using the adjoint method on a device.
Syntax:
operation ::= `gradient.adjoint` $callee `(` $args `)`
`size` `(` $gradSize `)`
( `in` `(` $data_in^ `:` type($data_in) `)` )?
attr-dict `:` functional-type($args, results)
Traits: AttrSizedOperandSegments
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
callee | ::mlir::SymbolRefAttr | symbol reference attribute |
Operands:
Operand | Description |
---|---|
gradSize | index |
args | variadic of any type |
data_in | variadic of memref of floating-point values |
Results:
Result | Description |
---|---|
«unnamed» | variadic of floating-point or ranked tensor of floating-point values |
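The adjoint method can be sketched on a toy one-qubit statevector simulator in NumPy. This illustrates the algorithm only and is not Catalyst's device implementation. The circuit (RX followed by RY), the observable Z and all function names are assumptions. A forward pass applies every gate. A single backward sweep then undoes one gate at a time while carrying the adjoint state λ = O|ψ⟩ backwards, so every partial derivative comes from that one sweep rather than from one extra circuit execution per parameter.

```python
import numpy as np

# Pauli matrices, used both as gate generators and as the observable.
X = np.array([[0, 1], [1, 0]], dtype=complex)
Y = np.array([[0, -1j], [1j, 0]], dtype=complex)
Z = np.array([[1, 0], [0, -1]], dtype=complex)


def rx(theta):
    """RX(theta) = exp(-i * theta * X / 2)."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -1j * s], [-1j * s, c]], dtype=complex)


def ry(theta):
    """RY(theta) = exp(-i * theta * Y / 2)."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -s], [s, c]], dtype=complex)


# Hypothetical circuit: (gate constructor, generator G) pairs,
# where each gate equals exp(-i * theta * G / 2).
CIRCUIT = [(rx, X), (ry, Y)]


def run(params):
    """Forward pass: apply every gate to |0> and return the final state."""
    psi = np.array([1, 0], dtype=complex)
    for (gate, _), theta in zip(CIRCUIT, params):
        psi = gate(theta) @ psi
    return psi


def expval(params, obs=Z):
    psi = run(params)
    return float(np.real(np.vdot(psi, obs @ psi)))


def adjoint_gradient(params, obs=Z):
    """Gradient of <psi|obs|psi> from a single backward sweep."""
    psi = run(params)
    lam = obs @ psi  # adjoint state, carried backwards through the circuit
    grads = np.zeros(len(params))
    for i in reversed(range(len(params))):
        gate, gen = CIRCUIT[i]
        u = gate(params[i])
        psi = u.conj().T @ psi  # undo gate i: state just before gate i
        du = -0.5j * gen @ u  # derivative of exp(-i theta G / 2) w.r.t. theta
        grads[i] = 2 * np.real(np.vdot(lam, du @ psi))
        lam = u.conj().T @ lam  # move the adjoint state past gate i
    return grads
```

For this circuit the expectation value is cos(θ0)·cos(θ1). The sweep therefore returns (−sin θ0 cos θ1, −cos θ0 sin θ1).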
gradient.backprop (::catalyst::gradient::BackpropOp)
Perform classic automatic differentiation using Enzyme AD.
Syntax:
operation ::= `gradient.backprop` $callee `(` $args `)`
( `grad_out` `(` $diffArgShadows^ `:` type($diffArgShadows) `)` )?
( `callee_out` `(` $calleeResults^ `:` type($calleeResults) `)` )?
`cotangents` `(` $cotangents `:` type($cotangents) `)`
attr-dict `:` functional-type($args, results)
Traits: AttrSizedOperandSegments, AttrSizedResultSegments
Interfaces: SymbolUserOpInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
callee | ::mlir::SymbolRefAttr | symbol reference attribute |
diffArgIndices | ::mlir::DenseIntElementsAttr | integer elements attribute |
keepValueResults | ::mlir::BoolAttr | bool attribute |
Operands:
Operand | Description |
---|---|
args | variadic of any type |
diffArgShadows | variadic of memref of floating-point values |
calleeResults | variadic of memref of floating-point values |
cotangents | variadic of ranked tensor of floating-point values or memref of floating-point values |
Results:
Result | Description |
---|---|
| | variadic of floating-point or ranked tensor of floating-point values |
| | variadic of floating-point or ranked tensor of floating-point values |
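The plumbing of shadows and cotangents can be illustrated with a hand-written reverse pass in plain Python. This is a sketch under assumptions, not the code Enzyme generates. The callee `f` and its reverse pass `f_reverse` are hypothetical, Python lists stand in for memrefs, and the shadow list loosely plays the role of a `diffArgShadows` buffer. Following Enzyme's convention for duplicated arguments, the reverse pass adds each derivative into the shadow rather than overwriting it, seeded by the output cotangents. The shadow therefore ends up holding the vector-Jacobian product cotangentsᵀ·J.

```python
import math


def f(x, out):
    """Hypothetical callee: results are written into the output buffer `out`."""
    out[0] = x[0] * x[1]
    out[1] = math.sin(x[0])


def f_reverse(x, x_shadow, cotangents):
    """Hand-written reverse pass of f.

    Adds cotangents^T . J(x) into x_shadow, mirroring how Enzyme
    accumulates derivatives into the shadow of a duplicated argument.
    """
    ct0, ct1 = cotangents
    # out[1] = sin(x0), so d out[1] / d x0 = cos(x0)
    x_shadow[0] += ct1 * math.cos(x[0])
    # out[0] = x0 * x1, so the partials are x1 and x0
    x_shadow[0] += ct0 * x[1]
    x_shadow[1] += ct0 * x[0]


def backprop(x, cotangents):
    """Sketch: cotangents^T . J(x) for f, starting from a zero shadow."""
    shadow = [0.0] * len(x)
    f_reverse(x, shadow, cotangents)
    return shadow
```

Because the reverse pass accumulates, a shadow buffer that is not zero on entry still holds its old contents plus the new gradient. Callers typically zero it first.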
gradient.custom_grad (::catalyst::gradient::CustomGradOp)
Operation denoting the registration of the custom gradient with Enzyme.
Syntax:
operation ::= `gradient.custom_grad` $callee $forward $reverse attr-dict
A triple of functions: the function itself, its forward pass, and its reverse pass.
Interfaces: SymbolUserOpInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
callee | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
forward | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
reverse | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
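The relationship between the three functions can be sketched in Python. The forward pass computes the primal result and returns a tape holding whatever the reverse pass will need. The reverse pass consumes that tape together with the output cotangent. The `CUSTOM_GRADS` registry, `register_custom_grad`, and the softplus example are hypothetical illustrations of the `callee`/`forward`/`reverse` triple. They are not Catalyst or Enzyme APIs.

```python
import math
from typing import Callable, Dict, Tuple

# Hypothetical registry: callee -> (forward pass, reverse pass).
CUSTOM_GRADS: Dict[Callable, Tuple[Callable, Callable]] = {}


def register_custom_grad(callee, forward, reverse):
    CUSTOM_GRADS[callee] = (forward, reverse)


def softplus(x):
    return math.log1p(math.exp(x))


def softplus_fwd(x):
    """Forward pass: primal result plus a tape for the reverse pass."""
    e = math.exp(x)
    tape = e / (1.0 + e)  # sigmoid(x), saved so the reverse pass need not recompute it
    return math.log1p(e), tape


def softplus_rev(tape, cotangent):
    """Reverse pass: d softplus / dx = sigmoid(x), read back from the tape."""
    return cotangent * tape


register_custom_grad(softplus, softplus_fwd, softplus_rev)


def value_and_vjp(fn, x, cotangent=1.0):
    """Run the registered forward pass, then feed its tape to the reverse pass."""
    forward, reverse = CUSTOM_GRADS[fn]
    value, tape = forward(x)
    return value, reverse(tape, cotangent)
```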
gradient.forward (::catalyst::gradient::ForwardOp)
Operation denoting the forward pass that is registered with Enzyme.
This wrapper around the concrete function enforces the calling convention that Enzyme expects, in which every pointer argument is accompanied by a shadow argument. The callbacks all take tensors, so every argument is a pointer and therefore gets a shadow. The callbacks also return their results through output parameters, so these are duplicated as well.
After lowering to LLVM, this function has the following parameters:
- @foo(%inp0: !llvm.ptr, %diff0: !llvm.ptr, … %inpArgc-1: !llvm.ptr, %diffArgc-1: !llvm.ptr, %out0: !llvm.ptr, %cotangent0: !llvm.ptr, … %outputResc-1: !llvm.ptr, %cotangentResc-1: !llvm.ptr)
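A small Python helper makes the interleaving of primal and shadow pointers concrete by listing the lowered parameters for given `argc` and `resc` values. The helper name and output format are hypothetical, and the output parameters are named with the `%out0` pattern throughout. The helper only restates the parameter order given above: each input is followed by its shadow, and each output buffer is followed by its cotangent.

```python
def forward_params(name, argc, resc):
    """Render the lowered LLVM parameter list of a forward-pass wrapper.

    Each of the argc inputs is paired with its shadow (%inpI, %diffI), and
    each of the resc output buffers with its cotangent (%outJ, %cotangentJ).
    Every parameter is an opaque pointer (!llvm.ptr).
    """
    params = []
    for i in range(argc):
        params += [f"%inp{i}", f"%diff{i}"]
    for j in range(resc):
        params += [f"%out{j}", f"%cotangent{j}"]
    typed = ", ".join(f"{p}: !llvm.ptr" for p in params)
    return f"@{name}({typed})"
```

For one input and one result, this yields four pointers: the input, its shadow, the output buffer, and the output's cotangent.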
Enzyme expects the return value to be the tape. Enzyme's documentation states: "The return type of the augmented forward pass is a struct type containing first the tape type, followed by the original return type, if any. If the return type is a duplicated type, then there is a third argument which contains the shadow of the return."
Breaking this down clause by clause:
"The return type of the augmented forward pass is a struct type containing first the tape type" means that, in pseudocode, the return type of a function foo is:
%tape0Type = { memref elements }
...
%tapeTapec-1Type = { memref elements }
%tape = { %tape0Type, ... %tapeTapec-1Type }
%returnTy = { %tape, ... }
"followed by the original return type, if any": the function has no return values (its results are passed through output parameters), so the return type reduces to:
%returnTy = { %tape }
"If the return type is a duplicated type, then there is a third argument which contains the shadow of the return": this does not affect the current implementation either, because there are no return values.
Experiments and Enzyme's own tests show that the tape can also be a pointer. This is used when there is no tape to return: instead of an empty struct, the function returns a null pointer that is never dereferenced.
Traits: IsolatedFromAbove
Interfaces: CallableOpInterface, FunctionOpInterface, Symbol
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
sym_name | ::mlir::StringAttr | string attribute |
function_type | ::mlir::TypeAttr | type attribute of function type |
implementation | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
argc | ::mlir::IntegerAttr | 64-bit signless integer attribute |
resc | ::mlir::IntegerAttr | 64-bit signless integer attribute |
tape | ::mlir::IntegerAttr | 64-bit signless integer attribute |
arg_attrs | ::mlir::ArrayAttr | Array of dictionary attributes |
res_attrs | ::mlir::ArrayAttr | Array of dictionary attributes |
gradient.grad (::catalyst::gradient::GradOp)
Compute the gradient of a function.
Syntax:
operation ::= `gradient.grad` $method $callee `(` $operands `)` attr-dict `:` functional-type($operands, results)
The `gradient.grad` operation computes the gradient of a function using the differentiation method named by the `method` attribute, for example the finite-difference method configured by `finiteDiffParam`.
This operation acts much like the `func.call` operation, taking a symbol reference and the arguments to the original function as input. However, it returns the gradient of the function instead of the function result.
Example:
func.func @foo(%arg0: f64) -> f64 {
%res = arith.mulf %arg0, %arg0 : f64
func.return %res : f64
}
%0 = arith.constant 2.0 : f64
%1 = gradient.grad "fd" @foo(%0) : (f64) -> f64
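For finite differences, the computation can be sketched in Python as a forward-difference quotient. The step `h` stands in for `finiteDiffParam`, and `diff_arg_indices` mirrors `diffArgIndices`. This is an illustration only: the exact difference scheme, default step, and default argument selection that Catalyst uses are assumptions here and may differ. Applied to the `@foo` example above (f(x) = x·x) at x = 2.0, it approximates f'(2) = 4.

```python
def fd_gradient(f, args, diff_arg_indices=None, h=1e-7):
    """Forward-difference approximation of df/dx_i for the selected arguments.

    `h` plays the role of finiteDiffParam; `diff_arg_indices` mirrors
    diffArgIndices (assumed default here: argument 0 only).
    Assumes f takes scalar arguments and returns a scalar.
    """
    if diff_arg_indices is None:
        diff_arg_indices = [0]
    base = f(*args)
    grads = []
    for i in diff_arg_indices:
        shifted = list(args)
        shifted[i] += h
        grads.append((f(*shifted) - base) / h)
    return grads


def foo(x):
    return x * x  # same computation as @foo in the example above
```

A forward difference has truncation error proportional to `h`. A smaller step reduces that error but amplifies floating-point cancellation, which is why the step is exposed as a parameter.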
Interfaces: CallOpInterface, GradientOpInterface, SymbolUserOpInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
method | ::mlir::StringAttr | string attribute |
callee | ::mlir::SymbolRefAttr | symbol reference attribute |
diffArgIndices | ::mlir::DenseIntElementsAttr | integer elements attribute |
finiteDiffParam | ::mlir::FloatAttr | An Attribute containing a floating-point value: either a float literal with an optional float type (`42.0` defaults to f64; `42.0 : f32`), or a hexadecimal bit pattern followed by a mandatory float type, which is useful for infinity and NaN (`0x7C00 : f16` is positive infinity) |
arg_attrs | ::mlir::ArrayAttr | Array of dictionary attributes |
res_attrs | ::mlir::ArrayAttr | Array of dictionary attributes |
Operands:
Operand | Description |
---|---|
operands | variadic of any type |
Results:
Result | Description |
---|---|
«unnamed» | variadic of floating-point or ranked tensor of floating-point values |
gradient.jvp (::catalyst::gradient::JVPOp)
Compute the Jacobian-vector product (JVP) of a function.
Syntax:
operation ::= `gradient.jvp` $method $callee `(` $params `)` `tangents` `(` $tangents `)`
attr-dict `:` functional-type(operands, results)
Traits: AttrSizedOperandSegments, SameVariadicResultSize
Interfaces: CallOpInterface, GradientOpInterface, SymbolUserOpInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
method | ::mlir::StringAttr | string attribute |
callee | ::mlir::SymbolRefAttr | symbol reference attribute |
diffArgIndices | ::mlir::DenseIntElementsAttr | integer elements attribute |
finiteDiffParam | ::mlir::FloatAttr | An Attribute containing a floating-point value |
arg_attrs | ::mlir::ArrayAttr | Array of dictionary attributes |
res_attrs | ::mlir::ArrayAttr | Array of dictionary attributes |
Operands:
Operand | Description |
---|---|
params | variadic of any type |
tangents | variadic of any type |
Results:
Result | Description |
---|---|
| | variadic of floating-point or ranked tensor of floating-point values |
| | variadic of floating-point or ranked tensor of floating-point values |
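A Jacobian-vector product can be sketched in Python with forward-mode dual numbers. Each input carries a tangent, and after one forward evaluation the output tangents equal J·t. The `Dual` class, `dsin`, `f`, and `jvp` are illustrative assumptions, not Catalyst code. The helper returns the function values together with the JVP, an assumed reading of the op's two equally sized result groups.

```python
import math
from dataclasses import dataclass


@dataclass
class Dual:
    """A value paired with its tangent (forward-mode AD)."""
    val: float
    tan: float

    def __add__(self, other):
        return Dual(self.val + other.val, self.tan + other.tan)

    def __mul__(self, other):
        # Product rule for the tangent component.
        return Dual(self.val * other.val,
                    self.tan * other.val + self.val * other.tan)


def dsin(x):
    return Dual(math.sin(x.val), math.cos(x.val) * x.tan)


def f(x0, x1):
    """Hypothetical callee with two inputs and two outputs."""
    return [x0 * x1, dsin(x0) + x1]


def jvp(fn, params, tangents):
    """Return (function values, J . tangents) from a single forward pass."""
    outs = fn(*[Dual(p, t) for p, t in zip(params, tangents)])
    return [o.val for o in outs], [o.tan for o in outs]
```

Seeding the tangent with a unit vector e_j extracts column j of the Jacobian. Forward mode therefore suits functions with few inputs.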
gradient.return (::catalyst::gradient::ReturnOp)
Terminator of gradient.forward and gradient.reverse: return the tapes, or nothing.
Syntax:
operation ::= `gradient.return` attr-dict ($tape^ `:` type($tape))?
Traits: HasParent<ForwardOp, ReverseOp>, ReturnLike, Terminator
Interfaces: RegionBranchTerminatorOpInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
empty | ::mlir::IntegerAttr | 1-bit signless integer attribute |
Operands:
Operand | Description |
---|---|
tape | variadic of ranked tensor of any type values or memref of any type values |
gradient.reverse (::catalyst::gradient::ReverseOp)
Operation denoting the reverse pass that is registered with Enzyme.
This wrapper around the concrete function enforces the calling convention that Enzyme expects. Enzyme's documentation states: "The final argument is a custom “tape” type that can be used to pass information from the forward to the reverse pass."
Experimentally, whenever there are no return values, the type passed to this function is the following, which matches the return type of the forward op. The documentation is somewhat ambiguous on this point.
%returnTy = { %tape }
Traits: IsolatedFromAbove
Interfaces: CallableOpInterface, FunctionOpInterface, Symbol
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
sym_name | ::mlir::StringAttr | string attribute |
function_type | ::mlir::TypeAttr | type attribute of function type |
implementation | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
argc | ::mlir::IntegerAttr | 64-bit signless integer attribute |
resc | ::mlir::IntegerAttr | 64-bit signless integer attribute |
tape | ::mlir::IntegerAttr | 64-bit signless integer attribute |
arg_attrs | ::mlir::ArrayAttr | Array of dictionary attributes |
res_attrs | ::mlir::ArrayAttr | Array of dictionary attributes |
gradient.vjp (::catalyst::gradient::VJPOp)
Compute the vector-Jacobian product (VJP) of a function.
Syntax:
operation ::= `gradient.vjp` $method $callee `(` $params `)` `cotangents` `(` $cotangents `)`
attr-dict `:` functional-type(operands, results)
Traits: AttrSizedOperandSegments, AttrSizedResultSegments
Interfaces: CallOpInterface, GradientOpInterface, SymbolUserOpInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
method | ::mlir::StringAttr | string attribute |
callee | ::mlir::SymbolRefAttr | symbol reference attribute |
diffArgIndices | ::mlir::DenseIntElementsAttr | integer elements attribute |
finiteDiffParam | ::mlir::FloatAttr | An Attribute containing a floating-point value |
arg_attrs | ::mlir::ArrayAttr | Array of dictionary attributes |
res_attrs | ::mlir::ArrayAttr | Array of dictionary attributes |
Operands:
Operand | Description |
---|---|
params | variadic of any type |
cotangents | variadic of any type |
Results:
Result | Description |
---|---|
| | variadic of floating-point or ranked tensor of floating-point values |
| | variadic of floating-point or ranked tensor of floating-point values |
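A vector-Jacobian product can be sketched with NumPy. The sketch builds the Jacobian, here by central finite differences with one column per input, and contracts it from the left with the cotangent vector to get cᵀ·J. The helper names are hypothetical. Forming the full Jacobian is only for illustration, since reverse-mode AD obtains cᵀ·J without materialising J. The helper returns the function values alongside the VJP, an assumed reading of the op's two result groups.

```python
import numpy as np


def f(x):
    """Hypothetical callee from R^2 to R^2."""
    return np.array([x[0] * x[1], np.sin(x[0]) + x[1]])


def jacobian_fd(fn, x, h=1e-6):
    """Central-difference Jacobian: column j approximates d fn / d x_j."""
    x = np.asarray(x, dtype=float)
    cols = []
    for j in range(x.size):
        e = np.zeros_like(x)
        e[j] = h
        cols.append((fn(x + e) - fn(x - e)) / (2 * h))
    return np.stack(cols, axis=1)


def vjp(fn, params, cotangents):
    """Return (function values, cotangents^T . J)."""
    values = fn(np.asarray(params, dtype=float))
    return values, np.asarray(cotangents, dtype=float) @ jacobian_fd(fn, params)
```

Seeding the cotangent with a unit vector e_i extracts row i of the Jacobian. Reverse mode therefore suits functions with few outputs.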
gradient.value_and_grad (::catalyst::gradient::ValueAndGradOp)
Compute the value and gradient of a function.
Syntax:
operation ::= `gradient.value_and_grad` $method $callee `(` $operands `)`
attr-dict `:` functional-type(operands, results)
Traits: AttrSizedResultSegments
Interfaces: CallOpInterface, GradientOpInterface, SymbolUserOpInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
method | ::mlir::StringAttr | string attribute |
callee | ::mlir::SymbolRefAttr | symbol reference attribute |
diffArgIndices | ::mlir::DenseIntElementsAttr | integer elements attribute |
finiteDiffParam | ::mlir::FloatAttr | An Attribute containing a floating-point value |
arg_attrs | ::mlir::ArrayAttr | Array of dictionary attributes |
res_attrs | ::mlir::ArrayAttr | Array of dictionary attributes |
Operands:
Operand | Description |
---|---|
operands | variadic of any type |
Results:
Result | Description |
---|---|
| | variadic of floating-point or ranked tensor of floating-point values |
| | variadic of floating-point or ranked tensor of floating-point values |
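Returning the value together with the gradient is natural in reverse mode, because the forward sweep already computes the value. The sketch below is a minimal scalar reverse-mode implementation in Python and is not how Catalyst implements the op. A hypothetical `Var` class records each operation's local derivatives. A backward sweep over the nodes in reverse topological order then accumulates the gradient.

```python
import math


class Var:
    """Scalar node for reverse-mode AD; parents hold (node, local derivative)."""

    def __init__(self, val, parents=()):
        self.val = val
        self.parents = parents
        self.grad = 0.0

    def __add__(self, other):
        return Var(self.val + other.val, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Var(self.val * other.val, ((self, other.val), (other, self.val)))


def vsin(x):
    return Var(math.sin(x.val), ((x, math.cos(x.val)),))


def backward(out):
    """Propagate d(out)/d(node) to every node, in reverse topological order."""
    order, seen = [], set()

    def visit(node):
        if id(node) not in seen:
            seen.add(id(node))
            for parent, _ in node.parents:
                visit(parent)
            order.append(node)  # appended only after all of its parents

    visit(out)
    out.grad = 1.0
    for node in reversed(order):  # each node is finished before its parents
        for parent, local in node.parents:
            parent.grad += local * node.grad


def value_and_grad(fn, *args):
    """Return (fn(*args), [d fn / d arg for each arg])."""
    inputs = [Var(a) for a in args]
    out = fn(*inputs)
    backward(out)
    return out.val, [v.grad for v in inputs]
```

When a variable is used more than once, as `x` is in `x * y + vsin(x)`, its contributions are summed during the backward sweep.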