
using args-mutation annotation to support inplace op in torch-blade pipeline #1209

Merged: 9 commits merged into main on Jul 2, 2023

Conversation

@Yancey1989 (Collaborator) commented Jun 21, 2023

fixed #1218

@Yancey1989 changed the title from "support input mutation torch-blade" to "using args-mutation annotation to support inplace op in torch-blade pipeline" on Jun 21, 2023
@qiuxiafei (Collaborator) left a comment


LGTM with minor comments.

@@ -765,7 +765,7 @@ LogicalResult ConvertAtenOp<AtenSizeIntOp>::matchAndRewrite(
     OpAdaptor adaptor,
     ConversionPatternRewriter& rewriter) const {
   // Not a tensor type.
-  auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
+  auto selfType = adaptor.getSelf().getType().cast<RankedTensorType>();
Collaborator


cast will raise an error if the cast is not possible, so the following if won't be reached?

Collaborator Author


you're right, will polish this in the next PR.

@@ -2725,6 +2726,8 @@ class ShapePropagator : public PropertyPropBase {
   int64_t end = endOptional.value() != c10::nullopt
       ? node->get<int64_t>(attr::end).value()
       : INT64_MAX;
+  start =
Collaborator


nit: So we could not handle negative indices before ... Orz.

func.walk([&](Operation* op) {
if (isa<OverwriteTensorContentsOp>(op)) {
rewriter.setInsertionPoint(op);
auto nonValueTensor = rewriter.create<CopyToNonValueTensorOp>(
Collaborator


Why do we need to convert OverwriteTensorContentsOp to CopyToNonValueTensorOp and then to disc.ArgsMutation? Is it possible to skip CopyToNonValueTensorOp?

Collaborator Author


CopyToNonValueTensorOp may be required: reduceTensorConversions(func) at L79 removes all tensor conversions and converts every !tensor to !vtensor, but the OverwriteTensorContentsOp operand requires !tensor, so we need to insert a CopyToNonValueTensorOp to keep OverwriteTensorContentsOp valid after reduceTensorConversions.

@Yancey1989 Yancey1989 merged commit 56de4c9 into main Jul 2, 2023
24 checks passed
@Yancey1989 Yancey1989 deleted the input_mutation branch July 2, 2023 08:47
Successfully merging this pull request may close these issues.

support inplace operator in TorchBlade
2 participants