Skip to content

Commit

Permalink
accumulate in fp32 for Reduce* (#19868)
Browse files (browse the repository at this point in the history)
  • Loading branch information
guschmue authored Mar 18, 2024
1 parent 28ad6c3 commit 7e0d424
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts
Original file line number | Diff line number | Diff line change
@@ -131,7 +131,7 @@ export const createReduceSharedProgramInfo =
const workgroupSize = 32;

const sharedMemorySnippet = `
var<workgroup> aBestValues : array<${output.type.storage}, ${workgroupSize}>;
var<workgroup> aBestValues : array<f32, ${workgroupSize}>;
`;

const getShaderSource = (shaderHelper: ShaderHelper) => `
@@ -145,10 +145,10 @@ export const createReduceSharedProgramInfo =
let outputIndex = global_idx / ${workgroupSize};
let offset = outputIndex * uniforms.reduceSize;
var bestValue = ${output.type.storage}(${reduceInitValues[reduceType]});
var bestValue = f32(${reduceInitValues[reduceType]});
let Length = uniforms.reduceSize;
for (var k = local_idx; k < Length; k = k + ${workgroupSize}) {
let candidate = ${output.type.storage}(${input.getByOffset('offset + k')});
let candidate = f32(${input.getByOffset('offset + k')});
bestValue = ${reduceOps[reduceType]};
}
aBestValues[local_idx] = bestValue;
@@ -172,8 +172,8 @@ export const createReduceSharedProgramInfo =
output.setByOffset(
'outputIndex',
`${
reduceType === 'mean' ? `bestValue / ${output.type.storage}(uniforms.reduceSize)` :
`${reduceOutputValues[reduceType]}`}`)};
reduceType === 'mean' ? `${output.type.storage}(bestValue / f32(uniforms.reduceSize))` :
`${output.type.storage}(${reduceOutputValues[reduceType]})`}`)};
}
}`;

Expand Down

0 comments on commit 7e0d424

Please sign in to comment.