Skip to content

Commit

Permalink
[scripts] Fix convert_nnet2_to_nnet3.py (kaldi-asr#1774)
Browse files Browse the repository at this point in the history
Use 'execute_command' instead of 'run_kaldi_command' and correctly parse/convert SumGroupComponents
  • Loading branch information
entn-at authored and danpovey committed Jul 24, 2017
1 parent 5c3c142 commit dbdd284
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions egs/wsj/s5/steps/nnet3/convert_nnet2_to_nnet3.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,16 +179,16 @@ def write_model(self, model, binary="true"):
raise IOError("Config file {0} does not exist.".format(self.config))

# write raw model
common_lib.run_kaldi_command("nnet3-init --binary=true {0} {1}"
common_lib.execute_command("nnet3-init --binary=true {0} {1}"
.format(self.config, os.path.join(tmpdir, "nnet3.raw")))

# add transition model
common_lib.run_kaldi_command("nnet3-am-init --binary=true {0} {1} {2}"
common_lib.execute_command("nnet3-am-init --binary=true {0} {1} {2}"
.format(self.transition_model, os.path.join(tmpdir, "nnet3.raw"),
os.path.join(tmpdir, "nnet3_no_prior.mdl")))

# add priors
common_lib.run_kaldi_command("nnet3-am-adjust-priors "
common_lib.execute_command("nnet3-am-adjust-priors "
"--binary={0} {1} {2} {3}"
.format(binary, os.path.join(tmpdir, "nnet3_no_prior.mdl"),
self.priors, model))
Expand Down Expand Up @@ -270,6 +270,8 @@ def parse_component(line, line_buffer):
pairs = parse_fixed_scale_component(component, line, line_buffer)
elif component == "<FixedBiasComponent>":
pairs = parse_fixed_bias_component(component, line, line_buffer)
elif component == "<SumGroupComponent>":
pairs = parse_sum_group_component(component, line, line_buffer)
elif component in KNOWN_COMPONENTS:
pairs = parse_standard_component(component, line, line_buffer)
else:
Expand Down Expand Up @@ -300,6 +302,14 @@ def parse_fixed_scale_component(component, line, line_buffer):

return {"<Scales>" : filename}

def parse_sum_group_component(component, line, line_buffer):
line = consume_token(component, line)
line = consume_token("<Sizes>", line)

sizes = line.strip().strip("[]").strip().replace(' ', ',')

return {"<Sizes>" : sizes}

def parse_fixed_bias_component(component, line, line_buffer):
line = consume_token(component, line)
line = consume_token("<Bias>", line)
Expand Down Expand Up @@ -433,7 +443,7 @@ def Main():
tmpdir = tempfile.mkdtemp(dir=args.tmpdir)

# Convert nnet2 model to text and remove preconditioning
common_lib.run_kaldi_command("nnet-am-copy "
common_lib.execute_command("nnet-am-copy "
"--remove-preconditioning=true --binary=false {0}/{1} {2}/{1}"
.format(args.nnet2_dir, args.model, tmpdir))

Expand Down

0 comments on commit dbdd284

Please sign in to comment.