support functions that return scalars #51

Closed

Conversation

baggepinnen
Contributor

No description provided.

@codecov

codecov bot commented May 24, 2023

Codecov Report

Patch coverage: 100.00% and project coverage change: +0.08 🎉

Comparison is base (1f93d66) 93.58% compared to head (8e7cc49) 93.67%.

❗ Current head 8e7cc49 differs from pull request most recent head 827f565. Consider uploading reports for the commit 827f565 to get more accurate results

Additional details and impacted files
@@            Coverage Diff             @@
##             main      #51      +/-   ##
==========================================
+ Coverage   93.58%   93.67%   +0.08%     
==========================================
  Files           5        5              
  Lines          78       79       +1     
==========================================
+ Hits           73       74       +1     
  Misses          5        5              
Impacted Files Coverage Δ
ext/ImplicitDifferentiationForwardDiffExt.jl 100.00% <100.00%> (ø)


@gdalle
Member

gdalle commented May 24, 2023

Thanks for the contribution! Not gonna merge it right away because I still struggle to make it work for reverse mode.
Could you explain to me why you need the empty getindex [] to make it work?

@gdalle gdalle marked this pull request as draft May 24, 2023 12:33
@baggepinnen
Contributor Author

baggepinnen commented May 24, 2023

The empty getindex [] is the same as only, i.e., it gets the first element [1], but errors if there is more than one element.
I used a one-element array as input, but a scalar as output.
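A minimal sketch of that equivalence in plain Julia, with illustrative variable names that are not taken from the PR. Both forms return the single element of a one-element array, and both throw when the array holds more than one element, although the exception types differ:

```julia
a = [3.0]
first_via_getindex = a[]      # zero-index getindex on a 1-element array
first_via_only = only(a)      # same value, via Base.only (Julia 1.4 or later)

b = [1.0, 2.0]

# Capture the exception thrown by each form on a 2-element array.
getindex_error = try
    b[]
    nothing
catch err
    err
end

only_error = try
    only(b)
    nothing
catch err
    err
end
```

Here `getindex_error` is a `BoundsError` while `only_error` is an `ArgumentError`, so `only` tends to give the more descriptive failure.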

@gdalle gdalle linked an issue May 24, 2023 that may be closed by this pull request
@gdalle gdalle added the feature New feature or request label May 24, 2023
@gdalle gdalle removed the feature New feature or request label May 27, 2023
@gdalle
Member

gdalle commented May 27, 2023

@mohamed82008 this PR is a bit of a mess and I could use some help. Basically, scalar outputs work in forward mode but it's much harder in reverse mode to juggle the arrays.

@mohamed82008
Collaborator

let me take a look

@mohamed82008
Collaborator

Let's take advantage of Forward and Conditions in #57 and make them convert scalars to SVectors of length 1 automatically. This should handle that case cleanly.

@mohamed82008
Collaborator

mohamed82008 commented May 27, 2023

even better

struct ScalarImplicitFunction{F}
  f::F
end
function ScalarImplicitFunction(forward, conditions, linear_solver)
  _forward = x -> begin
    y, z = forward(only(x))
    return @SVector([y]), z
  end
  _conditions = (x, y, z) -> @SVector [conditions(only(x), only(y), z)]
  f = ImplicitFunction(_forward, _conditions, linear_solver)
  return ScalarImplicitFunction(f)
end
function (f::ScalarImplicitFunction)(x::Real; kwargs...)
  return only(f.f(@SVector([x]); kwargs...))
end
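
The wrapping idea can be tried without ImplicitDifferentiation or StaticArrays. The sketch below assumes a hypothetical `ArrayOnly` type standing in for `ImplicitFunction`, and uses plain `Vector`s instead of `SVector`s; it only illustrates the scalar-in, scalar-out front end and is not the package's API:

```julia
# Hypothetical stand-in for a callable that only accepts and returns arrays
# (playing the role of ImplicitFunction in the snippet above).
struct ArrayOnly{F}
    forward::F
end
(a::ArrayOnly)(x::AbstractVector) = a.forward(x)

# Scalar front end: wraps the input in a 1-element vector, unwraps the output.
struct ScalarFrontEnd{A}
    inner::A
end
(s::ScalarFrontEnd)(x::Real) = only(s.inner([x]))

# Example: elementwise square root exposed as a scalar function.
sqrt_vec = ArrayOnly(x -> sqrt.(x))
sqrt_scalar = ScalarFrontEnd(sqrt_vec)
result = sqrt_scalar(4.0)
```

Because the inner callable never sees a scalar, any array-only differentiation rules attached to it would presumably keep working unchanged, which is the point of the suggestion above.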

@mohamed82008
Collaborator

this shouldn't require touching any line of code in ImplicitDifferentiation unless we don't support static arrays.

@gdalle
Member

gdalle commented May 30, 2023

Gonna close this and start a new PR for the issue at hand

pbmA = PullbackMul!(pbA, size(y))
pbmB = PullbackMul!(pbB, size(y))
- Aᵀ_op = LinearOperator(R, m, m, false, false, pbmA)
+ Aᵀ_op = LinearOperator(R, m, m, false, false, pbmA) # TODO: can it accept a vector if y is scalar?
Member


@baggepinnen the bug that made me give up was here, so if you find an elegant solution we'll be happy to implement it!

@@ -1,3 +1,6 @@
make_array(a::AbstractArray) = a
make_array(a::Number) = [a]
Member


For now it is not static but it should be

@@ -39,7 +43,9 @@ function (implicit::ImplicitFunction)(
reshape(dₖy_vec, size(y))
end

- y_and_dy = let y = y, dy = dy
+ y_and_dy = if y isa Number
Member


Is this type-stable or should it be two methods?

Contributor Author


If the type of y is inferred correctly, the code will be type-stable, because the isa check is resolved at compile time. If the type of y isn't inferred, you will have a problem in either case.

example:

julia> foo(x) = x isa Int ? 2 : 2.0
foo (generic function with 1 method)

julia> @code_warntype foo(3)
MethodInstance for foo(::Int64)
  from foo(x) @ Main REPL[15]:1
Arguments
  #self#::Core.Const(foo)
  x::Int64
Body::Int64
1 ─ %1 = (x isa Main.Int)::Core.Const(true)
└──      goto #3 if not %1
2 ─      return 2
3 ─      Core.Const(:(return 2.0))

the interesting part is (x isa Main.Int)::Core.Const(true)
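
The same comparison can be made for the scalar-versus-array case asked about. This is a sketch with hypothetical helper names (`wrap_branch`, `wrap_dispatch`), not code from the PR: one version branches on `isa`, the other uses two methods, and `Base.return_types` reports what the compiler infers for a concrete argument type.

```julia
# Single method with an `isa` branch.
wrap_branch(y) = y isa Number ? [y] : y

# Equivalent behavior split into two methods.
wrap_dispatch(y::Number) = [y]
wrap_dispatch(y::AbstractArray) = y

# Inferred return types when the argument type is concrete.
branch_scalar   = only(Base.return_types(wrap_branch, (Float64,)))
branch_vector   = only(Base.return_types(wrap_branch, (Vector{Float64},)))
dispatch_scalar = only(Base.return_types(wrap_dispatch, (Float64,)))
```

With a concrete argument type, both versions infer to `Vector{Float64}`. If the argument type is not known to the compiler, neither version can be inferred precisely, which matches the remark that the problem exists in either case.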

@gdalle gdalle reopened this May 30, 2023
@gdalle
Member

gdalle commented May 30, 2023

I think @mohamed82008 is right and this will be a bit easier once #57 is merged, because we'll have custom types for Forward and Conditions

@gdalle gdalle closed this Aug 7, 2023
Development

Successfully merging this pull request may close these issues.

Compatibility with scalar functions
3 participants