Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] fix Variable overloads and add shape/dtype properties #4049

Merged
merged 1 commit into from
Jul 2, 2024

Conversation

cgarciae
Copy link
Collaborator

@cgarciae cgarciae commented Jul 2, 2024

What does this PR do?

  • Fixes Variable inplace operators.
  • Updates the nnx_basics guide to use inplace updates on the Variable
  • Adds .shape and .dtype properties to Variable.

After these changes you can now correctly do in-place operations of Variables:

class Count(nnx.Variable): ...

class Counter(nnx.Module):
  def __init__(self):
    self.count = Count(jnp.array(0, dtype=jnp.uint32))
    
  def increment(self):
    self.count += 1

Previously you have do this on the value: self.count.value += 1 .

@cgarciae cgarciae force-pushed the nnx-fix-variable-inplace-operators branch from c5a66f9 to 61343fa Compare July 2, 2024 15:54
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@cgarciae cgarciae force-pushed the nnx-fix-variable-inplace-operators branch from 61343fa to 385ca73 Compare July 2, 2024 16:01
@copybara-service copybara-service bot merged commit 1367acb into main Jul 2, 2024
18 checks passed
@copybara-service copybara-service bot deleted the nnx-fix-variable-inplace-operators branch July 2, 2024 21:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants