Using args-mutation annotation to support in-place ops in torch-blade pipeline #1209
Conversation
LGTM with minor comments.
@@ -765,7 +765,7 @@ LogicalResult ConvertAtenOp<AtenSizeIntOp>::matchAndRewrite(
     OpAdaptor adaptor,
     ConversionPatternRewriter& rewriter) const {
   // Not a tensor type.
-  auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
+  auto selfType = adaptor.getSelf().getType().cast<RankedTensorType>();
`cast` will raise an error if the cast isn't possible, so the following `if` won't get reached?
you're right, will polish this in the next PR.
@@ -2725,6 +2726,8 @@ class ShapePropagator : public PropertyPropBase {
           int64_t end = endOptional.value() != c10::nullopt
               ? node->get<int64_t>(attr::end).value()
               : INT64_MAX;
+          start =
nit: So we could not handle negative indices before... Orz.
func.walk([&](Operation* op) {
  if (isa<OverwriteTensorContentsOp>(op)) {
    rewriter.setInsertionPoint(op);
    auto nonValueTensor = rewriter.create<CopyToNonValueTensorOp>(
Why do we need to convert `OverwriteTensorContentsOp` to `CopyToNonValueTensorOp` and then to `disc.ArgsMutation`? Is it possible to skip `CopyToNonValueTensorOp`?
Maybe `CopyToNonValueTensorOp` is required. The function `reduceTensorConversions(func)` at L79 removes all tensor conversions and converts all `!tensor` to `!vtensor`, but the operand of `OverwriteTensorContentsOp` requires `!tensor`, so we need to insert `CopyToNonValueTensorOp` to keep `OverwriteTensorContentsOp` valid after `reduceTensorConversions`.
fixed #1218