
TypeError: can't pickle torch.distributed.ProcessGroupNCCL objects #279

Closed
onnx20 opened this issue Jul 3, 2020 · 26 comments · Fixed by #296
Labels
bug Something isn't working

Comments

@onnx20

onnx20 commented Jul 3, 2020

Hi,
I meet a problem:

Traceback (most recent call last):
File "train.py", line 394, in
train(hyp)
File "train.py", line 331, in train
torch.save(ckpt, last)
File "/home/yy/anaconda3/lib/python3.6/site-packages/torch/serialization.py", line 328, in save
_legacy_save(obj, opened_file, pickle_module, pickle_protocol)
File "/home/yy/anaconda3/lib/python3.6/site-packages/torch/serialization.py", line 401, in _legacy_save
pickler.dump(obj)
TypeError: can't pickle torch.distributed.ProcessGroupNCCL objects

Thanks!

environment:
ubuntu 16.04
GPU 2080Ti *4
pytorch 1.4.0
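
For reference, the error means that pickle cannot serialize a torch.distributed ProcessGroup handle, so torch.save fails as soon as one is reachable from the checkpoint dict (here via an attribute on the saved model). A minimal CPU-only sketch with the gloo backend, not the repo's code, shows the same failure mode and the obvious remedy of keeping the handle out of whatever gets saved:

import os
import torch
import torch.distributed as dist
import torch.nn as nn

# single-process gloo group on CPU, only to get hold of a ProcessGroup object
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group('gloo', rank=0, world_size=1)
pg = dist.new_group(ranks=[0])

model = nn.Linear(2, 2)
model.process_group = pg                     # roughly what ends up on the saved model here

try:
    torch.save({'model': model}, 'last.pt')  # pickler hits the ProcessGroup attribute
except TypeError as e:
    print(e)                                 # can't pickle ... ProcessGroup ... objects

del model.process_group                      # drop the non-picklable handle
torch.save({'model': model}, 'last.pt')      # saves fine

dist.destroy_process_group()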

@onnx20 onnx20 added the enhancement New feature or request label Jul 3, 2020
@github-actions
Contributor

github-actions bot commented Jul 3, 2020

Hello @OYRQ, thank you for your interest in our work! Please visit our Custom Training Tutorial to get started, and see our Jupyter Notebook Open In Colab, Docker Image, and Google Cloud Quickstart Guide for example environments.

If this is a bug report, please provide screenshots and minimum viable code to reproduce your issue, otherwise we can not help you.

If this is a custom model or data training question, please note that Ultralytics does not provide free personal support. As a leader in vision ML and AI, we do offer professional consulting, from simple expert advice up to delivery of fully customized, end-to-end production solutions for our clients, such as:

  • Cloud-based AI systems operating on hundreds of HD video streams in realtime.
  • Edge AI integrated into custom iOS and Android apps for realtime 30 FPS video inference.
  • Custom data training, hyperparameter evolution, and model exportation to any destination.

For more information please visit https://www.ultralytics.com.

@NanoCode012
Contributor

Hm, I got this too after pulling from latest release. I thought I broke something in my code, but I guess not.

@onnx20
Author

onnx20 commented Jul 3, 2020

Hm, I got this too after pulling from latest release. I thought I broke something in my code, but I guess not.

I think your code is broken: there are four GPUs running, but they have the same PID.

[screenshot: nvidia-smi output showing all four GPUs sharing the same PID]

@NanoCode012
Contributor

NanoCode012 commented Jul 3, 2020

What do you mean? This is running single-process DistributedDataParallel, so it should be the same PID. I am working on multi-process DDP, if that's what you're thinking of.

@GWwangshuo

I encountered the same problem here after pulling from the latest version. Any ideas how to fix it?

@NanoCode012
Contributor

NanoCode012 commented Jul 3, 2020

I would just suggest waiting a while to see if others have the same issue and if @glenn-jocher sees this. It may be due to the most recent merge e02a189. You can also use an earlier version; the 30th June one was fine for me.

EDIT: Found out that this only happens for multiple GPUs because of the nccl backend. It works fine for a single GPU, so you can run it by setting --device 0 or whichever single GPU ID you want.

@GWwangshuo

@NanoCode012 Thanks for your suggestion. The 30th June branch works fine for me.

@yxNONG
Contributor

yxNONG commented Jul 3, 2020

@GWwangshuo @NanoCode012
I got this error with the new pull too, but my PR works; you guys can check it.

The problem may be in the new EMA update. I will try to find out what's going on with it.

@NanoCode012
Contributor

NanoCode012 commented Jul 3, 2020

@yxNONG, I'm not sure what went wrong. I tried to look through the code. The issue is with saving for multi-GPU, but the only related place is saving the ckpt for the EMA.

The only other things in the commits are updates to ONNX.

@glenn-jocher
Member

glenn-jocher commented Jul 3, 2020

@NanoCode012 Hi guys. Unfortunately I don't have access to a multi-gpu machine to debug. If anyone here finds the problem please submit a PR.

It seems the current workaround is to train single-GPU. Unit tests (run on a single GPU) are all passing currently:
python train.py --device 0

@glenn-jocher glenn-jocher added bug Something isn't working TODO and removed enhancement New feature or request labels Jul 3, 2020
@glenn-jocher
Member

To add a bit more detail, this issue likely originates in recent pushes to update the EMA code. I'll try to update the EMA handling to keep it single-GPU in all cases, as right now both the main model and the EMA can be a confusing mix of single-GPU and DP models.

We swapped out multi-GPU test.py last month in favor of single-GPU FP16 testing during training, so I suppose this will go well with that change.

@NanoCode012
Contributor

May I ask if you still have the multi-GPU test.py, or could you point to it?

So, ema.ema.module means that it's distributed, right? At which point does it become like that?

I only see that we pass model to the EMA class, which creates a deep copy called ema.ema. Where does module come from?
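
(For reference, a hedged sketch of where .module comes from, outside the repo's code: wrapping a model in nn.DataParallel or DistributedDataParallel keeps the original model one level down, under the wrapper's .module attribute, which is why an unwrap check like the one below usually appears before touching .module.)

import torch.nn as nn

model = nn.Linear(2, 2)
dp_model = nn.DataParallel(model)    # DDP behaves the same way for this purpose
print(dp_model.module is model)      # True: the original model lives at .module

def is_wrapped(m):
    # typical check before unwrapping; only DP/DDP wrappers carry .module
    return isinstance(m, (nn.DataParallel, nn.parallel.DistributedDataParallel))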

@glenn-jocher
Member

glenn-jocher commented Jul 3, 2020

@NanoCode012 unit tests are here. We run these in a Colab notebook. You can modify the train device to multi-gpu i.e. 0,1. Be warned this will delete your default yolov5 directory if it exists, so you should run from a subdirectory.

# Unit tests
rm -rf yolov5 && git clone https://github.com/ultralytics/yolov5 && cd yolov5
export PYTHONPATH="$PWD" # to run *.py files in subdirectories
pip install -r requirements.txt onnx
python3 -c "from utils.google_utils import *; gdrive_download('1n_oKgR81BJtqk75b00eAjdv03qVCQn2f', 'coco128.zip')" && mv ./coco128 ../
for x in yolov5s yolov5m yolov5l yolov5x # models
do
  python train.py --weights $x.pt --cfg $x.yaml --epochs 4 --img 320 --device 0,1  # train
  for di in 0 cpu # inference devices
  do
    python detect.py --weights $x.pt --device $di  # detect official
    python detect.py --weights weights/last.pt --device $di  # detect custom
    python test.py --weights $x.pt --device $di # test official
    python test.py --weights weights/last.pt --device $di # test custom
  done
  python models/yolo.py --cfg $x.yaml # inspect
  python models/export.py --weights $x.pt --img 640 --batch 1 # export
done

@glenn-jocher
Member

@NanoCode012 and everyone, I just pushed an EMA update which may or may not resolve this issue. The update 1) creates and maintains the EMA as a single-device model at all times, passing it to test.py and to checkpoint saving in that form, and 2) reverts the EMA to FP16 to slightly reduce device 0 GPU memory usage.

This passes all single-gpu unit tests above, though as I said before someone with a multi-gpu machine should run the tests themselves to verify.
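
(A rough sketch of the checkpoint-saving side of that idea, with made-up variable names rather than the real train.py wiring: keep the EMA as a plain single-device FP16 module and save that, so nothing in the checkpoint references a DDP wrapper or its process group.)

import copy
import torch
import torch.nn as nn

model = nn.Linear(2, 2)                  # stands in for the training model (or ddp_model.module)
ema_model = copy.deepcopy(model).half()  # single-device FP16 EMA copy, no DDP wrapper

ckpt = {'epoch': 0, 'model': ema_model, 'optimizer': None}
torch.save(ckpt, 'last.pt')              # nothing here holds a ProcessGroup, so pickling succeeds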

@NanoCode012
Contributor

NanoCode012 commented Jul 3, 2020

Right now, I only have one GPU available, so I will test multiple GPUs later, but I got a weird error when running on a single GPU.

Calling python train.py --weights yolov5s.pt --epochs 4 --img 320 --device 1 (modified from the unit tests) after pip install ..

Epoch   gpu_mem      GIoU       obj       cls     total   targets  img_size
       0/3     1.75G    0.1334    0.1112   0.04518    0.2898       235       320
               Class      Images     Targets           P           R      mAP@.5
Traceback (most recent call last):
  File "train.py", line 394, in <module>
    train(hyp)
  File "train.py", line 299, in train
    dataloader=testloader)
  File "test.py", line 97, in test
    output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, merge=merge)
  File "utils/utils.py", line 605, in non_max_suppression
    i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
  File "python3.7/site-packages/torchvision/ops/boxes.py", line 35, in nms
    return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
RuntimeError: Trying to create tensor with negative dimension -1754807296: [-1754807296]

Second time I ran it, I got RuntimeError: Trying to create tensor with negative dimension -1754807296: [-1754807296]

Third time, RuntimeError: Trying to create tensor with negative dimension -1754807296: [-1754807296]

EDIT: My device 0 is busy. Could it be related? I also tried using CUDA_VISIBLE_DEVICES=1, same result.

@glenn-jocher
Member

@NanoCode012 I don't know. I can't reproduce on Colab. May be specific to your environment? The tests are intended for Colab or Docker. Anything outside of that I can't speak for.

glenn-jocher added a commit that referenced this issue Jul 3, 2020
@glenn-jocher
Member

@NanoCode012 I got your same error in the Docker container, but strangely not on Colab. When I reverted the EMA to FP32, the Docker error went away.
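
(Side note on why the FP16/FP32 switch can matter, hedged since the exact root cause wasn't pinned down here: half precision overflows to inf above roughly 65504, and non-finite values flowing into box/NMS bookkeeping can turn into nonsense sizes like the negative dimension above.)

import torch

x = torch.tensor([70000.0]).half()
print(x)                  # tensor([inf], dtype=torch.float16): FP16 overflows above ~65504
print(torch.isfinite(x))  # tensor([False])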

@NanoCode012
Contributor

@glenn-jocher, thanks. That commit fixed the single-GPU issue for me. I also tested it on Colab with no issue.

For multiple GPUs, the same error unfortunately persists.

              Class      Images     Targets           P           R      mAP@.5
                 all         128         929       0.135       0.633       0.343       0.145
Traceback (most recent call last):
  File "train.py", line 394, in <module>
    train(hyp)
  File "train.py", line 331, in train
    torch.save(ckpt, last)
  File "torch/serialization.py", line 370, in save
    _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
  File "torch/serialization.py", line 443, in _legacy_save
    pickler.dump(obj)
TypeError: can't pickle torch.distributed.ProcessGroupNCCL objects

Single: python train.py --weights yolov5s.pt --epochs 4 --img 320 --device 0
Double: python train.py --weights yolov5s.pt --epochs 4 --img 320 --device 0,3

@glenn-jocher
Member

@NanoCode012 Ok. If you find a solution that works please submit a PR.

@glenn-jocher
Member

@NanoCode012 why don't you try to move the EMA definition up before the DDP init?

@NanoCode012
Contributor

@NanoCode012 why don't you try to move the EMA definition up before the DDP init?

Sure, that's one of the things I was looking to do. One other question is whether I can move the assignment of model attributes like model.nc to before the DDP wrapper. I read in the PyTorch source that modifying a model's parameters after it has been wrapped in DDP can cause unexpected behavior.

https://github.com/pytorch/pytorch/blob/15864d170384f584e9c8a06118781e9817ef8cc5/torch/nn/parallel/distributed.py#L138

Though I was hesitant to do that, as I would have to modify all references to model in the training loop.

@glenn-jocher
Member

Ok! "Parameters" in that context means model weights that have gradients. The values attached to the model after DDP are class attributes.
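
(A quick illustration of that distinction, outside the repo's code: registered parameters are the gradient-carrying tensors DDP synchronizes, while things like nc, names or stride are plain Python attributes living in the module's __dict__.)

import torch.nn as nn

m = nn.Linear(2, 2)
m.nc = 80                                    # plain attribute, not a parameter
print([n for n, _ in m.named_parameters()])  # ['weight', 'bias'] -- what DDP cares about
print('nc' in m.__dict__)                    # True: just an attribute on the instance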

@NanoCode012
Contributor

I've tried a few things.

  • Moved EMA above DDP wrapper
  • Moved DDP wrapper below EMA (had to re-assign attrib)
  • Changed torch.save(ckpt, last) to torch.save(ema.ema, last)
  • Setting only device cuda:0 to be allowed to save ckpt

It always errors on the torch.save(..) line, so I looked into the EMA. I don't see any major change (please correct me) besides the refactor from commit f02481c until now.

I printed the k,v pairs in module.__dict__ and noticed that the keys present when running on 1 GPU differ from those on 2 or more.
Three keys from the single-GPU run aren't present in multi-GPU. On the other hand, several keys from the multi-GPU run aren't present in single-GPU, with one being very interesting.

# 1 gpu
md with val {'nc': 80, 'depth_multiple': 0.33, 'width_multiple': 0.5, 'anchors': [[11...
save with val [4, 6, 10, 14, 17, 18, 21, 22]
stride with val tensor([32., 16.,  8.])

# Multi gpu
process_group with val <torch.distributed.ProcessGroupNCCL object at 0x7f20ce4c6930> # Should we save this?
....

@onnx20
Author

onnx20 commented Jul 4, 2020

@NanoCode012 @glenn-jocher
Thank you very much for your help. Your work is great, and I can use e02a189 to train normally in the meantime. I look forward to this bug being fixed.

@NanoCode012
Contributor

NanoCode012 commented Jul 4, 2020

@OYRQ, if you must use multi-GPU immediately, you can simply clone my branch, as @yxNONG and I have already tested it.

Or you can fix it yourself; it's just one line.

If you encounter a bug, it'd be great to mention it on the PR. Thanks.

@glenn-jocher glenn-jocher reopened this Jul 4, 2020
@glenn-jocher
Member

@NanoCode012 thanks for running the experiments! This process_group should definitely not be added; it must be the problem. I suppose we could insert a check into the EMA attribute update to prevent it from being added. Can you try this?

    def update_attr(self, model):
        # Update EMA attributes
        for k, v in model.__dict__.items():
            if not k.startswith('_') and k != 'module' and k != 'process_group':
                setattr(self.ema, k, v)
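
For context, that method lives on the EMA helper class; the standalone sketch below (hypothetical names, not the repo's exact code) does the same thing but skips any attribute that pickle can't handle instead of hard-coding the key, which likewise keeps the ProcessGroup handle off the EMA copy:

import copy
import pickle
import torch.nn as nn

class EMASketch:  # minimal stand-in for the EMA helper, illustration only
    def __init__(self, model):
        m = model.module if hasattr(model, 'module') else model  # unwrap DP/DDP
        self.ema = copy.deepcopy(m).eval()                       # always a plain single-device model

    def update_attr(self, model):
        # copy plain attributes, skipping private ones, the DP/DDP wrapper,
        # and anything that cannot be pickled (e.g. process_group)
        for k, v in model.__dict__.items():
            if k.startswith('_') or k == 'module':
                continue
            try:
                pickle.dumps(v)
            except Exception:
                continue
            setattr(self.ema, k, v)

m = nn.Linear(2, 2)
m.nc = 80
ema = EMASketch(m)
ema.update_attr(m)
print(ema.ema.nc)  # 80; a non-picklable attribute would simply have been skipped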
