Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optional shared bind group to custom materials #5962

Closed
wants to merge 10 commits into from
10 changes: 10 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -1236,6 +1236,16 @@ description = "A shader that shows how to reuse the core bevy PBR shading functi
category = "Shaders"
wasm = true

[[example]]
name = "shader_material_shared_group"
path = "examples/shader/shader_material_shared_group.rs"

[package.metadata.example.shader_material_shared_group]
name = "Material - Shared Bind Group"
description = "Two custom materials that share a common bind group"
category = "Shaders"
wasm = true

# Stress tests
[[package.metadata.category]]
name = "Stress Tests"
Expand Down
32 changes: 32 additions & 0 deletions assets/shaders/shared_group_common.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#define_import_path example::shared_group::common

struct Time {
seconds_since_startup: f32,
};

struct Emitter {
position: vec3<f32>,
radius: f32,
strength: f32,
propagation_speed: f32,
phase_speed: f32,
};

@group(3) @binding(0)
var<uniform> time: Time;
@group(3) @binding(1)
var<uniform> emitter: Emitter;

fn field_phase(distance: f32) -> f32 {
return sin((time.seconds_since_startup * emitter.propagation_speed) - (distance * emitter.phase_speed));
}

fn field_amplitude(distance: f32) -> f32 {
let amp = (emitter.strength * (emitter.radius - distance))/(emitter.radius * (distance + 1.0));
return max(0.0, amp);
}

fn field_impact(position: vec3<f32>) -> f32 {
let dist = distance(position, emitter.position);
return field_amplitude(dist) * (0.5 + field_phase(dist) * 0.5);
}
69 changes: 69 additions & 0 deletions assets/shaders/shared_group_mat1.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#import bevy_pbr::mesh_view_bindings
#import bevy_pbr::mesh_bindings

#import bevy_pbr::mesh_functions
#import bevy_pbr::utils
#import bevy_pbr::clustered_forward
#import bevy_pbr::pbr_types
#import bevy_pbr::lighting
#import bevy_pbr::shadows
#import bevy_pbr::pbr_functions

#import example::shared_group::common

struct EmitterMaterial {
base_color: vec4<f32>,
};

@group(1) @binding(0)
var<uniform> material: EmitterMaterial;


struct VertexInput {
@location(0) position: vec3<f32>,
@location(1) normal: vec3<f32>,
@location(2) uv: vec2<f32>,
};

struct VertexOutput {
@builtin(position) clip_position: vec4<f32>,
#import bevy_pbr::mesh_vertex_output
};

@vertex
fn vertex(in: VertexInput) -> VertexOutput {
var out: VertexOutput;

let scale = 1.0 + field_phase(0.0) * 0.01;
let scaled_position = vec4<f32>(in.position * vec3<f32>(scale, scale, scale), 1.0);

out.world_normal = mesh_normal_local_to_world(in.normal);
out.world_position = mesh_position_local_to_world(mesh.model, scaled_position);
out.uv = in.uv;
out.clip_position = mesh_position_world_to_clip(out.world_position);

return out;
}

struct FragmentInput {
@builtin(front_facing) is_front: bool,
@builtin(position) frag_coord: vec4<f32>,
#import bevy_pbr::mesh_vertex_output
};

@fragment
fn fragment(in: FragmentInput) -> @location(0) vec4<f32> {
var pbr_input = pbr_input_new();

// Set PBR material properties
pbr_input.material.base_color = material.base_color;

// Set PBR frament / world properties
pbr_input.frag_coord = in.frag_coord;
pbr_input.world_position = in.world_position;
pbr_input.world_normal = in.world_normal;
pbr_input.N = prepare_normal(pbr_input.material.flags, in.world_normal, in.uv, in.is_front);
pbr_input.V = calculate_view(in.world_position, pbr_input.is_orthographic);

return pbr(pbr_input);
}
46 changes: 46 additions & 0 deletions assets/shaders/shared_group_mat2.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#import bevy_pbr::mesh_view_bindings
#import bevy_pbr::mesh_bindings

#import bevy_pbr::mesh_functions
#import bevy_pbr::utils
#import bevy_pbr::clustered_forward
#import bevy_pbr::pbr_types
#import bevy_pbr::lighting
#import bevy_pbr::shadows
#import bevy_pbr::pbr_functions

#import example::shared_group::common

struct ReceiverMaterial {
base_color: vec4<f32>,
};

@group(1) @binding(0)
var<uniform> material: ReceiverMaterial;


struct FragmentInput {
@builtin(front_facing) is_front: bool,
@builtin(position) frag_coord: vec4<f32>,
#import bevy_pbr::mesh_vertex_output
};

@fragment
fn fragment(in: FragmentInput) -> @location(0) vec4<f32> {
var pbr_input = pbr_input_new();

// Set PBR material properties
let impact = field_impact(in.world_position.xyz);
pbr_input.material.base_color = material.base_color + vec4<f32>(impact, impact, impact, 1.0);

// Set PBR frament / world properties
pbr_input.frag_coord = in.frag_coord;
pbr_input.world_position = in.world_position;
pbr_input.world_normal = in.world_normal;

// Calculate stuff?
pbr_input.N = prepare_normal(pbr_input.material.flags, in.world_normal, in.uv, in.is_front);
pbr_input.V = calculate_view(in.world_position, pbr_input.is_orthographic);

return pbr(pbr_input);
}
2 changes: 2 additions & 0 deletions crates/bevy_pbr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ mod light;
mod material;
mod pbr_material;
mod render;
mod shared_group;

pub use alpha::*;
pub use bundle::*;
pub use light::*;
pub use material::*;
pub use pbr_material::*;
pub use render::*;
pub use shared_group::*;

use bevy_window::ModifiesWindows;

Expand Down
104 changes: 81 additions & 23 deletions crates/bevy_pbr/src/material.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
AlphaMode, DrawMesh, MeshPipeline, MeshPipelineKey, MeshUniform, SetMeshBindGroup,
SetMeshViewBindGroup,
SetMeshViewBindGroup, SharedBindGroup,
};
use bevy_app::{App, Plugin};
use bevy_asset::{AddAsset, AssetEvent, AssetServer, Assets, Handle};
Expand Down Expand Up @@ -151,15 +151,15 @@ pub trait Material: AsBindGroup + Send + Sync + Clone + TypeUuid + Sized + 'stat

/// Adds the necessary ECS resources and render logic to enable rendering entities using the given [`Material`]
/// asset type.
pub struct MaterialPlugin<M: Material>(PhantomData<M>);
pub struct MaterialPlugin<M: Material, G: AsBindGroup + 'static = ()>(PhantomData<fn() -> (M, G)>);

impl<M: Material> Default for MaterialPlugin<M> {
impl<M: Material, G: AsBindGroup + 'static> Default for MaterialPlugin<M, G> {
fn default() -> Self {
Self(Default::default())
}
}

impl<M: Material> Plugin for MaterialPlugin<M>
impl<M: Material, G: AsBindGroup + Send + Sync + 'static> Plugin for MaterialPlugin<M, G>
where
M::Data: PartialEq + Eq + Hash + Clone,
{
Expand All @@ -168,9 +168,6 @@ where
.add_plugin(ExtractComponentPlugin::<Handle<M>>::extract_visible());
if let Ok(render_app) = app.get_sub_app_mut(RenderApp) {
render_app
.add_render_command::<Transparent3d, DrawMaterial<M>>()
.add_render_command::<Opaque3d, DrawMaterial<M>>()
.add_render_command::<AlphaMask3d, DrawMaterial<M>>()
.init_resource::<MaterialPipeline<M>>()
.init_resource::<ExtractedMaterials<M>>()
.init_resource::<RenderMaterials<M>>()
Expand All @@ -179,8 +176,32 @@ where
.add_system_to_stage(
RenderStage::Prepare,
prepare_materials::<M>.after(PrepareAssetLabel::PreAssetPrepare),
)
.add_system_to_stage(RenderStage::Queue, queue_material_meshes::<M>);
);

if let Some(shared_group) = render_app.world.get_resource_mut::<SharedBindGroup<G>>() {
render_app
.world
.resource_mut::<MaterialPipeline<M>>()
.shared_layout = Some(shared_group.bind_group_layout.clone());

render_app
.add_render_command::<Transparent3d, DrawWithSharedGroup<M, G>>()
.add_render_command::<Opaque3d, DrawWithSharedGroup<M, G>>()
.add_render_command::<AlphaMask3d, DrawWithSharedGroup<M, G>>()
.add_system_to_stage(
RenderStage::Queue,
queue_material_meshes::<M, DrawWithSharedGroup<M, G>>,
);
} else {
render_app
.add_render_command::<Transparent3d, DrawWithoutSharedGroup<M>>()
.add_render_command::<Opaque3d, DrawWithoutSharedGroup<M>>()
.add_render_command::<AlphaMask3d, DrawWithoutSharedGroup<M>>()
.add_system_to_stage(
RenderStage::Queue,
queue_material_meshes::<M, DrawWithoutSharedGroup<M>>,
);
}
}
}
}
Expand Down Expand Up @@ -229,6 +250,7 @@ where
pub struct MaterialPipeline<M: Material> {
pub mesh_pipeline: MeshPipeline,
pub material_layout: BindGroupLayout,
pub shared_layout: Option<BindGroupLayout>,
pub vertex_shader: Option<Handle<Shader>>,
pub fragment_shader: Option<Handle<Shader>>,
marker: PhantomData<M>,
Expand Down Expand Up @@ -259,6 +281,20 @@ where
let descriptor_layout = descriptor.layout.as_mut().unwrap();
descriptor_layout.insert(1, self.material_layout.clone());

if let Some(shared_layout) = &self.shared_layout {
descriptor
.vertex
.shader_defs
.push(String::from("SHARED_DATA"));
descriptor
.fragment
.as_mut()
.unwrap()
.shader_defs
.push(String::from("SHARED_DATA"));
descriptor_layout.insert(3, shared_layout.clone());
}

M::specialize(self, &mut descriptor, layout, key)?;
Ok(descriptor)
}
Expand All @@ -272,6 +308,7 @@ impl<M: Material> FromWorld for MaterialPipeline<M> {
MaterialPipeline {
mesh_pipeline: world.resource::<MeshPipeline>().clone(),
material_layout: M::bind_group_layout(render_device),
shared_layout: None,
vertex_shader: match M::vertex_shader() {
ShaderRef::Default => None,
ShaderRef::Handle(handle) => Some(handle),
Expand All @@ -287,11 +324,20 @@ impl<M: Material> FromWorld for MaterialPipeline<M> {
}
}

type DrawMaterial<M> = (
type DrawWithoutSharedGroup<M> = (
SetItemPipeline,
SetMeshViewBindGroup<0>,
SetMaterialBindGroup<M, 1>,
SetMeshBindGroup<2>,
DrawMesh,
);

type DrawWithSharedGroup<M, G> = (
SetItemPipeline,
SetMeshViewBindGroup<0>,
SetMaterialBindGroup<M, 1>,
SetMeshBindGroup<2>,
SetSharedBindGroup<G, 3>,
DrawMesh,
);

Expand All @@ -312,8 +358,29 @@ impl<M: Material, const I: usize> EntityRenderCommand for SetMaterialBindGroup<M
}
}

pub struct SetSharedBindGroup<G: AsBindGroup, const I: usize>(PhantomData<G>);
impl<G: AsBindGroup + Send + Sync + 'static, const I: usize> EntityRenderCommand
for SetSharedBindGroup<G, I>
{
type Param = SRes<SharedBindGroup<G>>;
fn render<'w>(
_view: Entity,
_item: Entity,
param: SystemParamItem<'w, '_, Self::Param>,
pass: &mut TrackedRenderPass<'w>,
) -> RenderCommandResult {
let shared_group = param.into_inner();
if let Some(bind_group) = &shared_group.bind_group {
pass.set_bind_group(I, bind_group, &[]);
RenderCommandResult::Success
} else {
RenderCommandResult::Failure
}
}
}

#[allow(clippy::too_many_arguments)]
pub fn queue_material_meshes<M: Material>(
pub fn queue_material_meshes<M: Material, F: 'static>(
opaque_draw_functions: Res<DrawFunctions<Opaque3d>>,
alpha_mask_draw_functions: Res<DrawFunctions<AlphaMask3d>>,
transparent_draw_functions: Res<DrawFunctions<Transparent3d>>,
Expand All @@ -337,18 +404,9 @@ pub fn queue_material_meshes<M: Material>(
for (view, visible_entities, mut opaque_phase, mut alpha_mask_phase, mut transparent_phase) in
&mut views
{
let draw_opaque_pbr = opaque_draw_functions
.read()
.get_id::<DrawMaterial<M>>()
.unwrap();
let draw_alpha_mask_pbr = alpha_mask_draw_functions
.read()
.get_id::<DrawMaterial<M>>()
.unwrap();
let draw_transparent_pbr = transparent_draw_functions
.read()
.get_id::<DrawMaterial<M>>()
.unwrap();
let draw_opaque_pbr = opaque_draw_functions.read().get_id::<F>().unwrap();
let draw_alpha_mask_pbr = alpha_mask_draw_functions.read().get_id::<F>().unwrap();
let draw_transparent_pbr = transparent_draw_functions.read().get_id::<F>().unwrap();

let rangefinder = view.rangefinder3d();
let msaa_key = MeshPipelineKey::from_msaa_samples(msaa.samples);
Expand Down
Loading