fixed the GPU indexing error when running in colab #1229

Open
wants to merge 4 commits into
base: main

@mhwahdan commented Dec 9, 2022

When I ran the command

python train.py --workers 8 --device 0 --batch-size 16 --data data.yaml --img 640 640 --cfg cfg/training/yolov7.yaml --weights yolov7x.pt --name yolov7 --hyp data/hyp.scratch.p5.yaml

I got this error:

RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)

I modified loss.py so that the tensor is placed on the default GPU automatically, using torch.device('cuda').

Fixes #1224, #1045, #1101, #1225
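
For context, here is a minimal sketch of how this error can arise. It assumes that a tensor created with torch.ones (which lands on the CPU by default) is later indexed with a mask that lives on the GPU. The function name reproduce_mismatch and the tensor values are hypothetical and are not taken from loss.py. On a machine without CUDA the function returns None, and older PyTorch releases may accept this indexing without complaint, in which case it also returns None.

```python
import torch


def reproduce_mismatch():
    """Try to index a CPU tensor with a CUDA mask and return the error text.

    Returns None when no CUDA device is available or when the installed
    PyTorch version does not raise for this pattern.
    """
    if not torch.cuda.is_available():
        return None

    # Hypothetical stand-in for the per-layer index tensor: CPU by default.
    layer_ids = torch.ones(size=(4,)) * 2
    # Hypothetical stand-in for a mask computed from GPU tensors.
    mask = torch.tensor([True, False, True, False], device="cuda")

    try:
        layer_ids[mask]  # CPU tensor indexed with CUDA indices
    except RuntimeError as err:
        return str(err)
    return None


if __name__ == "__main__":
    print(reproduce_mismatch())
```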

@mateusz-lichota

+1, this change really needs to be merged so that GPU training becomes a painless process

utils/loss.py Outdated
@@ -682,8 +682,7 @@ def build_targets(self, p, targets, imgs):
all_gj.append(gj)
all_gi.append(gi)
all_anch.append(anch[i][idx])
- from_which_layer.append(torch.ones(size=(len(b),)) * i)
+ from_which_layer.append((torch.ones(size=(len(b),)) * i).to('cuda'))

I think it would be better to create the torch.ones tensor on the same device as targets (or some other nearby tensor), so that the CPU case works as well
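
One way to read this suggestion is sketched below. It is only an illustration, not the code that was committed to this PR: the helper name layer_index_like and its parameters are hypothetical, and the sketch assumes targets is an ordinary torch.Tensor whose device should be reused.

```python
import torch


def layer_index_like(reference, count, layer):
    """Return `count` copies of `layer` as a float tensor on `reference`'s device.

    Equivalent in value to `torch.ones(size=(count,)) * layer`, but created
    directly on the device of an existing tensor instead of a hard-coded one.
    """
    return torch.full((count,), float(layer), device=reference.device)


if __name__ == "__main__":
    targets = torch.zeros(3, 6)  # hypothetical stand-in for the real targets
    print(layer_index_like(targets, count=5, layer=2))
```

Because the device is taken from an existing tensor, the same call works whether training runs on the CPU or on a GPU.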

Author

You are right.

I have modified the code to detect the type of targets and create the torch.ones tensor to match it.
That makes both the CPU and GPU cases work.

Solved my problem!

Author

Fantastic

Great. I wonder how it worked before without needing such a modification :)

Author

Let's hope it gets merged

Author

@WongKinYiu
Your review would be much appreciated :)

TimoLob added a commit to TimoLob/yolov7 that referenced this pull request Dec 27, 2022
@SkalskiP

@WongKinYiu / @AlexeyAB, are there any plans to merge this change? The training script does not work with the latest PyTorch, so your installation instructions no longer work.

@magedhelmy1

Fixes train_aux bug with masks

Successfully merging this pull request may close these issues.

indices should be either on cpu or on the same device as the indexed tensor (cpu)
7 participants