Skip to content
This repository was archived by the owner on Apr 16, 2025. It is now read-only.

Commit e63b1a8

Browse files
Merge pull request #153 from HodgeLab/mb/immutable
add ImmutableNonlinearProblem
2 parents 40958d2 + f7c176b commit e63b1a8

16 files changed

+132
-40
lines changed

ext/SimpleNonlinearSolveChainRulesCoreExt.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@ module SimpleNonlinearSolveChainRulesCoreExt
22

33
using ChainRulesCore: ChainRulesCore, NoTangent
44
using DiffEqBase: DiffEqBase
5-
using SciMLBase: ChainRulesOriginator, NonlinearProblem, NonlinearLeastSquaresProblem
6-
using SimpleNonlinearSolve: SimpleNonlinearSolve
5+
using SciMLBase: ChainRulesOriginator, NonlinearLeastSquaresProblem
6+
using SimpleNonlinearSolve: SimpleNonlinearSolve, ImmutableNonlinearProblem
77

88
# The expectation here is that no-one is using this directly inside a GPU kernel. We can
99
# eventually lift this requirement using a custom adjoint
1010
function ChainRulesCore.rrule(::typeof(SimpleNonlinearSolve.__internal_solve_up),
11-
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
11+
prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
1212
sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...)
1313
out, ∇internal = DiffEqBase._solve_adjoint(
1414
prob, sensealg, u0, p, ChainRulesOriginator(), alg, args...; kwargs...)

ext/SimpleNonlinearSolveReverseDiffExt.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@ module SimpleNonlinearSolveReverseDiffExt
33
using ArrayInterface: ArrayInterface
44
using DiffEqBase: DiffEqBase
55
using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal
6-
using SciMLBase: ReverseDiffOriginator, NonlinearProblem, NonlinearLeastSquaresProblem
7-
using SimpleNonlinearSolve: SimpleNonlinearSolve
6+
using SciMLBase: ReverseDiffOriginator, NonlinearLeastSquaresProblem
7+
using SimpleNonlinearSolve: SimpleNonlinearSolve, ImmutableNonlinearProblem
88
import SimpleNonlinearSolve: __internal_solve_up
99

10-
for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
10+
for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem)
1111
@eval begin
1212
function __internal_solve_up(prob::$(pType), sensealg, u0::TrackedArray, u0_changed,
1313
p::TrackedArray, p_changed, alg, args...; kwargs...)

ext/SimpleNonlinearSolveTrackerExt.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
module SimpleNonlinearSolveTrackerExt
22

33
using DiffEqBase: DiffEqBase
4-
using SciMLBase: TrackerOriginator, NonlinearProblem, NonlinearLeastSquaresProblem, remake
5-
using SimpleNonlinearSolve: SimpleNonlinearSolve
4+
using SciMLBase: TrackerOriginator, NonlinearLeastSquaresProblem, remake
5+
using SimpleNonlinearSolve: SimpleNonlinearSolve, ImmutableNonlinearProblem
66
using Tracker: Tracker, TrackedArray
77

8-
for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
8+
for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem)
99
@eval begin
1010
function SimpleNonlinearSolve.__internal_solve_up(
1111
prob::$(pType), sensealg, u0::TrackedArray,

src/SimpleNonlinearSolve.jl

+17-4
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@ using LinearAlgebra: LinearAlgebra, I, convert, copyto!, diagind, dot, issuccess
1919
norm, transpose
2020
using MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
2121
using Reexport: @reexport
22-
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
22+
using SciMLBase: @add_kwonly, SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
23+
AbstractNonlinearFunction, StandardNonlinearProblem,
2324
NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem,
2425
ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm,
25-
build_solution, isinplace, _unwrap_val
26+
build_solution, isinplace, _unwrap_val, warn_paramtype
2627
using Setfield: @set!
2728
using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size
2829

@@ -35,7 +36,7 @@ abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorit
3536
abstract type AbstractNewtonAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end
3637

3738
@inline __is_extension_loaded(::Val) = false
38-
39+
include("immutable_nonlinear_problem.jl")
3940
include("utils.jl")
4041
include("linesearch.jl")
4142

@@ -70,6 +71,18 @@ end
7071
# By Pass the highlevel checks for NonlinearProblem for Simple Algorithms
7172
function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
7273
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
74+
prob = convert(ImmutableNonlinearProblem, prob)
75+
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
76+
sensealg = prob.kwargs[:sensealg]
77+
end
78+
new_u0 = u0 !== nothing ? u0 : prob.u0
79+
new_p = p !== nothing ? p : prob.p
80+
return __internal_solve_up(prob, sensealg, new_u0, u0 === nothing, new_p,
81+
p === nothing, alg, args...; prob.kwargs..., kwargs...)
82+
end
83+
84+
function SciMLBase.solve(prob::ImmutableNonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
85+
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
7386
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
7487
sensealg = prob.kwargs[:sensealg]
7588
end
@@ -79,7 +92,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSol
7992
p === nothing, alg, args...; prob.kwargs..., kwargs...)
8093
end
8194

82-
function __internal_solve_up(_prob::NonlinearProblem, sensealg, u0, u0_changed,
95+
function __internal_solve_up(_prob::ImmutableNonlinearProblem, sensealg, u0, u0_changed,
8396
p, p_changed, alg, args...; kwargs...)
8497
prob = u0_changed || p_changed ? remake(_prob; u0, p) : _prob
8598
return SciMLBase.__solve(prob, alg, args...; kwargs...)

src/ad.jl

+24-13
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,26 @@
1-
for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
2-
@eval function SciMLBase.solve(
3-
prob::$(pType){<:Union{Number, <:AbstractArray}, iip,
4-
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
5-
alg::AbstractSimpleNonlinearSolveAlgorithm,
6-
args...;
7-
kwargs...) where {T, V, P, iip}
8-
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
9-
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
10-
return SciMLBase.build_solution(
11-
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
12-
end
1+
function SciMLBase.solve(
2+
prob::NonlinearLeastSquaresProblem{<:Union{Number, <:AbstractArray}, iip,
3+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
4+
alg::AbstractSimpleNonlinearSolveAlgorithm,
5+
args...;
6+
kwargs...) where {T, V, P, iip}
7+
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
8+
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
9+
return SciMLBase.build_solution(
10+
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
11+
end
12+
13+
function SciMLBase.solve(
14+
prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
15+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
16+
alg::AbstractSimpleNonlinearSolveAlgorithm,
17+
args...;
18+
kwargs...) where {T, V, P, iip}
19+
prob = convert(ImmutableNonlinearProblem, prob)
20+
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
21+
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
22+
return SciMLBase.build_solution(
23+
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
1324
end
1425

1526
for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
@@ -31,7 +42,7 @@ for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
3142
end
3243

3344
function __nlsolve_ad(
34-
prob::Union{IntervalNonlinearProblem, NonlinearProblem}, alg, args...; kwargs...)
45+
prob::Union{IntervalNonlinearProblem, NonlinearProblem, ImmutableNonlinearProblem}, alg, args...; kwargs...)
3546
p = value(prob.p)
3647
if prob isa IntervalNonlinearProblem
3748
tspan = value.(prob.tspan)

src/immutable_nonlinear_problem.jl

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
struct ImmutableNonlinearProblem{uType, isinplace, P, F, K, PT} <:
2+
AbstractNonlinearProblem{uType, isinplace}
3+
f::F
4+
u0::uType
5+
p::P
6+
problem_type::PT
7+
kwargs::K
8+
@add_kwonly function ImmutableNonlinearProblem{iip}(f::AbstractNonlinearFunction{iip}, u0,
9+
p = NullParameters(),
10+
problem_type = StandardNonlinearProblem();
11+
kwargs...) where {iip}
12+
if haskey(kwargs, :p)
13+
error("`p` specified as a keyword argument `p = $(kwargs[:p])` to `NonlinearProblem`. This is not supported.")
14+
end
15+
warn_paramtype(p)
16+
new{typeof(u0), iip, typeof(p), typeof(f),
17+
typeof(kwargs), typeof(problem_type)}(f,
18+
u0,
19+
p,
20+
problem_type,
21+
kwargs)
22+
end
23+
24+
"""
25+
Define a steady state problem using the given function.
26+
`isinplace` optionally sets whether the function is inplace or not.
27+
This is determined automatically, but not inferred.
28+
"""
29+
function ImmutableNonlinearProblem{iip}(f, u0, p = NullParameters(); kwargs...) where {iip}
30+
ImmutableNonlinearProblem{iip}(NonlinearFunction{iip}(f), u0, p; kwargs...)
31+
end
32+
end
33+
34+
"""
35+
Define a nonlinear problem using an instance of
36+
[`AbstractNonlinearFunction`](@ref AbstractNonlinearFunction).
37+
"""
38+
function ImmutableNonlinearProblem(f::AbstractNonlinearFunction, u0, p = NullParameters(); kwargs...)
39+
ImmutableNonlinearProblem{isinplace(f)}(f, u0, p; kwargs...)
40+
end
41+
42+
function ImmutableNonlinearProblem(f, u0, p = NullParameters(); kwargs...)
43+
ImmutableNonlinearProblem(NonlinearFunction(f), u0, p; kwargs...)
44+
end
45+
46+
"""
47+
Define a ImmutableNonlinearProblem problem from SteadyStateProblem
48+
"""
49+
function ImmutableNonlinearProblem(prob::AbstractNonlinearProblem)
50+
ImmutableNonlinearProblem{isinplace(prob)}(prob.f, prob.u0, prob.p)
51+
end
52+
53+
54+
function Base.convert(::Type{ImmutableNonlinearProblem}, prob::T) where {T <: NonlinearProblem}
55+
ImmutableNonlinearProblem{isinplace(prob)}(prob.f,
56+
prob.u0,
57+
prob.p,
58+
prob.problem_type;
59+
prob.kwargs...)
60+
end
61+
62+
function DiffEqBase.get_concrete_problem(prob::ImmutableNonlinearProblem, isadapt; kwargs...)
63+
u0 = DiffEqBase.get_concrete_u0(prob, isadapt, nothing, kwargs)
64+
u0 = DiffEqBase.promote_u0(u0, prob.p, nothing)
65+
p = DiffEqBase.get_concrete_p(prob, kwargs)
66+
DiffEqBase.remake(prob; u0 = u0, p = p)
67+
end

src/nlsolve/broyden.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ end
2222

2323
__get_linesearch(::SimpleBroyden{LS}) where {LS} = Val(LS)
2424

25-
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleBroyden, args...;
25+
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleBroyden, args...;
2626
abstol = nothing, reltol = nothing, maxiters = 1000,
2727
alias_u0 = false, termination_condition = nothing, kwargs...)
2828
x = __maybe_unaliased(prob.u0, alias_u0)

src/nlsolve/dfsane.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ function SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real =
5454
σ_min, σ_max, σ_1, γ, τ_min, τ_max, nexp, η_strategy)
5555
end
5656

57-
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...;
57+
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleDFSane{M}, args...;
5858
abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false,
5959
termination_condition = nothing, kwargs...) where {M}
6060
x = __maybe_unaliased(prob.u0, alias_u0)

src/nlsolve/halley.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ A low-overhead implementation of Halley's Method.
2424
autodiff = nothing
2525
end
2626

27-
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
27+
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleHalley, args...;
2828
abstol = nothing, reltol = nothing, maxiters = 1000,
2929
alias_u0 = false, termination_condition = nothing, kwargs...)
3030
x = __maybe_unaliased(prob.u0, alias_u0)

src/nlsolve/klement.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ method is non-allocating on scalar and static array problems.
66
"""
77
struct SimpleKlement <: AbstractSimpleNonlinearSolveAlgorithm end
88

9-
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleKlement, args...;
9+
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleKlement, args...;
1010
abstol = nothing, reltol = nothing, maxiters = 1000,
1111
alias_u0 = false, termination_condition = nothing, kwargs...)
1212
x = __maybe_unaliased(prob.u0, alias_u0)

src/nlsolve/lbroyden.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ function SimpleLimitedMemoryBroyden(;
2929
return SimpleLimitedMemoryBroyden{_unwrap_val(threshold), _unwrap_val(linesearch)}(alpha)
3030
end
3131

32-
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyden,
32+
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleLimitedMemoryBroyden,
3333
args...; termination_condition = nothing, kwargs...)
3434
if prob.u0 isa SArray
3535
if termination_condition === nothing ||
@@ -44,7 +44,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyd
4444
return __generic_solve(prob, alg, args...; termination_condition, kwargs...)
4545
end
4646

47-
@views function __generic_solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyden,
47+
@views function __generic_solve(prob::ImmutableNonlinearProblem, alg::SimpleLimitedMemoryBroyden,
4848
args...; abstol = nothing, reltol = nothing, maxiters = 1000,
4949
alias_u0 = false, termination_condition = nothing, kwargs...)
5050
x = __maybe_unaliased(prob.u0, alias_u0)
@@ -114,7 +114,7 @@ end
114114
# Non-allocating StaticArrays version of SimpleLimitedMemoryBroyden is actually quite
115115
# finicky, so we'll implement it separately from the generic version
116116
# Ignore termination_condition. Don't pass things into internal functions
117-
function __static_solve(prob::NonlinearProblem{<:SArray}, alg::SimpleLimitedMemoryBroyden,
117+
function __static_solve(prob::ImmutableNonlinearProblem{<:SArray}, alg::SimpleLimitedMemoryBroyden,
118118
args...; abstol = nothing, maxiters = 1000, kwargs...)
119119
x = prob.u0
120120
fx = _get_fx(prob, x)

src/nlsolve/raphson.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ end
2323

2424
const SimpleGaussNewton = SimpleNewtonRaphson
2525

26-
function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
26+
function SciMLBase.__solve(prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
2727
alg::SimpleNewtonRaphson, args...; abstol = nothing, reltol = nothing,
2828
maxiters = 1000, termination_condition = nothing, alias_u0 = false, kwargs...)
2929
x = __maybe_unaliased(prob.u0, alias_u0)

src/nlsolve/trustRegion.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ scalar and static array problems.
5555
nlsolve_update_rule = Val(false)
5656
end
5757

58-
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleTrustRegion, args...;
58+
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleTrustRegion, args...;
5959
abstol = nothing, reltol = nothing, maxiters = 1000,
6060
alias_u0 = false, termination_condition = nothing, kwargs...)
6161
x = __maybe_unaliased(prob.u0, alias_u0)

src/utils.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ end
123123
error("Inplace NonlinearLeastSquaresProblem requires a `resid_prototype`")
124124
return _get_fx(prob.f, x, prob.p)
125125
end
126-
@inline _get_fx(prob::NonlinearProblem, x) = _get_fx(prob.f, x, prob.p)
126+
@inline _get_fx(prob::ImmutableNonlinearProblem, x) = _get_fx(prob.f, x, prob.p)
127127
@inline function _get_fx(f::NonlinearFunction, x, p)
128128
if isinplace(f)
129129
if f.resid_prototype !== nothing
@@ -145,7 +145,7 @@ end
145145
# different. NonlinearSolve is more for robust / cached solvers while SimpleNonlinearSolve
146146
# is meant for low overhead solvers, users can opt into the other termination modes but the
147147
# default is to use the least overhead version.
148-
function init_termination_cache(prob::NonlinearProblem, abstol, reltol, du, u, ::Nothing)
148+
function init_termination_cache(prob::ImmutableNonlinearProblem, abstol, reltol, du, u, ::Nothing)
149149
return init_termination_cache(
150150
prob, abstol, reltol, du, u, AbsNormTerminationMode(Base.Fix1(maximum, abs)))
151151
end
@@ -155,14 +155,14 @@ function init_termination_cache(
155155
prob, abstol, reltol, du, u, AbsNormTerminationMode(Base.Fix2(norm, 2)))
156156
end
157157

158-
function init_termination_cache(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
158+
function init_termination_cache(prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
159159
abstol, reltol, du, u, tc::AbstractNonlinearTerminationMode)
160160
T = promote_type(eltype(du), eltype(u))
161161
abstol = __get_tolerance(u, abstol, T)
162162
reltol = __get_tolerance(u, reltol, T)
163163
tc_ = if hasfield(typeof(tc), :internalnorm) && tc.internalnorm === nothing
164164
internalnorm = ifelse(
165-
prob isa NonlinearProblem, Base.Fix1(maximum, abs), Base.Fix2(norm, 2))
165+
prob isa ImmutableNonlinearProblem, Base.Fix1(maximum, abs), Base.Fix2(norm, 2))
166166
DiffEqBase.set_termination_mode_internalnorm(tc, internalnorm)
167167
else
168168
tc

test/core/adjoint_tests.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,6 @@
1616
∂p_forwarddiff = ForwardDiff.gradient(solve_nlprob, p)
1717
∂p_tracker = Tracker.data(only(Tracker.gradient(solve_nlprob, p)))
1818
∂p_reversediff = ReverseDiff.gradient(solve_nlprob, p)
19-
19+
@test ∂p_zygote ∂p_tracker ∂p_reversediff
2020
@test ∂p_zygote ∂p_forwarddiff ∂p_tracker ∂p_reversediff
2121
end

test/gpu/cuda_tests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ end
5151
end
5252

5353
prob = NonlinearProblem{false}(f, @SVector[1.0f0, 1.0f0])
54+
prob = convert(SimpleNonlinearSolve.ImmutableNonlinearProblem, prob)
5455

5556
@testset "$(nameof(typeof(alg)))" for alg in (
5657
SimpleNewtonRaphson(), SimpleDFSane(), SimpleTrustRegion(),

0 commit comments

Comments
 (0)