mirror of https://github.com/Jittor/Jittor
init
This commit is contained in:
parent
1c5519acf2
commit
949c6ed676
|
@ -57,55 +57,6 @@ Example::
|
|||
cdim += a.shape[dim]
|
||||
return s
|
||||
|
||||
def block_diag(*tensors):
|
||||
"""Create a block diagonal matrix from provided tensors.
|
||||
|
||||
Args:
|
||||
*tensors: One or more tensors with 0, 1, or 2 dimensions.
|
||||
|
||||
Returns:
|
||||
Tensor: A 2 dimensional tensor with all the input tensors arranged in
|
||||
order such that their upper left and lower right corners are
|
||||
diagonally adjacent. All other elements are set to 0.
|
||||
"""
|
||||
requires_grad = tensors[0].requires_grad
|
||||
|
||||
rows = 0
|
||||
cols = 0
|
||||
for tensor in tensors:
|
||||
shape = tensor.shape
|
||||
if len(shape) == 0: # 0-d tensor
|
||||
rows += 1
|
||||
cols += 1
|
||||
elif len(shape) == 1: # 1-d tensor
|
||||
rows += 1
|
||||
cols += shape[0]
|
||||
elif len(shape) == 2: # 2-d tensor
|
||||
rows += shape[0]
|
||||
cols += shape[1]
|
||||
|
||||
result = jt.zeros((rows, cols))
|
||||
result.requires_grad = requires_grad
|
||||
|
||||
current_row = 0
|
||||
current_col = 0
|
||||
for tensor in tensors:
|
||||
shape = tensor.shape
|
||||
if len(shape) == 0: # 0-d tensor
|
||||
result[current_row, current_col] = tensor
|
||||
current_row += 1
|
||||
current_col += 1
|
||||
elif len(shape) == 1: # 1-d tensor
|
||||
result[current_row, current_col:current_col + shape[0]] = tensor
|
||||
current_row += 1
|
||||
current_col += shape[0]
|
||||
elif len(shape) == 2: # 2-d tensor
|
||||
result[current_row:current_row + shape[0], current_col:current_col + shape[1]] = tensor
|
||||
current_row += shape[0]
|
||||
current_col += shape[1]
|
||||
|
||||
return result
|
||||
|
||||
def check(bc):
|
||||
bc = np.array(bc)
|
||||
if ((bc != 1) * (bc != bc.max(0))).sum() > 0:
|
||||
|
|
Loading…
Reference in New Issue