mirror of https://github.com/Jittor/Jittor
add fake Parameter and backward interface
This commit is contained in:
parent
1d0df10f13
commit
8aa478fa5e
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.2.3.39'
|
||||
__version__ = '1.2.3.40'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
@ -1254,6 +1254,7 @@ def get_len(var):
|
|||
Var.__len__ = get_len
|
||||
int = int32
|
||||
Var.int = Var.int32
|
||||
Var.long = Var.int32
|
||||
float = float32
|
||||
Var.float = Var.float32
|
||||
double = float64
|
||||
|
|
|
@ -1605,6 +1605,46 @@ class ParameterList(Module):
|
|||
|
||||
ParameterDict = ParameterList
|
||||
|
||||
def Parameter(data, requires_grad=True):
|
||||
''' The `Parameter` interface isn't needed in Jittor, this interface
|
||||
doesn't nothings and it is just used for compatible.
|
||||
|
||||
A Jittor Var is a Parameter
|
||||
when it is a member of Module, if you don't want a Jittor
|
||||
Var menber is treated as a Parameter, just name it startswith
|
||||
underscore `_`.
|
||||
'''
|
||||
LOG.w(Parameter.__doc__)
|
||||
data = data.clone()
|
||||
data.requires_grad = requires_grad
|
||||
return data
|
||||
|
||||
def backward(v, *args, **kw):
|
||||
''' The `backward` variable interface doesn't exist in Jittor.
|
||||
please use `optimizer.backward(loss)` or
|
||||
`optimizer.step(loss)` instead.
|
||||
For example, if your code looks like this::
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
It can be changed to this::
|
||||
|
||||
optimizer.zero_grad()
|
||||
optimizer.backward(loss)
|
||||
optimizer.step()
|
||||
|
||||
Or more concise::
|
||||
|
||||
optimizer.step(loss)
|
||||
|
||||
The step function will automatically zero grad and backward.
|
||||
'''
|
||||
LOG.f(backward.__doc__)
|
||||
|
||||
jt.Var.backward = backward
|
||||
|
||||
def unfold(X, kernel_size, dilation=1, padding=0, stride=1):
|
||||
assert X.ndim == 4
|
||||
if not isinstance(kernel_size, tuple):
|
||||
|
|
Loading…
Reference in New Issue