add fake Parameter and backward interface

This commit is contained in:
Dun Liang 2021-06-23 21:40:57 +08:00
parent 1d0df10f13
commit 8aa478fa5e
2 changed files with 42 additions and 1 deletion

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.3.39'
__version__ = '1.2.3.40'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
@ -1254,6 +1254,7 @@ def get_len(var):
Var.__len__ = get_len
int = int32
Var.int = Var.int32
Var.long = Var.int32
float = float32
Var.float = Var.float32
double = float64

View File

@ -1605,6 +1605,46 @@ class ParameterList(Module):
ParameterDict = ParameterList
def Parameter(data, requires_grad=True):
    ''' The `Parameter` interface isn't needed in Jittor: this interface
    does nothing beyond cloning the input and setting its ``requires_grad``
    flag, and exists only for compatibility with PyTorch-style code.

    A Jittor Var already acts as a Parameter when it is a member of a
    Module; if you don't want a member Var to be treated as a Parameter,
    give it a name starting with an underscore ``_``.

    Args:
        data: source Var; a clone of it is returned, so the original
            object is left untouched.
        requires_grad (bool): value assigned to the clone's
            ``requires_grad`` attribute. Defaults to True.

    Returns:
        A clone of ``data`` with ``requires_grad`` set.
    '''
    # NOTE: this docstring doubles as the runtime warning text below,
    # so any wording change here changes what users see in their logs.
    LOG.w(Parameter.__doc__)
    data = data.clone()
    data.requires_grad = requires_grad
    return data
def backward(v, *args, **kw):
    ''' The `backward` variable interface doesn't exist in Jittor.
    Please use `optimizer.backward(loss)` or
    `optimizer.step(loss)` instead.

    For example, if your code looks like this::

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    It can be changed to this::

        optimizer.zero_grad()
        optimizer.backward(loss)
        optimizer.step()

    Or, more concisely::

        optimizer.step(loss)

    The step function will automatically zero grad and backward.
    '''
    # The docstring doubles as the message shown to the user; LOG.f is
    # presumably the fatal-level logger, so this call never returns
    # normally -- TODO confirm LOG.f semantics against jittor_utils.
    LOG.f(backward.__doc__)
# Install as a Var method so `some_var.backward()` fails with a helpful
# migration message instead of an AttributeError.
jt.Var.backward = backward
def unfold(X, kernel_size, dilation=1, padding=0, stride=1):
assert X.ndim == 4
if not isinstance(kernel_size, tuple):