mirror of https://github.com/Jittor/Jittor
add changelog
This commit is contained in:
parent
060686fafd
commit
bec2e710af
|
@ -60,6 +60,17 @@ class flag_scope(_call_no_record_scope):
|
|||
setattr(flags, k, v)
|
||||
|
||||
class no_grad(flag_scope):
|
||||
''' no_grad scope, all variable created inside this
|
||||
scope will stop grad.
|
||||
|
||||
Example::
|
||||
|
||||
import jittor as jt
|
||||
|
||||
with jt.no_grad():
|
||||
...
|
||||
|
||||
'''
|
||||
def __init__(self, **jt_flags):
|
||||
self.jt_flags = jt_flags
|
||||
jt_flags["no_grad"] = 1
|
||||
|
|
|
@ -31,6 +31,23 @@ def matmul_transpose(a, b):
|
|||
|
||||
|
||||
def bmm(a, b):
|
||||
''' batch matrix multiply,
|
||||
shape of input a is [batch, n, m],
|
||||
shape of input b is [batch, m, k],
|
||||
return shape is [batch, n, k]
|
||||
|
||||
Example::
|
||||
|
||||
import jittor as jt
|
||||
from jittor import nn
|
||||
|
||||
batch, n, m, k = 100, 5, 6, 7
|
||||
|
||||
a = jt.random((batch, n, m))
|
||||
b = jt.random((batch, m, k))
|
||||
c = nn.bmm(a, b)
|
||||
|
||||
'''
|
||||
assert len(a.shape) >= 2 and len(b.shape) >= 2
|
||||
assert a.shape[-1] == b.shape[-2]
|
||||
|
||||
|
|
Loading…
Reference in New Issue