Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[compiler] Type provider infra for tests #30776

Closed
wants to merge 9 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import {
NonLocalBinding,
PolyType,
ScopeId,
SourceLocation,
Type,
ValidatedIdentifier,
ValueKind,
Expand Down Expand Up @@ -126,11 +127,6 @@ const HookSchema = z.object({

export type Hook = z.infer<typeof HookSchema>;

export const ModuleTypeResolver = z
.function()
.args(z.string())
.returns(z.nullable(TypeSchema));

Comment on lines -129 to -133
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

setting .returns() causes zod to return a function that internally validates the return type when you call it. that's cool but we want to control when the validation happens so i'm switching this to a plain function annotation

/*
* TODO(mofeiZ): User defined global types (with corresponding shapes).
* User defined global types should have inline ObjectShapes instead of directly
Expand All @@ -148,7 +144,7 @@ const EnvironmentConfigSchema = z.object({
* A function that, given the name of a module, can optionally return a description
* of that module's type signature.
*/
resolveModuleTypeSchema: z.nullable(ModuleTypeResolver).default(null),
moduleTypeProvider: z.nullable(z.function().args(z.string())).default(null),

/**
* A list of functions which the application compiles as macros, where
Expand Down Expand Up @@ -712,19 +708,27 @@ export class Environment {
return this.#outlinedFunctions;
}

#resolveModuleType(moduleName: string): Global | null {
if (this.config.resolveModuleTypeSchema == null) {
#resolveModuleType(moduleName: string, loc: SourceLocation): Global | null {
if (this.config.moduleTypeProvider == null) {
return null;
}
let moduleType = this.#moduleTypes.get(moduleName);
if (moduleType === undefined) {
const moduleConfig = this.config.resolveModuleTypeSchema(moduleName);
if (moduleConfig != null) {
const moduleTypes = TypeSchema.parse(moduleConfig);
const unparsedModuleConfig = this.config.moduleTypeProvider(moduleName);
if (unparsedModuleConfig != null) {
const parsedModuleConfig = TypeSchema.safeParse(unparsedModuleConfig);
if (!parsedModuleConfig.success) {
CompilerError.throwInvalidConfig({
reason: `Could not parse module type, the configured \`moduleTypeProvider\` function returned an invalid module description`,
description: parsedModuleConfig.error.toString(),
loc,
});
}
const moduleConfig = parsedModuleConfig.data;
moduleType = installTypeConfig(
this.#globals,
this.#shapes,
moduleTypes,
moduleConfig,
);
} else {
moduleType = null;
Expand All @@ -734,7 +738,10 @@ export class Environment {
return moduleType;
}

getGlobalDeclaration(binding: NonLocalBinding): Global | null {
getGlobalDeclaration(
binding: NonLocalBinding,
loc: SourceLocation,
): Global | null {
if (this.config.hookPattern != null) {
const match = new RegExp(this.config.hookPattern).exec(binding.name);
if (
Expand Down Expand Up @@ -772,7 +779,7 @@ export class Environment {
(isHookName(binding.imported) ? this.#getCustomHookType() : null)
);
} else {
const moduleType = this.#resolveModuleType(binding.module);
const moduleType = this.#resolveModuleType(binding.module, loc);
if (moduleType !== null) {
const importedType = this.getPropertyType(
moduleType,
Expand Down Expand Up @@ -805,10 +812,16 @@ export class Environment {
(isHookName(binding.name) ? this.#getCustomHookType() : null)
);
} else {
const moduleType = this.#resolveModuleType(binding.module);
const moduleType = this.#resolveModuleType(binding.module, loc);
if (moduleType !== null) {
// TODO: distinguish default/namespace cases
return moduleType;
if (binding.kind === 'ImportDefault') {
const defaultType = this.getPropertyType(moduleType, 'default');
if (defaultType !== null) {
return defaultType;
}
} else {
return moduleType;
}
}
return isHookName(binding.name) ? this.#getCustomHookType() : null;
}
Expand All @@ -819,9 +832,7 @@ export class Environment {
#isKnownReactModule(moduleName: string): boolean {
return (
moduleName.toLowerCase() === 'react' ||
moduleName.toLowerCase() === 'react-dom' ||
(this.config.enableSharedRuntime__testonly &&
moduleName === 'shared-runtime')
moduleName.toLowerCase() === 'react-dom'
);
}

Expand Down
23 changes: 23 additions & 0 deletions compiler/packages/babel-plugin-react-compiler/src/HIR/Globals.ts
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,9 @@ export function installTypeConfig(
case 'Ref': {
return {kind: 'Object', shapeId: BuiltInUseRefId};
}
case 'Any': {
return {kind: 'Poly'};
}
default: {
assertExhaustive(
typeConfig.name,
Expand All @@ -566,6 +569,20 @@ export function installTypeConfig(
calleeEffect: typeConfig.calleeEffect,
returnType: installTypeConfig(globals, shapes, typeConfig.returnType),
returnValueKind: typeConfig.returnValueKind,
noAlias: typeConfig.noAlias === true,
mutableOnlyIfOperandsAreMutable:
typeConfig.mutableOnlyIfOperandsAreMutable === true,
});
}
case 'hook': {
return addHook(shapes, {
hookKind: 'Custom',
positionalParams: typeConfig.positionalParams ?? [],
restParam: typeConfig.restParam ?? Effect.Freeze,
calleeEffect: Effect.Read,
returnType: installTypeConfig(globals, shapes, typeConfig.returnType),
returnValueKind: typeConfig.returnValueKind ?? ValueKind.Frozen,
noAlias: typeConfig.noAlias === true,
});
}
case 'object': {
Expand All @@ -578,6 +595,12 @@ export function installTypeConfig(
]),
);
}
default: {
assertExhaustive(
typeConfig,
`Unexpected type kind '${(typeConfig as any).kind}'`,
);
}
}
}

Expand Down
9 changes: 9 additions & 0 deletions compiler/packages/babel-plugin-react-compiler/src/HIR/HIR.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1361,6 +1361,15 @@ export enum ValueKind {
Context = 'context',
}

export const ValueKindSchema = z.enum([
ValueKind.MaybeFrozen,
ValueKind.Frozen,
ValueKind.Primitive,
ValueKind.Global,
ValueKind.Mutable,
ValueKind.Context,
]);

// The effect with which a value is modified.
export enum Effect {
// Default value: not allowed after lifetime inference
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import {isValidIdentifier} from '@babel/types';
import {z} from 'zod';
import {Effect, ValueKind} from '..';
import {EffectSchema} from './HIR';
import {EffectSchema, ValueKindSchema} from './HIR';

export type ObjectPropertiesConfig = {[key: string]: TypeConfig};
export const ObjectPropertiesSchema: z.ZodType<ObjectPropertiesConfig> = z
Expand All @@ -18,9 +18,9 @@ export const ObjectPropertiesSchema: z.ZodType<ObjectPropertiesConfig> = z
)
.refine(record => {
return Object.keys(record).every(
key => key === '*' || isValidIdentifier(key),
key => key === '*' || key === 'default' || isValidIdentifier(key),
);
}, 'Expected all "object" property names to be valid identifiers or `*` to match any property');
}, 'Expected all "object" property names to be valid identifier, `*` to match any property, of `default` to define a module default export');

export type ObjectTypeConfig = {
kind: 'object';
Expand All @@ -38,18 +38,45 @@ export type FunctionTypeConfig = {
calleeEffect: Effect;
returnType: TypeConfig;
returnValueKind: ValueKind;
noAlias?: boolean | null | undefined;
mutableOnlyIfOperandsAreMutable?: boolean | null | undefined;
};
export const FunctionTypeSchema: z.ZodType<FunctionTypeConfig> = z.object({
kind: z.literal('function'),
positionalParams: z.array(EffectSchema),
restParam: EffectSchema.nullable(),
calleeEffect: EffectSchema,
returnType: z.lazy(() => TypeSchema),
returnValueKind: z.nativeEnum(ValueKind),
returnValueKind: ValueKindSchema,
noAlias: z.boolean().nullable().optional(),
mutableOnlyIfOperandsAreMutable: z.boolean().nullable().optional(),
});

export type BuiltInTypeConfig = 'Ref' | 'Array' | 'Primitive' | 'MixedReadonly';
export type HookTypeConfig = {
kind: 'hook';
positionalParams?: Array<Effect> | null | undefined;
restParam?: Effect | null | undefined;
returnType: TypeConfig;
returnValueKind?: ValueKind | null | undefined;
noAlias?: boolean | null | undefined;
};
export const HookTypeSchema: z.ZodType<HookTypeConfig> = z.object({
kind: z.literal('hook'),
positionalParams: z.array(EffectSchema).nullable().optional(),
restParam: EffectSchema.nullable().optional(),
returnType: z.lazy(() => TypeSchema),
returnValueKind: ValueKindSchema.nullable().optional(),
noAlias: z.boolean().nullable().optional(),
});

export type BuiltInTypeConfig =
| 'Any'
| 'Ref'
| 'Array'
| 'Primitive'
| 'MixedReadonly';
export const BuiltInTypeSchema: z.ZodType<BuiltInTypeConfig> = z.union([
z.literal('Any'),
z.literal('Ref'),
z.literal('Array'),
z.literal('Primitive'),
Expand All @@ -68,9 +95,11 @@ export const TypeReferenceSchema: z.ZodType<TypeReferenceConfig> = z.object({
export type TypeConfig =
| ObjectTypeConfig
| FunctionTypeConfig
| HookTypeConfig
| TypeReferenceConfig;
export const TypeSchema: z.ZodType<TypeConfig> = z.union([
ObjectTypeSchema,
FunctionTypeSchema,
HookTypeSchema,
TypeReferenceSchema,
]);
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ function collectTemporaries(
break;
}
case 'LoadGlobal': {
const global = env.getGlobalDeclaration(value.binding);
const global = env.getGlobalDeclaration(value.binding, value.loc);
const hookKind = global !== null ? getHookKindForType(env, global) : null;
const lvalId = instr.lvalue.identifier.id;
if (hookKind === 'useMemo' || hookKind === 'useCallback') {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ function* generateInstructionTypes(
}

case 'LoadGlobal': {
const globalType = env.getGlobalDeclaration(value.binding);
const globalType = env.getGlobalDeclaration(value.binding, value.loc);
if (globalType) {
yield equation(left, globalType);
}
Expand Down
Loading
Loading