Skip to content

Commit

Permalink
Dispatch trivial variable operators to C++ aten functions. (pytorch#3372
Browse files Browse the repository at this point in the history
)

Implement __comparison_ops__ by calling the VariableBase methods.
  • Loading branch information
gchanan authored and colesbury committed Oct 30, 2017
1 parent 8cd0df0 commit 3e6e81d
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 43 deletions.
12 changes: 12 additions & 0 deletions tools/autograd/templates/python_variable_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,18 @@ namespace torch { namespace autograd {
${py_methods}

PyMethodDef variable_methods[] = {
{"__add__", (PyCFunction)THPVariable_add, METH_VARARGS | METH_KEYWORDS, NULL},
{"__radd__", (PyCFunction)THPVariable_add, METH_VARARGS | METH_KEYWORDS, NULL},
{"__iadd__", (PyCFunction)THPVariable_add_, METH_VARARGS | METH_KEYWORDS, NULL},
{"__rmul__", (PyCFunction)THPVariable_mul, METH_VARARGS | METH_KEYWORDS, NULL},
{"__mul__", (PyCFunction)THPVariable_mul, METH_VARARGS | METH_KEYWORDS, NULL},
{"__imul__", (PyCFunction)THPVariable_mul_, METH_VARARGS | METH_KEYWORDS, NULL},
{"__sub__", (PyCFunction)THPVariable_sub, METH_VARARGS | METH_KEYWORDS, NULL},
{"__isub__", (PyCFunction)THPVariable_sub_, METH_VARARGS | METH_KEYWORDS, NULL},
{"__div__", (PyCFunction)THPVariable_div, METH_VARARGS | METH_KEYWORDS, NULL},
{"__truediv__", (PyCFunction)THPVariable_div, METH_VARARGS | METH_KEYWORDS, NULL},
{"__idiv__", (PyCFunction)THPVariable_div_, METH_VARARGS | METH_KEYWORDS, NULL},
{"__mod__", (PyCFunction)THPVariable_remainder, METH_VARARGS | METH_KEYWORDS, NULL},
${py_method_defs}
{NULL}
};
Expand Down
51 changes: 8 additions & 43 deletions torch/autograd/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,38 +426,18 @@ def multinomial(self, num_samples=1, replacement=False):
def bernoulli(self):
return Bernoulli.apply(self)

__radd__ = __add__ = _C._VariableBase.add

def __iadd__(self, other):
return self.add_(other)

__sub__ = _C._VariableBase.sub

def __isub__(self, other):
return self.sub_(other)

def __rsub__(self, other):
return -self + other

__rmul__ = __mul__ = _C._VariableBase.mul

def __imul__(self, other):
return self.mul_(other)

def __matmul__(self, other):
if not isinstance(other, Variable):
return NotImplemented
return self.matmul(other)

__truediv__ = __div__ = _C._VariableBase.div

def __rdiv__(self, other):
return self.reciprocal() * other
__rtruediv__ = __rdiv__

def __idiv__(self, other):
return self.div_(other)

__pow__ = _C._VariableBase.pow

def __ipow__(self, other):
Expand All @@ -466,8 +446,14 @@ def __ipow__(self, other):
def __rpow__(self, other):
return PowConstant.apply(other, self)

def __neg__(self):
return Negate.apply(self)
__neg__ = _C._VariableBase.neg

__eq__ = _C._VariableBase.eq
__ne__ = _C._VariableBase.ne
__lt__ = _C._VariableBase.lt
__le__ = _C._VariableBase.le
__gt__ = _C._VariableBase.gt
__ge__ = _C._VariableBase.ge

def __len__(self):
return len(self.data)
Expand All @@ -481,27 +467,6 @@ def __iter__(self):
# map will interleave them.)
return iter(imap(lambda i: self[i], range(self.size(0))))

def __mod__(self, other):
return self.remainder(other)

def __eq__(self, other):
return self.eq(other)

def __ne__(self, other):
return self.ne(other)

def __lt__(self, other):
return self.lt(other)

def __le__(self, other):
return self.le(other)

def __gt__(self, other):
return self.gt(other)

def __ge__(self, other):
return self.ge(other)

def __hash__(self):
return id(self)

Expand Down

0 comments on commit 3e6e81d

Please sign in to comment.