Skip to content

Commit

Permalink
add chunksize
Browse files Browse the repository at this point in the history
  • Loading branch information
nuwandavek committed May 30, 2024
1 parent dba6893 commit aaf974f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def create_report(test, baseline, sample_rollouts, costs):
for controller_cat, controller_type in [('baseline', args.baseline_controller), ('test', args.test_controller)]:
print(f"Running batch rollouts => {controller_cat} controller: {controller_type}")
rollout_partial = partial(run_rollout, controller_type=controller_type, model_path=args.model_path, debug=False)
results = process_map(rollout_partial, files[SAMPLE_ROLLOUTS:], max_workers=16)
results = process_map(rollout_partial, files[SAMPLE_ROLLOUTS:], max_workers=16, chunksize=10)
costs += [{'controller': controller_cat, **result[0]} for result in results]

create_report(args.test_controller, args.baseline_controller, sample_rollouts, costs)
12 changes: 6 additions & 6 deletions tinyphysics.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
ACC_G = 9.81
FPS = 10
CONTROL_START_IDX = 100
COST_END_IDX = 550
COST_END_IDX = 500
CONTEXT_LENGTH = 20
VOCAB_SIZE = 1024
LATACCEL_RANGE = [-5, 5]
Expand Down Expand Up @@ -148,10 +148,10 @@ def get_state_target_futureplan(self, step_idx: int) -> Tuple[State, float]:
State(roll_lataccel=state['roll_lataccel'], v_ego=state['v_ego'], a_ego=state['a_ego']),
state['target_lataccel'],
FuturePlan(
lataccel=self.data['target_lataccel'].values[step_idx + 1 :step_idx + FUTURE_PLAN_STEPS].tolist(),
roll_lataccel=self.data['roll_lataccel'].values[step_idx + 1 :step_idx + FUTURE_PLAN_STEPS].tolist(),
v_ego=self.data['v_ego'].values[step_idx + 1 :step_idx + FUTURE_PLAN_STEPS].tolist(),
a_ego=self.data['a_ego'].values[step_idx + 1 :step_idx + FUTURE_PLAN_STEPS].tolist()
lataccel=self.data['target_lataccel'].values[step_idx + 1:step_idx + FUTURE_PLAN_STEPS].tolist(),
roll_lataccel=self.data['roll_lataccel'].values[step_idx + 1:step_idx + FUTURE_PLAN_STEPS].tolist(),
v_ego=self.data['v_ego'].values[step_idx + 1:step_idx + FUTURE_PLAN_STEPS].tolist(),
a_ego=self.data['a_ego'].values[step_idx + 1:step_idx + FUTURE_PLAN_STEPS].tolist()
)
)

Expand Down Expand Up @@ -232,7 +232,7 @@ def run_rollout(data_path, controller_type, model_path, debug=False):
elif data_path.is_dir():
run_rollout_partial = partial(run_rollout, controller_type=args.controller, model_path=args.model_path, debug=False)
files = sorted(data_path.iterdir())[:args.num_segs]
results = process_map(run_rollout_partial, files, max_workers=16)
results = process_map(run_rollout_partial, files, max_workers=16, chunksize=10)
costs = [result[0] for result in results]
costs_df = pd.DataFrame(costs)
print(f"\nAverage lataccel_cost: {np.mean(costs_df['lataccel_cost']):>6.4}, average jerk_cost: {np.mean(costs_df['jerk_cost']):>6.4}, average total_cost: {np.mean(costs_df['total_cost']):>6.4}")
Expand Down

0 comments on commit aaf974f

Please sign in to comment.