-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
support functions that return scalars #51
Conversation
Codecov ReportPatch coverage:
Additional details and impacted files@@ Coverage Diff @@
## main #51 +/- ##
==========================================
+ Coverage 93.58% 93.67% +0.08%
==========================================
Files 5 5
Lines 78 79 +1
==========================================
+ Hits 73 74 +1
Misses 5 5
☔ View full report in Codecov by Sentry. |
Thanks for the contribution! Not gonna merge it right away because I still struggle to make it work for reverse mode. |
The empty getindex |
@mohamed82008 this PR is a bit of a mess and I could use some help. Basically, scalar outputs work in forward mode but it's much harder in reverse mode to juggle the arrays. |
let me take a look |
Let's take advantage of |
even better struct ScalarImplicitFunction{F}
f::F
end
function ScalarImplicitFunction(forward, conditions, linear_solver)
_forward = x -> begin
y, z = forward(only(x))
return @SVector([y]), z
end
_conditions = (x, y, z) -> @SVector [conditions(only(x), only(y), z)]
f = ImplicitFunction(_forward, _conditions, linear_solver)
return ScalarImplicitFunction(f)
end
function (f::ScalarImplicitFunction)(x::Real;_ kwargs...)
return only(f.f(@SVector([x]); kwargs...))
end |
this shouldn't require touching any line of code in ImplicitDifferentiation unless we don't support static arrays. |
Gonna close this and start a new PR for the issue at hand |
pbmA = PullbackMul!(pbA, size(y)) | ||
pbmB = PullbackMul!(pbB, size(y)) | ||
Aᵀ_op = LinearOperator(R, m, m, false, false, pbmA) | ||
Aᵀ_op = LinearOperator(R, m, m, false, false, pbmA) # TODO: can it accept a vector if y is scalar? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@baggepinnen the bug that made me give up was here, so if you find an elegant solution we'll be happy to implement it!
@@ -1,3 +1,6 @@ | |||
make_array(a::AbstractArray) = a | |||
make_array(a::Number) = [a] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For now it is not static but it should be
@@ -39,7 +43,9 @@ function (implicit::ImplicitFunction)( | |||
reshape(dₖy_vec, size(y)) | |||
end | |||
|
|||
y_and_dy = let y = y, dy = dy | |||
y_and_dy = if y isa Number |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this type-stable or should it be two methods?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the type of y
is inferred correctly the code will be type stable, if it isn't type stable, you will have a problem either case
example:
julia> foo(x) = x isa Int ? 2 : 2.0
foo (generic function with 1 method)
julia> @code_warntype foo(3)
MethodInstance for foo(::Int64)
from foo(x) @ Main REPL[15]:1
Arguments
#self#::Core.Const(foo)
x::Int64
Body::Int64
1 ─ %1 = (x isa Main.Int)::Core.Const(true)
└── goto #3 if not %1
2 ─ return 2
3 ─ Core.Const(:(return 2.0))
the interesting part is (x isa Main.Int)::Core.Const(true)
I think @mohamed82008 is right and this will be a bit easier once #57 is merged, because we'll have custom types for |
No description provided.