This commit is contained in:
liylo 2024-08-28 21:27:02 +08:00
parent 1c5519acf2
commit 949c6ed676
1 changed files with 0 additions and 49 deletions

View File

@ -57,55 +57,6 @@ Example::
cdim += a.shape[dim]
return s
def block_diag(*tensors):
"""Create a block diagonal matrix from provided tensors.
Args:
*tensors: One or more tensors with 0, 1, or 2 dimensions.
Returns:
Tensor: A 2 dimensional tensor with all the input tensors arranged in
order such that their upper left and lower right corners are
diagonally adjacent. All other elements are set to 0.
"""
requires_grad = tensors[0].requires_grad
rows = 0
cols = 0
for tensor in tensors:
shape = tensor.shape
if len(shape) == 0: # 0-d tensor
rows += 1
cols += 1
elif len(shape) == 1: # 1-d tensor
rows += 1
cols += shape[0]
elif len(shape) == 2: # 2-d tensor
rows += shape[0]
cols += shape[1]
result = jt.zeros((rows, cols))
result.requires_grad = requires_grad
current_row = 0
current_col = 0
for tensor in tensors:
shape = tensor.shape
if len(shape) == 0: # 0-d tensor
result[current_row, current_col] = tensor
current_row += 1
current_col += 1
elif len(shape) == 1: # 1-d tensor
result[current_row, current_col:current_col + shape[0]] = tensor
current_row += 1
current_col += shape[0]
elif len(shape) == 2: # 2-d tensor
result[current_row:current_row + shape[0], current_col:current_col + shape[1]] = tensor
current_row += shape[0]
current_col += shape[1]
return result
def check(bc):
bc = np.array(bc)
if ((bc != 1) * (bc != bc.max(0))).sum() > 0: