diff --git a/code/chapter-8/01_classification/train_main.py b/code/chapter-8/01_classification/train_main.py index 77e8c71..c431bcd 100644 --- a/code/chapter-8/01_classification/train_main.py +++ b/code/chapter-8/01_classification/train_main.py @@ -28,10 +28,10 @@ def get_args_parser(add_help=True): parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help) - parser.add_argument("--data-path", default=r"G:\deep_learning_data\chest_xray", type=str, help="dataset path") + parser.add_argument("--data-path", required = True, type=str, help="dataset path, like G:\deep_learning_data\chest_xray/train") parser.add_argument("--model", default="convnext-tiny", type=str, help="model name; resnet50/convnext/convnext-tiny") - parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)") + parser.add_argument("--device", default="cpu", type=str, help="device (Use cuda or cpu Default: cuda)") parser.add_argument( "-b", "--batch-size", default=8, type=int, help="images per gpu, the total batch size is $NGPU x batch_size" ) @@ -216,5 +216,7 @@ def main(args): if __name__ == "__main__": args = get_args_parser().parse_args() utils.setup_seed(args.random_seed) - args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device_choices = ["cuda", "mps", "cpu"] + device = torch.device(args.device) if args.device in device_choices else torch.device("cpu") + args.device = device main(args)