Skip to content

Commit

Permalink
fix: include_file should handle proto without package (#1002)
Browse files Browse the repository at this point in the history
* fix #1001 and add tests

* add alloc:: imports

* rewrite write_includes to allow for empty modules.

* create test fixture for `write_includes`

* fix lints, remove line feeds

* fixes after merge master

* remove some duplicate tests and alter existing ones to test write_includes

* more test

* module.rs Module::starts_with visibility
  • Loading branch information
MixusMinimax authored May 5, 2024
1 parent 1f38ea6 commit baddf98
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 58 deletions.
97 changes: 44 additions & 53 deletions prost-build/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -826,8 +826,8 @@ impl Config {
self.write_includes(
modules.keys().collect(),
&mut file,
0,
if target_is_env { None } else { Some(&target) },
&file_names,
)?;
file.flush()?;
}
Expand Down Expand Up @@ -955,67 +955,58 @@ impl Config {
self.compile_fds(file_descriptor_set)
}

fn write_includes(
pub(crate) fn write_includes(
&self,
mut entries: Vec<&Module>,
outfile: &mut fs::File,
depth: usize,
mut modules: Vec<&Module>,
outfile: &mut impl Write,
basepath: Option<&PathBuf>,
) -> Result<usize> {
let mut written = 0;
entries.sort();

while !entries.is_empty() {
let modident = entries[0].part(depth);
let matching: Vec<&Module> = entries
.iter()
.filter(|&v| v.part(depth) == modident)
.copied()
.collect();
{
// Will NLL sort this mess out?
let _temp = entries
.drain(..)
.filter(|&v| v.part(depth) != modident)
.collect();
entries = _temp;
file_names: &HashMap<Module, String>,
) -> Result<()> {
modules.sort();

let mut stack = Vec::new();

for module in modules {
while !module.starts_with(&stack) {
stack.pop();
self.write_line(outfile, stack.len(), "}")?;
}
self.write_line(outfile, depth, &format!("pub mod {} {{", modident))?;
let subwritten = self.write_includes(
matching
.iter()
.filter(|v| v.len() > depth + 1)
.copied()
.collect(),
outfile,
depth + 1,
basepath,
)?;
written += subwritten;
if subwritten != matching.len() {
let modname = matching[0].to_partial_file_name(..=depth);
if basepath.is_some() {
self.write_line(
outfile,
depth + 1,
&format!("include!(\"{}.rs\");", modname),
)?;
} else {
self.write_line(
outfile,
depth + 1,
&format!("include!(concat!(env!(\"OUT_DIR\"), \"/{}.rs\"));", modname),
)?;
}
written += 1;
while stack.len() < module.len() {
self.write_line(
outfile,
stack.len(),
&format!("pub mod {} {{", module.part(stack.len())),
)?;
stack.push(module.part(stack.len()).to_owned());
}

let file_name = file_names
.get(module)
.expect("every module should have a filename");

if basepath.is_some() {
self.write_line(
outfile,
stack.len(),
&format!("include!(\"{}\");", file_name),
)?;
} else {
self.write_line(
outfile,
stack.len(),
&format!("include!(concat!(env!(\"OUT_DIR\"), \"/{}\"));", file_name),
)?;
}
}

for depth in (0..stack.len()).rev() {
self.write_line(outfile, depth, "}")?;
}
Ok(written)

Ok(())
}

fn write_line(&self, outfile: &mut fs::File, depth: usize, line: &str) -> Result<()> {
fn write_line(&self, outfile: &mut impl Write, depth: usize, line: &str) -> Result<()> {
outfile.write_all(format!("{}{}\n", (" ").to_owned().repeat(depth), line).as_bytes())
}

Expand Down
23 changes: 23 additions & 0 deletions prost-build/src/fixtures/write_includes/_.includes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
include!(concat!(env!("OUT_DIR"), "/_.default.rs"));
pub mod bar {
include!(concat!(env!("OUT_DIR"), "/bar.rs"));
}
pub mod foo {
include!(concat!(env!("OUT_DIR"), "/foo.rs"));
pub mod bar {
include!(concat!(env!("OUT_DIR"), "/foo.bar.rs"));
pub mod a {
pub mod b {
pub mod c {
include!(concat!(env!("OUT_DIR"), "/foo.bar.a.b.c.rs"));
}
}
}
pub mod baz {
include!(concat!(env!("OUT_DIR"), "/foo.bar.baz.rs"));
}
pub mod qux {
include!(concat!(env!("OUT_DIR"), "/foo.bar.qux.rs"));
}
}
}
28 changes: 28 additions & 0 deletions prost-build/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -530,4 +530,32 @@ mod tests {
f.read_to_string(&mut content).unwrap();
content
}

#[test]
fn write_includes() {
let modules = [
Module::from_protobuf_package_name("foo.bar.baz"),
Module::from_protobuf_package_name(""),
Module::from_protobuf_package_name("foo.bar"),
Module::from_protobuf_package_name("bar"),
Module::from_protobuf_package_name("foo"),
Module::from_protobuf_package_name("foo.bar.qux"),
Module::from_protobuf_package_name("foo.bar.a.b.c"),
];

let file_names = modules
.iter()
.map(|m| (m.clone(), m.to_file_name_or("_.default")))
.collect();

let mut buf = Vec::new();
Config::new()
.default_package_filename("_.default")
.write_includes(modules.iter().collect(), &mut buf, None, &file_names)
.unwrap();
let expected =
read_all_content("src/fixtures/write_includes/_.includes.rs").replace("\r\n", "\n");
let actual = String::from_utf8(buf).unwrap();
assert_eq!(expected, actual);
}
}
14 changes: 9 additions & 5 deletions prost-build/src/module.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::fmt;
use std::ops::RangeToInclusive;

use crate::ident::to_snake;

Expand Down Expand Up @@ -40,6 +39,15 @@ impl Module {
self.components.iter().map(|s| s.as_str())
}

#[must_use]
#[inline(always)]
pub(crate) fn starts_with(&self, needle: &[String]) -> bool
where
String: PartialEq,
{
self.components.starts_with(needle)
}

/// Format the module path into a filename for generated Rust code.
///
/// If the module path is empty, `default` is used to provide the root of the filename.
Expand All @@ -65,10 +73,6 @@ impl Module {
self.components.is_empty()
}

pub(crate) fn to_partial_file_name(&self, range: RangeToInclusive<usize>) -> String {
self.components[range].join(".")
}

pub(crate) fn part(&self, idx: usize) -> &str {
self.components[idx].as_str()
}
Expand Down
1 change: 1 addition & 0 deletions tests/src/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ fn main() {
no_root_packages_config
.out_dir(&no_root_packages)
.default_package_filename("__.default")
.include_file("__.include.rs")
.compile_protos(
&[src.join("no_root_packages/widget_factory.proto")],
&[src.join("no_root_packages")],
Expand Down
33 changes: 33 additions & 0 deletions tests/src/no_root_packages/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ pub mod widget {
}
}

pub mod generated_include {
include!(concat!(env!("OUT_DIR"), "/no_root_packages/__.include.rs"));
}

#[test]
fn test() {
use prost::Message;
Expand Down Expand Up @@ -44,3 +48,32 @@ fn test() {
widget_factory.gizmo_inner = Some(gizmo::gizmo::Inner {});
assert_eq!(14, widget_factory.encoded_len());
}

#[test]
fn generated_include() {
use prost::Message;

let mut widget_factory = generated_include::widget::factory::WidgetFactory::default();
assert_eq!(0, widget_factory.encoded_len());

widget_factory.inner = Some(generated_include::widget::factory::widget_factory::Inner {});
assert_eq!(2, widget_factory.encoded_len());

widget_factory.root = Some(generated_include::Root {});
assert_eq!(4, widget_factory.encoded_len());

widget_factory.root_inner = Some(generated_include::root::Inner {});
assert_eq!(6, widget_factory.encoded_len());

widget_factory.widget = Some(generated_include::widget::Widget {});
assert_eq!(8, widget_factory.encoded_len());

widget_factory.widget_inner = Some(generated_include::widget::widget::Inner {});
assert_eq!(10, widget_factory.encoded_len());

widget_factory.gizmo = Some(generated_include::gizmo::Gizmo {});
assert_eq!(12, widget_factory.encoded_len());

widget_factory.gizmo_inner = Some(generated_include::gizmo::gizmo::Inner {});
assert_eq!(14, widget_factory.encoded_len());
}

0 comments on commit baddf98

Please sign in to comment.