diff --git a/eval.py b/eval.py index d44b107..31fb825 100644 --- a/eval.py +++ b/eval.py @@ -99,7 +99,7 @@ def create_report(test, baseline, sample_rollouts, costs): for controller_cat, controller_type in [('baseline', args.baseline_controller), ('test', args.test_controller)]: print(f"Running batch rollouts => {controller_cat} controller: {controller_type}") rollout_partial = partial(run_rollout, controller_type=controller_type, model_path=args.model_path, debug=False) - results = process_map(rollout_partial, files[SAMPLE_ROLLOUTS:], max_workers=16) + results = process_map(rollout_partial, files[SAMPLE_ROLLOUTS:], max_workers=16, chunksize=10) costs += [{'controller': controller_cat, **result[0]} for result in results] create_report(args.test_controller, args.baseline_controller, sample_rollouts, costs) diff --git a/tinyphysics.py b/tinyphysics.py index 0e31244..f73df3f 100644 --- a/tinyphysics.py +++ b/tinyphysics.py @@ -22,7 +22,7 @@ ACC_G = 9.81 FPS = 10 CONTROL_START_IDX = 100 -COST_END_IDX = 550 +COST_END_IDX = 500 CONTEXT_LENGTH = 20 VOCAB_SIZE = 1024 LATACCEL_RANGE = [-5, 5] @@ -148,10 +148,10 @@ def get_state_target_futureplan(self, step_idx: int) -> Tuple[State, float]: State(roll_lataccel=state['roll_lataccel'], v_ego=state['v_ego'], a_ego=state['a_ego']), state['target_lataccel'], FuturePlan( - lataccel=self.data['target_lataccel'].values[step_idx + 1 :step_idx + FUTURE_PLAN_STEPS].tolist(), - roll_lataccel=self.data['roll_lataccel'].values[step_idx + 1 :step_idx + FUTURE_PLAN_STEPS].tolist(), - v_ego=self.data['v_ego'].values[step_idx + 1 :step_idx + FUTURE_PLAN_STEPS].tolist(), - a_ego=self.data['a_ego'].values[step_idx + 1 :step_idx + FUTURE_PLAN_STEPS].tolist() + lataccel=self.data['target_lataccel'].values[step_idx + 1:step_idx + FUTURE_PLAN_STEPS].tolist(), + roll_lataccel=self.data['roll_lataccel'].values[step_idx + 1:step_idx + FUTURE_PLAN_STEPS].tolist(), + v_ego=self.data['v_ego'].values[step_idx + 1:step_idx + FUTURE_PLAN_STEPS].tolist(), + a_ego=self.data['a_ego'].values[step_idx + 1:step_idx + FUTURE_PLAN_STEPS].tolist() ) ) @@ -232,7 +232,7 @@ def run_rollout(data_path, controller_type, model_path, debug=False): elif data_path.is_dir(): run_rollout_partial = partial(run_rollout, controller_type=args.controller, model_path=args.model_path, debug=False) files = sorted(data_path.iterdir())[:args.num_segs] - results = process_map(run_rollout_partial, files, max_workers=16) + results = process_map(run_rollout_partial, files, max_workers=16, chunksize=10) costs = [result[0] for result in results] costs_df = pd.DataFrame(costs) print(f"\nAverage lataccel_cost: {np.mean(costs_df['lataccel_cost']):>6.4}, average jerk_cost: {np.mean(costs_df['jerk_cost']):>6.4}, average total_cost: {np.mean(costs_df['total_cost']):>6.4}")