Skip to content

Commit

Permalink
[compiler] Allow refs to be lazily initialized during render
Browse files Browse the repository at this point in the history
Summary:
The official guidance for useRef notes an exception to the rule that refs cannot be accessed during render: to avoid recreating the ref's contents, you can test that the ref is uninitialized and then initialize it using an if statement:

```
if (ref.current == null) {
  ref.current = SomeExpensiveOperation()
}
```

The compiler didn't recognize this exception, however, leading to code that obeyed all the official guidance for refs being rejected by the compiler. This PR fixes that, by extending the ref validation machinery with an awareness of guard operations that allow lazy initialization. We now understand `== null` and similar operations, when applied to a ref and consumed by an if terminal, as marking the consequent of the if as a block in which the ref can be safely written to. In order to do so we need to create a notion of ref ids, which link different usages of the same ref via both the ref and the ref value.

ghstack-source-id: d2729274f351e1eb0268f28f629fa4c2568ebc4d
Pull Request resolved: #31188
  • Loading branch information
mvitousek committed Oct 11, 2024
1 parent 9c525ea commit 6cf5bd9
Show file tree
Hide file tree
Showing 19 changed files with 628 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

import {CompilerError, ErrorSeverity} from '../CompilerError';
import {
BlockId,
HIRFunction,
Identifier,
IdentifierId,
Place,
SourceLocation,
Expand All @@ -17,6 +17,7 @@ import {
isUseRefType,
} from '../HIR';
import {
eachInstructionOperand,
eachInstructionValueOperand,
eachPatternOperand,
eachTerminalOperand,
Expand Down Expand Up @@ -44,11 +45,32 @@ import {Err, Ok, Result} from '../Utils/Result';
* or based on property name alone (`foo.current` might be a ref).
*/

type RefAccessType = {kind: 'None'} | RefAccessRefType;
const opaqueRefId = Symbol();
type RefId = number & {[opaqueRefId]: 'RefId'};

function makeRefId(id: number): RefId {
CompilerError.invariant(id >= 0 && Number.isInteger(id), {
reason: 'Expected identifier id to be a non-negative integer',
description: null,
loc: null,
suggestions: null,
});
return id as RefId;
}
let _refId = 0;
function nextRefId(): RefId {
return makeRefId(_refId++);
}

type RefAccessType =
| {kind: 'None'}
| {kind: 'Nullable'}
| {kind: 'Guard'; refId: RefId}
| RefAccessRefType;

type RefAccessRefType =
| {kind: 'Ref'}
| {kind: 'RefValue'; loc?: SourceLocation}
| {kind: 'Ref'; refId: RefId}
| {kind: 'RefValue'; loc?: SourceLocation; refId?: RefId}
| {kind: 'Structure'; value: null | RefAccessRefType; fn: null | RefFnType};

type RefFnType = {readRefEffect: boolean; returnType: RefAccessType};
Expand Down Expand Up @@ -82,11 +104,11 @@ export function validateNoRefAccessInRender(fn: HIRFunction): void {
validateNoRefAccessInRenderImpl(fn, env).unwrap();
}

function refTypeOfType(identifier: Identifier): RefAccessType {
if (isRefValueType(identifier)) {
function refTypeOfType(place: Place): RefAccessType {
if (isRefValueType(place.identifier)) {
return {kind: 'RefValue'};
} else if (isUseRefType(identifier)) {
return {kind: 'Ref'};
} else if (isUseRefType(place.identifier)) {
return {kind: 'Ref', refId: nextRefId()};
} else {
return {kind: 'None'};
}
Expand All @@ -101,6 +123,14 @@ function tyEqual(a: RefAccessType, b: RefAccessType): boolean {
return true;
case 'Ref':
return true;
case 'Nullable':
return true;
case 'Guard':
CompilerError.invariant(b.kind === 'Guard', {
reason: 'Expected ref value',
loc: null,
});
return a.refId === b.refId;
case 'RefValue':
CompilerError.invariant(b.kind === 'RefValue', {
reason: 'Expected ref value',
Expand Down Expand Up @@ -133,11 +163,17 @@ function joinRefAccessTypes(...types: Array<RefAccessType>): RefAccessType {
b: RefAccessRefType,
): RefAccessRefType {
if (a.kind === 'RefValue') {
return a;
if (b.kind === 'RefValue' && a.refId === b.refId) {
return a;
}
return {kind: 'RefValue'};
} else if (b.kind === 'RefValue') {
return b;
} else if (a.kind === 'Ref' || b.kind === 'Ref') {
return {kind: 'Ref'};
if (a.kind === 'Ref' && b.kind === 'Ref' && a.refId === b.refId) {
return a;
}
return {kind: 'Ref', refId: nextRefId()};
} else {
CompilerError.invariant(
a.kind === 'Structure' && b.kind === 'Structure',
Expand Down Expand Up @@ -178,6 +214,16 @@ function joinRefAccessTypes(...types: Array<RefAccessType>): RefAccessType {
return b;
} else if (b.kind === 'None') {
return a;
} else if (a.kind === 'Guard' || b.kind === 'Guard') {
if (a.kind === 'Guard' && b.kind === 'Guard' && a.refId === b.refId) {
return a;
}
return {kind: 'None'};
} else if (a.kind === 'Nullable' || b.kind === 'Nullable') {
if (a.kind === 'Nullable' && b.kind === 'Nullable') {
return a;
}
return {kind: 'None'};
} else {
return joinRefAccessRefTypes(a, b);
}
Expand All @@ -198,13 +244,14 @@ function validateNoRefAccessInRenderImpl(
} else {
place = param.place;
}
const type = refTypeOfType(place.identifier);
const type = refTypeOfType(place);
env.set(place.identifier.id, type);
}

for (let i = 0; (i == 0 || env.hasChanged()) && i < 10; i++) {
env.resetChanged();
returnValues = [];
const safeBlocks = new Map<BlockId, RefId>();
const errors = new CompilerError();
for (const [, block] of fn.body.blocks) {
for (const phi of block.phis) {
Expand Down Expand Up @@ -238,11 +285,15 @@ function validateNoRefAccessInRenderImpl(
if (objType?.kind === 'Structure') {
lookupType = objType.value;
} else if (objType?.kind === 'Ref') {
lookupType = {kind: 'RefValue', loc: instr.loc};
lookupType = {
kind: 'RefValue',
loc: instr.loc,
refId: objType.refId,
};
}
env.set(
instr.lvalue.identifier.id,
lookupType ?? refTypeOfType(instr.lvalue.identifier),
lookupType ?? refTypeOfType(instr.lvalue),
);
break;
}
Expand All @@ -251,7 +302,7 @@ function validateNoRefAccessInRenderImpl(
env.set(
instr.lvalue.identifier.id,
env.get(instr.value.place.identifier.id) ??
refTypeOfType(instr.lvalue.identifier),
refTypeOfType(instr.lvalue),
);
break;
}
Expand All @@ -260,12 +311,12 @@ function validateNoRefAccessInRenderImpl(
env.set(
instr.value.lvalue.place.identifier.id,
env.get(instr.value.value.identifier.id) ??
refTypeOfType(instr.value.lvalue.place.identifier),
refTypeOfType(instr.value.lvalue.place),
);
env.set(
instr.lvalue.identifier.id,
env.get(instr.value.value.identifier.id) ??
refTypeOfType(instr.lvalue.identifier),
refTypeOfType(instr.lvalue),
);
break;
}
Expand All @@ -277,13 +328,10 @@ function validateNoRefAccessInRenderImpl(
}
env.set(
instr.lvalue.identifier.id,
lookupType ?? refTypeOfType(instr.lvalue.identifier),
lookupType ?? refTypeOfType(instr.lvalue),
);
for (const lval of eachPatternOperand(instr.value.lvalue.pattern)) {
env.set(
lval.identifier.id,
lookupType ?? refTypeOfType(lval.identifier),
);
env.set(lval.identifier.id, lookupType ?? refTypeOfType(lval));
}
break;
}
Expand Down Expand Up @@ -354,7 +402,11 @@ function validateNoRefAccessInRenderImpl(
types.push(env.get(operand.identifier.id) ?? {kind: 'None'});
}
const value = joinRefAccessTypes(...types);
if (value.kind === 'None') {
if (
value.kind === 'None' ||
value.kind === 'Guard' ||
value.kind === 'Nullable'
) {
env.set(instr.lvalue.identifier.id, {kind: 'None'});
} else {
env.set(instr.lvalue.identifier.id, {
Expand All @@ -369,7 +421,18 @@ function validateNoRefAccessInRenderImpl(
case 'PropertyStore':
case 'ComputedDelete':
case 'ComputedStore': {
validateNoRefAccess(errors, env, instr.value.object, instr.loc);
const safe = safeBlocks.get(block.id);
const target = env.get(instr.value.object.identifier.id);
if (
instr.value.kind === 'PropertyStore' &&
safe != null &&
target?.kind === 'Ref' &&
target.refId === safe
) {
safeBlocks.delete(block.id);
} else {
validateNoRefAccess(errors, env, instr.value.object, instr.loc);
}
for (const operand of eachInstructionValueOperand(instr.value)) {
if (operand === instr.value.object) {
continue;
Expand All @@ -381,23 +444,67 @@ function validateNoRefAccessInRenderImpl(
case 'StartMemoize':
case 'FinishMemoize':
break;
case 'Primitive': {
if (instr.value.value == null) {
env.set(instr.lvalue.identifier.id, {kind: 'Nullable'});
}
break;
}
case 'BinaryExpression': {
const left = env.get(instr.value.left.identifier.id);
const right = env.get(instr.value.right.identifier.id);
let nullish: boolean = false;
let refId: RefId | null = null;
if (left?.kind === 'RefValue' && left.refId != null) {
refId = left.refId;
} else if (right?.kind === 'RefValue' && right.refId != null) {
refId = right.refId;
}

if (left?.kind === 'Nullable') {
nullish = true;
} else if (right?.kind === 'Nullable') {
nullish = true;
}

if (refId !== null && nullish) {
env.set(instr.lvalue.identifier.id, {kind: 'Guard', refId});
} else {
for (const operand of eachInstructionValueOperand(instr.value)) {
validateNoRefValueAccess(errors, env, operand);
}
}
break;
}
default: {
for (const operand of eachInstructionValueOperand(instr.value)) {
validateNoRefValueAccess(errors, env, operand);
}
break;
}
}
if (isUseRefType(instr.lvalue.identifier)) {

// Guard values are derived from ref.current, so they can only be used in if statement targets
for (const operand of eachInstructionOperand(instr)) {
guardCheck(errors, operand, env);
}

if (
isUseRefType(instr.lvalue.identifier) &&
env.get(instr.lvalue.identifier.id)?.kind !== 'Ref'
) {
env.set(
instr.lvalue.identifier.id,
joinRefAccessTypes(
env.get(instr.lvalue.identifier.id) ?? {kind: 'None'},
{kind: 'Ref'},
{kind: 'Ref', refId: nextRefId()},
),
);
}
if (isRefValueType(instr.lvalue.identifier)) {
if (
isRefValueType(instr.lvalue.identifier) &&
env.get(instr.lvalue.identifier.id)?.kind !== 'RefValue'
) {
env.set(
instr.lvalue.identifier.id,
joinRefAccessTypes(
Expand All @@ -407,12 +514,24 @@ function validateNoRefAccessInRenderImpl(
);
}
}

if (block.terminal.kind === 'if') {
const test = env.get(block.terminal.test.identifier.id);
if (test?.kind === 'Guard') {
safeBlocks.set(block.terminal.consequent, test.refId);
}
}

for (const operand of eachTerminalOperand(block.terminal)) {
if (block.terminal.kind !== 'return') {
validateNoRefValueAccess(errors, env, operand);
if (block.terminal.kind !== 'if') {
guardCheck(errors, operand, env);
}
} else {
// Allow functions containing refs to be returned, but not direct ref values
validateNoDirectRefValueAccess(errors, operand, env);
guardCheck(errors, operand, env);
returnValues.push(env.get(operand.identifier.id));
}
}
Expand Down Expand Up @@ -444,6 +563,23 @@ function destructure(
return type;
}

function guardCheck(errors: CompilerError, operand: Place, env: Env): void {
if (env.get(operand.identifier.id)?.kind === 'Guard') {
errors.push({
severity: ErrorSeverity.InvalidReact,
reason:
'Ref values (the `current` property) may not be accessed during render. (https://react.dev/reference/react/useRef)',
loc: operand.loc,
description:
operand.identifier.name !== null &&
operand.identifier.name.kind === 'named'
? `Cannot access ref value \`${operand.identifier.name.value}\``
: null,
suggestions: null,
});
}
}

function validateNoRefValueAccess(
errors: CompilerError,
env: Env,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@

## Input

```javascript
//@flow
import {useRef} from 'react';

component C() {
const r = useRef(null);
if (r.current == null) {
r.current = 1;
}
}

export const FIXTURE_ENTRYPOINT = {
fn: C,
params: [{}],
};

```

## Code

```javascript
import { useRef } from "react";

function C() {
const r = useRef(null);
if (r.current == null) {
r.current = 1;
}
}

export const FIXTURE_ENTRYPOINT = {
fn: C,
params: [{}],
};

```
### Eval output
(kind: ok)
Loading

0 comments on commit 6cf5bd9

Please sign in to comment.