Automatic Differentiation
Hankel implements the primitives defined by ChainRules for automatic differentiation (AD). This enables all AD packages that use ChainRules' rules to differentiate the exported functions.
Here is an example of reverse-mode automatic differentiation using Zygote. To run this example, first call
julia> using Pkg
julia> Pkg.add("Zygote")
Then call the following:
julia> using Hankel, Zygote
julia> R, N = 10.0, 10
(10.0, 10)
julia> q = QDHT{0,1}(R, N);
julia> f(r) = exp(-r^2 / 2);
julia> fk = q * f.(q.r)
10-element Array{Float64,1}:
 0.97149813153671
 0.8586822523591615
 0.6876776056939765
 0.4989735437926904
 0.32802445451636614
 0.1953704406072909
 0.10540205038671452
 0.05143113892374989
 0.022439124159191845
 0.00795907967241482
julia> # Compute the function and a pullback function for computing the gradient
julia> I, back = Zygote.pullback(fk -> integrateR(q \ fk, q), fk);
julia> I
julia> Igrad = only(back(1)) # Compute the gradient
This example computes the gradient of the real-space integral of the function f with respect to each of its sampled values in reciprocal space.
Pushforwards and Pullbacks
For a summary of ChainRules' primitives and an introduction to the terminology used here, see the ChainRules docs. We define custom rules for two reasons:
- Many AD packages in Julia do not completely support mutating arrays. Since our internal functions mutate, we need custom rules.
- For vector samples, the functional forms of the pushforwards and pullbacks are simple (see below), but they are more complicated for multi-dimensional samples. Providing our own rules helps the AD system sidestep this difficulty.
The QDHT constructor
QDHT objects are intended to be used like Plans in AbstractFFTs: defined once and then potentially used many times. Consequently, we define the constructor as non-differentiable with respect to its inputs.
The transform
The quasi-discrete Hankel transform of a vector can be written in component form as
\[\tilde{f}(k_i) = s \sum_{j=1}^N T_{ij} f(r_j),\]
where $s = \left(\frac{R}{K}\right)^{(n + 1) / 2}$.
The pushforward of the transform is written as
\[\dot{\tilde{f}}(k_i) = s \sum_{j=1}^N T_{ij} \dot{f}(r_j).\]
That is, the pushforward is just the transform itself. The transform's pullback is
\[\overline{f}(r_j) = s \sum_{i=1}^N T_{ij} \overline{\tilde{f}}(k_i).\]
The pushforwards and pullbacks of the inverse transform have the same form, with the scalar $s$ replaced by its reciprocal $1/s$.
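The relationship between the two rules can be checked numerically. The following is a minimal sketch that does not use Hankel itself: `T` is a random stand-in for the transform matrix and `s` an arbitrary scalar (both are assumptions made for illustration), and `pushforward` and `pullback` are hypothetical helper names. It checks that the pullback is the adjoint of the pushforward, that is, that the inner product of an output cotangent with the pushforward of an input tangent equals the inner product of the pulled-back cotangent with that tangent.

```julia
using LinearAlgebra, Random

Random.seed!(1)
N = 5
T = randn(N, N)   # stand-in for the transform matrix T_ij (not the real QDHT matrix)
s = 0.7           # stand-in for the scalar prefactor s

# Pushforward of the transform: the transform applied to the input tangent
pushforward(df) = s * (T * df)
# Pullback of the transform: the transposed matrix applied to the output cotangent
pullback(ybar) = s * (T' * ybar)

df = randn(N)     # arbitrary input tangent
ybar = randn(N)   # arbitrary output cotangent

# Adjoint identity: the two inner products should agree up to round-off
lhs = dot(ybar, pushforward(df))
rhs = dot(pullback(ybar), df)
```

Comparing `lhs` and `rhs` with `isapprox` is a quick sanity check that a hand-written pullback matches its pushforward.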
Integration
Integration using the QDHT can be written as
\[I = \sum_{i=1}^N v_i f(r_i),\]
where the precomputed weights $v_i$ are given in Integration of functions. The pushforward is then
\[\dot{I} = \sum_{i=1}^N v_i \dot{f}(r_i),\]
and the pullback is
\[\overline{f}(r_i) = \overline{I} v_i.\]
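As a sketch of how these rules compose in the Zygote example at the top of this page, the gradient of the integral of an inverse transform with respect to the reciprocal-space samples can be obtained by applying the integration pullback followed by the inverse transform's pullback. The code below is illustrative only: `T`, `s`, and `v` are random stand-ins rather than the values stored in a `QDHT` object, the inverse transform is assumed to act as $s^{-1} T$, and all helper names are hypothetical. The result is compared against central finite differences.

```julia
using LinearAlgebra, Random

Random.seed!(2)
N = 4
T = randn(N, N)   # stand-in for the transform matrix (assumption)
s = 1.3           # stand-in for the scalar prefactor (assumption)
v = rand(N)       # stand-in for the precomputed integration weights (assumption)

inverse_transform(fk) = (T * fk) / s   # assumed form of the inverse transform
integrate(f) = dot(v, f)               # I = sum_i v_i f(r_i)

# Hand-written pullbacks following the formulas above
integrate_pullback(Ibar) = Ibar .* v       # fbar(r_i) = Ibar * v_i
inverse_pullback(fbar) = (T' * fbar) / s   # transposed matrix, reciprocal scalar

fk = randn(N)
# Gradient of I with respect to fk: seed the output cotangent with 1
grad = inverse_pullback(integrate_pullback(1.0))

# Central finite differences for comparison
h = 1e-6
objective(x) = integrate(inverse_transform(x))
fd = map(1:N) do j
    e = zeros(N)
    e[j] = 1.0
    (objective(fk .+ h .* e) - objective(fk .- h .* e)) / (2h)
end
```

Because both steps are linear, the gradient does not depend on the particular `fk`, and `grad` and `fd` should agree to within round-off.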