Skip to content

Commit

Permalink
make PD lyap solver differentiable (#853)
Browse files Browse the repository at this point in the history
* make PD lyap solver differentiable

* drop Duals in `isstable`

* mode isstable method

* update known limitations

* fix plyap matrix properties
  • Loading branch information
baggepinnen committed May 26, 2023
1 parent d9cac5a commit a82bf19
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 1 deletion.
2 changes: 1 addition & 1 deletion docs/src/examples/automatic_differentiation.md
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ The following issues are currently known to exist when using AD through ControlS
- [`hinfnorm`](@ref) requires ImplicitDifferentiation.jl and ComponentArrays.jl to be manually loaded by the user, after which there are implicit differentiation rules defined for [`hinfnorm`](@ref). The implicit rule calls `opnorm`, and is thus affected by the first limitation above for MIMO systems. [`hinfnorm`](@ref) has a reverse rule defined in RobustAndOptimalControl.jl, which is not affected by this limitation.
- [`are`](@ref), [`lqr`](@ref) and [`kalman`](@ref) all require ImplicitDifferentiation.jl and ComponentArrays.jl to be manually loaded by the user, after which there are implicit differentiation rules defined. To invoke the correct method of these functions, it is important that the second matrix (corresponding to input or measurement) has the `Dual` number type, i.e., the `R` matrix in `lqr(P, Q, R)` or `lqr(Continuous, A, B, Q, R)`
- The `schur` factorization is not amenable to differentiation using ForwardDiff. This is the fundamental reason for requireing ImplicitDifferentiation.jl to differentiate through the Ricatti equation solver. `schur` is called in several additional places, including [`balreal`](@ref) and all [`lyap`](@ref) solvers. To make `schur` differentiable, an implicit differentiation rule would be required.
- An implicit rule is defined for continuous-time [`lyap`](@ref) solvers, but not yet for discrete-time solvers. No rules are defined for [`plyap`](@ref) solvers (that return the Cholesky factors). This means that [`gram`](@ref) is currently not differentiable.
- An implicit rule is defined for continuous-time [`lyap`](@ref) and [`plyap`](@ref) solvers, but not yet for discrete-time solvers. This means that [`gram`](@ref) [`covar`](@ref) and [`norm`](@ref) (``H_2``-norm) is differentiable for continuous-time systems but not for discrete.

### Reverse-mode AD
- Zygote does not work very well at all, due to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,30 @@ function ControlSystemsBase.lyap(::ContinuousType, A::AbstractMatrix, Q::Abstrac
end


# plyap
function forward_plyapc(pars)
(; A,Q) = pars
ControlSystemsBase.plyapc(A, Q), 0
end

function conditions_plyapc(pars, Xc, noneed)
(; A,Q) = pars
Q = Q*Q'
X = Xc*Xc'
AX = A*X
O = AX .+ AX' .+ Q
vec(O) + vec(Xc - UpperTriangular(Xc))
end

# linear_solver = (A, b) -> (Matrix(A) \ b, (solved=true,))
const implicit_plyapc = ImplicitFunction(forward_plyapc, conditions_plyapc)

function ControlSystemsBase.plyap(::ContinuousType, A::AbstractMatrix, Q::AbstractMatrix{<:Dual}; kwargs...)
pars = ComponentVector(; A,Q)
X0, _ = implicit_plyapc(pars)
X0 isa AbstractMatrix ? X0 : reshape(X0, size(A))
end


## Hinf norm
import ControlSystemsBase: hinfnorm
Expand Down
6 changes: 6 additions & 0 deletions lib/ControlSystemsBase/src/matrix_comps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ LinearAlgebra.lyap(::ContinuousType, args...; kwargs...) = lyapc(args...; kwargs
LinearAlgebra.lyap(::DiscreteType, args...; kwargs...) = lyapd(args...; kwargs...)
LinearAlgebra.lyap(sys::AbstractStateSpace, args...; kwargs...) = lyap(timeevol(sys), sys.A, args...; kwargs...)

"""
Xc = plyap(sys::AbstractStateSpace, Ql; kwargs...)
Lyapunov solver that takes the `L` Cholesky factor of `Q` and returns a triangular matrix `Xc` such that `Xc*Xc' = X`.
"""
plyap(sys::AbstractStateSpace, args...; kwargs...) = plyap(timeevol(sys), sys.A, args...; kwargs...)
plyap(::ContinuousType, args...; kwargs...) = MatrixEquations.plyapc(args...; kwargs...)
plyap(::DiscreteType, args...; kwargs...) = MatrixEquations.plyapd(args...; kwargs...)

Expand Down
9 changes: 9 additions & 0 deletions lib/ControlSystemsBase/src/types/StateSpace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,15 @@ nstates(sys::AbstractStateSpace) = size(sys.A, 1)

isproper(sys::AbstractStateSpace) = iszero(sys.D)

function isstable(sys::StateSpace{Continuous, <:ForwardDiff.Dual})
# Drop duals for this check since it's not differentiable anyway
all(real.(eigvals(ForwardDiff.value.(sys.A))) .< 0)
end
function isstable(sys::StateSpace{<:Discrete, <:ForwardDiff.Dual})
all(abs.(ForwardDiff.value.(sys.A)) .< 1)
end


#####################################################################
## Math Operators ##
#####################################################################
Expand Down
39 changes: 39 additions & 0 deletions lib/ControlSystemsBase/test/test_implicit_diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,45 @@ J1 = vec((J1 + J1') ./ 2)
J2 = fdgrad(difffun, q)
@test J1 J2 rtol = 1e-6

## positive definite Lyap
Ql = [1 0; 0.1 1]
function difffun(q)
Ql = LowerTriangular(copy(reshape(q, 2, 2)))
sum(ControlSystemsBase.plyap(P, Ql))
end

q = Ql |> vec
J1 = ForwardDiff.gradient(difffun, q)
J2 = fdgrad(difffun, q)
@test J1 J2 rtol = 1e-6



## covar (tests plyap)
P = ssrand(1, 2, 2, proper=true)
function difffun(q)
Q = reshape(q, 2, 2)
Q = (Q .+ Q') ./ 2 # Needed for finite diff
sum(ControlSystemsBase.covar(P, Q))
end

q = Q |> vec
J1 = ForwardDiff.gradient(difffun, q)
J2 = fdgrad(difffun, q)
@test J1 J2 rtol = 1e-6

# covar w.r.t. plant
function difffun(a)
A = copy(reshape(a, 2, 2))
sys2 = ss(A, P.B, P.C, P.D)
sum(ControlSystemsBase.covar(sys2, Q))
end

a = P.A |> vec
J1 = ForwardDiff.gradient(difffun, a)
J2 = fdgrad(difffun, a)
@test J1 J2 rtol = 1e-6


## hinfnorm

Expand Down

0 comments on commit a82bf19

Please sign in to comment.