Compare commits


328 Commits

Author SHA1 Message Date
lidongyang 845c24c9f8 update to 1.3.10.0 2025-07-28 18:50:28 +08:00
DongYang Li d892a83d1c
Merge pull request #654 from Jittor/hyx
merge hw backend
2025-07-28 18:36:24 +08:00
lidongyang 4017b161d2 fix master 2025-07-28 18:33:35 +08:00
uyzhang c78db2a794 enable cuda and acl 2025-07-19 11:05:30 +08:00
uyzhang f8e44de79d merge by JittorHW 2025-07-19 08:59:51 +08:00
lidongyang b79ac22b05 add updated code 2025-07-15 20:05:00 +08:00
lidongyang 2f37158e3e revert huawei support code 2025-07-15 19:51:30 +08:00
Yuxuan Han daf04e9fb5
Merge pull request #646 from Jittor/fixHW
adjust aclnn.h reference
2025-06-17 13:34:28 +08:00
Yuxuan Han 3cf5d7f2a4 adjust aclnn.h reference 2025-06-16 11:05:05 +08:00
Zikai Xiao 330dec69d2 Merge pull request #623 from 514flowey/master
Fix Get Item Problem. Warning: This change has not passed a completely check.
2025-06-10 21:52:37 +08:00
514flowey 58192fc7ef fix cutlass.zip url 2025-06-10 21:50:39 +08:00
Zikai Xiao 8f5048882f Merge branch 'Jittor:master' into master 2025-06-10 21:48:20 +08:00
Yi Zhang 330ef620f7 Merge pull request #626 from fleurs03/master
fix unqualified call to 'std::move'
2025-05-14 11:27:01 +08:00
Yi Zhang b04e197c22 Merge pull request #638 from Jittor/JittorHW
Update Huawei ACL
2025-05-14 11:26:33 +08:00
Yuxuan Han 91190b949e temporarily unable hccl 2025-05-13 09:09:37 +08:00
Yuxuan Han e8f94f4003 update from HuaWei ACL 2025-05-12 15:28:18 +08:00
DongYang Li 0abdc60b77 Merge pull request #630 from Exusial/rdkit
debug rdkit
2025-04-22 20:20:49 +08:00
Exusial 45465befd3 fix histc. 2025-04-22 20:16:43 +08:00
Yuxuan Han 6407252044 fix bug: Conv2dBackward gradBias shape 2025-04-13 23:14:18 +08:00
Exusial 917d122e96 Merge branch 'master' of https://github.com/Jittor/jittor 2025-03-30 09:18:43 +08:00
CHEN Xinsheng 9b53d2b5a7 Merge pull request #632 from CHEN-Xinsheng/reduce-memory
reduce SFRL large block size to 5242880
2025-03-18 20:09:33 +08:00
Xinsheng Chen 95a17684fa reduce memory alloc (in certain cases) 2025-03-18 16:04:39 +08:00
DongYang Li 4c75b24cc9 Update README.md 2025-03-05 19:46:54 +08:00
Exusial 449874356d debug. 2025-02-26 16:21:30 +08:00
CHEN Xinsheng 8a74e9e78a support `jt.any` with argument `dim` 2025-02-18 23:29:05 +08:00
CHEN Xinsheng 23abcda711 copy `jittor.attention` from jittor official repo 2025-02-18 22:06:54 +08:00
CHEN Xinsheng cc3b402913 allow the input of `concat` to be tuple` 2025-02-18 20:56:57 +08:00
zjp_shadow e4be9b1f78 Update HCCL to support multi npus 2025-02-10 16:18:34 +08:00
DongYang Li 86841e858d fix arch90 2025-02-09 02:30:35 +08:00
DongYang Li b166e4e385 Update attention.py fix parameter_name error 2025-02-08 18:03:34 +08:00
MenghaoGuo 646a0346fb Merge pull request #627 from plutoZZZZ/master
update Cusparse op
2025-01-03 17:23:03 +08:00
Yuxuan Han 7ba878bf49 Merge pull request #17 from CSCG-Lab/splits
fix conv2dbackward
2025-01-02 14:40:25 +08:00
Yuxuan Han 2a193eb836 fix conv2dbackward 2025-01-02 14:32:02 +08:00
lusz ece4e3efaa update cusparse trans 2024-12-29 17:25:18 +08:00
lusz 02c3173def update cusparse trans 2024-12-29 17:23:49 +08:00
lusz 1bf6f73d4c update cusparse trans 2024-12-29 16:48:34 +08:00
Yuxuan Han 4e462a6b85 Merge pull request #16 from CSCG-Lab/splits
delete acl_op.h
2024-12-24 10:09:13 +08:00
Exusial 6483f2710b delete acl_op.h 2024-12-24 10:04:44 +08:00
Yuxuan Han 71b990590d Merge pull request #15 from CSCG-Lab/splits
fix conv,relu, split expand
2024-12-23 22:05:31 +08:00
Exusial 9679b992a5 fix conv,relu, split expand 2024-12-23 17:24:09 +08:00
Yuxuan Han 5ce5b45c58 Merge pull request #14 from CSCG-Lab/splits
split triu,embedding,batchnorm
2024-12-23 15:57:30 +08:00
Yuxuan Han 36cc2b33d6 Merge branch 'main' into splits 2024-12-23 15:57:22 +08:00
Exusial fb89c96cc4 split triu,embedding,batchnorm 2024-12-23 15:49:07 +08:00
Exusial 4c2c9bc8e1 polish code 2024-12-23 13:43:21 +08:00
Yi Zhang 093d562aeb Update acl_op_exec.cc 2024-12-23 12:32:31 +08:00
Yi Zhang 572f4301c2 Update acl_compiler.py 2024-12-20 22:44:49 +08:00
Yuxuan Han 1c752fbd83 Merge pull request #13 from CSCG-Lab/splits
split stack,rope,nantonum
2024-12-19 19:42:16 +08:00
Exusial 130f02814d split stack,rope,nantonum 2024-12-19 19:40:37 +08:00
lidongyang 8419709e31 update version 1.3.9.14 2024-12-19 16:33:10 +08:00
lidongyang 14f000f867 tmp fix zipfile for jittorllama 2024-12-19 16:32:37 +08:00
DongYang Li 30bb3dbf22 Merge pull request #625 from Exusial/master
fix duplicate definition in cudnnops
2024-12-19 16:29:32 +08:00
Yuxuan Han c3a6df6682 Merge pull request #12 from CSCG-Lab/splits
split silu,sigmoid,softmax
2024-12-19 10:18:26 +08:00
Exusial 144b7bc57d split silu,sigmoid,softmax 2024-12-19 10:17:06 +08:00
Yuxuan Han 4219d445a4 Merge pull request #11 from CSCG-Lab/splits
split relu,dropout,transpose,flashattention
2024-12-19 09:37:32 +08:00
Exusial bfe1ceb82b split relu,dropout,transpose,flashattention 2024-12-19 09:36:32 +08:00
hjc21 67f9b7ad61 fix unqualified call to 'std::move' 2024-12-18 17:09:24 +08:00
Exusial 9ce77dfb82 fix duplicate definition in cudnnops 2024-12-18 14:19:04 +08:00
CHEN Xinsheng 2a67644b0d fix lack of import 2024-12-17 19:44:35 +08:00
514flowey 4225804df2 Merge branch 'Jittor:master' into master 2024-12-16 22:25:47 +08:00
514flowey eefd57c0f4 Merge branch 'master' of github.com:514flowey/jittor 2024-12-16 22:20:20 +08:00
514flowey 9e7e479df2 Add Index Check for Get Item. Warning: It may slow down the speed, and has not passed a fully check! 2024-12-16 22:19:55 +08:00
Yuxuan Han fa89429a21 Merge pull request #10 from CSCG-Lab/splits
split where,scatter,floor
2024-12-14 10:37:15 +08:00
Exusial 8762352c64 split where,scatter,floor 2024-12-14 10:36:37 +08:00
MenghaoGuo ac78f57a7e Merge pull request #622 from plutoZZZZ/master
add cuda extern: Cusparse
2024-12-13 17:08:36 +08:00
Yuxuan Han 1554d416b0 Merge pull request #9 from CSCG-Lab/splits
split cumsum,gather,index
2024-12-13 14:32:01 +08:00
Exusial 064af9d543 split cumsum,gather,index 2024-12-13 14:31:29 +08:00
Yuxuan Han 5a30cd334f Merge pull request #8 from CSCG-Lab/splits
split maxpool,flip,concat
2024-12-12 19:52:31 +08:00
Exusial f8c8f7e8d7 split maxpool,flip,concat 2024-12-12 19:49:06 +08:00
Exusial 64e3ceb59e shut off sync expect reduce op 2024-12-12 17:02:51 +08:00
Exusial c3b1f380eb Merge branch 'main' of https://github.com/CSCG-Lab/JittorHW 2024-12-12 16:31:40 +08:00
Exusial 14af6f0980 Merge branch 'ddd' 2024-12-12 16:27:07 +08:00
Exusial 2b63a07aa0 udpate base. 2024-12-12 16:26:38 +08:00
Exusial c9c02508d4 Debug nan. 2024-12-12 14:49:18 +08:00
lusz a5fdfd1408 cusparse 2024-12-11 21:38:17 +08:00
lusz 2d93b36cbb cusparse 2024-12-11 21:33:33 +08:00
Exusial 722cb8e3fc add sync in broadcast_to when shape is [1] 2024-12-10 19:20:18 +08:00
Exusial da6acc6cc3 Add flags for sync. 2024-12-10 10:17:58 +08:00
Exusial d1b313bf1d add random 2024-12-09 17:18:27 +08:00
Exusial ce533cbeb3 fix setitem 2024-12-09 16:22:37 +08:00
Exusial 99413285cb modify getitem & setitem 2024-12-07 18:52:34 +08:00
邓一轩 1db9bc2993 splite matmul and bmm from acl_op 2024-12-07 11:16:49 +08:00
Exusial 420f94f283 update 2024-12-06 14:23:08 +08:00
Exusial 86331a8d8f add setitem & getitem op 2024-12-06 11:21:55 +08:00
Exusial fb00b8a558 add conv_forward op. 2024-12-04 15:45:48 +08:00
Exusial 15a7fba3da add conv_op. 2024-12-04 15:41:38 +08:00
Exusial 5544147573 Merge branch 'main' of https://github.com/CSCG-Lab/JittorHW 2024-12-04 15:40:05 +08:00
Exusial d71e59b262 fixed bug of cpp 2024-12-04 15:36:25 +08:00
CHEN Xinsheng 135446ca59 improve reduce op output 2024-12-03 11:12:56 +08:00
Exusial 3bea663698 fixed the bug of not recompile 2024-12-02 22:24:01 +08:00
Exusial f7edd32327 fix bug 2024-12-02 17:53:42 +08:00
Exusial e24a37f5ce add base op class 2024-12-01 23:36:28 +08:00
DongYang Li 63d9392e49 update version 1.3.9.13 2024-11-28 22:13:06 +08:00
CHEN Xinsheng acf5d1a05e add `jt.Var.isnan` and `jt.Var.isinf` 2024-11-28 22:04:42 +08:00
514flowey 7638ab5ffb Merge pull request #604 from 514flowey/master
Fix RNN code op bug
2024-11-28 19:41:29 +08:00
514flowey 421d5a4fa4 Merge branch 'Jittor:master' into master 2024-11-28 19:40:55 +08:00
514flowey 9ee61d26f1 fix rnn op bug 2024-11-28 19:40:08 +08:00
Exusial 352bb8d6a7 update reduce. 2024-11-28 16:24:07 +08:00
Exusial ca712e241b update 2024-11-28 15:45:15 +08:00
Exusial a1add64d6c fix compile include aclops.h in aclops 2024-11-28 13:24:47 +08:00
Exusial 0dc84ebed8 update get_dtype 2024-11-27 20:21:34 +08:00
Exusial 8e5ee574f5 merge reduce. 2024-11-27 19:31:48 +08:00
Exusial edf2755cb5 Merge branch 'main' into dev 2024-11-27 11:18:47 +08:00
Exusial 8c33770036 update. 2024-11-27 11:18:19 +08:00
张仪 89f5b98741 split binary and unary op by hy 2024-11-26 23:22:41 +08:00
张仪 e8ae65d797 update 2024-11-26 21:27:06 +08:00
CHEN Xinsheng d4793e2146 Merge pull request #603 from CHEN-Xinsheng/master
fix `nn.Dropout` (dtype convert)
2024-11-26 19:08:17 +08:00
Xinsheng Chen 2bab0bb8dd fix `nn.Dropout` (dtype convert) 2024-11-26 19:07:20 +08:00
CHEN Xinsheng b37fae105b fix `nn.Dropout` (dtype) 2024-11-26 18:59:12 +08:00
Yuxuan Han 8ee6a45d5c Merge pull request #7 from CSCG-Lab/unittest
add isnan, isinf
2024-11-26 10:05:03 +08:00
Yuxuan Han 2580a98710 add isnan, isinf 2024-11-26 10:01:58 +08:00
Yuxuan Han 8a55cfe5b3 Merge pull request #5 from CSCG-Lab/unittest
add stack
2024-11-24 09:47:29 +08:00
Yuxuan Han 66e18b85a8 add stack 2024-11-24 09:46:23 +08:00
张仪 2e137f73b0 big op and recompile bug 2024-11-21 23:59:41 +08:00
Yuxuan Han c8b76acece Merge pull request #4 from CSCG-Lab/unittest
Fix relu grad, flip. Add more unittest
2024-11-21 16:01:35 +08:00
Yuxuan Han d6917eda4c fix relu grad, add more unnittest 2024-11-21 15:59:40 +08:00
Yuxuan Han 0ff1deccb7 fix flip, add softmax 2024-11-21 15:12:12 +08:00
CHEN Xinsheng 8159093262 fix `getitem` (list case) 2024-11-18 20:15:39 +08:00
CHEN Xinsheng d3f2dc5606 add: support numpy int as an index for `getitem` 2024-11-18 17:36:22 +08:00
张仪 47f0c8acda update unit test 2024-11-18 11:54:32 +08:00
CHEN Xinsheng 0b637852f1 Merge pull request #601 from CHEN-Xinsheng/master
fix `stack`
2024-11-14 19:31:45 +08:00
CHEN Xinsheng 1d0602ae32 fix `stack` 2024-11-14 19:30:22 +08:00
CHEN Xinsheng 19b7bbbe57 fix `stack` 2024-11-14 18:01:25 +08:00
CHEN Xinsheng a5b16925e8 add unit test for `any` 2024-11-14 17:35:25 +08:00
CHEN Xinsheng 9b6fd17e20 add `jt.Var.cumsum` and `jt.Var.cub_cumsum` 2024-11-14 17:34:48 +08:00
DongYang Li c10acf34bc Update version 1.3.9.12 2024-11-14 16:17:33 +08:00
CHEN Xinsheng 3d06d25077 add ACL op `any` 2024-11-12 10:32:07 +08:00
CHEN Xinsheng a0dfdc5ff0 add `getitem` (`None` case) 2024-11-11 21:33:55 +08:00
邓一轩 e0537e5c1a concat 2024-11-11 19:58:31 +08:00
邓一轩 58fc5a9b35 fix conv 2024-11-11 19:43:56 +08:00
邓一轩 495a26a458 add sync at end of all op 2024-11-11 17:09:22 +08:00
邓一轩 b9986ac53b use switch 2024-11-11 15:52:35 +08:00
CHEN Xinsheng c747053a54 fix `concat` 2024-11-09 20:13:47 +08:00
514flowey 1474ebe608 Merge pull request #600 from 514flowey/master
update optimizer
2024-11-05 14:18:37 +08:00
514flowey 382bd3f0e5 update optimizer 2024-11-05 14:14:10 +08:00
CHEN Xinsheng 4e4e67dfd5 fix `nonzero` 2024-11-05 12:46:00 +08:00
dengyx21 810af5953b Revert "sync only on broadcast_to from [1]"
This reverts commit 1439a03fca.
2024-11-05 11:29:10 +08:00
CHEN Xinsheng 33bd28fdb3 add ACL op `where` (unary case) 2024-11-04 22:41:16 +08:00
CHEN Xinsheng 19d2e2e912 add ACL op `nonzero`, a temporary implementation, a bit slow 2024-11-04 22:38:29 +08:00
dengyx21 1902dab9c5 sync only on broadcast_to from [1] 2024-11-04 19:42:09 +08:00
CHEN Xinsheng f79e2908ed Merge pull request #3 from CSCG-Lab/jtorch
fix some bugs for jtorch
2024-10-30 12:34:08 +08:00
dengyx21 158ec0756c shut off a stream 2024-10-29 20:14:41 +08:00
CHEN Xinsheng b279960344 Merge branch 'main' 2024-10-29 15:42:39 +08:00
DongYang Li 56bc5f65be fix jt.index error 2024-10-24 02:32:02 +08:00
CHEN Xinsheng f34e1beafa fix warp (class case) 2024-10-21 20:39:15 +08:00
CHEN Xinsheng 1776dd4da9 fix warp (class case) 2024-10-21 17:08:01 +08:00
CHEN Xinsheng 8a31c402de fix `jt.index` 2024-10-18 17:18:06 +08:00
CHEN Xinsheng 811dc241d4 fix `jt.Var.triu_` 2024-10-18 11:13:32 +08:00
CHEN Xinsheng 17048da065 fix finfo & iinfo 2024-10-17 22:00:22 +08:00
CHEN Xinsheng cf4ce2c95e fix finfo bug in jittor 2024-10-17 21:36:00 +08:00
CHEN Xinsheng 89010f5475 fix `Var.triu` & `Var.triu_` 2024-10-17 21:24:19 +08:00
CHEN Xinsheng 6edc1f74a3 add cub_cumsum & cumprod 2024-10-17 21:23:32 +08:00
CHEN Xinsheng c886f01b53 fix warp (class case) 2024-10-17 21:22:26 +08:00
lidongyang fc1fff8c0e update version to 1.3.9.11 2024-10-08 23:03:24 +08:00
lidongyang 72be1396d9 fix: jupyter restart error 2024-10-08 23:02:27 +08:00
zjp_shadow 8966ca4320 fix transpose 2024-10-06 21:20:12 +08:00
uyzhang a078268e18 polish 2024-10-01 19:34:51 +08:00
uyzhang 33898421e4 Merge branch 'main' of https://github.com/CSCG-Lab/JittorHW into main 2024-10-01 18:16:47 +08:00
uyzhang 4c6d726a4c Refactor transpose_acl function and fix bug in matmul_acl 2024-10-01 18:14:08 +08:00
张仪 cb75c8dedd format 2024-09-29 13:47:22 +08:00
uyzhang c268a0bfaf Refactor aclnn.h and acl_op.h to add support for FlashAttention and FlashAttentionBackward 2024-09-29 12:29:41 +08:00
uyzhang 146574d7d1 Refactor transpose_acl function and fix bug in matmul_acl 2024-09-27 19:47:13 +08:00
uyzhang 4329f3b287 Refactor transpose_acl function and fix bug in matmul_acl 2024-09-27 19:44:04 +08:00
zjp_shadow b48d8664a1 add transpose 2024-09-27 19:37:19 +08:00
uyzhang c7c7326456 fixed the bug in matmul 2024-09-27 16:54:43 +08:00
Yi Zhang 810530b3cc Merge pull request #1 from CSCG-Lab/concat
Update concat
2024-09-25 12:19:26 +08:00
zjp_shadow d648713ec5 Update concat 2024-09-25 00:37:36 +08:00
uyzhang 934885c96e Merge branch 'main' of https://github.com/CSCG-Lab/JittorHW into main 2024-09-23 23:12:49 +08:00
uyzhang dc29fa69dc FEAT! opt transpose in matmul and bmm 2024-09-23 23:12:44 +08:00
uyzhang c3df41e77b Refactor acl_compiler.py to handle gradient accumulation in bmm_acl and matmul_acl functions 2024-09-23 23:11:04 +08:00
uyzhang d092b83d0b Merge branch 'main' of https://github.com/CSCG-Lab/JittorHW into main 2024-09-23 22:43:47 +08:00
uyzhang 74aa4e68c2 Refactor acl_compiler.py to handle gradient accumulation in bmm_acl and matmul_acl functions 2024-09-23 22:43:44 +08:00
uyzhang 7fa22e2e32 add Ellipsis 2024-09-23 22:27:44 +08:00
uyzhang 9578e30972 Refactor acl_compiler.py to handle gradient accumulation in bmm_acl and matmul_acl functions 2024-09-23 20:40:53 +08:00
uyzhang 37671ccec1 Refactor acl_compiler.py to handle gradient accumulation in bmm_acl and matmul_acl functions 2024-09-23 20:26:46 +08:00
uyzhang 657687e0c0 Refactor acl_compiler.py to handle gradient accumulation in bmm_acl and matmul_acl functions 2024-09-23 16:09:42 +08:00
uyzhang 2a142ae73d fix bug of setitem cpu when use acl 2024-09-23 15:34:45 +08:00
uyzhang 9907aad7de fix getitem&setitem slice bug 2024-09-23 13:58:37 +08:00
uyzhang 2c2e8abe59 fix slice setitem 2024-09-23 13:18:49 +08:00
uyzhang 0d5035443e fix setitem not in graph 2024-09-23 03:26:12 +08:00
uyzhang fa288cb4d9 Refactor acl_op.h to use __fp16 for alphaValue in the case of ACL_FLOAT16 dtype 2024-09-22 18:06:38 +08:00
uyzhang 9ff62acf7d Refactor acl_op.h to use __fp16 for alphaValue in the case of ACL_FLOAT16 dtype
Refactor grad method for improved performance and synchronization
Index indices to int32
Fix getitem bug
Add getitem&setitem mask
2024-09-22 16:41:43 +08:00
lidongyang 8888b25ea7 fix getitem bug 2024-09-22 02:30:16 +08:00
lidongyang 464009af42 add getitem&setitem mask 2024-09-21 22:57:54 +08:00
uyzhang a357a7913d Refactor acl_op.h to use __fp16 for alphaValue in the case of ACL_FLOAT16 dtype 2024-09-21 17:17:47 +08:00
uyzhang 631a9a3aaa Refactor grad method for improved performance and synchronization 2024-09-21 14:20:10 +08:00
lidongyang 0705ed9d8f index indices to int32 2024-09-20 22:10:15 +08:00
uyzhang 015bd10210 Refactor flip and squeeze operations for improved performance and synchronization 2024-09-20 21:54:49 +08:00
lidongyang 898ec600b4 polish getitem&setitem 2024-09-20 21:44:47 +08:00
lidongyang babd92a002 polish getitem&setitem -1 2024-09-20 20:01:45 +08:00
lidongyang cdad66c01d polish output dtype 2024-09-20 19:43:56 +08:00
张仪 18afb843ad Fix synchronization issue in acl_op.h 2024-09-19 19:52:25 +08:00
张仪 4006f242de fixed bugs 2024-09-18 17:33:23 +08:00
张仪 e47a74a497 Fix broadcasting issue in acl_compiler.py and add support for setting item in jt.Var 2024-09-14 16:00:15 +08:00
lidongyang 651b24e634 add sigmoid embedding silu 2024-09-13 03:19:25 +08:00
lidongyang 0641a50a5d change op file to acl_op.h 2024-09-12 22:29:20 +08:00
lidongyang e00e4f099c add getitem&setitem 2024-09-12 20:25:48 +08:00
张仪 c55d49a8de add new aclop 2024-09-12 20:14:22 +08:00
张仪 3beeec78b1 add new aclop & fixed some bugs 2024-09-12 17:11:23 +08:00
张仪 eb89ae19ed add new aclop 2024-09-07 22:11:39 +08:00
张仪 21580ce80e update aclnn 2024-09-07 18:18:00 +08:00
514flowey 593519203b Merge pull request #586 from fansunqi/dim
fix dim=3 error
2024-09-05 20:18:02 +08:00
范孙奇 2c141fa996 fix dim=3 error 2024-09-05 20:14:36 +08:00
DongYang Li 4b907d493c Merge pull request #584 from liylo/module
Make forward hook modifiy the inputs and outputs
2024-09-04 16:33:04 +08:00
DongYang Li 79527c40e9 Merge pull request #583 from liylo/func
Add support for block diag function
2024-09-04 16:32:45 +08:00
DongYang Li a1fcd0f337 Merge pull request #503 from 514flowey/attention_mask
add attention mask
2024-09-04 16:19:46 +08:00
DongYang Li 96b97ccf55 Merge pull request #549 from fansunqi/bilinear
check input1 and input2 shape in jt.nn.Bilinear()
2024-09-04 16:16:31 +08:00
DongYang Li 818edc962e Merge pull request #558 from fansunqi/Upsample
check input shape and scale factor's positiveness in jt.nn.Upsample
2024-09-04 16:16:18 +08:00
DongYang Li 60d4f5a2ef Merge pull request #582 from liylo/master
fix load_parameter for Parameterlist issue Jittor#581
2024-09-04 16:15:49 +08:00
lidongyang 30b8a637de remove compatibility 2024-09-04 16:11:51 +08:00
liylo df442516ab forward hooks now could modifiy inputs and outputs 2024-08-28 21:35:12 +08:00
liylo 949c6ed676 init 2024-08-28 21:27:02 +08:00
liylo 1c5519acf2 simple implementation for block diag with proper grad 2024-08-28 21:18:56 +08:00
liylo c8ca6d30eb simple implementation for block diag 2024-08-28 21:13:00 +08:00
liylo ddaf3520e3 fix load 2024-08-28 20:50:37 +08:00
514flowey dc6e888d19 Merge pull request #579 from 514flowey/complex
Add Complex Operators
2024-08-22 12:55:05 +08:00
514flowey 1fbd56bb6d fix unique bug 2024-08-22 12:53:03 +08:00
张仪 b4244090ae first commit 2024-08-21 22:15:12 +08:00
514flowey 822955ac00 add several ffunctions 2024-08-20 15:08:19 +08:00
Yi Zhang c124023085 Merge pull request #567 from Hanyx2021/master
complement of test_aclop
2024-08-12 19:51:07 +08:00
Yuxuan Han 1c0cf4c2e4 complement of test_aclop: error of scatter()-multiple and where() 2024-08-12 19:50:29 +08:00
Yuxuan Han b46264b9f8 complement of test_aclop 2024-08-12 19:28:01 +08:00
Yuxuan Han f353b18472 complement of test_aclop 2024-08-01 16:00:00 +08:00
Yuxuan Han 4deb69c4e5 Merge pull request #1 from Jittor/master
Fixed the BUG of ACL op memory
2024-07-26 21:20:43 +08:00
Yuxuan Han 550ca96a75 complement of test_aclop 2024-07-26 21:16:09 +08:00
张仪 c25ac3a4e8 Fixed the BUG of ACL op memory 2024-07-25 15:54:57 +08:00
hanyx 69b6dd3b42 Merge remote-tracking branch 'upstream/master' 2024-07-24 21:18:52 +08:00
Yi Zhang 496b771211 Update acl_compiler.py 2024-07-24 16:20:02 +08:00
Yi Zhang 29f2fbd853 Update compile_extern.py 2024-07-24 15:43:19 +08:00
张仪 53327feff2 feat: enable ACL optimization in split function 2024-07-24 15:25:10 +08:00
Yi Zhang f2a471c2ec Merge pull request #575 from dengyx21/dev-dyx
FEAT! add floor_int
2024-07-24 15:20:15 +08:00
邓一轩 a755d64f9e FEAT! add floor_int 2024-07-24 15:13:57 +08:00
Yi Zhang 279e4113f3 Update compile_extern.py 2024-07-24 15:01:30 +08:00
Yi Zhang 140b17b824 Update acl_compiler.py 2024-07-24 14:53:04 +08:00
Yi Zhang 67d79a66d4 Merge pull request #572 from dengyx21/dev-dyx
FEAT! add aclop unittest
2024-07-19 17:05:00 +08:00
Yi Zhang 8a9c10d615 Format test_aclop.py 2024-07-19 17:04:42 +08:00
邓一轩 2b12e55447 FEAT! add aclop unittest 2024-07-19 17:01:17 +08:00
Yi Zhang f71c00c3d5 Merge pull request #571 from CHEN-Xinsheng/dev-cross_entropy_loss
fix dtype mismatch in `nn.cross_entropy_loss`
2024-07-19 16:46:41 +08:00
CHEN Xinsheng 9758b18c7d fix dtype mismatch in `nn.cross_entropy_loss` 2024-07-19 16:42:55 +08:00
Yi Zhang 54bc8484e9 Merge pull request #570 from dengyx21/dev-dyx
FEAT! where,scatter,cumsum,gather,flip
2024-07-18 20:08:37 +08:00
邓一轩 8f6563cba9 FEAT! where,scatter,cumsum,gather,flip 2024-07-18 20:04:40 +08:00
lidongyang 121fee583d add no gpu device error 2024-07-12 15:07:35 +08:00
Jiapeng Zhang f7bc197200 fix load bugs
fix load bugs of state
2024-07-10 19:58:07 +08:00
hanyx fa8b332f32 ComplexNumber:polar,view_as_complex,view_as_real 2024-07-09 22:27:28 +08:00
Yi Zhang 3f0814b482 Update acl_compiler.py 2024-07-09 21:48:35 +08:00
张仪 2ae2f1d453 update acl 2024-07-09 19:50:35 +08:00
lidongyang 3b2ca1c2c0 Merge branch 'master' of https://github.com/Jittor/jittor 2024-07-09 14:28:57 +08:00
lidongyang a58c8c7988 polish nn.Sequential __getattr__ 2024-07-09 14:28:17 +08:00
DongYang Li 914cd170b4 Merge pull request #548 from fansunqi/binary_cross_entropy_with_logits
check target and output shape in jt.nn.binary_cross_entropy_with_logits
2024-07-08 17:16:26 +08:00
DongYang Li 6736ce68e3 Merge pull request #553 from fansunqi/conv_transpose3d
modify stride positive check in jt.nn.conv_transpose3d/jt.nn.conv_transpose; add input shape check in jt.nn.conv_transpose3d/jt.nn.conv_transpose
2024-07-08 17:15:38 +08:00
DongYang Li dde745407e Merge pull request #554 from fansunqi/ConvTranspose
check stride positiveness and input shape in jt.nn.ConvTranspose
2024-07-08 17:14:05 +08:00
DongYang Li 133307627e Update nn.py 2024-07-08 17:13:32 +08:00
DongYang Li 9983779d7a Merge pull request #551 from fansunqi/Conv1d_sp
check input shape in jt.nn.Conv1d_sp
2024-07-08 17:08:33 +08:00
DongYang Li bdd6bb6de5 Merge pull request #550 from fansunqi/Conv1d
check input shape in jt.nn.Conv1d
2024-07-08 17:07:57 +08:00
DongYang Li 2b57b2d988 Merge pull request #555 from fansunqi/Dropout2d
check input shape in nn.Dropout2d
2024-07-08 17:05:49 +08:00
DongYang Li c669b1219a Merge pull request #556 from fansunqi/zeroPad2d
check input shape in jt.nn.ZeroPad2d
2024-07-08 17:05:22 +08:00
JittorRepos 596368ae7c Merge pull request #557 from fansunqi/ReplicationPad2d
check input shape in jt.nn.ReplicationPad2d
2024-07-08 17:04:21 +08:00
JittorRepos 98d7c2d0fa Merge pull request #562 from fansunqi/unfold
check parameter's positive in jt.nn.Unfold
2024-07-08 17:03:02 +08:00
lidongyang c47549e673 add isin 2024-07-05 18:12:43 +08:00
DongYang Li dcd6c6b2be update version 2024-07-02 20:02:02 +08:00
DongYang Li 7a2b94a91d Merge pull request #561 from fansunqi/fold
check parameters' positive in jt.nn.fold
2024-07-02 20:01:27 +08:00
fansunqi f4d4c9d55c check parameter's positive in jt.nn.unfold 2024-07-01 15:41:24 +08:00
Sunqi Fan c45dac35e6 Merge branch 'Jittor:master' into fold 2024-07-01 12:26:35 +08:00
fansunqi 45ccf3d2ac check parameters' positive in jt.nn.fold 2024-07-01 12:23:03 +08:00
Sunqi Fan dfec39c2b8 Merge branch 'Jittor:master' into master 2024-07-01 11:00:18 +08:00
DongYang Li 4196cb8154 update version 2024-06-25 16:49:44 +08:00
DongYang Li d8ce49cd70 Update setup.py
fix numpy version
2024-06-25 16:47:09 +08:00
范孙奇 f358fb7518 check input shape and scale factor's positiveness in jt.nn.Upsample 2024-06-10 19:27:29 +08:00
范孙奇 969d810f55 resume 2024-06-10 19:26:40 +08:00
范孙奇 78b7cf091b check input shape and scale factor's positiveness in jt.nn.Upsample 2024-06-10 19:25:53 +08:00
范孙奇 c4480b7e3b check input shape in jt.nn.ReplicationPad2d 2024-06-10 19:08:53 +08:00
范孙奇 d31b0a244d check input shape in jt.nn.ZeroPad2d 2024-06-10 19:05:49 +08:00
范孙奇 1fba329474 check input shape in nn.Dropout2d 2024-06-10 17:02:05 +08:00
范孙奇 958708ed60 modify error information 2024-06-10 16:48:27 +08:00
范孙奇 e6e5949765 add stride check in jt.nn.ConvTranspose 2024-06-10 16:45:53 +08:00
范孙奇 2266d21a8b remove 3D(unbatch) description 2024-06-10 16:43:22 +08:00
范孙奇 db8fcb33da modify stride positive check in jt.nn.conv_transpose; add input shape check in jt.nn.conv_transpose 2024-06-10 16:39:05 +08:00
范孙奇 baf6b45cf1 add input shape check in jt.nn.transpose3d 2024-06-10 16:29:58 +08:00
范孙奇 8fd834465c modify stride positive check in jt.nn.transpose3d 2024-06-10 16:08:48 +08:00
范孙奇 ae0e52dca5 check input shape in jt.nn.ConvTranspose 2024-06-06 20:55:46 +08:00
范孙奇 d895cb9d36 jt.nn.Conv1d in_channels and out_channels must be positive 2024-06-06 20:39:10 +08:00
范孙奇 b4155d8021 jt.nn.Conv1d_sp in_channels and out_channels must be positive 2024-06-06 20:35:05 +08:00
范孙奇 a42198705b check input shape in jt.nn.Conv1d_sp 2024-06-06 20:25:29 +08:00
范孙奇 4d11325634 check input shape in jt.nn.Conv1d 2024-06-06 20:18:31 +08:00
范孙奇 7f6beb58b9 check input1 and input2 shape in jt.nn.Bilinear() 2024-06-06 20:04:35 +08:00
范孙奇 2f11e3bbbe check target shape and output shape in jt.nn.binary_cross_entropy_with_logits 2024-06-06 18:02:29 +08:00
lidongyang 393684f196 polish nn.Sequential attribute 2024-06-05 22:31:20 +08:00
DongYang Li c49be7cf79 Merge pull request #546 from Hanyx2021/fix-expand
fix: some function&class input illegal paramters
2024-06-05 21:37:59 +08:00
DongYang Li 1baf90dd1b Update README.md 2024-06-04 17:55:59 +08:00
Hanyuxuan f20ea9dcf1 fix illegal parameters of ConvTranspose and Pool,issue #478,#480,#481,#482,#483 2024-05-31 15:23:42 +08:00
Hanyuxuan 26963bc70f fix Pad2d with illegal padding,issue #464,#465,#466,#467 2024-05-31 14:50:47 +08:00
Hanyuxuan 64c6400070 check x.shape and kernel_size of Pool and Pool3d,issue #461,#463 2024-05-31 14:33:07 +08:00
Hanyuxuan c7e31604c2 fix illegal parameters of PixelShuffle of issue #458,fix validity of concat of issue #459 2024-05-30 19:03:17 +08:00
Hanyuxuan 981f60c381 fix illegal parameters of Conv2d issue #471,#472,#473,#474,#475,#476,#477 2024-05-30 16:10:06 +08:00
Hanyuxuan 102a689fee fix illegal parameters of Pool and Pool3d of issue #451,#453,#456,#457 2024-05-30 15:19:19 +08:00
Hanyuxuan 681174a606 a ValueError fix of issue #450 2024-05-30 14:29:34 +08:00
Hanyuxuan fe8fb30136 a IndexError fix of issue #448 2024-05-30 14:00:27 +08:00
Hanyuxuan 317defa7a1 fix: jt.Var.expand with valid index -1 2024-05-29 11:05:53 +08:00
DongYang Li a83eea318d Merge pull request #545 from zhc7/patch-1
fix: fix for issue #544
2024-05-28 20:12:33 +08:00
zhc7 fce14c8d9d fix: a minimal quick fix for issue #544 2024-05-22 11:08:28 +08:00
Yi Zhang 426c83a8d4 Merge pull request #533 from uyzhang/master
Update ACL library and fix bugs in ACL integration
2024-05-21 12:50:05 +08:00
Yi Zhang 2c1a5e14f1 Merge branch 'master' into master 2024-05-21 12:48:58 +08:00
DongYang Li f8fde94c3f Update version to 1.3.9.8 2024-05-20 21:43:48 +08:00
DongYang Li 4e41e6b070 Merge pull request #539 from fansunqi/issue523_branch
fix issue 523;update jt.nn.Conv1d/Conv3d/conv2d/conv3d
2024-05-20 21:42:16 +08:00
DongYang Li 21eaa919b9 Merge pull request #540 from fansunqi/issue522_branch
fix issue 522,520,519,516; update jt.Pool/Pool3d
2024-05-20 21:41:52 +08:00
DongYang Li f3744aa47d Merge pull request #541 from fansunqi/issue521_branch
fix issue 521;update jt.nn.MaxUnpool2d/MaxUnpool3d
2024-05-20 21:41:34 +08:00
DongYang Li b8df8e0098 Merge pull request #543 from LDYang694/master
polish rocm support
2024-05-20 21:37:53 +08:00
lidongyang 9190180d8d polish rocm support 2024-05-20 21:34:26 +08:00
范孙奇 e4981653e3 fix issue 522;update jt.Pool/Pool3d
Signed-off-by: 范孙奇 <fansq20@mails.tsinghua.edu.cn>
2024-05-17 11:22:12 +00:00
范孙奇 dd9ac69eec fix issue 523;update jt.nn.Conv1d/Conv3d/conv2d/conv3d
Signed-off-by: 范孙奇 <fansq20@mails.tsinghua.edu.cn>
2024-05-17 09:35:12 +00:00
uyzhang 871ed92fc4 fix: Add conditional import for change_function in __init__.py 2024-05-16 16:03:09 +08:00
DongYang Li 4efbbbf75c Merge pull request #443 from co63oc/patch-1
Update mnist.py
2024-05-16 15:20:07 +08:00
DongYang Li 9943ddf8de Merge pull request #536 from fansunqi/issue528_branch
fix issue 528;update conv_transpose
2024-05-16 15:19:19 +08:00
DongYang Li 5934b20720 Merge pull request #535 from fansunqi/issue529_branch
fix issue 529;update contrib.argmax_pool()
2024-05-16 15:18:14 +08:00
DongYang Li 9370896b35 Merge pull request #537 from fansunqi/issue527_branch
fix issue 527,526;update jt.zeros/ones/full/randn/randint/random
2024-05-16 14:47:53 +08:00
DongYang Li ee3c68ce7a Merge pull request #538 from fansunqi/issue525_branch
fix issue 525;update jt.nn.Reflection2d/Replication2d
2024-05-16 14:43:28 +08:00
DongYang Li 136a710775 polish PixelShuffle in nn.py 2024-05-16 14:35:21 +08:00
DongYang Li 75429f83b7 Merge pull request #534 from fansunqi/master
fix issue 531,530;update jt.nn.PixelShuffle/jt.histc
2024-05-16 14:27:24 +08:00
514flowey 96fda6dee5 Merge pull request #518 from 514flowey/complex
add complex matmul, inv, qr, eig, and svd
2024-05-16 13:38:53 +08:00
范孙奇 82595dc766 fix issue 525;update jt.nn.Reflection2d/Replication2d
Signed-off-by: 范孙奇 <fansq20@mails.tsinghua.edu.cn>
2024-05-15 13:49:26 +00:00
范孙奇 69fc229912 fix issue 526;update jt.randn/random/randint
Signed-off-by: 范孙奇 <fansq20@mails.tsinghua.edu.cn>
2024-05-15 13:34:06 +00:00
范孙奇 9f4c156e12 fix issue 527;update jt.zeros/ones/full
Signed-off-by: 范孙奇 <fansq20@mails.tsinghua.edu.cn>
2024-05-15 13:23:30 +00:00
范孙奇 9baccaed4d fix issue 528;update conv_transpose
Signed-off-by: 范孙奇 <fansq20@mails.tsinghua.edu.cn>
2024-05-15 12:21:21 +00:00
范孙奇 fc252af9a2 fix issue 529;update contrib.argmax_pool()-2
Signed-off-by: 范孙奇 <fansq20@mails.tsinghua.edu.cn>
2024-05-15 12:04:38 +00:00
范孙奇 72f72900d6 fix issue 529;update contrib.argmax_pool()
Signed-off-by: 范孙奇 <fansq20@mails.tsinghua.edu.cn>
2024-05-15 12:01:08 +00:00
uyzhang d587961209 Update ACL library and fix bugs in ACL integration 2024-05-15 19:03:04 +08:00
514flowey 290f2aec60 add complex matmul, inv, qr, eig, and svd 2024-05-08 21:09:07 +08:00
514flowey e5bfdfb162 add attention mask 2024-04-08 19:26:47 +08:00
co63oc 11e29147db Update mnist.py 2023-05-22 13:18:08 +08:00
216 changed files with 14225 additions and 8000 deletions

View File

@@ -96,7 +96,6 @@ Jittor environment requirements:
| OS | CPU | Python | Compiler | (Optional) GPU platform |
|--------------------------------------------------------|-------------------------------------|--------|--------------|---------------------------------------------|
| Linux<br>(Ubuntu, CentOS, Arch, <br>UOS, KylinOS, ...) | x86 <br>x86_64 <br>ARM <br>loongson | >= 3.7 | g++ >=5.4 | Nvidia CUDA >= 10.0, [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#installlinux-tar) <br> or [AMD ROCm](https://docs.amd.com/) >= 4.0 <br> or [Hygon DCU DTK](https://tycloud.hpccube.com/doc/1.0.6/11277/general-handbook/software-tutorial/jittor.html) >= 22.04 |
| macOS <br>(>= 10.14 Mojave) | intel<br>Apple Silicon | >= 3.7 | clang >= 8.0 | - |
| Windows 10 & 11 | x86_64 | [>= 3.8](https://www.python.org/downloads/windows/) | - | Nvidia CUDA >= 10.2 [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#install-windows) |
@@ -116,25 +115,6 @@ python3.7 -m jittor.test.test_example
### macOS install
Please first install additional dependencies with [homebrew](https://brew.sh).
```bash
brew install libomp
```
Then you can install jittor through pip and run the example.
```bash
python3.7 -m pip install jittor
python3.7 -m jittor.test.test_example
```
Currently jittor only supports CPU on macOS.
### Windows install
@@ -382,10 +362,10 @@ Email: jittor@qq.com
File an issue: https://github.com/Jittor/jittor/issues
QQ Group: 761222083
QQ Group: 836860279
<img src="https://cg.cs.tsinghua.edu.cn/jittor/images/news/2020-12-8-21-19-1_2_2/fig4.png" width="200"/>
<img src="https://github.com/Jittor/jittor/assets/62846124/8dd830bd-b31c-4e4f-9a78-5fd7a3409145" width="200"/>
## The Team

View File

@@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.9.6'
__version__ = '1.3.10.0'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
@@ -26,7 +26,7 @@ with lock.lock_scope():
from .compile_extern import mkl_ops, mpi, mpi_ops, in_mpi, rank, world_size
if core.get_device_count() == 0:
has_cuda = compile_extern.has_cuda = compiler.has_cuda = False
from .compile_extern import cudnn, curand, cublas, cufft
from .compile_extern import cudnn, curand, cublas, cufft, cusparse
from .init_cupy import numpy2cupy
from typing import List, Tuple
@@ -428,7 +428,9 @@ def random(shape, dtype="float32", type="uniform"):
jt.Var([[0.96788853 0.28334728 0.30482838]
[0.46107793 0.62798643 0.03457401]], dtype=float32)
'''
for dim in shape:
if dim < 0:
raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}")
ret = ops.random(shape, "float32", type)
## TODO: move those code to core
#if dtype in ["float16", "bfloat16"]:
@@ -484,6 +486,9 @@ def ones(*shape, dtype="float32"):
shape = shape[:-1]
if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)):
shape = shape[0]
for dim in shape:
if dim < 0:
raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}")
return unary(1, dtype).broadcast(shape)
def new_ones(x, size):
@@ -515,6 +520,9 @@ def zeros(*shape, dtype="float32"):
shape = shape[:-1]
if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)):
shape = shape[0]
for dim in shape:
if dim < 0:
raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}")
return unary(0, dtype).broadcast(shape)
def new_zeros(x, size):
@@ -547,6 +555,9 @@ def full(shape,val,dtype="float32"):
'''
if not isinstance(shape, (NanoVector, Sequence)):
shape = (shape,)
for dim in shape:
if dim < 0:
raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}")
return unary(val, dtype).broadcast(shape)
def new_full(x, size, val):
@@ -641,14 +652,22 @@ def var(x, dim=None, dims=None, unbiased=False, keepdims=False):
return sqr
Var.var = var
def std(x):
matsize=1
for i in x.shape:
matsize *= i
out=(x-x.mean()).sqr().sum()
out=out/(matsize-1)
out=out.maximum(1e-6).sqrt()
return out
def std(x, dim=None, keepdim=False):
if dim is None:
matsize=1
for i in x.shape:
matsize *= i
out=(x-x.mean()).sqr().sum()
out=out/(matsize-1)
out=out.maximum(1e-6).sqrt()
return out
else:
dimsize=x.size(dim)
mean=jt.mean(x, dim, keepdim=True)
out=(x - mean).sqr().sum(dim=dim, keepdim=keepdim)
out=out/(dimsize-1)
out=out.maximum(1e-6).sqrt()
return out
Var.std = std
def norm(x, p=2, dim=-1, keepdims=False, eps=1e-30, keepdim=False):
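The hunk above extends `std` with optional `dim` and `keepdim` arguments. A minimal sketch of the resulting behaviour, assuming the jittor API shown above:

```python
import jittor as jt

x = jt.randn(4, 5)
print(x.std())                            # scalar: std over all elements, as before
print(x.std(dim=1).shape)                 # per-row std, shape (4,)
print(x.std(dim=1, keepdim=True).shape)   # shape (4, 1)
```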
@@ -687,6 +706,8 @@ def flatten(input, start_dim=0, end_dim=-1):
start_dim = len(in_shape) + start_dim if start_dim < 0 else start_dim
end_dim = len(in_shape) + end_dim if end_dim < 0 else end_dim
assert end_dim >= start_dim, "end_dim should be larger than or equal to start_dim for flatten function"
if len(in_shape) <= end_dim:
raise IndexError(f"Dimension out of range (expected to be in range of [{-len(in_shape)}, {len(in_shape) - 1}], but got {end_dim})")
out_shape = []
for i in range(0,start_dim,1): out_shape.append(in_shape[i])
dims = 1
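The flatten hunk above adds an explicit bounds check on `end_dim`. An illustrative sketch of the new behaviour:

```python
import jittor as jt

x = jt.ones((2, 3, 4))
print(jt.flatten(x, start_dim=1).shape)    # (2, 12), unchanged for valid arguments

try:
    jt.flatten(x, start_dim=0, end_dim=3)  # only dims 0..2 exist
except IndexError as e:
    print(e)  # Dimension out of range (expected to be in range of [-3, 2], but got 3)
```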
@@ -917,6 +938,9 @@ def randn(*size, dtype="float32", requires_grad=True) -> Var:
[-0.612632 -1.1471151 -1.1879086 ]], dtype=float32)
'''
if isinstance(size, tuple) and isinstance(size[0], (tuple, list, NanoVector)): size = size[0]
for dim in size:
if dim < 0:
raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {size}")
arr = jt.random(size, dtype, "normal")
if not requires_grad: return arr.stop_grad()
return arr
@@ -1013,6 +1037,9 @@ def randint(low, high=None, shape=(1,), dtype="int32") -> Var:
[1 1 1]], dtype=int32)
'''
if high is None: low, high = 0, low
for dim in shape:
if dim < 0:
raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}")
v = (jt.random(shape) * (high - low) + low).clamp(low, high-0.5)
v = jt.floor_int(v)
return v.astype(dtype)
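The same negative-dimension guard is added to `jt.random`, `jt.ones`, `jt.zeros`, `jt.full`, `jt.randn` and `jt.randint` in the hunks above. A minimal sketch of the expected behaviour:

```python
import jittor as jt

print(jt.zeros((2, 3)).shape)   # [2, 3]: valid shapes are unaffected

try:
    jt.randn(-1, 3)             # negative dimensions are now rejected early
except RuntimeError as e:
    print(e)  # Trying to create tensor with negative dimension -1: ...
```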
@@ -1437,9 +1464,17 @@ class Module:
def __hooked_call__(self, *args, **kw):
if hasattr(self, "__fhook2__"):
if len(kw):
self.__fhook2__(self, args, kw)
args_kw_result = self.__fhook2__(self, args, kw)
else:
self.__fhook2__(self, args)
args_kw_result = self.__fhook2__(self, args)
if args_kw_result is not None:
if isinstance(args_kw_result, tuple) and len(args_kw_result) == 2:
args, kw = args_kw_result
else:
raise RuntimeError(
"forward pre-hook must return None or a tuple "
f"of (new_args, new_kwargs), but got {args_kw_result}."
)
if hasattr(self, "__bihook__"):
if len(kw):
LOG.w("backward hook not support kw")
@@ -1458,9 +1493,11 @@ class Module:
ret = grad_hooker(ret, self.__bohook__)
if hasattr(self, "__fhook__"):
if len(kw):
self.__fhook__(self, args, ret, kw)
res = self.__fhook__(self, args, ret, kw)
else:
self.__fhook__(self, args, ret)
res = self.__fhook__(self, args, ret)
if res is not None:
ret = res
return ret
def _place_hooker(self):
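With the change above, a pre-forward hook may return replacement `(args, kwargs)` and a forward hook may return a replacement output. A hedged sketch of how this could be used; the `register_pre_forward_hook` / `register_forward_hook` helper names are assumed here, not taken from this diff:

```python
import jittor as jt
from jittor import nn

layer = nn.Linear(4, 4)

def double_input(module, args, kwargs=None):
    # Pre-hook: return None to leave the call unchanged,
    # or a (new_args, new_kwargs) tuple to rewrite it.
    x, = args
    return (x * 2,), (kwargs or {})

def add_one(module, args, output):
    # Forward hook: a non-None return value replaces the output.
    return output + 1

layer.register_pre_forward_hook(double_input)   # helper name assumed
layer.register_forward_hook(add_one)            # helper name assumed

y = layer(jt.ones((1, 4)))
```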
@@ -1595,6 +1632,8 @@ Arguments of hook are defined as::
else:
if hasattr(v, k):
v = getattr(v, k)
if v is None:
continue
assert isinstance(v, (Module, Var)), \
f"expect a jittor Module or Var, but got <{v.__class__.__name__}>, key: {key}"
else:
@@ -2119,6 +2158,7 @@ from . import sparse
from . import optim
from . import dataset
from . import init
from . import gradfunctional
dtype = NanoString
@@ -2152,3 +2192,7 @@ for k,v in list(Var.__dict__.items()):
from . import math_util
from .math_util import *
from . import distributions
if jt.compiler.has_acl:
from jittor.extern.acl.acl_compiler import change_function
change_function()

View File

@@ -1,8 +1,8 @@
from jittor_core import *
from jittor_core.ops import *
from .misc import *
from . import attention as attention, contrib as contrib, dataset as dataset, init as init, linalg as linalg, lr_scheduler as lr_scheduler, numpy2cupy as numpy2cupy, optim as optim, sparse as sparse
from .compile_extern import cublas as cublas, cudnn as cudnn, cufft as cufft, curand as curand, mkl_ops as mkl_ops, mpi_ops as mpi_ops, world_size as world_size
from . import attention as attention, contrib as contrib, dataset as dataset, init as init, linalg as linalg, lr_scheduler as lr_scheduler, numpy2cupy as numpy2cupy, optim as optim, sparse as sparse, gradfunctional as gradfunctional
from .compile_extern import cublas as cublas, cudnn as cudnn, cufft as cufft, curand as curand, cusparse as cusparse ,mkl_ops as mkl_ops, mpi_ops as mpi_ops, world_size as world_size
from .compiler import compile_custom_op as compile_custom_op, compile_custom_ops as compile_custom_ops
from .contrib import concat as concat
from .nn import bmm as bmm, bmm_transpose as bmm_transpose, matmul as matmul

View File

@@ -9,168 +9,575 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
from jittor import init, Module, nn
import numpy as np
from typing import Optional, Tuple, List
import warnings
import math
import jittor as jt
from jittor import Var
from jittor.nn import Module, Linear, softmax, pad, linear, dropout
from jittor.init import xavier_uniform_, xavier_gauss_, constant_
def _canonical_mask(
mask: Optional[Var],
mask_name: str,
other_type,
other_name: str,
target_type,
check_other: bool = True,
) -> Optional[Var]:
if mask is not None:
_mask_dtype = mask.dtype
_mask_is_float = mask.dtype == jt.float16 or mask.dtype == jt.float32 or mask.dtype == jt.float64
if _mask_dtype != jt.bool and not _mask_is_float:
raise AssertionError(
f"only bool and floating types of {mask_name} are supported")
if check_other and other_type is not None:
if _mask_dtype != other_type:
warnings.warn(
f"Support for mismatched {mask_name} and {other_name} "
"is deprecated. Use same type for both instead."
)
if not _mask_is_float:
# WARNING(514flowey): Check Here
new_mask = jt.zeros_like(mask, dtype=target_type)
new_mask[mask] = float("-inf")
mask = new_mask
return mask
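`_canonical_mask` above converts a boolean mask into an additive float mask with `-inf` at masked positions. A small illustrative sketch; the `jittor.attention` import path is assumed:

```python
import jittor as jt
from jittor.attention import _canonical_mask  # import path assumed

pad_mask = jt.array([True, False, False])      # True marks a masked position
float_mask = _canonical_mask(pad_mask, "key_padding_mask",
                             other_type=None, other_name="",
                             target_type="float32", check_other=False)
print(float_mask)  # [-inf, 0., 0.]
```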
def _none_or_dtype(input: Optional[Var]):
if input is None:
return None
elif isinstance(input, jt.Var):
return input.dtype
raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor")
def baddbmm(input_var:jt.Var, batch1:jt.Var, batch2:jt.Var, beta=1, alpha=1) -> jt.Var:
# WARNING(514flowey): Check here
return beta * input_var + alpha * (batch1 @ batch2)
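`baddbmm` above follows the torch semantics `beta * input + alpha * (batch1 @ batch2)`. A quick numeric sanity check; the `jittor.attention` import path is assumed:

```python
import jittor as jt
from jittor.attention import baddbmm  # import path assumed

inp = jt.ones((2, 3, 4))
b1 = jt.ones((2, 3, 5))
b2 = jt.ones((2, 5, 4))

out = baddbmm(inp, b1, b2, beta=2, alpha=0.5)
print(out.shape)      # [2, 3, 4]
print(out[0, 0, 0])   # 2*1 + 0.5*5 = 4.5
```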
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> jt.Var:
# Efficient implementation equivalent to the following:
L, S = query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
attn_bias = jt.zeros(L, S, dtype=query.dtype)
if is_causal:
assert attn_mask is None
temp_mask = jt.ones(L, S, dtype=jt.bool).tril(diagonal=0)
attn_bias[jt.logical_not(temp_mask)] = float("-inf")
# attn_bias.to(query.dtype)
attn_bias = jt.array(attn_bias, query.dtype)
if attn_mask is not None:
if attn_mask.dtype == jt.bool:
attn_bias[jt.logical_not(attn_mask)] = float("-inf")
else:
attn_bias += attn_mask
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = softmax(attn_weight, dim=-1)
attn_weight = dropout(attn_weight, dropout_p, is_train=True)
return attn_weight @ value
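A minimal usage sketch of the `scaled_dot_product_attention` helper defined above; shapes are chosen for illustration and the `jittor.attention` import path is assumed:

```python
import jittor as jt
from jittor.attention import scaled_dot_product_attention  # import path assumed

# (batch, heads, seq_len, head_dim)
q = jt.randn(2, 8, 16, 64)
k = jt.randn(2, 8, 16, 64)
v = jt.randn(2, 8, 16, 64)

out = scaled_dot_product_attention(q, k, v)   # no mask, no dropout
print(out.shape)   # [2, 8, 16, 64]
```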
def _mha_shape_check(query: Var, key: Var, value: Var,
key_padding_mask: Optional[Var], attn_mask: Optional[Var], num_heads: int):
if query.dim() == 3:
is_batched = True
assert key.dim() == 3 and value.dim() == 3, \
("For batched (3-D) `query`, expected `key` and `value` to be 3-D"
f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")
if key_padding_mask is not None:
assert key_padding_mask.dim() == 2, \
("For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D"
f" but found {key_padding_mask.dim()}-D tensor instead")
if attn_mask is not None:
assert attn_mask.dim() in (2, 3), \
("For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
f" but found {attn_mask.dim()}-D tensor instead")
elif query.dim() == 2:
is_batched = False
assert key.dim() == 2 and value.dim() == 2, \
("For unbatched (2-D) `query`, expected `key` and `value` to be 2-D"
f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")
if key_padding_mask is not None:
assert key_padding_mask.dim() == 1, \
("For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D"
f" but found {key_padding_mask.dim()}-D tensor instead")
if attn_mask is not None:
assert attn_mask.dim() in (2, 3), \
("For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
f" but found {attn_mask.dim()}-D tensor instead")
if attn_mask.dim() == 3:
expected_shape = (num_heads, query.shape[0], key.shape[0])
assert attn_mask.shape == expected_shape, \
(f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}")
else:
raise AssertionError(
f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor")
return is_batched
def _in_projection_packed(
q: Var,
k: Var,
v: Var,
w: Var,
b: Optional[Var] = None,
) -> List[Var]:
E = q.size(-1)
if k is v:
if q is k:
# self-attention
proj = linear(q, w, b)
# reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
# proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
nshape = proj.shape[:-1] + (3, E)
proj = proj.reshape(nshape).unsqueeze(0).transpose(0, -2).squeeze(-2)
return proj[0], proj[1], proj[2]
else:
# encoder-decoder attention
w_q, w_kv = w.split([E, E * 2])
if b is None:
b_q = b_kv = None
else:
b_q, b_kv = b.split([E, E * 2])
q_proj = linear(q, w_q, b_q)
kv_proj = linear(k, w_kv, b_kv)
# reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk()
# kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
nshape = kv_proj.shape[:-1] + (2, E)
kv_proj = kv_proj.reshape(nshape).unsqueeze(0).transpose(0, -2).squeeze(-2)
return (q_proj, kv_proj[0], kv_proj[1])
else:
w_q, w_k, w_v = w.chunk(3)
if b is None:
b_q = b_k = b_v = None
else:
b_q, b_k, b_v = b.chunk(3)
return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
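`_in_projection_packed` above splits a single packed projection into q, k and v. A small shape sketch for the self-attention path; the `jittor.attention` import path is assumed:

```python
import jittor as jt
from jittor.attention import _in_projection_packed  # import path assumed

E = 8
x = jt.randn(5, 2, E)        # (seq_len, batch, embed_dim)
w = jt.randn(3 * E, E)       # packed q/k/v projection weight

q, k, v = _in_projection_packed(x, x, x, w)   # q is k is v -> self-attention path
print(q.shape, k.shape, v.shape)              # each [5, 2, 8]
```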
def _in_projection(
q: Var,
k: Var,
v: Var,
w_q: Var,
w_k: Var,
w_v: Var,
b_q: Optional[Var] = None,
b_k: Optional[Var] = None,
b_v: Optional[Var] = None,
) -> Tuple[Var, Var, Var]:
Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1)
assert w_q.shape == (Eq, Eq), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
assert w_k.shape == (Eq, Ek), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
assert w_v.shape == (Eq, Ev), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
assert b_q is None or b_q.shape == (Eq,), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
assert b_k is None or b_k.shape == (Eq,), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
assert b_v is None or b_v.shape == (Eq,), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
def multi_head_attention_forward(
query: Var,
key: Var,
value: Var,
embed_dim_to_check: int,
num_heads: int,
in_proj_weight: Optional[Var],
in_proj_bias: Optional[Var],
bias_k: Optional[Var],
bias_v: Optional[Var],
add_zero_attn: bool,
dropout_p: float,
out_proj_weight: Var,
out_proj_bias: Optional[Var],
training: bool = True,
key_padding_mask: Optional[Var] = None,
need_weights: bool = True,
attn_mask: Optional[Var] = None,
use_separate_proj_weight: bool = False,
q_proj_weight: Optional[Var] = None,
k_proj_weight: Optional[Var] = None,
v_proj_weight: Optional[Var] = None,
static_k: Optional[Var] = None,
static_v: Optional[Var] = None,
average_attn_weights: bool = True,
is_causal: bool = False,
) -> Tuple[Var, Optional[Var]]:
is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
# For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
# is batched, run the computation and before returning squeeze the
# batch dimension so that the output doesn't carry this temporary batch dimension.
if not is_batched:
# unsqueeze if the input is unbatched
query = query.unsqueeze(1)
key = key.unsqueeze(1)
value = value.unsqueeze(1)
if key_padding_mask is not None:
key_padding_mask = key_padding_mask.unsqueeze(0)
# set up shape vars
tgt_len, bsz, embed_dim = query.shape
src_len, _, _ = key.shape
key_padding_mask = _canonical_mask(
mask=key_padding_mask,
mask_name="key_padding_mask",
other_type=_none_or_dtype(attn_mask),
other_name="attn_mask",
target_type=query.dtype
)
if is_causal and attn_mask is None:
raise RuntimeError(
"Need attn_mask if specifying the is_causal hint. "
"You may use the Transformer module method "
"`generate_square_subsequent_mask` to create this mask."
)
if is_causal and key_padding_mask is None and not need_weights:
# when we have a kpm or need weights, we need attn_mask
# Otherwise, we use the is_causal hint go as is_causal
# indicator to SDPA.
attn_mask = None
else:
attn_mask = _canonical_mask(
mask=attn_mask,
mask_name="attn_mask",
other_type=None,
other_name="",
target_type=query.dtype,
check_other=False,
)
if key_padding_mask is not None:
# We have the attn_mask, and use that to merge kpm into it.
# Turn off use of is_causal hint, as the merged mask is no
# longer causal.
is_causal = False
assert embed_dim == embed_dim_to_check, \
f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
if isinstance(embed_dim, jt.Var):
# embed_dim can be a tensor when JIT tracing
head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
else:
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
if use_separate_proj_weight:
# allow MHA to have different embedding dimensions when separate projection weights are used
assert key.shape[:2] == value.shape[:2], \
f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
else:
assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
#
# compute in-projection
#
if not use_separate_proj_weight:
assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
else:
assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
if in_proj_bias is None:
b_q = b_k = b_v = None
else:
b_q, b_k, b_v = in_proj_bias.chunk(3)
q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)
# prep attention mask
if attn_mask is not None:
# ensure attn_mask's dim is 3
if attn_mask.dim() == 2:
correct_2d_size = (tgt_len, src_len)
if attn_mask.shape != correct_2d_size:
raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
attn_mask = attn_mask.unsqueeze(0)
elif attn_mask.dim() == 3:
correct_3d_size = (bsz * num_heads, tgt_len, src_len)
if attn_mask.shape != correct_3d_size:
raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
else:
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
# add bias along batch dimension (currently second)
if bias_k is not None and bias_v is not None:
assert static_k is None, "bias cannot be added to static key."
assert static_v is None, "bias cannot be added to static value."
k = jt.concat([k, bias_k.repeat(1, bsz, 1)])
v = jt.concat([v, bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None:
key_padding_mask = pad(key_padding_mask, (0, 1))
else:
assert bias_k is None
assert bias_v is None
#
# reshape q, k, v for multihead attention and make em batch first
#
q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
if static_k is None:
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert static_k.size(0) == bsz * num_heads, \
f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
assert static_k.size(2) == head_dim, \
f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
k = static_k
if static_v is None:
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert static_v.size(0) == bsz * num_heads, \
f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
assert static_v.size(2) == head_dim, \
f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
v = static_v
# add zero attention along batch dimension (now first)
if add_zero_attn:
zero_attn_shape = (bsz * num_heads, 1, head_dim)
k = jt.concat([k, jt.zeros(zero_attn_shape, dtype=k.dtype)], dim=1)
v = jt.concat([v, jt.zeros(zero_attn_shape, dtype=v.dtype)], dim=1)
if attn_mask is not None:
attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None:
key_padding_mask = pad(key_padding_mask, (0, 1))
# update source sequence length after adjustments
src_len = k.size(1)
# merge key padding and attention masks
if key_padding_mask is not None:
assert key_padding_mask.shape == (bsz, src_len), \
f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
if attn_mask is None:
attn_mask = key_padding_mask
else:
attn_mask = attn_mask + key_padding_mask
# adjust dropout probability
if not training:
dropout_p = 0.0
#
# (deep breath) calculate attention and out projection
#
if need_weights:
B, Nt, E = q.shape
q_scaled = q / math.sqrt(E)
assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
if attn_mask is not None:
attn_output_weights = baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
else:
attn_output_weights = jt.bmm(q_scaled, k.transpose(-2, -1))
attn_output_weights = softmax(attn_output_weights, dim=-1)
if dropout_p > 0.0:
attn_output_weights = dropout(attn_output_weights, p=dropout_p)
attn_output = jt.bmm(attn_output_weights, v)
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
# optionally average attention weights over heads
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
if average_attn_weights:
attn_output_weights = attn_output_weights.mean(dim=1)
if not is_batched:
# squeeze the output if input was unbatched
attn_output = attn_output.squeeze(1)
attn_output_weights = attn_output_weights.squeeze(0)
return attn_output, attn_output_weights
else:
# attn_mask can be either (L,S) or (N*num_heads, L, S)
# if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
# in order to match the input for SDPA of (N, num_heads, L, S)
if attn_mask is not None:
if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
attn_mask = attn_mask.unsqueeze(0)
else:
attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
q = q.view(bsz, num_heads, tgt_len, head_dim)
k = k.view(bsz, num_heads, src_len, head_dim)
v = v.view(bsz, num_heads, src_len, head_dim)
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
if not is_batched:
# squeeze the output if input was unbatched
attn_output = attn_output.squeeze(1)
return attn_output, None
class MultiheadAttention(Module):
def __init__(
self,
embed_dim,
num_heads,
kdim=None,
vdim=None,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
self_attention=False,
encoder_decoder_attention=False,
q_noise=0.0,
qn_block_size=8,
):
__constants__ = ['batch_first']
bias_k: Optional[jt.Var]
bias_v: Optional[jt.Var]
def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
kdim=None, vdim=None, batch_first=False, dtype=jt.float32) -> None:
if embed_dim <= 0 or num_heads <= 0:
raise ValueError(
f"embed_dim and num_heads must be greater than 0,"
f" got embed_dim={embed_dim} and num_heads={num_heads} instead"
)
factory_kwargs = {'dtype': dtype}
super().__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
assert dropout==0, "TODO: dropout>0"
self.dropout = dropout
self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
assert (self.head_dim * num_heads == self.embed_dim), "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim ** -0.5
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.self_attention = self_attention
self.encoder_decoder_attention = encoder_decoder_attention
if not self._qkv_same_embed_dim:
self.q_proj_weight = jt.empty((embed_dim, embed_dim), **factory_kwargs)
self.k_proj_weight = jt.empty((embed_dim, self.kdim), **factory_kwargs)
self.v_proj_weight = jt.empty((embed_dim, self.vdim), **factory_kwargs)
self.in_proj_weight = None
else:
self.q_proj_weight = None
self.k_proj_weight = None
self.v_proj_weight = None
self.in_proj_weight = jt.empty((3 * embed_dim, embed_dim), **factory_kwargs)
assert not self.self_attention or self.qkv_same_dim, ("Self-attention requires query, key and " "value to be of the same size")
if bias:
self.in_proj_bias = jt.empty(3 * embed_dim, **factory_kwargs)
else:
self.in_proj_bias = None
self.out_proj = Linear(embed_dim, embed_dim, bias=bias)
#TODO: quant_noise
self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
assert not add_bias_kv, "TODO: add_bias_kv=True"
if add_bias_kv:
self.bias_k = jt.empty((1, 1, embed_dim), **factory_kwargs)
self.bias_v = jt.empty((1, 1, embed_dim), **factory_kwargs)
else:
self.bias_k = self.bias_v = None
self.add_zero_attn = add_zero_attn
self.reset_parameters()
self._reset_parameters()
self.onnx_trace = False
self.tpu = False
def reset_parameters(self):
if self.qkv_same_dim:
# Empirically observed the convergence to be much better with
# the scaled initialization
init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
def _reset_parameters(self):
if self._qkv_same_embed_dim:
xavier_uniform_(self.in_proj_weight)
else:
init.xavier_uniform_(self.k_proj.weight)
init.xavier_uniform_(self.v_proj.weight)
init.xavier_uniform_(self.q_proj.weight)
xavier_uniform_(self.q_proj_weight)
xavier_uniform_(self.k_proj_weight)
xavier_uniform_(self.v_proj_weight)
# init.xavier_uniform_(self.out_proj.weight)
if self.out_proj.bias is not None:
init.constant_(self.out_proj.bias, 0.)
if self.in_proj_bias is not None:
init.constant_(self.in_proj_bias, 0.)
if self.bias_k is not None:
init.xavier_normal_(self.bias_k)
if self.bias_v is not None:
init.xavier_normal_(self.bias_v)
def __setstate__(self, state):
# Support loading old MultiheadAttention checkpoints generated by v1.1.0
if '_qkv_same_embed_dim' not in state:
state['_qkv_same_embed_dim'] = True
super().__setstate__(state)
def execute(
self,
query,
key = None,
value = None,
key_padding_mask = None,
incremental_state = None,
need_weights = True,
static_kv = False,
attn_mask = None,
before_softmax = False,
need_head_weights = False,
average_attn_weights = True,
is_causal = False,
):
if need_head_weights:
need_weights = True
tgt_len, bsz, embed_dim = query.shape
assert embed_dim == self.embed_dim
assert list(query.shape) == [tgt_len, bsz, embed_dim]
#####
# Fast Path is not Supported.
#####
assert incremental_state is None, "TODO: incremental_state is not None"
saved_state = None
is_batched = query.dim() == 3
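# fairseq-style projections: build q/k/v with the separate q_proj/k_proj/v_proj layers,
# depending on whether this is self-attention or encoder-decoder attention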
if self.self_attention:
q = self.q_proj(query)
k = self.k_proj(query)
v = self.v_proj(query)
elif self.encoder_decoder_attention:
# encoder-decoder attention
q = self.q_proj(query)
if key is None:
assert value is None
k = v = None
key_padding_mask = _canonical_mask(
mask=key_padding_mask,
mask_name="key_padding_mask",
other_type=_none_or_dtype(attn_mask),
other_name="attn_mask",
target_type=query.dtype
)
attn_mask = _canonical_mask(
mask=attn_mask,
mask_name="attn_mask",
other_type=None,
other_name="",
target_type=query.dtype,
check_other=False,
)
if self.batch_first and is_batched:
# make sure that the transpose op does not affect the "is" property
if key is value:
if query is key:
query = key = value = query.transpose(1, 0)
else:
query, key = (x.transpose(1, 0) for x in (query, key))
value = key
else:
k = self.k_proj(key)
v = self.v_proj(key)
query, key, value = (x.transpose(1, 0) for x in (query, key, value))
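# torch-style path: delegate to multi_head_attention_forward, passing the separate
# q/k/v projection weights when kdim/vdim differ from embed_dim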
if not self._qkv_same_embed_dim:
attn_output, attn_output_weights = multi_head_attention_forward(
query, key, value, self.embed_dim, self.num_heads,
self.in_proj_weight, self.in_proj_bias,
self.bias_k, self.bias_v, self.add_zero_attn,
self.dropout, self.out_proj.weight, self.out_proj.bias,
training=self.is_training(),
key_padding_mask=key_padding_mask, need_weights=need_weights,
attn_mask=attn_mask,
use_separate_proj_weight=True,
q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight,
average_attn_weights=average_attn_weights,
is_causal=is_causal)
else:
assert key is not None and value is not None
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
q = q*self.scaling
assert self.bias_k is None, "TODO: self.bias_k is not None:"
q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2)
if k is not None:
k = k.view(-1, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2)
if v is not None:
v = v.view(-1, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2)
assert saved_state is None, "TODO: saved_state is not None"
assert k is not None
src_len = k.shape[1]
assert key_padding_mask is None, "TODO: key_padding_mask is not None"
assert not self.add_zero_attn, "TODO: self.add_zero_attn=True"
attn_weights = nn.bmm(q, k.transpose(0, 2, 1))
assert list(attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]
assert attn_mask is None, "TODO: attn_mask is not None"
assert key_padding_mask is None, "TODO: key_padding_mask is not None"
if before_softmax:
return attn_weights, v
attn_weights_float = nn.softmax(attn_weights, dim=-1)
attn_weights = attn_weights_float.type_as(attn_weights)
assert v is not None
attn = nn.bmm(attn_weights, v)
assert list(attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
if self.onnx_trace and attn.shape[1] == 1:
# when ONNX tracing a single decoder step (sequence length == 1)
# the transpose is a no-op copy before view, thus unnecessary
attn = attn.view(tgt_len, bsz, embed_dim)
attn_output, attn_output_weights = multi_head_attention_forward(
query, key, value, self.embed_dim, self.num_heads,
self.in_proj_weight, self.in_proj_bias,
self.bias_k, self.bias_v, self.add_zero_attn,
self.dropout, self.out_proj.weight, self.out_proj.bias,
training=self.is_training(),
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
average_attn_weights=average_attn_weights,
is_causal=is_causal)
if self.batch_first and is_batched:
return attn_output.transpose(1, 0), attn_output_weights
else:
attn = attn.transpose(1, 0, 2).view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
attn_weights = None
if need_weights:
attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0, 2, 3)
if not need_head_weights:
# average attention weights over heads
attn_weights = attn_weights.mean(dims=[0])
return attn, attn_weights
return attn_output, attn_output_weights
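For orientation, a minimal usage sketch of the module above (the import path, flag values and resulting shapes are assumptions; with the default settings execute() takes sequence-first inputs of shape (tgt_len, bsz, embed_dim)):

    import jittor as jt
    from jittor.attention import MultiheadAttention  # assumed import path for this module

    mha = MultiheadAttention(embed_dim=16, num_heads=4)   # head_dim = 16 // 4 = 4
    x = jt.randn(5, 2, 16)                                # (tgt_len, bsz, embed_dim)
    out, weights = mha(x, x, x, need_weights=True)
    print(out.shape)      # expected (5, 2, 16)
    print(weights.shape)  # expected (2, 5, 5) after averaging over heads on the torch-style path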


@ -1,430 +0,0 @@
import os
os.environ["FIX_TORCH_ERROR"] = "0"
import jittor as jt
from jittor import *
from typing import Tuple
org_int = int = type(1)
org_float = float = type(1.0)
org_bool = bool = type(True)
import jtorch.compiler
import jtorch_core
from jtorch_core import *
device.__reduce__ = lambda self: (device, (self.type,))
device.__module__ = "jtorch"
jt.jittor_core.device = device
def handle_dtype(args, kw, dtype):
def convert(x):
if isinstance(x, jt.Var):
return x.cast(dtype)
return x
if dtype is not None:
if args is not None:
if isinstance(args, (tuple,list)):
args = [ convert(a) for a in args ]
else:
args = convert(args)
if kw is not None:
kw = { k:convert(v) for k,v in kw.items() }
return args, kw
def get_args_names(func):
import inspect
spec = inspect.getfullargspec(func)
return spec[0] + spec[4]
def wrapper(func):
has_dtype = False
if hasattr(func, "__code__"):
has_dtype = "dtype" in get_args_names(func)
def inner(*args, **kw):
requires_grad = None
dtype = None
if "requires_grad" in kw:
requires_grad = kw["requires_grad"]
del kw["requires_grad"]
if not has_dtype and "dtype" in kw:
dtype = kw["dtype"]
del kw["dtype"]
if "device" in kw:
del kw["device"]
if 'pin_memory' in kw:
del kw['pin_memory']
args, kw = handle_dtype(args, kw, dtype)
ret = func(*args, **kw)
if isinstance(ret, jt.Var):
if requires_grad is not None:
ret.requires_grad = requires_grad
if dtype is not None:
ret.astype(dtype)
return ret
return inner
import inspect
_wrapper_keys = set(["shape", "start", "size"])
_wrapper_keys.add("x")
for k,v in list(globals().items()):
if callable(v) and not isinstance(v, type):
try:
spec = inspect.getfullargspec(v)
args_name = spec[0]
if len(args_name) and args_name[0] in _wrapper_keys:
globals()[k] = wrapper(v)
elif spec.varargs in _wrapper_keys:
globals()[k] = wrapper(v)
except:
pass
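As a rough illustration (assuming the jtorch package built from this tree is importable), wrapping a jittor factory function by hand behaves like this: torch-only keywords such as device and pin_memory are dropped, while requires_grad and dtype are applied to the returned Var.

    import jittor as jt
    from jtorch import wrapper  # the helper defined above

    zeros = wrapper(jt.zeros)
    t = zeros((2, 3), dtype=jt.float32, requires_grad=True,
              device="cuda", pin_memory=True)   # device / pin_memory are silently discarded
    print(t.shape, t.requires_grad)             # (2, 3) True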
def empty(*size, dtype=jt.float32, device=None, requires_grad=False):
if len(size) == 1 and not isinstance(size[0], org_int):
size = size[0]
return jt.empty(size, dtype)
Tensor = Var
Tensor.backward = lambda x: jtorch_core.backward(x)
Tensor.grad = property(grad_get, grad_set, grad_del)
Tensor.retains_grad = property(retain_grad_get, retain_grad_set)
def retain_grad(x:Tensor, value:bool=True):
x.retains_grad = value
return value
Tensor.retain_grad = retain_grad
Tensor.dim = lambda self: self.ndim
Tensor.ndimension = lambda self: self.ndim
Tensor.nelement = lambda self: self.numel()
Tensor.cuda = lambda self: self
def device_get(x:Tensor):
return device("cpu") if not jt.has_cuda or not jt.flags.use_cuda else device("cuda")
Tensor.device = property(device_get)
def argmax(x: Var, dim=None, keepdim: bool = False):
return jt.argmax(x, dim, keepdim)[0]
Tensor.argmax = argmax
def tensor_type(x: Var, dtype=None, **kwargs):
if dtype:
return x.astype(dtype)
else:
return x.dtype
Tensor.type = tensor_type
def is_floating_point(x: Var):
return "float" in str(x.dtype)
Tensor.is_floating_point = is_floating_point
from . import autograd
from .autograd import *
def tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False):
if isinstance(data,list):
data_list = []
check = True
for p in data:
if isinstance(p, Tensor) and p.numel()==1:
data_list.append(p.item())
elif isinstance(p, (org_int,org_float)):
data_list.append(p)
else:
check = False
break
if check:
data = data_list
return wrapper(array)(data, dtype=dtype, device=device, requires_grad=requires_grad, pin_memory=pin_memory)
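A short sketch of the list flattening above (import assumed): one-element Vars and plain Python numbers inside a list are converted to scalars before the array is built, so mixed lists behave like torch.tensor.

    import jittor as jt
    from jtorch import tensor

    a = tensor([1.0, 2.0, 3.0], requires_grad=True)
    b = tensor([jt.array(4.0), 5, 6.0])   # the 1-element Var is flattened via .item()
    print(a.requires_grad, b.numpy())     # True [4. 5. 6.]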
# tensor = wrapper(array)
from_numpy = wrapper(array)
strided = None
def mod_zero_grad(self):
for p in self.parameters():
p.grad = None
Module.zero_grad = mod_zero_grad
class ModuleMisc:
def parameters(self):
return iter(super().parameters())
def load_state_dict(self, state_dict, strict=False):
return super().load_state_dict(state_dict)
def to(self, device=None,dtype=None):
''' do nothing but return its self'''
return self
def register_parameter(self,name,data):
setattr(self, name, data)
def buffers(self):
for _, buf in self.named_buffers():
yield buf
def make_module(cls):
class TMod(ModuleMisc, cls):
def __init__(self, *args, **kw):
dtype = None
if "dtype" in kw:
dtype = kw["dtype"]
del kw["dtype"]
self._dtype = dtype
with jt.flag_scope(th_mode=0):
if "device" in kw:
del kw["device"]
super().__init__(*args, **kw)
for k,v in self.__dict__.items():
if not k.startswith("_") and isinstance(v, Var) \
and v.requires_grad:
v.retain_grad()
if dtype is not None and isinstance(v, Var):
v.assign(v.cast(dtype))
def __call__(self, *args, **kw):
args, kw = handle_dtype(args, kw, self._dtype)
# if forward is override by user, call forward
if self.__class__.forward is not TMod.forward:
return self.forward(*args, **kw)
return self.execute(*args, **kw)
def forward(self, *args, **kw):
args, kw = handle_dtype(args, kw, self._dtype)
return self.execute(*args, **kw)
@property
def training(self):
if not hasattr(self, "is_train"):
self.is_train = True
return self.is_train
@training.setter
def training(self, value):
self.is_train = value
TMod.__name__ = cls.__name__
return TMod
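An illustration of how make_module is intended to be used (the wrapped class and keyword values are assumptions): it wraps an existing jittor Module class so construction accepts torch-style dtype/device keywords and calls route through forward/execute with dtype handling.

    import jittor as jt
    from jittor import nn
    from jtorch import make_module

    TorchLinear = make_module(nn.Linear)
    layer = TorchLinear(4, 8, dtype=jt.float32, device="cuda")  # device is dropped, params cast to dtype
    y = layer(jt.randn(2, 4))                                   # TMod.__call__ routes to execute
    layer.training = False                                      # maps onto jittor's is_train flag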
import jtorch.cuda
import jtorch.nn
from jtorch.nn import Module, Parameter
import jtorch.optim
from jtorch.utils.dtype import Dtype, get_string_dtype
def frombuffer(buffer: bytearray,
*,
dtype: Dtype,
count: int = -1,
offset: int = 0,
requires_grad: bool = True) -> Tensor:
dtype = get_string_dtype(dtype)
tensor = jt.array(np.frombuffer(buffer, dtype, count=count, offset=offset))
if requires_grad and tensor.dtype.is_float():
tensor.requires_grad = True
return tensor
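A quick sketch of frombuffer (struct is only used here to build a test buffer, imports assumed): the raw bytes are reinterpreted with numpy and wrapped in a Var, and float buffers get requires_grad=True by default.

    import struct
    from jtorch import frombuffer, float32

    buf = bytearray(struct.pack("4f", 1.0, 2.0, 3.0, 4.0))
    t = frombuffer(buf, dtype=float32)
    print(t.numpy(), t.requires_grad)   # [1. 2. 3. 4.] True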
def conflict_wrapper(origin_func, new_func):
def wrapper(*args, **kw):
if jt.flags.th_mode:
return new_func(*args, **kw)
else:
return origin_func(*args, **kw)
return wrapper
def min(*args, **kw):
dim = None
if len(args) >= 2 and isinstance(args[1], org_int):
dim = args[1]
elif "dim" in kw and isinstance(kw["dim"], org_int):
dim = kw["dim"]
if dim is not None:
k, v = jt.argmin(*args, **kw)
return v, k
elif len(args) == 2 and isinstance(args[1], jt.Var):
return jt.minimum(args[0], args[1])
else:
return jt.min(*args, **kw)
Tensor.min = conflict_wrapper(jt.min, min)
def max(*args, **kw):
dim = None
if "dim" in kw:
x = kw["dim"]
if len(args) >= 2 and isinstance(args[1], org_int):
dim = args[1]
elif "dim" in kw and isinstance(kw["dim"], org_int):
dim = kw["dim"]
if dim is not None:
k, v = jt.argmax(*args, **kw)
return v, k
elif len(args) == 2 and isinstance(args[1], jt.Var):
return jt.maximum(args[0], args[1])
else:
return jt.max(*args, **kw)
Tensor.max = conflict_wrapper(jt.max, max)
def argsort(*args, **kw):
k, v = jt.argsort(*args, **kw)
return k
Tensor.argsort = conflict_wrapper(jt.argsort, argsort)
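The torch-style dispatch in min/max above can be summarised with a small sketch (import assumed): with an integer dim the functions return a (values, indices) pair, with a second Var they act element-wise, and with a single argument they reduce over all elements.

    import jittor as jt
    from jtorch import max  # the torch-style max defined above

    x = jt.array([[1.0, 5.0], [3.0, 2.0]])
    values, indices = max(x, 1)             # torch-style: values [5., 3.], indices [1, 0]
    elementwise = max(x, jt.zeros_like(x))  # like torch.maximum
    overall = max(x)                        # plain reduction -> 5.0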
LongTensor = jt.int64
FloatTensor = jt.float
HalfTensor = jt.float16
BoolTensor = jt.bool
IntTensor = jt.int32
class JDType:
def __init__(self, func, str):
self.func = func
self.str = str
self.__name__ = str.split(".")[-1]
def __call__(self, *args, **kw):
return self.func(*args, **kw)
def __str__(self):
return self.str
def is_floating_point(self):
return "float" in str(self.str)
int8 = JDType(jt.int8, "torch.int8")
int16 = JDType(jt.int16, "torch.int16")
int = int32 = JDType(jt.int32, "torch.int32")
long = int64 = JDType(jt.int64, "torch.int64")
half = float16 = JDType(jt.float16, "torch.float16")
float = float32 = JDType(jt.float32, "torch.float32")
double = float64 = JDType(jt.float64, "torch.float64")
bfloat16 = "bfloat16" # TODO
complex64 = "complex64" # TODO
complex128 = "complex128" # TODO
def get_JDtype(dtype):
if dtype=='float32' or dtype == jt.float32:
return float32
elif dtype=='float64' or dtype == jt.float64:
return float64
elif dtype=='float16' or dtype == jt.float16:
return float16
elif dtype=='int32' or dtype == jt.int32:
return int32
elif dtype=='int64' or dtype == jt.int64:
return int64
elif dtype=='int16' or dtype == jt.int16:
return int16
elif dtype=='int8' or dtype == jt.int8:
return int8
else:
raise Exception("dtype {} not supported".format(dtype))
def load(path,**kwargs):
def _to_jittor(data):
if isinstance(data,dict):
return {k:_to_jittor(d) for k,d in data.items()}
if isinstance(data,list):
return [_to_jittor(d) for d in data]
if isinstance(data,np.ndarray):
return jt.array(data)
return data
data = jt.load(path)
return _to_jittor(data)
def is_tensor(x):
return isinstance(x, Tensor)
manual_seed = jt.set_global_seed
jt.flags.amp_level = 3
Size = jt.NanoVector
class Generator:
def __init__(self,*args,**kw) -> None:
self.seed = None
def manual_seed(self,seed):
self.seed = seed
from . import fx
_default_type = "float32"
def get_default_dtype():
return _default_type
def set_default_dtype(dtype):
global _default_type
_default_type = dtype
dtype = JDType
def div(x,y,rounding_mode="floor"):
assert rounding_mode == "floor"
z = (x / y)
if rounding_mode == "floor":
z = z.floor()
if x.dtype == "int32" and (isinstance(y,org_int) or y.dtype == "int32"):
z = z.int32()
return z
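For the floor rounding mode above, a two-line sketch (import assumed): negative values round toward minus infinity, and int32 inputs are cast back to int32.

    import jittor as jt
    from jtorch import div

    a = jt.array([7, -7], dtype=jt.int32)
    print(div(a, 2).numpy())    # [ 3 -4]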
def randn(*args,**kw):
wrap_randn = wrapper(jt.randn)
generator = kw.get('generator',None)
kw.pop('generator',None)
if 'layout' in kw:
del kw['layout']
if generator is not None and generator.seed is not None:
jt.set_global_seed(generator.seed)
return wrap_randn(*args,**kw)
def rand(*args,**kw):
print("rand")
wrap_rand = wrapper(jt.rand)
generator = kw.get('generator',None)
kw.pop('generator',None)
if 'layout' in kw:
del kw['layout']
if generator is not None and generator.seed is not None:
jt.set_global_seed(generator.seed)
return wrap_rand(*args,**kw)
def set_default_tensor_type(t):
if isinstance(t, str):
info = t.split(".")
if len(info) == 3 and info[1] == 'cuda':
jt.flags.use_cuda = 1
#TODO: type
def clamp(x, min=None, max=None):
return jt.clamp(x, min, max)
def to(x,*args,**kw):
device = None
if len(args) == 1:
device = args[0]
if isinstance(device, jt.NanoString) or callable(device):
return jt.to(x,*args,**kw)
if 'cpu' in str(device):
args = []
device = kw.get("device",None)
if 'cpu' in str(device):
kw.pop('device',None)
print("to cpu")
# print(kw)
return jt.to(x,*args,**kw)
Tensor.to = conflict_wrapper(jt.to, to)
mm = wrapper(jt.matmul)
def _data_get(x):
return x
def _data_set(x, value):
x.assign(value)
Tensor.data = property(_data_get, _data_set)
Tensor.layout = None


@ -1,134 +0,0 @@
import jittor as jt
from jittor import Var
from collections.abc import Sequence, Mapping
Variable = Var
class FunctionContext:
def save_for_backward(self, *args):
self.saved_tensors = args
class Function:
''' Function Module for customized backward operations
Example 1 (Function can have multiple input and multiple output, and user
can store value for backward computation)::
import jtorch
from jtorch import Function
class MyFunc(Function):
@staticmethod
def forward(self, x, y):
self.x = x
self.y = y
return x*y, x/y
@staticmethod
def backward(self, grad0, grad1):
return grad0 * self.y, grad1 * self.x
a = jtorch.array(3.0)
a.requires_grad = True
b = jtorch.array(4.0)
b.requires_grad = True
func = MyFunc.apply
c,d = func(a, b)
(c+d*3).backward()
assert a.grad.data == 4
assert b.grad.data == 9
Example 2(Function can return None for no gradiant, and gradiant
can also be None)::
import jtorch
from jtorch import Function
class MyFunc(Function):
@staticmethod
def forward(self, x, y):
self.x = x
self.y = y
return x*y, x/y
@staticmethod
def backward(self, grad0, grad1):
assert grad1 is None
return grad0 * self.y, None
a = jt.array(3.0)
a.requires_grad = True
b = jt.array(4.0)
b.requires_grad = True
func = MyFunc.apply
c,d = func(a, b)
d.stop_grad()
da, db = jt.grad(c+d*3, [a, b])
assert da.data == 4
assert db.data == 0
'''
def __call__(self, *args):
backup = args
args = list(args)
taped_inputs = []
taped_outputs = []
input_mask = [-1] * len(args)
for i,v in enumerate(args):
if isinstance(v, Var):
if v.is_stop_grad():
# -2 in input_mask represents it is stop_grad
input_mask[i] = -2
continue
v = v.tape()
input_mask[i] = len(taped_inputs)
args[i] = v
taped_inputs.append(v)
ctx = FunctionContext()
ori_res = self.forward(ctx, *args)
# ori_res = self.execute(*args)
if not isinstance(ori_res, Sequence):
res = [ori_res]
else:
res = list(ori_res)
output_mask = [-1] * len(res)
for i,v in enumerate(res):
if isinstance(v, Var):
v = v.tape()
output_mask[i] = len(taped_outputs)
res[i] = v
taped_outputs.append(v)
ctx.input_mask = input_mask
ctx.output_mask = output_mask
# tape output and input together so
# backward treat them as one operator
jt.tape_together(taped_inputs, taped_outputs,
lambda *args: self._grad(ctx, self, *args))
if isinstance(ori_res, Sequence):
return res
else:
return res[0]
@staticmethod
def _grad(ctx, func, *args):
new_args = ( (args[i] if i>=0 else None) for i in ctx.output_mask )
ret = func.backward(ctx, *new_args)
if not isinstance(ret, Sequence):
ret = (ret,)
new_ret = []
for i, r in enumerate(ret):
j = ctx.input_mask[i]
if j<0:
# -2 in input_mask represents it is stop_grad
assert r is None or j==-2, f"{type(func)}'s {i}-th returned grad should be None, "\
"because the input value is not jittor variable."
else:
new_ret.append(r)
return new_ret
def dfs(self, parents, k, callback, callback_leave=None):
pass
@classmethod
def apply(cls, *args, **kw):
func = cls()
return func(*args, **kw)


@ -1,39 +0,0 @@
import jittor as jt
import jittor_utils
import glob
import os
from jittor import pyjt_compiler
import sys
from jittor_utils import lock
jtorch_path = os.path.dirname(__file__)
cache_path = os.path.join(jt.compiler.cache_path, "jtorch")
# os.makedirs(cache_path, exist_ok=True)
os.makedirs(os.path.join(cache_path, "gen"), exist_ok=True)
with lock.lock_scope():
pyjt_gen_src = pyjt_compiler.compile(cache_path, jtorch_path)
ext_args = 'c[cu]' if jt.has_cuda else 'cc'
files = glob.glob(jtorch_path+"/src/**/*."+ext_args, recursive=True)
files += pyjt_gen_src
cc_flags = " -I\""+os.path.join(jtorch_path, "src")+"\" "
if os.environ.get("use_data_o", "1") == "1":
files += glob.glob(jtorch_path+"/src/**/*.o", recursive=True)
files = [f for f in files if "__data__" not in f]
with lock.lock_scope():
jt.compiler.compile(
jt.compiler.cc_path,
jt.compiler.cc_flags+jt.compiler.opt_flags+ cc_flags,
files,
"jtorch_core"+jt.compiler.extension_suffix,
obj_dirname="jtorch_objs")
with jittor_utils.import_scope(jt.compiler.import_flags):
import jtorch_core as core
jt.flags.th_mode = 1


@ -1,64 +0,0 @@
import jittor as jt
import jtorch
def is_available():
return jt.has_cuda
def device_count():
return int(jt.has_cuda)
def set_device(device=None):
pass
def get_rng_state(device=None):
pass
def current_device():
return jtorch.device("cuda")
def mem_get_info(i):
return ("75GB",)
class Generator:
def __init__(self):
pass
def set_state(self, state):
self.state = state
default_generators = [Generator()]
_lazy_call = lambda func: func()
device = None
LongTensor = jt.int64
FloatTensor = jt.float
HalfTensor = jt.float16
BoolTensor = jt.bool
manual_seed = jt.set_global_seed
manual_seed_all = jt.set_global_seed
def synchronize():
jt.sync_all(True)
class Event:
pass
class Stream:
pass
from typing import Any
from .gradscaler import GradScaler
class autocast:
def __init__(self,**kwargs):
pass
def __enter__(self,):
pass
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
pass


@ -1,53 +0,0 @@
import datetime
from enum import Enum
import jittor as jt
class DistributedDataParallel:
def __new__(cls, model):
return model
def is_initialized():
return True
def get_rank(group=None):
return 0
def get_world_size(group=None):
return 1
def get_backend(group=None):
return "nccl"
def new_group(ranks=None, timeout=datetime.timedelta(seconds=1800), backend=None, pg_options=None):
return 1
def barrier():
pass
def is_available():
return True
def is_built():
return True
class ReduceOp:
SUM = 0
class GroupMember:
WORLD = 0
class ProcessGroup:
pass
class Join:
pass
dist_backend = Enum("dist_backend", ("GLOO", "MPI", "NCCL"))
_backend = dist_backend.NCCL
def is_mpi_available():
return jt.in_mpi
def DistributedDataParallel(model, *args, **kw):
return model


@ -1,15 +0,0 @@
import jittor as jt
class RelaxedBernoulli:
def __init__(self, temperature, probs=None, logits=None):
self.temperature = temperature
self.probs = probs
self.logits = logits
def rsample(self):
noise = jt.rand_like(self.logits)
eps = 1e-20
noise = jt.clamp(noise, eps, 1.0 - eps)
logit_noise = jt.log(noise) - jt.log(1 - noise)
sample = (self.logits + logit_noise) / self.temperature
return jt.sigmoid(sample)
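rsample implements the usual logistic (Gumbel-sigmoid) relaxation: logistic noise log(u) - log(1 - u) is added to the logits and the sum is squashed by a temperature-scaled sigmoid. A short sketch, assuming this file is importable as jtorch.distributions:

    import jittor as jt
    from jtorch.distributions import RelaxedBernoulli

    dist = RelaxedBernoulli(temperature=0.5, logits=jt.array([2.0, 0.0, -2.0]))
    sample = dist.rsample()   # differentiable values in (0, 1); lower temperature pushes them toward {0, 1}
    print(sample.numpy())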


@ -1,5 +0,0 @@
#TODO: Implement FFT and IFFT
fftn = None
fftshift = None
ifftn = None
ifftshift = None


@ -1,2 +0,0 @@
class Proxy:
pass


@ -1,519 +0,0 @@
from collections import defaultdict, abc
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, cast
import inspect
import warnings
import jittor as jt
# import torch
def _refresh_per_optimizer_state():
return {}
class GradScaler:
_scale: Optional[jt.Var]
_growth_tracker: Optional[jt.Var]
_per_optimizer_states: Dict[int, Dict[str, Any]]
"""
An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling
conveniently.
* ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor.
* ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``.
* ``scaler.update()`` updates ``scaler``'s scale factor.
Example::
# Creates a GradScaler once at the beginning of training.
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
# Scales loss. Calls backward() on scaled loss to create scaled gradients.
scaler.scale(loss).backward()
# scaler.step() first unscales gradients of the optimizer's params.
# If gradients don't contain infs/NaNs, optimizer.step() is then called,
# otherwise, optimizer.step() is skipped.
scaler.step(optimizer)
# Updates the scale for next iteration.
scaler.update()
See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage
(along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty,
and multiple losses/optimizers.
``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow,
a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if
the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used
without incurring inf or NaN gradient values.
``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every
``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`).
* If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params
themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``.
* If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual.
If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by
``growth_factor``.
The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its
value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these
iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations).
Args:
init_scale (float, optional, default=2.**16): Initial scale factor.
growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during
:meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during
:meth:`update` if inf/NaN gradients occur in an iteration.
growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients
that must occur for the scale to be multiplied by ``growth_factor``.
enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply
invokes the underlying ``optimizer.step()``, and other methods become no-ops.
Default: ``True``
"""
def __init__(self,
init_scale=2.**16,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=2000,
enabled=True):
self._enabled = enabled
if self._enabled:
assert growth_factor > 1.0, "The growth factor must be > 1.0."
assert backoff_factor < 1.0, "The backoff factor must be < 1.0."
self._init_scale = init_scale
# self._scale will be lazily initialized during the first call to scale()
self._scale = None
self._growth_factor = growth_factor
self._backoff_factor = backoff_factor
self._growth_interval = growth_interval
self._init_growth_tracker = 0
# self._growth_tracker will be lazily initialized during the first call to scale()
self._growth_tracker = None
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
def _check_scale_growth_tracker(self, funcname) -> Tuple[jt.Var, jt.Var]:
fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix
assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix
return (self._scale, self._growth_tracker)
def _lazy_init_scale_growth_tracker(self):
assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
self._scale = self._init_scale
self._growth_tracker = self._init_growth_tracker
def scale(self, outputs):
"""
Multiplies ('scales') a tensor or list of tensors by the scale factor.
Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
unmodified.
Args:
outputs (Tensor or iterable of Tensors): Outputs to scale.
"""
if not self._enabled:
return outputs
# Short-circuit for the common case.
if isinstance(outputs, jt.Var):
assert jt.flags.use_cuda == 1
if self._scale is None:
self._lazy_init_scale_growth_tracker()
assert self._scale is not None
return outputs * self._scale
def apply_scale(val):
if isinstance(val, jt.Var):
assert jt.flags.use_cuda == 1
if self._scale is None:
self._lazy_init_scale_growth_tracker()
assert self._scale is not None
return val * self._scale
elif isinstance(val, abc.Iterable):
iterable = map(apply_scale, val)
if isinstance(val, (list, tuple)):
return type(val)(iterable)
else:
return iterable
else:
raise ValueError("outputs must be a Tensor or an iterable of Tensors")
return apply_scale(outputs)
def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
with jt.no_grad():
optimizer.pre_step()
for group in optimizer.param_groups:
for to_unscale in group["grads"]:
if to_unscale is None or isinstance(to_unscale,(int,float)):
continue
if (not allow_fp16) and str(to_unscale.dtype) == "float16":
raise ValueError("Attempting to unscale FP16 gradients.")
if not (to_unscale.isinf().any()):
if inv_scale != 1.0:
to_unscale.update(to_unscale*inv_scale)
else:
found_inf = 1.0
return found_inf
def unscale_(self, optimizer):
"""
Divides ("unscales") the optimizer's gradient tensors by the scale factor.
:meth:`unscale_` is optional, serving cases where you need to
:ref:`modify or inspect gradients<working-with-unscaled-gradients>`
between the backward pass(es) and :meth:`step`.
If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
...
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
scaler.step(optimizer)
scaler.update()
Args:
optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
.. note::
:meth:`unscale_` does not incur a CPU-GPU sync.
.. warning::
:meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
and only after all gradients for that optimizer's assigned parameters have been accumulated.
Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
.. warning::
:meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
"""
if not self._enabled:
return
self._check_scale_growth_tracker("unscale_")
optimizer_state = self._per_optimizer_states[id(optimizer)]
if hasattr(optimizer,"get_find_inf"):
return
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
assert self._scale is not None
inv_scale = 1.0 / self._scale
found_inf = 0.0
optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
def step(self, optimizer, *args, **kwargs):
"""
:meth:`step` carries out the following two operations:
1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer``
earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs.
2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled
gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params.
``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.
Returns the return value of ``optimizer.step(*args, **kwargs)``.
Args:
optimizer (torch.optim.Optimizer): Optimizer that applies the gradients.
args: Any arguments.
kwargs: Any keyword arguments.
.. warning::
Closure use is not currently supported.
"""
if (not self._enabled):
return optimizer.step(*args, **kwargs)
if "closure" in kwargs:
raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.")
self._check_scale_growth_tracker("step")
optimizer_state = self._per_optimizer_states[id(optimizer)]
retval = None
if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling):
# This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
# The contract with custom optimizers is that their step() should accept an additional,
# optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information:
# it can query its own state, invoke unscale_ on itself, etc
# The contract above is being deprecated to avoid introducing `grad_scaler: GradScaler` argument
# to `Optimizer.step`. The new behavior is going to add two Tensor attributes of `grad_scale`
# and `found_inf` to the passed optimizer so that the optimizer can utilize those
# to skip the parameter updates or unscale gradients before updating parameters in
# the fused kernel, e.g. `FusedAdamMathFunctor`.
# In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`,
# while the method is expected to be called by users side, i.e. their optimizers.
kwargs_ = kwargs
has_grad_scaler_kwarg = "grad_scaler" in inspect.signature(optimizer.step).parameters
if has_grad_scaler_kwarg:
warnings.warn(
"GradScaler is going to stop passing itself as a keyword argument to the passed "
"optimizer. In the near future GradScaler registers `grad_scale: Tensor` and "
"`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.",
FutureWarning)
kwargs_.update({"grad_scaler": self})
else:
if optimizer_state["stage"] is OptState.READY:
self._check_inf_per_device(optimizer)
scaler = self._get_scale_async()
found_inf = cast(
jt.Var,
sum([
t for t in optimizer_state["found_inf_per_device"].values()
])
)
optimizer.grad_scale = None if optimizer_state["stage"] == OptState.UNSCALED else scaler
optimizer.found_inf = found_inf
retval = optimizer.step(*args, **kwargs_)
optimizer_state["stage"] = OptState.STEPPED
if not has_grad_scaler_kwarg:
del optimizer.grad_scale
del optimizer.found_inf
return retval
if hasattr(optimizer,"get_find_inf"):
optimizer.set_grad_scale(self._scale)
optimizer.step()
optimizer_state["found_inf_per_device"] = optimizer.get_find_inf()
return
retval = None
if not optimizer_state["found_inf_per_device"]:
retval = optimizer.step(*args, **kwargs)
else:
optimizer.post_step()
return retval
def update(self, new_scale=None):
"""
Updates the scale factor.
If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
the scale is multiplied by ``growth_factor`` to increase it.
Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
used directly, it's used to fill GradScaler's internal scale tensor. So if
``new_scale`` was a tensor, later in-place changes to that tensor will not further
affect the scale GradScaler uses internally.)
Args:
new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor.
.. warning::
:meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
been invoked for all optimizers used this iteration.
"""
if not self._enabled:
return
_scale, _growth_tracker = self._check_scale_growth_tracker("update")
if new_scale is not None:
# Accept a new user-defined scale.
if isinstance(new_scale, float):
self._scale.fill_(new_scale) # type: ignore[union-attr]
else:
reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
assert isinstance(new_scale, torch.cuda.FloatTensor), reason # type: ignore[attr-defined]
assert new_scale.numel() == 1, reason
assert new_scale.requires_grad is False, reason
self._scale.copy_(new_scale) # type: ignore[union-attr]
else:
# Consume shared inf/nan data collected from optimizers to update the scale.
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
found_infs = [state["found_inf_per_device"]
for state in self._per_optimizer_states.values()
]
assert len(found_infs) > 0, "No inf checks were recorded prior to update."
found_inf_combined = found_infs[0]
if len(found_infs) > 1:
for i in range(1, len(found_infs)):
found_inf_combined += found_infs[i]
current_scale = _scale
if found_inf_combined:
current_scale *=self._backoff_factor
_growth_tracker = 0
else:
successful = _growth_tracker+1
if successful == self._growth_interval:
new_scale = current_scale*self._growth_factor
if new_scale < 1e9:
current_scale = new_scale
_growth_tracker = 0
else:
_growth_tracker = successful
self._scale, self._growth_tracker = current_scale,_growth_tracker
# To prepare for next iteration, clear the data collected from optimizers this iteration.
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
def _get_scale_async(self):
return self._scale
def get_scale(self):
"""
Returns a Python float containing the current scale, or 1.0 if scaling is disabled.
.. warning::
:meth:`get_scale` incurs a CPU-GPU sync.
"""
if self._enabled:
return self._init_scale if self._scale is None else self._get_scale_async()
else:
return 1.0
def get_growth_factor(self):
r"""
Returns a Python float containing the scale growth factor.
"""
return self._growth_factor
def set_growth_factor(self, new_factor):
r"""
Args:
new_scale (float): Value to use as the new scale growth factor.
"""
self._growth_factor = new_factor
def get_backoff_factor(self):
r"""
Returns a Python float containing the scale backoff factor.
"""
return self._backoff_factor
def set_backoff_factor(self, new_factor):
r"""
Args:
new_scale (float): Value to use as the new scale backoff factor.
"""
self._backoff_factor = new_factor
def get_growth_interval(self):
r"""
Returns a Python int containing the growth interval.
"""
return self._growth_interval
def set_growth_interval(self, new_interval):
r"""
Args:
new_interval (int): Value to use as the new growth interval.
"""
self._growth_interval = new_interval
def _get_growth_tracker(self):
if self._enabled:
return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker.item()
else:
return 0
def is_enabled(self):
r"""
Returns a bool indicating whether this instance is enabled.
"""
return self._enabled
def state_dict(self):
r"""
Returns the state of the scaler as a :class:`dict`. It contains five entries:
* ``"scale"`` - a Python float containing the current scale
* ``"growth_factor"`` - a Python float containing the current growth factor
* ``"backoff_factor"`` - a Python float containing the current backoff factor
* ``"growth_interval"`` - a Python int containing the current growth interval
* ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.
If this instance is not enabled, returns an empty dict.
.. note::
If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
should be called after :meth:`update`.
"""
return {"scale": self.get_scale(),
"growth_factor": self._growth_factor,
"backoff_factor": self._backoff_factor,
"growth_interval": self._growth_interval,
"_growth_tracker": self._get_growth_tracker()} if self._enabled else {}
def load_state_dict(self, state_dict):
r"""
Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op.
Args:
state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`.
"""
if not self._enabled:
return
if len(state_dict) == 0:
raise RuntimeError("The source state dict is empty, possibly because it was saved "
"from a disabled instance of GradScaler.")
self._init_scale = state_dict["scale"]
if self._scale is not None:
self._scale.fill_(state_dict["scale"])
self._growth_factor = state_dict["growth_factor"]
self._backoff_factor = state_dict["backoff_factor"]
self._growth_interval = state_dict["growth_interval"]
self._init_growth_tracker = state_dict["_growth_tracker"]
if self._growth_tracker is not None:
self._growth_tracker.fill_(state_dict["_growth_tracker"])
def __getstate__(self):
state = self.__dict__.copy()
if self._enabled:
assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\
"of an iteration, or at the end after scaler.update()."
# Pickling _scale and _growth_tracker Tensors directly triggers
# "warnings.warn("pickle support for Storage will be removed in 1.5..."
# so instead, we set the unpickled instance up to reinitialize them lazily.
state['_init_scale'] = self.get_scale()
state['_init_growth_tracker'] = self._get_growth_tracker()
state['_scale'] = None
state['_growth_tracker'] = None
return state
def __setstate__(self, state):
self.__dict__.update(state)
def _check_inf_per_device(self, optimizer):
_scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")
dummy_inv_scale = 1.0
found_inf = 0.0
self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
def _found_inf_per_device(self, optimizer):
return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
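The scale bookkeeping in update() above can be illustrated with a small standalone sketch using the default growth_factor=2.0, backoff_factor=0.5 and growth_interval=2000 (the real method additionally caps the grown scale at 1e9): an inf/NaN step halves the scale immediately, while a doubling only happens after growth_interval consecutive clean steps.

    def after_step(scale, tracker, found_inf,
                   growth_factor=2.0, backoff_factor=0.5, growth_interval=2000):
        # mirrors the branch structure of GradScaler.update()
        if found_inf:
            return scale * backoff_factor, 0
        tracker += 1
        if tracker == growth_interval:
            return scale * growth_factor, 0
        return scale, tracker

    scale, tracker = 2.0 ** 16, 0
    scale, tracker = after_step(scale, tracker, found_inf=True)     # 65536 -> 32768
    for _ in range(2000):
        scale, tracker = after_step(scale, tracker, found_inf=False)
    print(scale)                                                     # back to 65536.0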


@ -1,556 +0,0 @@
from collections import defaultdict, abc
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, cast
import inspect
import warnings
import jittor as jt
# import torch
__all__ = ["OptState", "GradScaler"]
# Defines default_factory for GradScaler's _per_optimizer_states defaultdict,
# as well as associated "enum" values. Prefers defining these at top level because
# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory.
# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler
# causes a circular reference, which we'd rather avoid.
class OptState(Enum):
READY = 0
UNSCALED = 1
STEPPED = 2
def _refresh_per_optimizer_state():
return {"stage": OptState.READY, "found_inf_per_device": {}}
class GradScaler:
_scale: Optional[jt.Var]
_growth_tracker: Optional[jt.Var]
_per_optimizer_states: Dict[int, Dict[str, Any]]
"""
An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling
conveniently.
* ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor.
* ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``.
* ``scaler.update()`` updates ``scaler``'s scale factor.
Example::
# Creates a GradScaler once at the beginning of training.
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
# Scales loss. Calls backward() on scaled loss to create scaled gradients.
scaler.scale(loss).backward()
# scaler.step() first unscales gradients of the optimizer's params.
# If gradients don't contain infs/NaNs, optimizer.step() is then called,
# otherwise, optimizer.step() is skipped.
scaler.step(optimizer)
# Updates the scale for next iteration.
scaler.update()
See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage
(along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty,
and multiple losses/optimizers.
``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow,
a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if
the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used
without incurring inf or NaN gradient values.
``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every
``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`).
* If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params
themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``.
* If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual.
If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by
``growth_factor``.
The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its
value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these
iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations).
Args:
init_scale (float, optional, default=2.**16): Initial scale factor.
growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during
:meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during
:meth:`update` if inf/NaN gradients occur in an iteration.
growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients
that must occur for the scale to be multiplied by ``growth_factor``.
enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply
invokes the underlying ``optimizer.step()``, and other methods become no-ops.
Default: ``True``
"""
def __init__(self,
init_scale=2.**16,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=2000,
enabled=True):
self._enabled = enabled
if self._enabled:
assert growth_factor > 1.0, "The growth factor must be > 1.0."
assert backoff_factor < 1.0, "The backoff factor must be < 1.0."
self._init_scale = init_scale
# self._scale will be lazily initialized during the first call to scale()
self._scale = None
self._growth_factor = growth_factor
self._backoff_factor = backoff_factor
self._growth_interval = growth_interval
self._init_growth_tracker = 0
# self._growth_tracker will be lazily initialized during the first call to scale()
self._growth_tracker = None
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
def _check_scale_growth_tracker(self, funcname) -> Tuple[jt.Var, jt.Var]:
fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix
assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix
return (self._scale, self._growth_tracker)
def _lazy_init_scale_growth_tracker(self):
assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
self._scale = self._init_scale
self._growth_tracker = self._init_growth_tracker
def scale(self, outputs):
"""
Multiplies ('scales') a tensor or list of tensors by the scale factor.
Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
unmodified.
Args:
outputs (Tensor or iterable of Tensors): Outputs to scale.
"""
print("scale")
if not self._enabled:
return outputs
# Short-circuit for the common case.
if isinstance(outputs, jt.Var):
assert jt.flags.use_cuda == 1
if self._scale is None:
self._lazy_init_scale_growth_tracker()
assert self._scale is not None
return outputs * self._scale
def apply_scale(val):
if isinstance(val, jt.Var):
assert jt.flags.use_cuda == 1
if self._scale is None:
self._lazy_init_scale_growth_tracker()
assert self._scale is not None
return val * self._scale
elif isinstance(val, abc.Iterable):
iterable = map(apply_scale, val)
if isinstance(val, (list, tuple)):
return type(val)(iterable)
else:
return iterable
else:
raise ValueError("outputs must be a Tensor or an iterable of Tensors")
return apply_scale(outputs)
def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
# To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
# There could be hundreds of grads, so we'd like to iterate through them just once.
# However, we don't know their devices or dtypes in advance.
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
# Google says mypy struggles with defaultdicts type annotations.
with jt.no_grad():
optimizer.pre_step()
for group in optimizer.param_groups:
for to_unscale in group["grads"]:
if to_unscale is None or isinstance(to_unscale,(int,float)):
continue
if (not allow_fp16) and str(to_unscale.dtype) == "float16":
raise ValueError("Attempting to unscale FP16 gradients.")
if not (to_unscale.isinf().any()):
if inv_scale != 1.0:
to_unscale.update(to_unscale*inv_scale)
else:
found_inf = 1.0
return found_inf
def unscale_(self, optimizer):
"""
Divides ("unscales") the optimizer's gradient tensors by the scale factor.
:meth:`unscale_` is optional, serving cases where you need to
:ref:`modify or inspect gradients<working-with-unscaled-gradients>`
between the backward pass(es) and :meth:`step`.
If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
...
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
scaler.step(optimizer)
scaler.update()
Args:
optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
.. note::
:meth:`unscale_` does not incur a CPU-GPU sync.
.. warning::
:meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
and only after all gradients for that optimizer's assigned parameters have been accumulated.
Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
.. warning::
:meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
"""
if not self._enabled:
return
self._check_scale_growth_tracker("unscale_")
optimizer_state = self._per_optimizer_states[id(optimizer)]
if optimizer_state["stage"] is OptState.UNSCALED:
raise RuntimeError("unscale_() has already been called on this optimizer since the last update().")
elif optimizer_state["stage"] is OptState.STEPPED:
raise RuntimeError("unscale_() is being called after step().")
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
assert self._scale is not None
inv_scale = 1.0 / self._scale
found_inf = 0.0
optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
optimizer_state["stage"] = OptState.UNSCALED
def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
retval = None
if not optimizer_state["found_inf_per_device"]:
retval = optimizer.step(*args, **kwargs)
else:
optimizer.post_step()
return retval
def step(self, optimizer, *args, **kwargs):
"""
:meth:`step` carries out the following two operations:
1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer``
earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs.
2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled
gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params.
``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.
Returns the return value of ``optimizer.step(*args, **kwargs)``.
Args:
optimizer (torch.optim.Optimizer): Optimizer that applies the gradients.
args: Any arguments.
kwargs: Any keyword arguments.
.. warning::
Closure use is not currently supported.
"""
if (not self._enabled):
return optimizer.step(*args, **kwargs)
if "closure" in kwargs:
raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.")
self._check_scale_growth_tracker("step")
optimizer_state = self._per_optimizer_states[id(optimizer)]
if optimizer_state["stage"] is OptState.STEPPED:
raise RuntimeError("step() has already been called since the last update().")
retval = None
if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling):
# This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
# The contract with custom optimizers is that their step() should accept an additional,
# optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information:
# it can query its own state, invoke unscale_ on itself, etc
# The contract above is being deprecated to avoid introducing `grad_scaler: GradScaler` argument
# to `Optimizer.step`. The new behavior is going to add two Tensor attributes of `grad_scale`
# and `found_inf` to the passed optimizer so that the optimizer can utilize those
# to skip the parameter updates or unscale gradients before updating parameters in
# the fused kernel, e.g. `FusedAdamMathFunctor`.
# In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`,
# while the method is expected to be called by users side, i.e. their optimizers.
kwargs_ = kwargs
has_grad_scaler_kwarg = "grad_scaler" in inspect.signature(optimizer.step).parameters
if has_grad_scaler_kwarg:
warnings.warn(
"GradScaler is going to stop passing itself as a keyword argument to the passed "
"optimizer. In the near future GradScaler registers `grad_scale: Tensor` and "
"`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.",
FutureWarning)
kwargs_.update({"grad_scaler": self})
else:
if optimizer_state["stage"] is OptState.READY:
self._check_inf_per_device(optimizer)
scaler = self._get_scale_async()
found_inf = cast(
jt.Var,
sum([
t for t in optimizer_state["found_inf_per_device"].values()
])
)
optimizer.grad_scale = None if optimizer_state["stage"] == OptState.UNSCALED else scaler
optimizer.found_inf = found_inf
retval = optimizer.step(*args, **kwargs_)
optimizer_state["stage"] = OptState.STEPPED
if not has_grad_scaler_kwarg:
del optimizer.grad_scale
del optimizer.found_inf
return retval
if optimizer_state["stage"] is OptState.READY:
self.unscale_(optimizer)
assert "found_inf_per_device" in optimizer_state, "No inf checks were recorded for this optimizer."
retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)
optimizer_state["stage"] = OptState.STEPPED
return retval
def update(self, new_scale=None):
"""
Updates the scale factor.
If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
the scale is multiplied by ``growth_factor`` to increase it.
Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
used directly, it's used to fill GradScaler's internal scale tensor. So if
``new_scale`` was a tensor, later in-place changes to that tensor will not further
affect the scale GradScaler uses internally.)
Args:
new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor.
.. warning::
:meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
been invoked for all optimizers used this iteration.
"""
if not self._enabled:
return
_scale, _growth_tracker = self._check_scale_growth_tracker("update")
if new_scale is not None:
# Accept a new user-defined scale.
if isinstance(new_scale, float):
self._scale.fill_(new_scale) # type: ignore[union-attr]
else:
reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
assert isinstance(new_scale, torch.cuda.FloatTensor), reason # type: ignore[attr-defined]
assert new_scale.numel() == 1, reason
assert new_scale.requires_grad is False, reason
self._scale.copy_(new_scale) # type: ignore[union-attr]
else:
# Consume shared inf/nan data collected from optimizers to update the scale.
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
found_infs = [state["found_inf_per_device"]
for state in self._per_optimizer_states.values()
]
assert len(found_infs) > 0, "No inf checks were recorded prior to update."
found_inf_combined = found_infs[0]
if len(found_infs) > 1:
for i in range(1, len(found_infs)):
found_inf_combined += found_infs[i]
current_scale = _scale
if found_inf_combined:
current_scale *=self._backoff_factor
_growth_tracker = 0
else:
successful = _growth_tracker+1
if successful == self._growth_interval:
new_scale = current_scale*self._growth_factor
if new_scale < 1e9:
current_scale = new_scale
_growth_tracker = 0
else:
_growth_tracker = successful
self._scale, self._growth_tracker = current_scale,_growth_tracker
# To prepare for next iteration, clear the data collected from optimizers this iteration.
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
def _get_scale_async(self):
return self._scale
def get_scale(self):
"""
Returns a Python float containing the current scale, or 1.0 if scaling is disabled.
.. warning::
:meth:`get_scale` incurs a CPU-GPU sync.
"""
if self._enabled:
return self._init_scale if self._scale is None else self._get_scale_async()
else:
return 1.0
def get_growth_factor(self):
r"""
Returns a Python float containing the scale growth factor.
"""
return self._growth_factor
def set_growth_factor(self, new_factor):
r"""
Args:
new_scale (float): Value to use as the new scale growth factor.
"""
self._growth_factor = new_factor
def get_backoff_factor(self):
r"""
Returns a Python float containing the scale backoff factor.
"""
return self._backoff_factor
def set_backoff_factor(self, new_factor):
r"""
Args:
new_scale (float): Value to use as the new scale backoff factor.
"""
self._backoff_factor = new_factor
def get_growth_interval(self):
r"""
Returns a Python int containing the growth interval.
"""
return self._growth_interval
def set_growth_interval(self, new_interval):
r"""
Args:
new_interval (int): Value to use as the new growth interval.
"""
self._growth_interval = new_interval
def _get_growth_tracker(self):
if self._enabled:
return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker.item()
else:
return 0
def is_enabled(self):
r"""
Returns a bool indicating whether this instance is enabled.
"""
return self._enabled
def state_dict(self):
r"""
Returns the state of the scaler as a :class:`dict`. It contains five entries:
* ``"scale"`` - a Python float containing the current scale
* ``"growth_factor"`` - a Python float containing the current growth factor
* ``"backoff_factor"`` - a Python float containing the current backoff factor
* ``"growth_interval"`` - a Python int containing the current growth interval
* ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.
If this instance is not enabled, returns an empty dict.
.. note::
If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
should be called after :meth:`update`.
"""
return {"scale": self.get_scale(),
"growth_factor": self._growth_factor,
"backoff_factor": self._backoff_factor,
"growth_interval": self._growth_interval,
"_growth_tracker": self._get_growth_tracker()} if self._enabled else {}
def load_state_dict(self, state_dict):
r"""
Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op.
Args:
state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`.
"""
if not self._enabled:
return
if len(state_dict) == 0:
raise RuntimeError("The source state dict is empty, possibly because it was saved "
"from a disabled instance of GradScaler.")
self._init_scale = state_dict["scale"]
if self._scale is not None:
self._scale.fill_(state_dict["scale"])
self._growth_factor = state_dict["growth_factor"]
self._backoff_factor = state_dict["backoff_factor"]
self._growth_interval = state_dict["growth_interval"]
self._init_growth_tracker = state_dict["_growth_tracker"]
if self._growth_tracker is not None:
self._growth_tracker.fill_(state_dict["_growth_tracker"])
def __getstate__(self):
state = self.__dict__.copy()
if self._enabled:
assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\
"of an iteration, or at the end after scaler.update()."
# Pickling _scale and _growth_tracker Tensors directly triggers
# "warnings.warn("pickle support for Storage will be removed in 1.5..."
# so instead, we set the unpickled instance up to reinitialize them lazily.
state['_init_scale'] = self.get_scale()
state['_init_growth_tracker'] = self._get_growth_tracker()
state['_scale'] = None
state['_growth_tracker'] = None
return state
def __setstate__(self, state):
self.__dict__.update(state)
def _check_inf_per_device(self, optimizer):
_scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")
dummy_inv_scale = 1.0
found_inf = 0.0
self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
def _found_inf_per_device(self, optimizer):
return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
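The class above mirrors the torch.cuda.amp.GradScaler control flow on top of jittor Vars. A minimal sketch of the intended training-loop usage, assuming the scale()/unscale_() methods defined earlier in this file and hypothetical model, optimizer, loss_fn and loader objects:

    scaler = GradScaler()
    for data, target in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(data), target)
        scaler.scale(loss).backward()   # backward on the scaled loss
        scaler.step(optimizer)          # unscales grads, skips the step if inf/nan was found
        scaler.update()                 # backs off the scale on overflow, grows it after growth_interval clean steps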


@ -1,12 +0,0 @@
import math
def _jit_set_profiling_mode(x): pass
def _jit_set_profiling_executor(x): pass
def _jit_override_can_fuse_on_cpu(x): pass
def _jit_override_can_fuse_on_gpu(x): pass
def script(func):
return func
inf = math.inf
nan = math.nan


@ -1,281 +0,0 @@
import jtorch
from typing import List, Optional, Tuple, Iterable, Iterator, Mapping, Any, overload, TypeVar, Dict
from typing_extensions import Self
import jittor as jt
from jtorch import make_module, Tensor, ModuleMisc, wrapper
#from . import init
from jittor import Function
import operator
import warnings
for k,v in jt.nn.__dict__.items():
if callable(v):
globals()[k] = wrapper(v)
for k,v in jt.nn.__dict__.items():
if isinstance(v, type) and issubclass(v, jt.Module):
globals()[k] = make_module(v)
from collections import OrderedDict
from collections import abc as container_abcs
class Module(ModuleMisc, jt.Module):
def __call__(self, *args, **kw):
return self.execute(*args, **kw)
def execute(self, *args, **kw):
return self.forward(*args, **kw)
def get_submodule(self, target: str):
if target == "":
return self
atoms: List[str] = target.split(".")
mod: jt.nn.Module = self
for item in atoms:
if not hasattr(mod, item):
raise AttributeError(mod._get_name() + " has no "
"attribute `" + item + "`")
mod = getattr(mod, item)
if not isinstance(mod, jt.nn.Module):
raise AttributeError("`" + item + "` is not "
"an nn.Module")
return mod
def Parameter(x:Tensor, requires_grad:bool=True) -> Tensor:
x = x.clone()
x.requires_grad = requires_grad
x.retains_grad = requires_grad
return x
def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False):
return jt.nn.embedding(input, weight)
def dropout(x, p=0.5, training=False):
return jt.nn.dropout(x, p, training)
class Flatten(Module):
''' Flattens the contiguous range of dimensions in a Var.
:param start_dim: the first dimension to be flattened. Defaults: 1.
:type start_dim: int
:param end_dim: the last dimension to be flattened. Defaults: -1.
:type end_dim: int
'''
def __init__(self, start_dim=1, end_dim=-1):
self.start_dim = start_dim
self.end_dim = end_dim
def forward(self, x) -> jt.Var:
return x.flatten(self.start_dim, self.end_dim)
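With a 3-D input the defaults merge everything after the batch dimension; a quick sketch, assuming a jittor Var:

    x = jt.randn(2, 3, 4)
    print(Flatten()(x).shape)        # (2, 12): dims 1..-1 merged
    print(Flatten(0, -1)(x).shape)   # (24,): all dims merged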
class _IncompatibleKeys:
def __init__(self, missing_keys, unexpected_keys):
self.missing_keys = missing_keys
self.unexpected_keys = unexpected_keys
_BatchNorm = None
#from . import utils
normalize = wrapper(jt.normalize)
T = TypeVar('T', bound=Module)
class ModuleDict(Module):
_modules: Dict[str, Module] # type: ignore[assignment]
def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None:
super().__init__()
if modules is not None:
self.update(modules)
def __getitem__(self, key: str) -> Module:
return self._modules[key]
def __setitem__(self, key: str, module: Module) -> None:
self.add_module(key, module)
def __delitem__(self, key: str) -> None:
del self._modules[key]
def __len__(self) -> int:
return len(self._modules)
def __iter__(self) -> Iterator[str]:
return iter(self._modules)
def __contains__(self, key: str) -> bool:
return key in self._modules
def clear(self) -> None:
"""Remove all items from the ModuleDict."""
self._modules.clear()
def pop(self, key: str) -> Module:
r"""Remove key from the ModuleDict and return its module.
Args:
key (str): key to pop from the ModuleDict
"""
v = self[key]
del self[key]
return v
def keys(self) -> Iterable[str]:
r"""Return an iterable of the ModuleDict keys."""
return self._modules.keys()
def items(self) -> Iterable[Tuple[str, Module]]:
r"""Return an iterable of the ModuleDict key/value pairs."""
return self._modules.items()
def values(self) -> Iterable[Module]:
r"""Return an iterable of the ModuleDict values."""
return self._modules.values()
def update(self, modules: Mapping[str, Module]) -> None:
r"""Update the :class:`~torch.nn.ModuleDict` with key-value pairs from a mapping, overwriting existing keys.
.. note::
If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or
an iterable of key-value pairs, the order of new elements in it is preserved.
Args:
modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`,
or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`)
"""
if not isinstance(modules, container_abcs.Iterable):
raise TypeError("ModuleDict.update should be called with an "
"iterable of key/value pairs, but got " +
type(modules).__name__)
if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)):
for key, module in modules.items():
self[key] = module
else:
# modules here can be a list with two items
for j, m in enumerate(modules):
if not isinstance(m, container_abcs.Iterable):
raise TypeError("ModuleDict update sequence element "
"#" + str(j) + " should be Iterable; is" +
type(m).__name__)
if not len(m) == 2:
raise ValueError("ModuleDict update sequence element "
"#" + str(j) + " has length " + str(len(m)) +
"; 2 is required")
# modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
# that's too cumbersome to type correctly with overloads, so we add an ignore here
self[m[0]] = m[1] # type: ignore[assignment]
# remove forward altogether to fall back on Module's _forward_unimplemented
class ParameterList(Module):
def __init__(self, values: Optional[Iterable[Any]] = None) -> None:
super().__init__()
self._size = 0
if values is not None:
self += values
def _get_abs_string_index(self, idx):
"""Get the absolute index for the list of modules."""
idx = operator.index(idx)
if not (-len(self) <= idx < len(self)):
raise IndexError(f'index {idx} is out of range')
if idx < 0:
idx += len(self)
return str(idx)
@overload
def __getitem__(self, idx: int) -> Any:
...
@overload
def __getitem__(self: T, idx: slice) -> T:
...
def __getitem__(self, idx):
if isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
out = self.__class__()
for i in range(start, stop, step):
out.append(self[i])
return out
else:
idx = self._get_abs_string_index(idx)
return getattr(self, str(idx))
def __setitem__(self, idx: int, param: Any) -> None:
# Note that all other function that add an entry to the list part of
# the ParameterList end up here. So this is the only place where we need
# to wrap things into Parameter if needed.
# Objects added via setattr() are not in the list part and thus won't
# call into this function.
idx = self._get_abs_string_index(idx)
# `Parameter` is a plain function in this port, not a class, so the PyTorch-style
# isinstance(param, Parameter) check would raise TypeError; key off requires_grad instead.
if isinstance(param, jt.Var) and not getattr(param, "requires_grad", False):
param = Parameter(param)
return setattr(self, str(idx), param)
def __len__(self) -> int:
return self._size
def __iter__(self) -> Iterator[Any]:
return iter(self[i] for i in range(len(self)))
def __iadd__(self, parameters: Iterable[Any]) -> Self:
return self.extend(parameters)
def __dir__(self):
keys = super().__dir__()
keys = [key for key in keys if not key.isdigit()]
return keys
def append(self, value: Any) -> 'ParameterList':
"""Append a given value at the end of the list.
Args:
value (Any): value to append
"""
new_idx = len(self)
self._size += 1
self[new_idx] = value
return self
def extend(self, values: Iterable[Any]) -> Self:
"""Append values from a Python iterable to the end of the list.
Args:
values (iterable): iterable of values to append
"""
# Tensor is an iterable but we never want to unpack it here
if not isinstance(values, container_abcs.Iterable) or isinstance(values, jt.Var):
raise TypeError("ParameterList.extend should be called with an "
"iterable, but got " + type(values).__name__)
for value in values:
self.append(value)
return self
def extra_repr(self) -> str:
child_lines = []
for k, p in enumerate(self):
if isinstance(p, jt.Var):
size_str = 'x'.join(str(size) for size in p.size())
parastr = '{} containing: [{} of size {}{}]'.format(
"Parameter" if isinstance(p, Parameter) else "Tensor",
p.dtype, size_str, "cuda" if jt.flags.use_cuda else "cpu")
child_lines.append(' (' + str(k) + '): ' + parastr)
else:
child_lines.append(' (' + str(k) + '): Object of type: ' + type(p).__name__)
tmpstr = '\n'.join(child_lines)
return tmpstr
def __call__(self, *args, **kwargs):
raise RuntimeError('ParameterList should not be called.')
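A short sketch of how these containers are meant to be used in this port. It assumes the wrapped Linear/ReLU modules created from jt.nn by the loops at the top of this file, and that the ModuleMisc helpers (add_module, _modules) behave like their PyTorch counterparts:

    layers = ModuleDict({"fc": Linear(4, 8), "act": ReLU()})
    layers.update([("out", Linear(8, 2))])               # update() also accepts (name, module) pairs
    x = jt.randn(3, 4)
    for name in layers:                                  # iteration preserves insertion order
        x = layers[name](x)
    params = ParameterList([jt.randn(8), jt.randn(2)])   # each Var is wrapped via Parameter()
    print(len(params), params[0].requires_grad)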


@ -1,16 +0,0 @@
import jittor as jt
for k,v in jt.nn.init.__dict__.items():
if callable(v):
globals()[k] = v
normal = gauss
normal_ = gauss_
xavier_normal = xavier_gauss
xavier_normal_ = xavier_gauss_
zeros_ = zero_
jt.Var.normal_ = normal_


@ -1 +0,0 @@
from . import rnn


@ -1,20 +0,0 @@
import jittor as jt
PackedSequence = None
def pad_sequence(sequences,batch_first=False,padding_value=0.0):
max_f = max([len(s) for s in sequences])
# max_f = 512
b = len(sequences)
if batch_first:
ret = sequences[0].new_full([b,max_f,]+list(sequences[0].shape[1:]),padding_value)
for i,s in enumerate(sequences):
ret[i,:len(s)] = s
else:
ret = sequences[0].new_full([max_f,b,]+list(sequences[0].shape[1:]),padding_value)
for i,s in enumerate(sequences):
ret[:len(s),i] = s
# print(ret.shape)
# ret = ret[:,:406]
return ret
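A small usage sketch of pad_sequence, assuming jtorch Tensors (which provide the new_full method used above) and that jtorch exposes a wrapped ones():

    import jtorch
    a = jtorch.ones(3, 5)
    b = jtorch.ones(1, 5)
    out = pad_sequence([a, b], batch_first=True, padding_value=0.0)
    print(out.shape)   # (2, 3, 5): both sequences padded to the longest length, 3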

File diff suppressed because it is too large.


@ -1,102 +0,0 @@
#include "pyjt/py_obj_holder.h"
#include "utils/str_utils.h"
#include "jtorch_core.h"
#include "graph.h"
#include "grad.h"
#include "ops/op_register.h"
namespace jittor {
void pyjt_def_all(PyObject* m);
EXTERN_LIB void setter_use_cuda(int value);
Device::Device(const string& name, int ordinal) : name(name) {
if (startswith(name, "cpu"))
setter_use_cuda(0);
else
setter_use_cuda(1);
}
unordered_map<int64, VarPtr> grad_backup;
EXTERN_LIB void (*_var_free_hook)(Var*);
EXTERN_LIB unordered_map<int64, VarPtr>* _grad_backup_ptr;
void jtorch_var_free_hook(Var* v) {
auto iter = grad_backup.find(v->id);
if (iter != grad_backup.end()) {
grad_backup.erase(iter);
}
}
void jtorch_init() {
_var_free_hook = &jtorch_var_free_hook;
_grad_backup_ptr = &grad_backup;
}
inline static VarPtr& get_grad(Var* v) {
return grad_backup[v->id];
}
static auto make_binary = get_op_info("binary")
.get_constructor<VarPtr, Var*, Var*, NanoString>();
inline static void add_grad(VarPtr& a, VarPtr&& b) {
if (!a) a = move(b);
else {
a = make_binary(a, b, ns_add);
}
}
void grad_set(VarHolder* x, Maybe<VarHolder> v) {
if (!v) {
grad_del(x);
return;
}
grad_backup[x->var->id] = v.ptr->var;
}
Maybe<VarHolder> grad_get(VarHolder* x) {
auto iter = grad_backup.find(x->var->id);
if (iter != grad_backup.end()) {
if (!iter->second.ptr) return nullptr;
return new VarHolder(iter->second.ptr);
}
return nullptr;
}
void grad_del(VarHolder* x) {
auto iter = grad_backup.find(x->var->id);
if (iter != grad_backup.end())
grad_backup.erase(iter);
}
void backward(VarHolder* x) {
vector<Node*> gnodes({x->var});
bfs_backward(gnodes, [&](Node* node) {
if (node->is_stop_grad())
return false;
return true;
});
vector<Var*> targets;
for (auto* node : gnodes) {
if (node->is_var() && node->flags.get(NodeFlags::_th_require_grad))
targets.push_back(node->var());
}
auto grads = grad(x->var, targets);
for (int i=0; i<targets.size(); i++) {
auto& gptr = get_grad(targets[i]);
add_grad(gptr, move(grads[i]));
}
}
}
static void init_module(PyModuleDef* mdef, PyObject* m) {
jittor::jtorch_init();
mdef->m_doc = "Inner c++ core of jtorch";
jittor::pyjt_def_all(m);
}
PYJT_MODULE_INIT(jtorch_core);
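The core above keeps gradients outside the Var objects: they live in grad_backup keyed by variable id, are dropped by the var-free hook, and repeated contributions are folded together with a binary add op. A rough Python illustration of that bookkeeping (names here are illustrative only, not the actual bindings):

    grad_backup = {}                      # var id -> accumulated gradient
    def add_grad(var_id, g):
        # first contribution is stored as-is, later ones are summed in,
        # mirroring add_grad() in the C++ above
        if var_id not in grad_backup:
            grad_backup[var_id] = g
        else:
            grad_backup[var_id] = grad_backup[var_id] + g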


@ -1,40 +0,0 @@
#pragma once
#include "common.h"
#include "var_holder.h"
#include "misc/fast_shared_ptr.h"
namespace jittor {
// @pyjt(device)
// @attrs(heaptype)
struct Device {
string name;
// @pyjt(__init__)
Device(const string& name, int ordinal=0);
// @pyjt(__get__type, __str__)
inline string get_type() {return name;}
// @pyjt(__get__index)
inline int index() {return 0;}
};
// @pyjt(backward)
void backward(VarHolder* x);
// @pyjt(grad_set)
void grad_set(VarHolder* x, Maybe<VarHolder> v);
// @pyjt(grad_get)
Maybe<VarHolder> grad_get(VarHolder* x);
// @pyjt(grad_del)
void grad_del(VarHolder* x);
// @pyjt(retain_grad_set)
inline void retain_grad_set(VarHolder* x, bool v) {
x->var->flags.set(NodeFlags::_th_require_grad, v);
}
// @pyjt(retain_grad_get)
inline bool retain_grad_get(VarHolder* x) {
return x->var->flags.get(NodeFlags::_th_require_grad);
}
}


@ -1,25 +0,0 @@
import unittest
import numpy as np
import torch
import jittor as jt
class TestConflictFunc(unittest.TestCase):
def test_max(self):
a = torch.Tensor([1,4,2])
assert a.max() == 4
v, k = a.max(dim=0)
assert v==4 and k==1
def test_argsort(self):
a = torch.Tensor([1,4,2])
k = a.argsort()
assert jt.all_equal(k, [0,2,1])
with jt.flag_scope(th_mode=0):
k, v = a.argsort()
assert jt.all_equal(k, [0,2,1])
if __name__ == "__main__":
unittest.main()


@ -1,58 +0,0 @@
import unittest
import numpy as np
import torch
class TestFunction(unittest.TestCase):
def test_example1(self):
import jtorch
from jtorch import Function
class MyFunc(Function):
@staticmethod
def forward(self, x, y):
self.x = x
self.y = y
return x*y, x/y
@staticmethod
def backward(self, grad0, grad1):
return grad0 * self.y, grad1 * self.x
a = jtorch.array(3.0)
a.requires_grad = True
b = jtorch.array(4.0)
b.requires_grad = True
func = MyFunc.apply
c,d = func(a, b)
(c+d*3).backward()
assert a.grad.data == 4
assert b.grad.data == 9
def test_example2(self):
import jtorch as jt
from jtorch import Function
class MyFunc(Function):
@staticmethod
def forward(self, x, y):
self.x = x
self.y = y
return x*y, x/y
@staticmethod
def backward(self, grad0, grad1):
assert grad1 is None
return grad0 * self.y, None
a = jt.array(3.0)
a.requires_grad = True
b = jt.array(4.0)
b.requires_grad = True
func = MyFunc.apply
c,d = func(a, b)
d.stop_grad()
da, db = jt.grad(c+d*3, [a, b])
assert da.data == 4
assert db.data == 0
if __name__ == "__main__":
unittest.main()


@ -1,24 +0,0 @@
import unittest
import numpy as np
import torch
class TestMisc(unittest.TestCase):
def test_update_grad(self):
class Net(torch.nn.Module):
def __init__(self):
super().__init__()
self.a = torch.nn.Parameter(torch.Tensor([1.0, 2.0]))
net = Net()
assert(net.a.requires_grad)
net.load_state_dict({"a": torch.Tensor([3.0, 4.0])})
assert(net.a.requires_grad)
def test_reshape(self):
a = torch.ones(3,3)
a.requires_grad = True
b = torch.reshape(a, [9])
assert b.requires_grad == True
if __name__ == "__main__":
unittest.main()


@ -1,56 +0,0 @@
import unittest
import numpy as np
import os
import subprocess as sp
import sys
def check_two(cmd, parser=None, checker=None):
jtorch_out = sp.getoutput(cmd)
print("=========JTORCH OUT==========")
print(jtorch_out)
torch_out = sp.getoutput("PYTHONPATH= "+cmd)
print("=========TORCH OUT==========")
print(torch_out)
if parser:
torch_out = parser(torch_out)
jtorch_out = parser(jtorch_out)
if checker:
checker(torch_out, jtorch_out)
else:
assert torch_out == jtorch_out
return jtorch_out, torch_out
jtorch_path = os.path.join(os.path.dirname(__file__), "..")
# come from https://pytorch.org/tutorials/beginner/pytorch_with_examples.html
class TestTutorial(unittest.TestCase):
def test_auto_grad1(self):
check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad1.py",
parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float),
checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4))
def test_auto_grad2(self):
check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad2.py",
parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float),
checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4))
def test_auto_grad3(self):
check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad3.py",
parser=lambda s: np.array(s.split())[[-9,-7,-4,-2]].astype(float),
checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4))
def test_auto_grad4(self):
check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad4.py",
parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float),
checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4))
def test_auto_grad5(self):
check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad5_optim.py",
parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float),
checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-2))
def test_auto_grad6(self):
check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad6_module.py",
parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float),
checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4))
def test_auto_grad7(self):
check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad7_dynet.py",
parser=lambda s: np.array(s.split())[[-13,-10,-7,-3]].astype(float),
checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-2))
if __name__ == "__main__":
unittest.main()


@ -1,44 +0,0 @@
import torch
import math
dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU
# Create random input and output data
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)
# Randomly initialize weights
a = torch.randn((), device=device, dtype=dtype)
b = torch.randn((), device=device, dtype=dtype)
c = torch.randn((), device=device, dtype=dtype)
d = torch.randn((), device=device, dtype=dtype)
learning_rate = 1e-6
for t in range(20000):
# Forward pass: compute predicted y
y_pred = a + b * x + c * x ** 2 + d * x ** 3
# Compute and print loss
loss = (y_pred - y).pow(2).sum().item()
if t % 1000 == 999:
print(t, loss)
# Backprop to compute gradients of a, b, c, d with respect to loss
grad_y_pred = 2.0 * (y_pred - y)
grad_a = grad_y_pred.sum()
grad_b = (grad_y_pred * x).sum()
grad_c = (grad_y_pred * x ** 2).sum()
grad_d = (grad_y_pred * x ** 3).sum()
# Update weights using gradient descent
a -= learning_rate * grad_a
b -= learning_rate * grad_b
c -= learning_rate * grad_c
d -= learning_rate * grad_d
# print(t, torch.liveness_info())
# torch.sync_all()
print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')


@ -1,60 +0,0 @@
# -*- coding: utf-8 -*-
import torch
import math
dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU
# Create Tensors to hold input and outputs.
# By default, requires_grad=False, which indicates that we do not need to
# compute gradients with respect to these Tensors during the backward pass.
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)
# Create random Tensors for weights. For a third order polynomial, we need
# 4 weights: y = a + b x + c x^2 + d x^3
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
a = torch.randn((), device=device, dtype=dtype, requires_grad=True)
b = torch.randn((), device=device, dtype=dtype, requires_grad=True)
c = torch.randn((), device=device, dtype=dtype, requires_grad=True)
d = torch.randn((), device=device, dtype=dtype, requires_grad=True)
learning_rate = 1e-6
for t in range(20000):
# Forward pass: compute predicted y using operations on Tensors.
y_pred = a + b * x + c * x ** 2 + d * x ** 3
# print(y_pred.requires_grad)
# y_pred.requires_grad = False
# Compute and print loss using operations on Tensors.
# Now loss is a Tensor of shape (1,)
# loss.item() gets the scalar value held in the loss.
loss = (y_pred - y).pow(2).sum()
if t % 1000 == 990:
print(t, loss.item())
# Use autograd to compute the backward pass. This call will compute the
# gradient of loss with respect to all Tensors with requires_grad=True.
# After this call a.grad, b.grad, c.grad and d.grad will be Tensors holding
# the gradient of the loss with respect to a, b, c, d respectively.
# torch.backward(loss)
loss.backward()
# Manually update weights using gradient descent. Wrap in torch.no_grad()
# because weights have requires_grad=True, but we don't need to track this
# in autograd.
with torch.no_grad():
a -= learning_rate * a.grad
b -= learning_rate * b.grad
c -= learning_rate * c.grad
d -= learning_rate * d.grad
# Manually zero the gradients after updating weights
a.grad = None
b.grad = None
c.grad = None
d.grad = None
print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')


@ -1,85 +0,0 @@
# -*- coding: utf-8 -*-
import torch
import math
class LegendrePolynomial3(torch.autograd.Function):
"""
We can implement our own custom autograd Functions by subclassing
torch.autograd.Function and implementing the forward and backward passes
which operate on Tensors.
"""
@staticmethod
def forward(ctx, input):
"""
In the forward pass we receive a Tensor containing the input and return
a Tensor containing the output. ctx is a context object that can be used
to stash information for backward computation. You can cache arbitrary
objects for use in the backward pass using the ctx.save_for_backward method.
"""
ctx.save_for_backward(input)
return 0.5 * (5 * input ** 3 - 3 * input)
@staticmethod
def backward(ctx, grad_output):
"""
In the backward pass we receive a Tensor containing the gradient of the loss
with respect to the output, and we need to compute the gradient of the loss
with respect to the input.
"""
input, = ctx.saved_tensors
return grad_output * 1.5 * (5 * input ** 2 - 1)
dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU
# Create Tensors to hold input and outputs.
# By default, requires_grad=False, which indicates that we do not need to
# compute gradients with respect to these Tensors during the backward pass.
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)
# Create random Tensors for weights. For this example, we need
# 4 weights: y = a + b * P3(c + d * x), these weights need to be initialized
# not too far from the correct result to ensure convergence.
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
a = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
b = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True)
c = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
d = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True)
learning_rate = 5e-6
for t in range(2000):
# To apply our Function, we use Function.apply method. We alias this as 'P3'.
P3 = LegendrePolynomial3.apply
# Forward pass: compute predicted y using operations; we compute
# P3 using our custom autograd operation.
y_pred = a + b * P3(c + d * x)
# Compute and print loss
loss = (y_pred - y).pow(2).sum()
if t % 100 == 99:
print(t, loss.item())
# Use autograd to compute the backward pass.
loss.backward()
# Update weights using gradient descent
with torch.no_grad():
a -= learning_rate * a.grad
b -= learning_rate * b.grad
c -= learning_rate * c.grad
d -= learning_rate * d.grad
# Manually zero the gradients after updating weights
a.grad = None
b.grad = None
c.grad = None
d.grad = None
print(f'Result: y = {a.item()} + {b.item()} * P3( {c.item()} + {d.item()} x)')


@ -1,71 +0,0 @@
# -*- coding: utf-8 -*-
import torch
import math
# Create Tensors to hold input and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)
# For this example, the output y is a linear function of (x, x^2, x^3), so
# we can consider it as a linear layer neural network. Let's prepare the
# tensor (x, x^2, x^3).
p = torch.tensor([1, 2, 3])
xx = x.unsqueeze(-1).pow(p)
# In the above code, x.unsqueeze(-1) has shape (2000, 1), and p has shape
# (3,), for this case, broadcasting semantics will apply to obtain a tensor
# of shape (2000, 3)
# Use the nn package to define our model as a sequence of layers. nn.Sequential
# is a Module which contains other Modules, and applies them in sequence to
# produce its output. The Linear Module computes output from input using a
# linear function, and holds internal Tensors for its weight and bias.
# The Flatten layer flattens the output of the linear layer to a 1D tensor,
# to match the shape of `y`.
model = torch.nn.Sequential(
torch.nn.Linear(3, 1),
torch.nn.Flatten(0, 1)
)
# The nn package also contains definitions of popular loss functions; in this
# case we will use Mean Squared Error (MSE) as our loss function.
loss_fn = torch.nn.MSELoss(reduction='sum')
# print(model[0].weight.requires_grad)
learning_rate = 1e-6
for t in range(8000):
# Forward pass: compute predicted y by passing x to the model. Module objects
# override the __call__ operator so you can call them like functions. When
# doing so you pass a Tensor of input data to the Module and it produces
# a Tensor of output data.
y_pred = model(xx)
# Compute and print loss. We pass Tensors containing the predicted and true
# values of y, and the loss function returns a Tensor containing the
# loss.
loss = loss_fn(y_pred, y)
if t % 1000 == 999:
print(t, loss.item())
# Zero the gradients before running the backward pass.
model.zero_grad()
# Backward pass: compute gradient of the loss with respect to all the learnable
# parameters of the model. Internally, the parameters of each Module are stored
# in Tensors with requires_grad=True, so this call will compute gradients for
# all learnable parameters in the model.
loss.backward()
# Update the weights using gradient descent. Each parameter is a Tensor, so
# we can access its gradients like we did before.
with torch.no_grad():
for param in model.parameters():
param -= learning_rate * param.grad
# You can access the first layer of `model` like accessing the first item of a list
linear_layer = model[0]
# For linear layer, its parameters are stored as `weight` and `bias`.
print(f'Result: y = {linear_layer.bias.item()} + {linear_layer.weight[:, 0].item()} x + {linear_layer.weight[:, 1].item()} x^2 + {linear_layer.weight[:, 2].item()} x^3')


@ -1,53 +0,0 @@
# -*- coding: utf-8 -*-
import torch
import math
# Create Tensors to hold input and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)
# Prepare the input tensor (x, x^2, x^3).
p = torch.tensor([1, 2, 3])
xx = x.unsqueeze(-1).pow(p)
# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
torch.nn.Linear(3, 1),
torch.nn.Flatten(0, 1)
)
loss_fn = torch.nn.MSELoss(reduction='sum')
# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we will use RMSprop; the optim package contains many other
# optimization algorithms. The first argument to the RMSprop constructor tells the
# optimizer which Tensors it should update.
learning_rate = 1e-3
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
for t in range(8000):
# Forward pass: compute predicted y by passing x to the model.
y_pred = model(xx)
# Compute and print loss.
loss = loss_fn(y_pred, y)
if t % 1000 == 999:
print(t, loss.item())
# Before the backward pass, use the optimizer object to zero all of the
# gradients for the variables it will update (which are the learnable
# weights of the model). This is because by default, gradients are
# accumulated in buffers (i.e., not overwritten) whenever .backward()
# is called. Checkout docs of torch.autograd.backward for more details.
optimizer.zero_grad()
# Backward pass: compute gradient of the loss with respect to model
# parameters
loss.backward()
# Calling the step function on an Optimizer makes an update to its
# parameters
optimizer.step()
linear_layer = model[0]
print(f'Result: y = {linear_layer.bias.item()} + {linear_layer.weight[:, 0].item()} x + {linear_layer.weight[:, 1].item()} x^2 + {linear_layer.weight[:, 2].item()} x^3')


@ -1,59 +0,0 @@
# -*- coding: utf-8 -*-
import torch
import math
class Polynomial3(torch.nn.Module):
def __init__(self):
"""
In the constructor we instantiate four parameters and assign them as
member parameters.
"""
super().__init__()
self.a = torch.nn.Parameter(torch.randn(()))
self.b = torch.nn.Parameter(torch.randn(()))
self.c = torch.nn.Parameter(torch.randn(()))
self.d = torch.nn.Parameter(torch.randn(()))
def forward(self, x):
"""
In the forward function we accept a Tensor of input data and we must return
a Tensor of output data. We can use Modules defined in the constructor as
well as arbitrary operators on Tensors.
"""
return self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3
def string(self):
"""
Just like any class in Python, you can also define custom method on PyTorch modules
"""
return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3'
# Create Tensors to hold input and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)
# Construct our model by instantiating the class defined above
model = Polynomial3()
# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor will contain the learnable parameters (defined
# with torch.nn.Parameter) which are members of the model.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)
for t in range(8000):
# Forward pass: Compute predicted y by passing x to the model
y_pred = model(x)
# Compute and print loss
loss = criterion(y_pred, y)
if t % 1000 == 999:
print(t, loss.item())
# Zero gradients, perform a backward pass, and update the weights.
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f'Result: {model.string()}')


@ -1,69 +0,0 @@
# -*- coding: utf-8 -*-
import random
import torch
import math
class DynamicNet(torch.nn.Module):
def __init__(self):
"""
In the constructor we instantiate five parameters and assign them as members.
"""
super().__init__()
self.a = torch.nn.Parameter(torch.randn(()))
self.b = torch.nn.Parameter(torch.randn(()))
self.c = torch.nn.Parameter(torch.randn(()))
self.d = torch.nn.Parameter(torch.randn(()))
self.e = torch.nn.Parameter(torch.randn(()))
def forward(self, x):
"""
For the forward pass of the model, we randomly choose either 4, 5
and reuse the e parameter to compute the contribution of these orders.
Since each forward pass builds a dynamic computation graph, we can use normal
Python control-flow operators like loops or conditional statements when
defining the forward pass of the model.
Here we also see that it is perfectly safe to reuse the same parameter many
times when defining a computational graph.
"""
y = self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3
for exp in range(4, random.randint(4, 6)):
y = y + self.e * x ** exp
return y
def string(self):
"""
Just like any class in Python, you can also define custom method on PyTorch modules
"""
return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3 + {self.e.item()} x^4 ? + {self.e.item()} x^5 ?'
# Create Tensors to hold input and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)
# Construct our model by instantiating the class defined above
model = DynamicNet()
# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.9)
for t in range(60000):
# Forward pass: Compute predicted y by passing x to the model
y_pred = model(x)
# Compute and print loss
loss = criterion(y_pred, y)
if t % 2000 == 1999:
print(t, loss.item())
# Zero gradients, perform a backward pass, and update the weights.
optimizer.zero_grad()
loss.backward()
optimizer.step()
# print(torch.liveness_info())
print(f'Result: {model.string()}')


@ -1,106 +0,0 @@
import torch
from torch import nn
# from jtorch.utils import DataLoader
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
)
# Download test data from open datasets.
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor(),
)
batch_size = 64
# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
print(len(train_dataloader))
for X, y in test_dataloader:
print(f"Shape of X [N, C, H, W]: {X.shape}")
print(f"Shape of y: {y.shape} {y.dtype}")
break
# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
# Define model
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10)
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
model = NeuralNetwork().to(device)
print(model)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
def train(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
model.train()
for batch, (X, y) in enumerate(dataloader):
X, y = X.to(device), y.to(device)
# Compute prediction error
pred = model(X)
loss = loss_fn(pred, y)
# Backpropagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
if batch % 100 == 0:
loss, current = loss.item(), batch * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
def test(dataloader, model, loss_fn):
size = len(dataloader.dataset)
num_batches = len(dataloader)
model.eval()
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
X, y = X.to(device), y.to(device)
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
epochs = 5
test(test_dataloader, model, loss_fn)
for t in range(epochs):
print(f"Epoch {t+1}\n-------------------------------")
train(train_dataloader, model, loss_fn, optimizer)
test(test_dataloader, model, loss_fn)
print("Done!")


@ -1,5 +0,0 @@
cpp_extension = None
_flatten_dense_tensors = None
_unflatten_dense_tensors = None
tensorboard = None


@ -1,3 +0,0 @@
#TODO: Implement this
_register_pytree_node = None
_dict_flatten = None


@ -1,8 +0,0 @@
detach_variable = None
def checkpoint(function, *args, **kwargs):
# jtorch does not implement activation checkpointing; run the wrapped
# function eagerly so call sites still receive real outputs (checkpoint-only
# kwargs such as use_reentrant are accepted and ignored).
return function(*args)


@ -1,137 +0,0 @@
import jittor as jt
import jittor.dataset
from jittor.dataset import Dataset as JDataset
from collections import namedtuple
from typing import Any, Callable, Iterable, Optional, Sequence, Union
class Dataset:
def __getitem__(self, index):
raise NotImplementedError
class IterableDataset:
def __iter__(self):
raise NotImplementedError
class DataLoader(JDataset):
def __init__(self, dataset,
batch_size: Optional[int] = 1,
shuffle: Optional[bool] = False,
sampler = None,
batch_sampler = None,
num_workers: int = 0,
collate_fn = None,
pin_memory: bool = False,
drop_last: bool = False,
timeout: float = 0,
worker_init_fn = None,
multiprocessing_context=None,
generator=None,
*, prefetch_factor: int = 2,
persistent_workers: bool = False,
pin_memory_device: str = "") -> None:
super().__init__(batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
drop_last=drop_last)
unsupported_kwargs = {
"batch_sampler": batch_sampler,
"pin_memory": pin_memory,
"timeout": timeout,
"worker_init_fn": worker_init_fn,
"multiprocessing_context": multiprocessing_context,
"generator": generator,
"persistent_workers": persistent_workers,
"pin_memory_device": pin_memory_device
}
for kwarg, value in unsupported_kwargs.items():
if value:
jt.LOG.w(f"Not implemented Dataloader kwarg: {kwarg}")
self.dataset = dataset
self.collate_fn = collate_fn
self.sampler = sampler
if not isinstance(dataset, IterableDataset):
self.total_len = len(dataset)
else:
# TODO: support multiple worker for iterable dataset
assert(num_workers == 0)
def collate_batch(self, batch):
if self.collate_fn is not None:
return self.collate_fn(batch)
else:
return super().collate_batch(batch)
def __getitem__(self, i):
return self.dataset[i]
def __iter__(self):
if isinstance(self.dataset, IterableDataset):
return self.inner_iter()
else:
return super().__iter__()
def inner_iter(self):
current_batch = []
if jt.world_size > 1:
assert self.batch_size % jt.world_size == 0, \
f"IterableDataset does not support a batch size ({self.batch_size}) that is not evenly divisible by the number of processes f{jt.world_size}"
real_batch_size = int(self.batch_size / jt.world_size)
else:
real_batch_size = self.batch_size
for element in self.dataset:
current_batch.append(element)
if len(current_batch) == real_batch_size:
current_batch = self.collate_batch(current_batch)
current_batch = self.to_jittor(current_batch)
yield current_batch
current_batch = []
if not self.drop_last and len(current_batch) > 0:
current_batch = self.collate_batch(current_batch)
yield self.to_jittor(current_batch)
def get_worker_info():
# always return the fake worker info
return namedtuple('WorkerInfo', 'id num_workers')(0, 1)
class RandomSampler(jt.dataset.RandomSampler):
def __init__(self, dataset, generator=None, **kwargs):
super().__init__(dataset, **kwargs)
def __iter__(self):
if getattr(self.dataset, "support_random_access", True):
return super().__iter__()
else:
self.dataset.shuffle()
return iter(range(self.dataset.__real_len__() if hasattr(self.dataset,"__real_len__") else self.dataset.__len__()))
class DistributedSampler(jt.dataset.Sampler):
def __init__(self, sampler: RandomSampler):
assert(isinstance(sampler, RandomSampler))
self.sampler = sampler
def set_epoch(self, epoch: int):
### do nothing, let jittor's inner dataset handle
pass
def __iter__(self):
return self.sampler.__iter__()
def __len__(self):
return self.sampler.__len__()
BatchSampler = jt.dataset.BatchSampler
Sampler = jt.dataset.Sampler
SequentialSampler = jt.dataset.SequentialSampler
SubsetRandomSampler = jt.dataset.SubsetRandomSampler
TensorDataset = Dataset
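A minimal sketch of the map-style path above, with a hypothetical in-memory dataset; batching and collation are delegated to jittor's Dataset machinery:

    import numpy as np
    class ToyDataset(Dataset):
        def __init__(self):
            self.xs = np.arange(10, dtype="float32")
        def __len__(self):
            return len(self.xs)
        def __getitem__(self, i):
            return self.xs[i], int(i % 2)
    loader = DataLoader(ToyDataset(), batch_size=4, shuffle=True, drop_last=False)
    for x, y in loader:
        print(x.shape, y.shape)   # batches produced by jittor's default collate_batch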


@ -1,9 +0,0 @@
from typing import Callable, Union
Dtype = Union[Callable, str]
def get_string_dtype(dtype):
if callable(dtype):
dtype = dtype.__name__
if not isinstance(dtype, str):
raise ValueError(f"dtype is expected to be str, python type function, or jittor type function, but got {dtype}.")
return dtype


@ -1,34 +0,0 @@
import os
import glob
import shutil
import sys
home_path = os.path.join(os.path.dirname(__file__), "..", "..", "..")
home_path = os.path.abspath(home_path)
def callback(func, path, exc_info):
print(f"remove \"{path}\" failed.")
def rmtree(path):
if os.path.isdir(path):
print(f"remove \"{path}\" recursive.")
shutil.rmtree(path, onerror=callback)
def remove_tmpfile():
dist_file = home_path+"/dist"
egg_file = glob.glob(home_path+"/**/*egg-info")
rmtree(dist_file)
for e in egg_file:
rmtree(e)
def run_cmd(cmd):
print("[CMD]", cmd)
assert os.system(cmd)==0
os.chdir(home_path)
remove_tmpfile()
run_cmd(f"{sys.executable} ./setup.py sdist")
run_cmd(f"{sys.executable} -m twine upload dist/*")
remove_tmpfile()


@ -1,46 +0,0 @@
import importlib.machinery
import os
def _download_file_from_remote_location(fpath: str, url: str) -> None:
pass
def _is_remote_location_available() -> bool:
return False
def _get_extension_path(lib_name):
lib_dir = os.path.dirname(__file__)
if os.name == "nt":
# Register the main torchvision library location on the default DLL path
import ctypes
import sys
kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
prev_error_mode = kernel32.SetErrorMode(0x0001)
if with_load_library_flags:
kernel32.AddDllDirectory.restype = ctypes.c_void_p
if sys.version_info >= (3, 8):
os.add_dll_directory(lib_dir)
elif with_load_library_flags:
res = kernel32.AddDllDirectory(lib_dir)
if res is None:
err = ctypes.WinError(ctypes.get_last_error())
err.strerror += f' Error adding "{lib_dir}" to the DLL directories.'
raise err
kernel32.SetErrorMode(prev_error_mode)
loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES)
extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
ext_specs = extfinder.find_spec(lib_name)
if ext_specs is None:
raise ImportError
return ext_specs.origin


@ -1,9 +0,0 @@
from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST
__all__ = (
"EMNIST",
"FashionMNIST",
"QMNIST",
"MNIST",
"KMNIST",
)


@ -1,558 +0,0 @@
import codecs
import os
import os.path
import shutil
import string
import sys
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple
from urllib.error import URLError
import numpy as np
import torch
from PIL import Image
from .utils import check_integrity, download_and_extract_archive, extract_archive, verify_str_arg
from .vision import VisionDataset
class MNIST(VisionDataset):
"""`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
Args:
root (string): Root directory of dataset where ``MNIST/raw/train-images-idx3-ubyte``
and ``MNIST/raw/t10k-images-idx3-ubyte`` exist.
train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
otherwise from ``t10k-images-idx3-ubyte``.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
mirrors = [
"http://yann.lecun.com/exdb/mnist/",
"https://ossci-datasets.s3.amazonaws.com/mnist/",
]
resources = [
("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"),
]
training_file = "training.pt"
test_file = "test.pt"
classes = [
"0 - zero",
"1 - one",
"2 - two",
"3 - three",
"4 - four",
"5 - five",
"6 - six",
"7 - seven",
"8 - eight",
"9 - nine",
]
@property
def train_labels(self):
warnings.warn("train_labels has been renamed targets")
return self.targets
@property
def test_labels(self):
warnings.warn("test_labels has been renamed targets")
return self.targets
@property
def train_data(self):
warnings.warn("train_data has been renamed data")
return self.data
@property
def test_data(self):
warnings.warn("test_data has been renamed data")
return self.data
def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super().__init__(root, transform=transform, target_transform=target_transform)
self.train = train # training set or test set
if self._check_legacy_exist():
self.data, self.targets = self._load_legacy_data()
return
if download:
self.download()
if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it")
self.data, self.targets = self._load_data()
def _check_legacy_exist(self):
processed_folder_exists = os.path.exists(self.processed_folder)
if not processed_folder_exists:
return False
return all(
check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file)
)
def _load_legacy_data(self):
# This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
# directly.
data_file = self.training_file if self.train else self.test_file
return torch.load(os.path.join(self.processed_folder, data_file))
def _load_data(self):
image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
data = read_image_file(os.path.join(self.raw_folder, image_file))
label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
targets = read_label_file(os.path.join(self.raw_folder, label_file))
return data, targets
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], int(self.targets[index])
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img.numpy(), mode="L")
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self) -> int:
return len(self.data)
@property
def raw_folder(self) -> str:
return os.path.join(self.root, self.__class__.__name__, "raw")
@property
def processed_folder(self) -> str:
return os.path.join(self.root, self.__class__.__name__, "processed")
@property
def class_to_idx(self) -> Dict[str, int]:
return {_class: i for i, _class in enumerate(self.classes)}
def _check_exists(self) -> bool:
return all(
check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]))
for url, _ in self.resources
)
def download(self) -> None:
"""Download the MNIST data if it doesn't exist already."""
if self._check_exists():
return
os.makedirs(self.raw_folder, exist_ok=True)
# download files
for filename, md5 in self.resources:
for mirror in self.mirrors:
url = f"{mirror}{filename}"
try:
print(f"Downloading {url}")
download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
except URLError as error:
print(f"Failed to download (trying next):\n{error}")
continue
finally:
print()
break
else:
raise RuntimeError(f"Error downloading {filename}")
def extra_repr(self) -> str:
split = "Train" if self.train is True else "Test"
return f"Split: {split}"
class FashionMNIST(MNIST):
"""`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.
Args:
root (string): Root directory of dataset where ``FashionMNIST/raw/train-images-idx3-ubyte``
and ``FashionMNIST/raw/t10k-images-idx3-ubyte`` exist.
train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
otherwise from ``t10k-images-idx3-ubyte``.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
mirrors = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"]
resources = [
("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"),
("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"),
("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"),
("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310"),
]
classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
class KMNIST(MNIST):
"""`Kuzushiji-MNIST <https://github.com/rois-codh/kmnist>`_ Dataset.
Args:
root (string): Root directory of dataset where ``KMNIST/raw/train-images-idx3-ubyte``
and ``KMNIST/raw/t10k-images-idx3-ubyte`` exist.
train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
otherwise from ``t10k-images-idx3-ubyte``.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
mirrors = ["http://codh.rois.ac.jp/kmnist/dataset/kmnist/"]
resources = [
("train-images-idx3-ubyte.gz", "bdb82020997e1d708af4cf47b453dcf7"),
("train-labels-idx1-ubyte.gz", "e144d726b3acfaa3e44228e80efcd344"),
("t10k-images-idx3-ubyte.gz", "5c965bf0a639b31b8f53240b1b52f4d7"),
("t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134"),
]
classes = ["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"]
class EMNIST(MNIST):
"""`EMNIST <https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist>`_ Dataset.
Args:
root (string): Root directory of dataset where ``EMNIST/raw/train-images-idx3-ubyte``
and ``EMNIST/raw/t10k-images-idx3-ubyte`` exist.
split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
which one to use.
train (bool, optional): If True, creates dataset from ``training.pt``,
otherwise from ``test.pt``.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
url = "https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip"
md5 = "58c8d27c78d21e728a6bc7b3cc06412e"
splits = ("byclass", "bymerge", "balanced", "letters", "digits", "mnist")
# Merged Classes assumes Same structure for both uppercase and lowercase version
_merged_classes = {"c", "i", "j", "k", "l", "m", "o", "p", "s", "u", "v", "w", "x", "y", "z"}
_all_classes = set(string.digits + string.ascii_letters)
classes_split_dict = {
"byclass": sorted(list(_all_classes)),
"bymerge": sorted(list(_all_classes - _merged_classes)),
"balanced": sorted(list(_all_classes - _merged_classes)),
"letters": ["N/A"] + list(string.ascii_lowercase),
"digits": list(string.digits),
"mnist": list(string.digits),
}
def __init__(self, root: str, split: str, **kwargs: Any) -> None:
self.split = verify_str_arg(split, "split", self.splits)
self.training_file = self._training_file(split)
self.test_file = self._test_file(split)
super().__init__(root, **kwargs)
self.classes = self.classes_split_dict[self.split]
@staticmethod
def _training_file(split) -> str:
return f"training_{split}.pt"
@staticmethod
def _test_file(split) -> str:
return f"test_{split}.pt"
@property
def _file_prefix(self) -> str:
return f"emnist-{self.split}-{'train' if self.train else 'test'}"
@property
def images_file(self) -> str:
return os.path.join(self.raw_folder, f"{self._file_prefix}-images-idx3-ubyte")
@property
def labels_file(self) -> str:
return os.path.join(self.raw_folder, f"{self._file_prefix}-labels-idx1-ubyte")
def _load_data(self):
return read_image_file(self.images_file), read_label_file(self.labels_file)
def _check_exists(self) -> bool:
return all(check_integrity(file) for file in (self.images_file, self.labels_file))
def download(self) -> None:
"""Download the EMNIST data if it doesn't exist already."""
if self._check_exists():
return
os.makedirs(self.raw_folder, exist_ok=True)
download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5)
gzip_folder = os.path.join(self.raw_folder, "gzip")
for gzip_file in os.listdir(gzip_folder):
if gzip_file.endswith(".gz"):
extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder)
shutil.rmtree(gzip_folder)
class QMNIST(MNIST):
"""`QMNIST <https://github.com/facebookresearch/qmnist>`_ Dataset.
Args:
root (string): Root directory of dataset whose ``raw``
subdir contains binary files of the datasets.
what (string,optional): Can be 'train', 'test', 'test10k',
'test50k', or 'nist' for respectively the mnist compatible
training set, the 60k qmnist testing set, the 10k qmnist
examples that match the mnist testing set, the 50k
remaining qmnist testing examples, or all the nist
digits. The default is to select 'train' or 'test'
according to the compatibility argument 'train'.
compat (bool,optional): A boolean that says whether the target
for each example is class number (for compatibility with
the MNIST dataloader) or a torch vector containing the
full qmnist information. Default=True.
download (bool, optional): If True, downloads the dataset from
the internet and puts it in root directory. If dataset is
already downloaded, it is not downloaded again.
transform (callable, optional): A function/transform that
takes in a PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform
that takes in the target and transforms it.
train (bool,optional,compatibility): When argument 'what' is
not specified, this boolean decides whether to load the
training set or the testing set. Default: True.
"""
subsets = {"train": "train", "test": "test", "test10k": "test", "test50k": "test", "nist": "nist"}
resources: Dict[str, List[Tuple[str, str]]] = { # type: ignore[assignment]
"train": [
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz",
"ed72d4157d28c017586c42bc6afe6370",
),
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz",
"0058f8dd561b90ffdd0f734c6a30e5e4",
),
],
"test": [
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz",
"1394631089c404de565df7b7aeaf9412",
),
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz",
"5b5b05890a5e13444e108efe57b788aa",
),
],
"nist": [
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz",
"7f124b3b8ab81486c9d8c2749c17f834",
),
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz",
"5ed0e788978e45d4a8bd4b7caec3d79d",
),
],
}
classes = [
"0 - zero",
"1 - one",
"2 - two",
"3 - three",
"4 - four",
"5 - five",
"6 - six",
"7 - seven",
"8 - eight",
"9 - nine",
]
def __init__(
self, root: str, what: Optional[str] = None, compat: bool = True, train: bool = True, **kwargs: Any
) -> None:
if what is None:
what = "train" if train else "test"
self.what = verify_str_arg(what, "what", tuple(self.subsets.keys()))
self.compat = compat
self.data_file = what + ".pt"
self.training_file = self.data_file
self.test_file = self.data_file
super().__init__(root, train, **kwargs)
@property
def images_file(self) -> str:
(url, _), _ = self.resources[self.subsets[self.what]]
return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])
@property
def labels_file(self) -> str:
_, (url, _) = self.resources[self.subsets[self.what]]
return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])
def _check_exists(self) -> bool:
return all(check_integrity(file) for file in (self.images_file, self.labels_file))
def _load_data(self):
data = read_sn3_pascalvincent_tensor(self.images_file)
if data.dtype != torch.uint8:
raise TypeError(f"data should be of dtype torch.uint8 instead of {data.dtype}")
if data.ndimension() != 3:
raise ValueError("data should have 3 dimensions instead of {data.ndimension()}")
targets = read_sn3_pascalvincent_tensor(self.labels_file).long()
if targets.ndimension() != 2:
raise ValueError(f"targets should have 2 dimensions instead of {targets.ndimension()}")
if self.what == "test10k":
data = data[0:10000, :, :].clone()
targets = targets[0:10000, :].clone()
elif self.what == "test50k":
data = data[10000:, :, :].clone()
targets = targets[10000:, :].clone()
return data, targets
def download(self) -> None:
"""Download the QMNIST data if it doesn't exist already.
Note that we only download what has been asked for (argument 'what').
"""
if self._check_exists():
return
os.makedirs(self.raw_folder, exist_ok=True)
split = self.resources[self.subsets[self.what]]
for url, md5 in split:
download_and_extract_archive(url, self.raw_folder, md5=md5)
def __getitem__(self, index: int) -> Tuple[Any, Any]:
# redefined to handle the compat flag
img, target = self.data[index], self.targets[index]
img = Image.fromarray(img.numpy(), mode="L")
if self.transform is not None:
img = self.transform(img)
if self.compat:
target = int(target[0])
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def extra_repr(self) -> str:
return f"Split: {self.what}"
def get_int(b: bytes) -> int:
return int(codecs.encode(b, "hex"), 16)
SN3_PASCALVINCENT_BITSMAP = {
8: torch.uint8,
9: torch.int8,
11: torch.int16,
12: torch.int32,
13: torch.float32,
14: torch.float64,
}
TORCH_TYPE_BITS = {
torch.uint8: 8,
torch.int8: 8,
torch.int16: 16,
torch.int32: 32,
torch.float32: 32,
torch.float64: 64,
}
def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor:
"""Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
Argument may be a filename, compressed filename, or file object.
"""
# read
with open(path, "rb") as f:
data = f.read()
# parse
magic = get_int(data[0:4])
nd = magic % 256
ty = magic // 256
assert 1 <= nd <= 3
assert 8 <= ty <= 14
torch_type = SN3_PASCALVINCENT_BITSMAP[ty]
s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]
num_bytes_per_value = TORCH_TYPE_BITS[torch_type] // 8
# The MNIST format uses the big endian byte order. If the system uses little endian byte order by default,
# we need to reverse the bytes before we can read them with torch.frombuffer().
needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1
parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1)))
if needs_byte_reversal:
parsed = parsed.flip(0)
assert parsed.shape[0] == np.prod(s) or not strict
return parsed.view(*s)
def read_label_file(path: str) -> torch.Tensor:
x = read_sn3_pascalvincent_tensor(path, strict=False)
if x.dtype != torch.uint8:
raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
if x.ndimension() != 1:
raise ValueError(f"x should have 1 dimension instead of {x.ndimension()}")
return x.long()
def read_image_file(path: str) -> torch.Tensor:
x = read_sn3_pascalvincent_tensor(path, strict=False)
if x.dtype != torch.uint8:
raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
if x.ndimension() != 3:
raise ValueError(f"x should have 3 dimension instead of {x.ndimension()}")
return x
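# Worked sketch of the SN3/IDX header decoding performed in read_sn3_pascalvincent_tensor().
# The byte values are an assumption chosen to match a standard MNIST image file header.
_magic = get_int(b"\x00\x00\x08\x03")     # 0x00000803 == 2051
_nd = _magic % 256                        # 3 -> three dimensions (count, rows, cols)
_ty = _magic // 256                       # 8 -> SN3_PASCALVINCENT_BITSMAP[8] is torch.uint8
assert (_nd, SN3_PASCALVINCENT_BITSMAP[_ty]) == (3, torch.uint8)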

View File

@ -1,522 +0,0 @@
import bz2
import contextlib
import gzip
import hashlib
import itertools
import lzma
import os
import os.path
import pathlib
import re
import sys
import tarfile
import urllib
import urllib.error
import urllib.request
import warnings
import zipfile
from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar
from urllib.parse import urlparse
import numpy as np
import requests
import torch
from tqdm import tqdm
from .._internally_replaced_utils import _download_file_from_remote_location, _is_remote_location_available
USER_AGENT = "pytorch/vision"
def _save_response_content(
content: Iterator[bytes],
destination: str,
length: Optional[int] = None,
) -> None:
with open(destination, "wb") as fh, tqdm(total=length) as pbar:
for chunk in content:
# filter out keep-alive new chunks
if not chunk:
continue
fh.write(chunk)
pbar.update(len(chunk))
def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
_save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)
def gen_bar_updater() -> Callable[[int, int, int], None]:
warnings.warn("The function `gen_bar_update` is deprecated since 0.13 and will be removed in 0.15.")
pbar = tqdm(total=None)
def bar_update(count, block_size, total_size):
if pbar.total is None and total_size:
pbar.total = total_size
progress_bytes = count * block_size
pbar.update(progress_bytes - pbar.n)
return bar_update
def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
# Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are
# not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without
# it torchvision.datasets is unusable in these environments since we perform a MD5 check everywhere.
if sys.version_info >= (3, 9):
md5 = hashlib.md5(usedforsecurity=False)
else:
md5 = hashlib.md5()
with open(fpath, "rb") as f:
for chunk in iter(lambda: f.read(chunk_size), b""):
md5.update(chunk)
return md5.hexdigest()
def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
return md5 == calculate_md5(fpath, **kwargs)
def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
if not os.path.isfile(fpath):
return False
if md5 is None:
return True
return check_md5(fpath, md5)
def _get_redirect_url(url: str, max_hops: int = 3) -> str:
initial_url = url
headers = {"Method": "HEAD", "User-Agent": USER_AGENT}
for _ in range(max_hops + 1):
with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response:
if response.url == url or response.url is None:
return url
url = response.url
else:
raise RecursionError(
f"Request to {initial_url} exceeded {max_hops} redirects. The last redirect points to {url}."
)
def _get_google_drive_file_id(url: str) -> Optional[str]:
parts = urlparse(url)
if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
return None
match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
if match is None:
return None
return match.group("id")
def download_url(
url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3
) -> None:
"""Download a file from a url and place it in root.
Args:
url (str): URL to download file from
root (str): Directory to place downloaded file in
filename (str, optional): Name to save the file under. If None, use the basename of the URL
md5 (str, optional): MD5 checksum of the download. If None, do not check
max_redirect_hops (int, optional): Maximum number of redirect hops allowed
"""
root = os.path.expanduser(root)
if not filename:
filename = os.path.basename(url)
fpath = os.path.join(root, filename)
os.makedirs(root, exist_ok=True)
# check if file is already present locally
if check_integrity(fpath, md5):
print("Using downloaded and verified file: " + fpath)
return
if _is_remote_location_available():
_download_file_from_remote_location(fpath, url)
else:
# expand redirect chain if needed
url = _get_redirect_url(url, max_hops=max_redirect_hops)
# check if file is located on Google Drive
file_id = _get_google_drive_file_id(url)
if file_id is not None:
return download_file_from_google_drive(file_id, root, filename, md5)
# download the file
try:
print("Downloading " + url + " to " + fpath)
_urlretrieve(url, fpath)
except (urllib.error.URLError, OSError) as e: # type: ignore[attr-defined]
if url[:5] == "https":
url = url.replace("https:", "http:")
print("Failed download. Trying https -> http instead. Downloading " + url + " to " + fpath)
_urlretrieve(url, fpath)
else:
raise e
# check integrity of downloaded file
if not check_integrity(fpath, md5):
raise RuntimeError("File not found or corrupted.")
def list_dir(root: str, prefix: bool = False) -> List[str]:
"""List all directories at a given root
Args:
root (str): Path to directory whose folders need to be listed
prefix (bool, optional): If true, prepends the path to each result, otherwise
only returns the name of the directories found
"""
root = os.path.expanduser(root)
directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))]
if prefix is True:
directories = [os.path.join(root, d) for d in directories]
return directories
def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
"""List all files ending with a suffix at a given root
Args:
root (str): Path to directory whose folders need to be listed
suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
It uses the Python "str.endswith" method and is passed directly
prefix (bool, optional): If true, prepends the path to each result, otherwise
only returns the name of the files found
"""
root = os.path.expanduser(root)
files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]
if prefix is True:
files = [os.path.join(root, d) for d in files]
return files
def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[bytes, Iterator[bytes]]:
content = response.iter_content(chunk_size)
first_chunk = None
# filter out keep-alive new chunks
while not first_chunk:
first_chunk = next(content)
content = itertools.chain([first_chunk], content)
try:
match = re.search("<title>Google Drive - (?P<api_response>.+?)</title>", first_chunk.decode())
api_response = match["api_response"] if match is not None else None
except UnicodeDecodeError:
api_response = None
return api_response, content
def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
"""Download a Google Drive file from and place it in root.
Args:
file_id (str): id of file to be downloaded
root (str): Directory to place downloaded file in
filename (str, optional): Name to save the file under. If None, use the id of the file.
md5 (str, optional): MD5 checksum of the download. If None, do not check
"""
# Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
root = os.path.expanduser(root)
if not filename:
filename = file_id
fpath = os.path.join(root, filename)
os.makedirs(root, exist_ok=True)
if check_integrity(fpath, md5):
print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
return
url = "https://drive.google.com/uc"
params = dict(id=file_id, export="download")
with requests.Session() as session:
response = session.get(url, params=params, stream=True)
for key, value in response.cookies.items():
if key.startswith("download_warning"):
token = value
break
else:
api_response, content = _extract_gdrive_api_response(response)
token = "t" if api_response == "Virus scan warning" else None
if token is not None:
response = session.get(url, params=dict(params, confirm=token), stream=True)
api_response, content = _extract_gdrive_api_response(response)
if api_response == "Quota exceeded":
raise RuntimeError(
f"The daily quota of the file {filename} is exceeded and it "
f"can't be downloaded. This is a limitation of Google Drive "
f"and can only be overcome by trying again later."
)
_save_response_content(content, fpath)
# In case we deal with an unhandled GDrive API response, the file should be smaller than 10kB and contain only text
if os.stat(fpath).st_size < 10 * 1024:
with contextlib.suppress(UnicodeDecodeError), open(fpath) as fh:
text = fh.read()
# Regular expression to detect HTML. Copied from https://stackoverflow.com/a/70585604
if re.search(r"</?\s*[a-z-][^>]*\s*>|(&(?:[\w\d]+|#\d+|#x[a-f\d]+);)", text):
warnings.warn(
f"We detected some HTML elements in the downloaded file. "
f"This most likely means that the download triggered an unhandled API response by GDrive. "
f"Please report this to torchvision at https://github.com/pytorch/vision/issues including "
f"the response:\n\n{text}"
)
if md5 and not check_md5(fpath, md5):
raise RuntimeError(
f"The MD5 checksum of the download file {fpath} does not match the one on record."
f"Please delete the file and try again. "
f"If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues."
)
def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
tar.extractall(to_path)
_ZIP_COMPRESSION_MAP: Dict[str, int] = {
".bz2": zipfile.ZIP_BZIP2,
".xz": zipfile.ZIP_LZMA,
}
def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None:
with zipfile.ZipFile(
from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
) as zip:
zip.extractall(to_path)
_ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = {
".tar": _extract_tar,
".zip": _extract_zip,
}
_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = {
".bz2": bz2.open,
".gz": gzip.open,
".xz": lzma.open,
}
_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {
".tbz": (".tar", ".bz2"),
".tbz2": (".tar", ".bz2"),
".tgz": (".tar", ".gz"),
}
def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
"""Detect the archive type and/or compression of a file.
Args:
file (str): the filename
Returns:
(tuple): tuple of suffix, archive type, and compression
Raises:
RuntimeError: if file has no suffix or suffix is not supported
"""
suffixes = pathlib.Path(file).suffixes
if not suffixes:
raise RuntimeError(
f"File '{file}' has no suffixes that could be used to detect the archive type and compression."
)
suffix = suffixes[-1]
# check if the suffix is a known alias
if suffix in _FILE_TYPE_ALIASES:
return (suffix, *_FILE_TYPE_ALIASES[suffix])
# check if the suffix is an archive type
if suffix in _ARCHIVE_EXTRACTORS:
return suffix, suffix, None
# check if the suffix is a compression
if suffix in _COMPRESSED_FILE_OPENERS:
# check for suffix hierarchy
if len(suffixes) > 1:
suffix2 = suffixes[-2]
# check if the suffix2 is an archive type
if suffix2 in _ARCHIVE_EXTRACTORS:
return suffix2 + suffix, suffix2, suffix
return suffix, None, suffix
valid_suffixes = sorted(set(_FILE_TYPE_ALIASES) | set(_ARCHIVE_EXTRACTORS) | set(_COMPRESSED_FILE_OPENERS))
raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.")
def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
r"""Decompress a file.
The compression is automatically detected from the file name.
Args:
from_path (str): Path to the file to be decompressed.
to_path (str): Path to the decompressed file. If omitted, ``from_path`` without compression extension is used.
remove_finished (bool): If ``True``, remove the file after the extraction.
Returns:
(str): Path to the decompressed file.
"""
suffix, archive_type, compression = _detect_file_type(from_path)
if not compression:
raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.")
if to_path is None:
to_path = from_path.replace(suffix, archive_type if archive_type is not None else "")
# We don't need to check for a missing key here, since this was already done in _detect_file_type()
compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression]
with compressed_file_opener(from_path, "rb") as rfh, open(to_path, "wb") as wfh:
wfh.write(rfh.read())
if remove_finished:
os.remove(from_path)
return to_path
def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
"""Extract an archive.
The archive type and a possible compression is automatically detected from the file name. If the file is compressed
but not an archive the call is dispatched to :func:`decompress`.
Args:
from_path (str): Path to the file to be extracted.
to_path (str): Path to the directory the file will be extracted to. If omitted, the directory of the file is
used.
remove_finished (bool): If ``True``, remove the file after the extraction.
Returns:
(str): Path to the directory the file was extracted to.
"""
if to_path is None:
to_path = os.path.dirname(from_path)
suffix, archive_type, compression = _detect_file_type(from_path)
if not archive_type:
return _decompress(
from_path,
os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")),
remove_finished=remove_finished,
)
# We don't need to check for a missing key here, since this was already done in _detect_file_type()
extractor = _ARCHIVE_EXTRACTORS[archive_type]
extractor(from_path, to_path, compression)
if remove_finished:
os.remove(from_path)
return to_path
def download_and_extract_archive(
url: str,
download_root: str,
extract_root: Optional[str] = None,
filename: Optional[str] = None,
md5: Optional[str] = None,
remove_finished: bool = False,
) -> None:
download_root = os.path.expanduser(download_root)
if extract_root is None:
extract_root = download_root
if not filename:
filename = os.path.basename(url)
download_url(url, download_root, filename, md5)
archive = os.path.join(download_root, filename)
print(f"Extracting {archive} to {extract_root}")
extract_archive(archive, extract_root, remove_finished)
def iterable_to_str(iterable: Iterable) -> str:
return "'" + "', '".join([str(item) for item in iterable]) + "'"
T = TypeVar("T", str, bytes)
def verify_str_arg(
value: T,
arg: Optional[str] = None,
valid_values: Optional[Iterable[T]] = None,
custom_msg: Optional[str] = None,
) -> T:
if not isinstance(value, torch._six.string_classes):
if arg is None:
msg = "Expected type str, but got type {type}."
else:
msg = "Expected type str for argument {arg}, but got type {type}."
msg = msg.format(type=type(value), arg=arg)
raise ValueError(msg)
if valid_values is None:
return value
if value not in valid_values:
if custom_msg is not None:
msg = custom_msg
else:
msg = "Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}."
msg = msg.format(value=value, arg=arg, valid_values=iterable_to_str(valid_values))
raise ValueError(msg)
return value
def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray:
"""Read file in .pfm format. Might contain either 1 or 3 channels of data.
Args:
file_name (str): Path to the file.
slice_channels (int): Number of channels to slice out of the file.
Useful for reading different data formats stored in .pfm files: Optical Flows, Stereo Disparity Maps, etc.
"""
with open(file_name, "rb") as f:
header = f.readline().rstrip()
if header not in [b"PF", b"Pf"]:
raise ValueError("Invalid PFM file")
dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline())
if not dim_match:
raise Exception("Malformed PFM header.")
w, h = (int(dim) for dim in dim_match.groups())
scale = float(f.readline().rstrip())
if scale < 0: # little-endian
endian = "<"
scale = -scale
else:
endian = ">" # big-endian
data = np.fromfile(f, dtype=endian + "f")
pfm_channels = 3 if header == b"PF" else 1
data = data.reshape(h, w, pfm_channels).transpose(2, 0, 1)
data = np.flip(data, axis=1) # flip on h dimension
data = data[:slice_channels, :, :]
return data.astype(np.float32)
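# Illustrative sketch of the suffix handling in _detect_file_type(); the file names below
# are hypothetical and never touch the filesystem.
assert _detect_file_type("images.tar.gz") == (".tar.gz", ".tar", ".gz")   # archive + compression
assert _detect_file_type("archive.tgz") == (".tgz", ".tar", ".gz")        # alias expansion
assert _detect_file_type("labels.gz") == (".gz", None, ".gz")             # plain compression -> _decompress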

View File

@ -1,104 +0,0 @@
import os
from typing import Any, Callable, List, Optional, Tuple
import torch
import torch.utils.data as data
from ..utils import _log_api_usage_once
class VisionDataset(data.Dataset):
"""
Base Class For making datasets which are compatible with torchvision.
It is necessary to override the ``__getitem__`` and ``__len__`` methods.
Args:
root (string): Root directory of dataset.
transforms (callable, optional): A function/transforms that takes in
an image and a label and returns the transformed versions of both.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g., ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
.. note::
:attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive.
"""
_repr_indent = 4
def __init__(
self,
root: str,
transforms: Optional[Callable] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
self.root = root
has_transforms = transforms is not None
has_separate_transform = transform is not None or target_transform is not None
if has_transforms and has_separate_transform:
raise ValueError("Only transforms or transform/target_transform can be passed as argument")
# for backwards-compatibility
self.transform = transform
self.target_transform = target_transform
if has_separate_transform:
transforms = StandardTransform(transform, target_transform)
self.transforms = transforms
def __getitem__(self, index: int) -> Any:
"""
Args:
index (int): Index
Returns:
(Any): Sample and meta data, optionally transformed by the respective transforms.
"""
raise NotImplementedError
def __len__(self) -> int:
raise NotImplementedError
def __repr__(self) -> str:
head = "Dataset " + self.__class__.__name__
body = [f"Number of datapoints: {self.__len__()}"]
if self.root is not None:
body.append(f"Root location: {self.root}")
body += self.extra_repr().splitlines()
if hasattr(self, "transforms") and self.transforms is not None:
body += [repr(self.transforms)]
lines = [head] + [" " * self._repr_indent + line for line in body]
return "\n".join(lines)
def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
lines = transform.__repr__().splitlines()
return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]
def extra_repr(self) -> str:
return ""
class StandardTransform:
def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None:
self.transform = transform
self.target_transform = target_transform
def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]:
if self.transform is not None:
input = self.transform(input)
if self.target_transform is not None:
target = self.target_transform(target)
return input, target
def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
lines = transform.__repr__().splitlines()
return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]
def __repr__(self) -> str:
body = [self.__class__.__name__]
if self.transform is not None:
body += self._format_transform_repr(self.transform, "Transform: ")
if self.target_transform is not None:
body += self._format_transform_repr(self.target_transform, "Target transform: ")
return "\n".join(body)

View File

@ -1 +0,0 @@
from jittor.transform import *

View File

@ -1,582 +0,0 @@
import collections
import math
import pathlib
import warnings
from itertools import repeat
from types import FunctionType
from typing import Any, BinaryIO, List, Optional, Tuple, Union
import numpy as np
import torch
from PIL import Image, ImageColor, ImageDraw, ImageFont
__all__ = [
"make_grid",
"save_image",
"draw_bounding_boxes",
"draw_segmentation_masks",
"draw_keypoints",
"flow_to_image",
]
@torch.no_grad()
def make_grid(
tensor: Union[torch.Tensor, List[torch.Tensor]],
nrow: int = 8,
padding: int = 2,
normalize: bool = False,
value_range: Optional[Tuple[int, int]] = None,
scale_each: bool = False,
pad_value: float = 0.0,
**kwargs,
) -> torch.Tensor:
"""
Make a grid of images.
Args:
tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
or a list of images all of the same size.
nrow (int, optional): Number of images displayed in each row of the grid.
The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
padding (int, optional): amount of padding. Default: ``2``.
normalize (bool, optional): If True, shift the image to the range (0, 1),
by the min and max values specified by ``value_range``. Default: ``False``.
value_range (tuple, optional): tuple (min, max) where min and max are numbers,
then these numbers are used to normalize the image. By default, min and max
are computed from the tensor.
range (tuple, optional):
.. warning::
This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``value_range``
instead.
scale_each (bool, optional): If ``True``, scale each image in the batch of
images separately rather than the (min, max) over all images. Default: ``False``.
pad_value (float, optional): Value for the padded pixels. Default: ``0``.
Returns:
grid (Tensor): the tensor containing grid of images.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(make_grid)
if not torch.is_tensor(tensor):
if isinstance(tensor, list):
for t in tensor:
if not torch.is_tensor(t):
raise TypeError(f"tensor or list of tensors expected, got a list containing {type(t)}")
else:
raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")
if "range" in kwargs.keys():
warnings.warn(
"The parameter 'range' is deprecated since 0.12 and will be removed in 0.14. "
"Please use 'value_range' instead."
)
value_range = kwargs["range"]
# if list of tensors, convert to a 4D mini-batch Tensor
if isinstance(tensor, list):
tensor = torch.stack(tensor, dim=0)
if tensor.dim() == 2: # single image H x W
tensor = tensor.unsqueeze(0)
if tensor.dim() == 3: # single image
if tensor.size(0) == 1: # if single-channel, convert to 3-channel
tensor = torch.cat((tensor, tensor, tensor), 0)
tensor = tensor.unsqueeze(0)
if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images
tensor = torch.cat((tensor, tensor, tensor), 1)
if normalize is True:
tensor = tensor.clone() # avoid modifying tensor in-place
if value_range is not None and not isinstance(value_range, tuple):
raise TypeError("value_range has to be a tuple (min, max) if specified. min and max are numbers")
def norm_ip(img, low, high):
img.clamp_(min=low, max=high)
img.sub_(low).div_(max(high - low, 1e-5))
def norm_range(t, value_range):
if value_range is not None:
norm_ip(t, value_range[0], value_range[1])
else:
norm_ip(t, float(t.min()), float(t.max()))
if scale_each is True:
for t in tensor: # loop over mini-batch dimension
norm_range(t, value_range)
else:
norm_range(tensor, value_range)
if not isinstance(tensor, torch.Tensor):
raise TypeError("tensor should be of type torch.Tensor")
if tensor.size(0) == 1:
return tensor.squeeze(0)
# make the mini-batch of images into a grid
nmaps = tensor.size(0)
xmaps = min(nrow, nmaps)
ymaps = int(math.ceil(float(nmaps) / xmaps))
height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
num_channels = tensor.size(1)
grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
k = 0
for y in range(ymaps):
for x in range(xmaps):
if k >= nmaps:
break
# Tensor.copy_() is a valid method but seems to be missing from the stubs
# https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_
grid.narrow(1, y * height + padding, height - padding).narrow( # type: ignore[attr-defined]
2, x * width + padding, width - padding
).copy_(tensor[k])
k = k + 1
return grid
@torch.no_grad()
def save_image(
tensor: Union[torch.Tensor, List[torch.Tensor]],
fp: Union[str, pathlib.Path, BinaryIO],
format: Optional[str] = None,
**kwargs,
) -> None:
"""
Save a given Tensor into an image file.
Args:
tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
saves the tensor as a grid of images by calling ``make_grid``.
fp (string or file object): A filename or a file object
format(Optional): If omitted, the format to use is determined from the filename extension.
If a file object was used instead of a filename, this parameter should always be used.
**kwargs: Other arguments are documented in ``make_grid``.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(save_image)
grid = make_grid(tensor, **kwargs)
# Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
im = Image.fromarray(ndarr)
im.save(fp, format=format)
@torch.no_grad()
def draw_bounding_boxes(
image: torch.Tensor,
boxes: torch.Tensor,
labels: Optional[List[str]] = None,
colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
fill: Optional[bool] = False,
width: int = 1,
font: Optional[str] = None,
font_size: Optional[int] = None,
) -> torch.Tensor:
"""
Draws bounding boxes on given image.
The values of the input image should be uint8 between 0 and 255.
If fill is True, the resulting Tensor should be saved as a PNG image.
Args:
image (Tensor): Tensor of shape (C x H x W) and dtype uint8.
boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
`0 <= ymin < ymax < H`.
labels (List[str]): List containing the labels of bounding boxes.
colors (color or list of colors, optional): List containing the colors
of the boxes or single color for all boxes. The color can be represented as
PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
By default, random colors are generated for boxes.
fill (bool): If `True` fills the bounding box with specified color.
width (int): Width of bounding box.
font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
`/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
font_size (int): The requested font size in points.
Returns:
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(draw_bounding_boxes)
if not isinstance(image, torch.Tensor):
raise TypeError(f"Tensor expected, got {type(image)}")
elif image.dtype != torch.uint8:
raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
elif image.dim() != 3:
raise ValueError("Pass individual images, not batches")
elif image.size(0) not in {1, 3}:
raise ValueError("Only grayscale and RGB images are supported")
elif (boxes[:, 0] > boxes[:, 2]).any() or (boxes[:, 1] > boxes[:, 3]).any():
raise ValueError(
"Boxes need to be in (xmin, ymin, xmax, ymax) format. Use torchvision.ops.box_convert to convert them"
)
num_boxes = boxes.shape[0]
if num_boxes == 0:
warnings.warn("boxes doesn't contain any box. No box was drawn")
return image
if labels is None:
labels: Union[List[str], List[None]] = [None] * num_boxes # type: ignore[no-redef]
elif len(labels) != num_boxes:
raise ValueError(
f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box."
)
if colors is None:
colors = _generate_color_palette(num_boxes)
elif isinstance(colors, list):
if len(colors) < num_boxes:
raise ValueError(f"Number of colors ({len(colors)}) is less than number of boxes ({num_boxes}). ")
else: # colors specifies a single color for all boxes
colors = [colors] * num_boxes
colors = [(ImageColor.getrgb(color) if isinstance(color, str) else color) for color in colors]
if font is None:
if font_size is not None:
warnings.warn("Argument 'font_size' will be ignored since 'font' is not set.")
txt_font = ImageFont.load_default()
else:
txt_font = ImageFont.truetype(font=font, size=font_size or 10)
# Handle Grayscale images
if image.size(0) == 1:
image = torch.tile(image, (3, 1, 1))
ndarr = image.permute(1, 2, 0).cpu().numpy()
img_to_draw = Image.fromarray(ndarr)
img_boxes = boxes.to(torch.int64).tolist()
if fill:
draw = ImageDraw.Draw(img_to_draw, "RGBA")
else:
draw = ImageDraw.Draw(img_to_draw)
for bbox, color, label in zip(img_boxes, colors, labels): # type: ignore[arg-type]
if fill:
fill_color = color + (100,)
draw.rectangle(bbox, width=width, outline=color, fill=fill_color)
else:
draw.rectangle(bbox, width=width, outline=color)
if label is not None:
margin = width + 1
draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font)
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
@torch.no_grad()
def draw_segmentation_masks(
image: torch.Tensor,
masks: torch.Tensor,
alpha: float = 0.8,
colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
) -> torch.Tensor:
"""
Draws segmentation masks on given RGB image.
The values of the input image should be uint8 between 0 and 255.
Args:
image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
0 means full transparency, 1 means no transparency.
colors (color or list of colors, optional): List containing the colors
of the masks or single color for all masks. The color can be represented as
PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
By default, random colors are generated for each mask.
Returns:
img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(draw_segmentation_masks)
if not isinstance(image, torch.Tensor):
raise TypeError(f"The image must be a tensor, got {type(image)}")
elif image.dtype != torch.uint8:
raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
elif image.dim() != 3:
raise ValueError("Pass individual images, not batches")
elif image.size()[0] != 3:
raise ValueError("Pass an RGB image. Other Image formats are not supported")
if masks.ndim == 2:
masks = masks[None, :, :]
if masks.ndim != 3:
raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)")
if masks.dtype != torch.bool:
raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}")
if masks.shape[-2:] != image.shape[-2:]:
raise ValueError("The image and the masks must have the same height and width")
num_masks = masks.size()[0]
if colors is not None and num_masks > len(colors):
raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})")
if num_masks == 0:
warnings.warn("masks doesn't contain any mask. No mask was drawn")
return image
if colors is None:
colors = _generate_color_palette(num_masks)
if not isinstance(colors, list):
colors = [colors]
if not isinstance(colors[0], (tuple, str)):
raise ValueError("colors must be a tuple or a string, or a list thereof")
if isinstance(colors[0], tuple) and len(colors[0]) != 3:
raise ValueError("It seems that you passed a tuple of colors instead of a list of colors")
out_dtype = torch.uint8
colors_ = []
for color in colors:
if isinstance(color, str):
color = ImageColor.getrgb(color)
colors_.append(torch.tensor(color, dtype=out_dtype))
img_to_draw = image.detach().clone()
# TODO: There might be a way to vectorize this
for mask, color in zip(masks, colors_):
img_to_draw[:, mask] = color[:, None]
out = image * (1 - alpha) + img_to_draw * alpha
return out.to(out_dtype)
@torch.no_grad()
def draw_keypoints(
image: torch.Tensor,
keypoints: torch.Tensor,
connectivity: Optional[List[Tuple[int, int]]] = None,
colors: Optional[Union[str, Tuple[int, int, int]]] = None,
radius: int = 2,
width: int = 3,
) -> torch.Tensor:
"""
Draws Keypoints on given RGB image.
The values of the input image should be uint8 between 0 and 255.
Args:
image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
keypoints (Tensor): Tensor of shape (num_instances, K, 2) giving the K keypoint locations for each of the N instances,
in the format [x, y].
connectivity (List[Tuple[int, int]]): A list of tuples where
each tuple contains a pair of keypoints to be connected.
colors (str, Tuple): The color can be represented as
PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
radius (int): Integer denoting radius of keypoint.
width (int): Integer denoting width of line connecting keypoints.
Returns:
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(draw_keypoints)
if not isinstance(image, torch.Tensor):
raise TypeError(f"The image must be a tensor, got {type(image)}")
elif image.dtype != torch.uint8:
raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
elif image.dim() != 3:
raise ValueError("Pass individual images, not batches")
elif image.size()[0] != 3:
raise ValueError("Pass an RGB image. Other Image formats are not supported")
if keypoints.ndim != 3:
raise ValueError("keypoints must be of shape (num_instances, K, 2)")
ndarr = image.permute(1, 2, 0).cpu().numpy()
img_to_draw = Image.fromarray(ndarr)
draw = ImageDraw.Draw(img_to_draw)
img_kpts = keypoints.to(torch.int64).tolist()
for kpt_id, kpt_inst in enumerate(img_kpts):
for inst_id, kpt in enumerate(kpt_inst):
x1 = kpt[0] - radius
x2 = kpt[0] + radius
y1 = kpt[1] - radius
y2 = kpt[1] + radius
draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0)
if connectivity:
for connection in connectivity:
start_pt_x = kpt_inst[connection[0]][0]
start_pt_y = kpt_inst[connection[0]][1]
end_pt_x = kpt_inst[connection[1]][0]
end_pt_y = kpt_inst[connection[1]][1]
draw.line(
((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)),
width=width,
)
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
@torch.no_grad()
def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
"""
Converts a flow to an RGB image.
Args:
flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float.
Returns:
img (Tensor): Image Tensor of dtype uint8 where each color corresponds
to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input.
"""
if flow.dtype != torch.float:
raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.")
orig_shape = flow.shape
if flow.ndim == 3:
flow = flow[None] # Add batch dim
if flow.ndim != 4 or flow.shape[1] != 2:
raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.")
max_norm = torch.sum(flow**2, dim=1).sqrt().max()
epsilon = torch.finfo((flow).dtype).eps
normalized_flow = flow / (max_norm + epsilon)
img = _normalized_flow_to_image(normalized_flow)
if len(orig_shape) == 3:
img = img[0] # Remove batch dim
return img
@torch.no_grad()
def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
"""
Converts a batch of normalized flow to an RGB image.
Args:
normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W)
Returns:
img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8.
"""
N, _, H, W = normalized_flow.shape
device = normalized_flow.device
flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device)
colorwheel = _make_colorwheel().to(device) # shape [55x3]
num_cols = colorwheel.shape[0]
norm = torch.sum(normalized_flow**2, dim=1).sqrt()
a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi
fk = (a + 1) / 2 * (num_cols - 1)
k0 = torch.floor(fk).to(torch.long)
k1 = k0 + 1
k1[k1 == num_cols] = 0
f = fk - k0
for c in range(colorwheel.shape[1]):
tmp = colorwheel[:, c]
col0 = tmp[k0] / 255.0
col1 = tmp[k1] / 255.0
col = (1 - f) * col0 + f * col1
col = 1 - norm * (1 - col)
flow_image[:, c, :, :] = torch.floor(255 * col)
return flow_image
def _make_colorwheel() -> torch.Tensor:
"""
Generates a color wheel for optical flow visualization as presented in:
Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf.
Returns:
colorwheel (Tensor[55, 3]): Colorwheel Tensor.
"""
RY = 15
YG = 6
GC = 4
CB = 11
BM = 13
MR = 6
ncols = RY + YG + GC + CB + BM + MR
colorwheel = torch.zeros((ncols, 3))
col = 0
# RY
colorwheel[0:RY, 0] = 255
colorwheel[0:RY, 1] = torch.floor(255 * torch.arange(0, RY) / RY)
col = col + RY
# YG
colorwheel[col : col + YG, 0] = 255 - torch.floor(255 * torch.arange(0, YG) / YG)
colorwheel[col : col + YG, 1] = 255
col = col + YG
# GC
colorwheel[col : col + GC, 1] = 255
colorwheel[col : col + GC, 2] = torch.floor(255 * torch.arange(0, GC) / GC)
col = col + GC
# CB
colorwheel[col : col + CB, 1] = 255 - torch.floor(255 * torch.arange(CB) / CB)
colorwheel[col : col + CB, 2] = 255
col = col + CB
# BM
colorwheel[col : col + BM, 2] = 255
colorwheel[col : col + BM, 0] = torch.floor(255 * torch.arange(0, BM) / BM)
col = col + BM
# MR
colorwheel[col : col + MR, 2] = 255 - torch.floor(255 * torch.arange(MR) / MR)
colorwheel[col : col + MR, 0] = 255
return colorwheel
def _generate_color_palette(num_objects: int):
palette = torch.tensor([2**25 - 1, 2**15 - 1, 2**21 - 1])
return [tuple((i * palette) % 255) for i in range(num_objects)]
def _log_api_usage_once(obj: Any) -> None:
"""
Logs API usage (module and name) within an organization.
In a large ecosystem, it's often useful to track the PyTorch and
TorchVision API usage. This API provides similar functionality to the
logging module in the Python stdlib. It can be used for debugging purposes
to log which methods are used, and by default it is inactive unless the user
manually subscribes a logger via the `SetAPIUsageLogger method <https://github.com/pytorch/pytorch/blob/eb3b9fe719b21fae13c7a7cf3253f970290a573e/c10/util/Logging.cpp#L114>`_.
Please note it is triggered only once for the same API call within a process.
It does not collect any data from open-source users since it is a no-op by default.
For more information, please refer to
* PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging;
* Logging policy: https://github.com/pytorch/vision/issues/5052;
Args:
obj (class instance or method): an object to extract info from.
"""
pass
def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]:
"""
Make n-tuple from input x. If x is an iterable, then we just convert it to a tuple.
Otherwise we make a tuple of length n, all with the value x.
reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8
Args:
x (Any): input value
n (int): length of the resulting tuple
"""
if isinstance(x, collections.abc.Iterable):
return tuple(x)
return tuple(repeat(x, n))
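# Usage sketch for make_grid()/save_image() as documented above; the batch is random data
# and the output file name is a placeholder assumption.
def _grid_demo(path="grid.png"):
    batch = torch.rand(16, 3, 32, 32)              # 16 RGB images with values in [0, 1]
    grid = make_grid(batch, nrow=4, padding=2)     # -> single (3, H, W) tensor tiling the batch 4x4
    save_image(grid, path)                         # scales to [0, 255] and writes the image
    return grid.shape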

View File

@ -224,7 +224,7 @@ def setup_cuda_extern():
line = traceback.format_exc()
LOG.w(f"CUDA found but cub is not loaded:\n{line}")
libs = ["cublas", "cudnn", "curand", "cufft"]
libs = ["cublas", "cudnn", "curand", "cufft", "cusparse"]
# in cuda 11.4, module memory consumptions:
# default context: 259 MB
# cublas: 340 MB
@ -240,6 +240,9 @@ def setup_cuda_extern():
msg += """Develop version of CUDNN not found,
please refer to CUDA official tar file installation:
https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#installlinux-tar"""
if lib_name == "cusparse":
msg += """CUSPARSE library is not loaded,
please ensure it is installed along with the CUDA toolkit."""
if platform.machine() in ["x86_64", "AMD64"]:
msg += f"""
or you can let jittor install cuda and cudnn for you:
@ -300,6 +303,13 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
link_flags = f"-l{lib_name} -L\"{os.path.dirname(culib_path)}\""
# print("link_flags", link_flags, culib_path)
if lib_name == "cusparse" :
try:
cusparse_spmv_path = search_file([cuda_lib, extra_lib_path], "libcusparse.so")
ctypes.CDLL(cusparse_spmv_path, dlopen_flags)
except:
LOG.w("Failed to load cusparse-specific shared libraries.")
# find all source files
culib_src_dir = os.path.join(jittor_path, "extern", "cuda", lib_name)
culib_src_files = []
@ -446,7 +456,8 @@ def setup_cutt():
def install_cutlass(root_folder):
# Modified from: https://github.com/ap-hynninen/cutlass
url = "https://cloud.tsinghua.edu.cn/f/171e49e5825549548bc4/?dl=1"
# url = "https://cloud.tsinghua.edu.cn/f/171e49e5825549548bc4/?dl=1"
url = "https://cg.cs.tsinghua.edu.cn/jittor/assets/cutlass.zip"
filename = "cutlass.zip"
fullname = os.path.join(root_folder, filename)
@ -600,6 +611,26 @@ def setup_nccl():
nccl_ops = nccl.ops
LOG.vv("Get nccl_ops: "+str(dir(nccl_ops)))
def setup_hccl():
global hccl_ops
hccl_src_dir = os.path.join(jittor_path, "extern", "acl", "hccl")
hccl_src_files = []
for r, _, f in os.walk(hccl_src_dir):
for fname in f:
hccl_src_files.append(os.path.join(r, fname))
hccl_include_path = os.path.join(os.environ.get("ASCEND_TOOLKIT_HOME"), "aarch64-linux/include/hccl")
hccl_lib_name = os.path.join(os.environ.get("ASCEND_TOOLKIT_HOME"), "aarch64-linux/lib64/libhccl.so")
ctypes.CDLL(hccl_lib_name, dlopen_flags)
hccl = compile_custom_ops(hccl_src_files,
extra_flags=f" -I\"{hccl_include_path}\" {mpi_compile_flags} ",
return_module=True, dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW,
gen_name_="jittor_hccl_core")
hccl_ops = hccl.ops
LOG.vv("Get hccl_ops: "+str(dir(hccl_ops)))
def manual_link(flags):
lib_dirs = []
libs = []
@ -629,6 +660,7 @@ def setup_mpi():
mpi_ops = None
mpi = None
has_mpi = False
if not use_mpi: return
mpicc_path = env_or_try_find('mpicc_path', 'mpicc')
if mpicc_path == "":
# LOG.i("mpicc not found, distribution disabled.")
@ -692,12 +724,18 @@ if FIX_TORCH_ERROR:
except:
pass
cudnn = cublas = curand = cufft = None
cudnn = cublas = curand = cufft = cusparse = None
setup_mpi()
rank = mpi.world_rank() if in_mpi else 0
world_size = mpi.world_size() if in_mpi else 1
# if has_acl:
# setup_hccl()
# elif has_cuda:
# setup_nccl()
# setup_cutt()
# setup_cutlass()
setup_nccl()
setup_cutt()
setup_cutlass()
@ -711,4 +749,4 @@ setup_cuda_extern()
# install backend extern library
for mod in jit_utils.backends:
if mod.install_extern():
break
break
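# Sketch of the preload pattern used in the cusparse and hccl hunks above: dlopen the shared
# library with RTLD_GLOBAL so its symbols are visible to the custom ops compiled afterwards.
# The library path is a placeholder assumption; the real code locates it via search_file().
import ctypes, os
_dlopen_flags = os.RTLD_GLOBAL | os.RTLD_NOW
try:
    ctypes.CDLL("/usr/local/cuda/lib64/libcusparse.so", _dlopen_flags)
except OSError:
    pass   # mirrors the warning-only fallback in setup_cuda_lib()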

View File

@ -1002,6 +1002,8 @@ if nvcc_path:
r, s = sp.getstatusoutput(f"log_v=0 {sys.executable} -m jittor_utils.query_cuda_cc")
if r==0:
s = sorted(list(set(s.strip().split())))
if len(s)==0:
LOG.e("No GPU Device Found!")
cu += "_sm_" + "_".join(s)
if "cuda_arch" not in os.environ:
os.environ["cuda_arch"] = " ".join(cu)
@ -1186,7 +1188,22 @@ make_cache_dir(ck_path)
# build cache_compile
cc_flags += f" -I\"{os.path.join(jittor_path, 'src')}\" "
cc_flags += f" -I\"{os.path.join(jittor_path, 'extern')}\" "
ascend_toolkit_home = os.getenv('ASCEND_TOOLKIT_HOME')
if ascend_toolkit_home:
cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include')}\" "
cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include/acl')}\" "
cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include/aclnn')}\" "
cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include/aclnnop')}\" "
cc_flags += f" -L\"{os.path.join(ascend_toolkit_home, 'lib64')}\" "
cc_flags += " -llibascendcl "
cc_flags += " -llibnnopbase "
cc_flags += " -llibopapi "
cc_flags += py_include
check_cache_compile()
LOG.v(f"Get cache_compile: {jit_utils.cc}")
@ -1369,7 +1386,7 @@ if has_cuda and is_cuda:
nvcc_flags = " " + os.environ.get("nvcc_flags", "") + " "
nvcc_flags += convert_nvcc_flags(cc_flags)
nvcc_version = list(jit_utils.get_int_version(nvcc_path))
max_arch = 89
max_arch = 90
if nvcc_version < [11,]:
max_arch = 75
elif nvcc_version < [11,1]:

View File

@ -15,6 +15,8 @@ from collections.abc import Sequence
def argmax_pool(x, size, stride, padding=0):
if stride<=0:
raise RuntimeError(f"stride must be > 0, but got {stride}")
return pool.pool(x, size, 'maximum', padding, stride)
def concat(arr, dim):
@ -241,9 +243,17 @@ Example::
if len(arr) == 0:
raise ValueError("need at least one array to concat")
total_dim = 0
if dim < 0: dim += len(arr[0].shape)
base_dim = len(arr[0].shape)
if dim < 0: dim += base_dim
if dim < 0 or dim >= base_dim:
raise IndexError(f"Dimension out of range (expected to be in range of [{-base_dim}, {base_dim-1}], but got {dim})")
dtypes = []
for a in arr:
if len(a.shape) != base_dim:
raise RuntimeError(f"get different number of dimensions of {base_dim} and {len(a.shape)}")
for i in range(base_dim):
if i != dim and a.shape[i] != arr[0].shape[i]:
raise RuntimeError(f"Sizes of vars must match except in dimension {dim}. Expected size {arr[0].shape[i]} but got size {a.shape[i]} for dimension number {i} in the list.")
total_dim += a.shape[dim]
dtypes.append(str(a.dtype))
cdim = 0
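# Sketch of the added checks above (shapes are hypothetical): negative dims are normalized
# against the first input, and every non-concat dimension must match.
import jittor as jt
a, b = jt.ones((2, 3)), jt.ones((2, 5))
c = jt.concat([a, b], dim=-1)      # ok: -1 is normalized to 1 and the other dims match -> shape (2, 8)
# jt.concat([a, b], dim=0)         # would raise RuntimeError: sizes must match except in dimension 0
# jt.concat([a, b], dim=2)         # would raise IndexError: dim expected in range [-2, 1]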

View File

@ -26,7 +26,7 @@ class MNIST(Dataset):
[in] data_root(str): your data root.
[in] train(bool): choose model train or val.
[in] download(bool): Download data automatically if download is Ture.
[in] download(bool): Download data automatically if download is True.
[in] batch_size(int): Data batch size.
[in] shuffle(bool): Shuffle data if true.
[in] transform(jittor.transform): transform data.
@ -106,7 +106,7 @@ class EMNIST(Dataset):
[in] data_root(str): your data root.
[in] split(str): one of 'byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist'.
[in] train(bool): choose model train or val.
[in] download(bool): Download data automatically if download is Ture.
[in] download(bool): Download data automatically if download is True.
[in] batch_size(int): Data batch size.
[in] shuffle(bool): Shuffle data if true.
[in] transform(jittor.transform): transform data.

View File

@ -1,6 +1,6 @@
# ***************************************************************
# Copyright (c) 2023 Jittor. All Rights Reserved.
# Maintainers: Dun Liang <randonlang@gmail.com>.
# Copyright (c) 2023 Jittor. All Rights Reserved.
# Maintainers: Dun Liang <randonlang@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
@ -10,6 +10,25 @@ import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def _ntuple(n):
def parse(x):
if isinstance(x, Iterable):
return x
return tuple([x] * n)
return parse
_pair = _ntuple(2)
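# Quick illustration of the helper above: scalars become 2-tuples, iterables pass through.
assert _pair(3) == (3, 3)
assert _pair((1, 2)) == (1, 2)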
has_acl = 0
cc_flags = ""
@ -34,38 +53,60 @@ compiler.has_acl = has_acl
# export DUMP_GRAPH_LEVEL=1
# build pytorch-npu
# bash ./ci/build.sh
# python3 -m pip install ./dist/torch_npu-1.11.0.post1-cp37-cp37m-linux_x86_64.whl --force-reinstall
# bash ./ci/build.sh
# python3 -m pip install ./dist/torch_npu-1.11.0.post1-cp37-cp37m-linux_x86_64.whl --force-reinstall
# pytorch: conda activate cann && source /usr/local/Ascend/ascend-toolkit/set_env.sh && export TASK_QUEUE_ENABLE=0 && cd /home/cjld/new_jittor/jittor/my/mm_benchmark
# python3 ./mm_bench_pt_npu.py
def install():
import jittor.compiler as compiler
global has_acl, cc_flags
acl_compiler_home = os.path.dirname(__file__)
cc_files = sorted(glob.glob(acl_compiler_home+"/**/*.cc", recursive=True))
cc_files = sorted(glob.glob(acl_compiler_home + "/**/*.cc",
recursive=True))
cc_files2 = []
for name in cc_files:
if "acl_op_exec" in name:
# Skip files in hccl directory
if "hccl" in name:
continue
# if "acl_op_exec" in name or "_op_acl.cc" in name:
if "acl_op_exec" in name or "_op_acl.cc" in name or "utils.cc" in name:
compiler.extra_core_files.append(name)
else:
cc_files2.append(name)
cc_files = cc_files2
cc_flags += f" -DHAS_CUDA -DIS_ACL \
-I/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/include/ \
-L/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/lib64 \
-I{acl_compiler_home} -lascendcl -lacl_op_compiler "
ascend_toolkit_home = os.getenv('ASCEND_TOOLKIT_HOME')
#print(ascend_toolkit_home)
#print(acl_compiler_home)
cc_flags += f" -MD -DHAS_CUDA -DIS_ACL \
-I{ascend_toolkit_home}/include/ \
-I{ascend_toolkit_home}/include/acl/ \
-I{ascend_toolkit_home}/include/aclnn/ \
-I{ascend_toolkit_home}/include/aclnnop/ \
-I{acl_compiler_home} -lascendcl -lacl_op_compiler \
-I{acl_compiler_home}/aclnn \
-I{acl_compiler_home}/aclops \
-L{ascend_toolkit_home}/lib64/"
cc_flags += " -llibascendcl "
cc_flags += " -llibnnopbase "
cc_flags += " -llibopapi "
#pdb.set_trace()
ctypes.CDLL("libascendcl.so", dlopen_flags)
'''
f'''
-ltikc_runtime
-I/usr/local/Ascend/driver/include \
-L/usr/local/Ascend/compiler/lib64 \
-L/usr/local/Ascend/runtime/lib64 \
-I/usr/local/Ascend/driver/include/ \
-L{ascend_toolkit_home}/compiler/lib64/ \
-L{ascend_toolkit_home}/runtime/lib64/ \
'''
jittor_utils.LOG.i("ACL detected")
global mod
mod = jittor_utils.compile_module('''
mod = jittor_utils.compile_module(
'''
#include "common.h"
namespace jittor {
// @pyjt(process)
@ -98,9 +139,10 @@ def check():
if not has_acl: return False
compiler.cc_flags += cc_flags
compiler.nvcc_path = tikcc_path
compiler.nvcc_flags = compiler.cc_flags.replace("-std=c++14","")
compiler.nvcc_flags = compiler.cc_flags.replace("-std=c++14", "")
return True
def post_process():
if has_acl:
from jittor import pool
@ -108,5 +150,547 @@ def post_process():
import jittor as jt
jt.flags.use_cuda_host_allocator = 1
jt.flags.use_parallel_op_compiler = 0
jt.flags.amp_reg |= 32 + 4 # 32 keep float16, 4 keep reduce type
mod.init_acl_ops()
jt.flags.amp_reg |= 32 + 4 # 32 keep float16, 4 keep reduce type
mod.init_acl_ops()
def change_function():
import jittor as jt
from jittor import Function
from .aclops.flashattention_op import FlashAttentionACL
from .aclops.conv_op import ConvACL
from .aclops.pool_op import PoolACL
from .aclops.nantonum_op import NanToNumACL
from .aclops.stack_op import StackACL
from .aclops.rope_op import RopeACL
from .aclops.softmax_op import SoftmaxACL
from .aclops.sigmoid_op import SigmoidACL
from .aclops.silu_op import SiLUACL
from .aclops.dropout_op import DropoutACL
from .aclops.relu_op import LeakyReLUACL
from .aclops.flip_op import FlipACL
from .aclops.concat_op import ConcatACL
from .aclops.gather_scatter_op import GatherACL
from .aclops.cumsum_op import CumsumACL
from .aclops.index_op import IndexACL
from .aclops.gather_scatter_op import ScatterACL
from .aclops.where_op import WhereACL
from .aclops.where_op import NonzeroACL
from .aclops.floor_op import FloorIntACL
from .aclops.getitem_op import GetItemACL
from .aclops.setitem_op import SetItemACL
from .aclops.bmm_op import BmmACL
from .aclops.matmul_op import MatmulACL
from .aclops.transpose_op import TransPoseACL
from .aclops.triu_op import TriuACL
def triu_acl(x, diagonal=0):
return TriuACL()(x, diagonal)
from .aclops.conv_op import ConvACL
def conv_acl(x,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
groups=1):
return ConvACL()(x, weight, bias, stride, padding, dilation, groups)
class Conv2D(jt.nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True):
if in_channels <= 0:
raise ValueError(
f"in_channels must be greater than zero, got {in_channels}"
)
if out_channels <= 0:
raise ValueError(
f"out_channels must be greater than zero, got {out_channels}"
)
if groups <= 0:
raise ValueError(
f"groups must must be greater than zero, got {groups}")
assert in_channels % groups == 0, 'in_channels must be divisible by groups'
assert out_channels % groups == 0, 'out_channels must be divisible by groups'
if isinstance(kernel_size, tuple):
for size in kernel_size:
if size <= 0:
raise ValueError(
f"kernel_size must be greater than zero, got {kernel_size}"
)
else:
if kernel_size <= 0:
raise ValueError(
f"kernel_size must be greater than zero, got {kernel_size}"
)
if isinstance(stride, tuple):
for size in stride:
if size <= 0:
raise ValueError(
f"stride must be greater than zero, got {stride}")
else:
if stride <= 0:
raise ValueError(
f"stride must be greater than zero, got {stride}")
if isinstance(padding, tuple):
for size in padding:
if size < 0:
raise ValueError(
f"padding must be nonnegative, got {padding}")
else:
if padding < 0:
raise ValueError(
f"padding must be nonnegative, got {padding}")
if isinstance(dilation, tuple):
for size in dilation:
if size <= 0:
raise ValueError(
f"dilation must be greater than zero, got {dilation}"
)
else:
if dilation <= 0:
raise ValueError(
f"dilation must be greater than zero, got {dilation}")
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size if isinstance(
kernel_size, tuple) else (kernel_size, kernel_size)
self.stride = stride if isinstance(stride, tuple) else (stride,
stride)
self.padding = padding if isinstance(padding, tuple) else (padding,
padding)
self.dilation = dilation if isinstance(
dilation, tuple) else (dilation, dilation)
self.groups = groups
self.is_depthwise_conv = self.groups == self.out_channels and self.groups == self.in_channels
if self.is_depthwise_conv and jt.flags.use_cuda and jt.compiler.is_cuda:
self.depthwise_conv = jt.nn.DepthwiseConv(
stride, padding, dilation)
Kh, Kw = self.kernel_size
# self.weight = init.relu_invariant_gauss([out_channels, in_channels//groups, Kh, Kw], dtype="float", mode="fan_out")
self.weight = jt.init.invariant_uniform(
[out_channels, in_channels // groups, Kh, Kw], dtype="float")
if bias:
fan = 1
for i in self.weight.shape[1:]:
fan *= i
bound = 1 / math.sqrt(fan)
self.bias = jt.init.uniform([out_channels],
dtype="float",
low=-bound,
high=bound)
else:
self.bias = None
def execute(self, x):
ret = jt.nn.conv2d(x, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)
return ret
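# In the Conv2D initialisation above, `fan` is the fan-in of the weight
# (in_channels // groups * Kh * Kw); when bias is enabled it is drawn from
# U(-1/sqrt(fan), 1/sqrt(fan)), while the weight uses Jittor's
# invariant_uniform initializer.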
from .aclops.flip_op import FlipACL
def flip_acl(x, dim):
return FlipACL()(x, dim)
from .aclops.concat_op import ConcatACL
def concat(x, dim=0):
return ConcatACL()(x, dim)
from .aclops.gather_scatter_op import GatherACL
def gather_acl(input, dim, index):
return GatherACL()(input, dim, index)
def any_acl(input, dim=None):
if dim is None:
if jt.sum(input != 0).item() > 0:
return jt.array([True])
else:
return jt.array([False])
else:
return jt.sum(input != 0, dim=dim) > 0
from .aclops.cumsum_op import CumsumACL
def cumsum_acl(input, dim=-1):
return CumsumACL()(input, dim)
def cumprod_acl(x, dim=None):
x = jt.log(x)
x = cumsum_acl(x, dim=dim)
return jt.exp(x)
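# The cumulative product above relies on the identity
# cumprod(x) = exp(cumsum(log(x))), e.g. [a, b, c] -> [a, a*b, a*b*c].
# It assumes strictly positive inputs, since log() of zero or of a
# negative value is not finite.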
from .aclops.index_op import IndexACL
def index_acl(inshape: Union[jt.Var, list], dim=None, dtype="int32"):
if isinstance(inshape, jt.Var):
inshape = inshape.shape
return IndexACL()(inshape, dim, dtype)
from .aclops.gather_scatter_op import ScatterACL
def scatter_acl(input, dim, index, src, reduce='void'):
return ScatterACL()(input, dim, index, src, reduce)
from .aclops.where_op import WhereACL
def where_acl(condition, x=None, y=None):
return WhereACL()(condition, x, y)
from .aclops.where_op import NonzeroACL
def nonzero_acl(x):
return NonzeroACL()(x)
from .aclops.floor_op import FloorIntACL
def floor_int_acl(x):
return FloorIntACL()(x)
from .aclops.getitem_op import GetItemACL
def getitem_acl(x, slices, return_x=None):
# Transform numpy int to int
if isinstance(slices, (np.int8, np.int16, np.int32, np.int64)):
slices = int(slices)
if hasattr(np, 'int128') and isinstance(slices, np.int128):
slices = int(slices)
if hasattr(np, 'int256') and isinstance(slices, np.int256):
slices = int(slices)
## If not related to `None`, directly use `GetItemACL`
if slices is not None and (not isinstance(slices, Iterable)
or all([s is not None for s in slices])):
return GetItemACL()(x, slices, return_x)
## If related to `None`, filter out `None` first, then use `GetItemACL`, and finally insert `None` (new dimensions) back
# Transform to tuple
if isinstance(slices, int) or isinstance(slices, slice):
slices = (slices, )
assert isinstance(slices, tuple)
def get_insert_positions(slices):
result = []
pos = 0
not_none_cnt = len(slices) - slices.count(None)
for s in slices:
if isinstance(s, int):
continue
elif s is None:
result.append(pos)
pos += 1
elif s == Ellipsis:
pos += 1 + x.ndim - not_none_cnt
else:
pos += 1
return result
insert_positions = get_insert_positions(slices)
slices_without_none = tuple(s for s in slices if s is not None)
result = GetItemACL()(x, slices_without_none, return_x)
for i in insert_positions:
result = result.unsqueeze(i)
return result
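# Illustration of the `None` handling above (assuming x.ndim == 3):
#   slices = (None, 0, Ellipsis, None)
#   -> slices_without_none = (0, Ellipsis), insert_positions = [0, 3]
# GetItemACL returns a (d1, d2) view and the unsqueeze calls restore the
# new axes, giving shape (1, d1, d2, 1), the same as x[None, 0, ..., None].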
from .aclops.setitem_op import SetItemACL
def setitem_acl(x, slices, value):
res = SetItemACL()(x, slices, value)
return x.assign(res)
from .aclops.bmm_op import BmmACL
def bmm_acl(x1, x2):
return BmmACL()(x1, x2)
def bmm_transpose_acl(x1, x2):
return BmmACL(True)(x1, x2)
from .aclops.matmul_op import MatmulACL
def matmul_acl(x1, x2):
return MatmulACL()(x1, x2)
def matmul_transpose_acl(x1, x2):
return MatmulACL(True)(x1, x2)
from .aclops.transpose_op import TransPoseACL
def transpose_acl(x, *dim):
return TransPoseACL()(x, *dim)
from .aclops.relu_op import ReLUACL
class ReLU(jt.nn.Module):
def __init__(self):
super(ReLU, self).__init__()
def execute(self, x):
return ReLUACL()(x)
def relu(x):
return ReLUACL()(x)
from .aclops.relu_op import LeakyReLUACL
class LeakyReLU(jt.nn.Module):
def __init__(self, negative_slope=0.01):
super(LeakyReLU, self).__init__()
self.negative_slope = negative_slope
def execute(self, x):
return LeakyReLUACL()(x, self.negative_slope)
def leaky_relu(x, scale=0.01):
return LeakyReLUACL()(x, scale)
from .aclops.dropout_op import DropoutACL
class Dropout(jt.nn.Module):
def __init__(self, p=0.5, is_train=False):
super(Dropout, self).__init__()
self.p = p
self.is_train = is_train
def execute(self, x):
return DropoutACL()(x, self.p, self.is_train)
def dropout_acl(x, p=0.5, is_train=False):
return DropoutACL()(x, p, is_train)
from .aclops.silu_op import SiLUACL
def silu_acl(x):
return SiLUACL()(x)
class SiLU(jt.nn.Module):
def __init__(self):
super(SiLU, self).__init__()
def execute(self, x):
return SiLUACL()(x)
from .aclops.sigmoid_op import SigmoidACL
def sigmoid_acl(x):
return SigmoidACL()(x)
class Sigmoid(jt.nn.Module):
def __init__(self):
super(Sigmoid, self).__init__()
def execute(self, x):
return SigmoidACL()(x)
# class Embedding(jt.nn.Module):
# def __init__(self,
# num_embeddings,
# embedding_dim,
# padding_idx=None,
# dtype="float32"):
# self.num_embeddings = num_embeddings
# self.embedding_dim = embedding_dim
# self.padding_idx = padding_idx
# self.weight = jt.init.gauss(
# [self.num_embeddings, self.embedding_dim], dtype)
# if padding_idx is not None:
# self.weight[padding_idx] = 0
# def execute(self, x):
# res = embedding_acl(x, self.weight)
# return res
class Softmax(jt.nn.Module):
def __init__(self):
super(Softmax, self).__init__()
def execute(self, x, dim):
return SoftmaxACL()(x, dim)
def softmax_acl(x, dim):
return SoftmaxACL()(x, dim)
from .aclops.rope_op import RopeACL
def rope_acl(xq, xk, freqs_cis=None, freq_sin=None, freq_cos=None):
return RopeACL()(xq, xk, freqs_cis, freq_sin, freq_cos)
from .aclops.stack_op import StackACL
def stack_acl(x, dim=0):
return StackACL()(x, dim)
from .aclops.nantonum_op import NanToNumACL
def isnan_acl(x):
tonum = NanToNumACL()(x, -1.0)
return jt.not_equal(x, tonum).logical_and(
jt.not_equal(tonum, jt.ones_like(x)))
def isinf_acl(x):
tonum = NanToNumACL()(x, 1.0)
return jt.not_equal(x, tonum).logical_and(
jt.not_equal(tonum, jt.ones_like(x)))
def warp(origin_func, new_func, name=None):
if isinstance(origin_func, type):
class WrappedClass(origin_func, new_func):
def __init__(self, *args, **kwargs):
if jt.flags.use_acl:
new_func.__init__(self, *args, **kwargs)
else:
origin_func.__init__(self, *args, **kwargs)
def execute(self, *args, **kwargs):
if jt.flags.use_acl:
return new_func.execute(self, *args, **kwargs)
elif name == 'setitem':
return args[0].assign(origin_func(*args, **kwargs))
else:
return origin_func.execute(self, *args, **kwargs)
return WrappedClass
else:
def warpper(*args, **kwargs):
if jt.flags.use_acl:
return new_func(*args, **kwargs)
elif name == 'setitem':
return args[0].assign(origin_func(*args, **kwargs))
else:
return origin_func(*args, **kwargs)
return warpper
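# warp(origin_func, new_func) builds a drop-in replacement: the ACL
# implementation runs only when jt.flags.use_acl is set, otherwise the
# original Jittor op (or its in-place setitem variant) is used, so the
# monkey patches below remain safe when the ACL backend is disabled.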
jt.triu = warp(jt.triu, triu_acl)
jt.triu_ = warp(jt.triu, triu_acl)
jt.Var.triu = jt.triu
jt.Var.triu_ = lambda x, diagonal=0: x.assign(x.triu(diagonal))
jt.nn.conv2d = warp(jt.nn.conv2d, conv_acl)
jt.nn.Conv2d = warp(jt.nn.Conv2d, Conv2D)
jt.nn.Conv = warp(jt.nn.Conv, Conv2D)
jt.nn.Pool = warp(jt.nn.Pool, PoolACL)
jt.flip = warp(jt.flip, flip_acl)
jt.Var.flip = lambda x, dim_vector=0: jt.flip(x, dim_vector)
jt.concat = warp(jt.concat, concat)
jt.stack = warp(jt.stack, stack_acl)
jt.gather = warp(jt.gather, gather_acl)
jt.any = warp(jt.any, any_acl)
jt.Var.any = jt.any
jt.cumsum = warp(jt.cumsum, cumsum_acl)
jt.cub_cumsum = jt.cumsum
jt.Var.cumsum = jt.cumsum
jt.Var.cub_cumsum = jt.cumsum
jt.cumprod = warp(jt.cumprod, cumprod_acl)
jt.index = warp(jt.index, index_acl)
jt.Var.index = jt.index
jt.scatter = warp(jt.scatter, scatter_acl)
jt.Var.scatter = lambda x, dim, index, src, reduce="void": jt.scatter(
x, dim, index, src, reduce)
jt.where = warp(jt.where, where_acl)
jt.nonzero = warp(jt.nonzero, nonzero_acl)
jt.misc.nonzero = warp(jt.misc.nonzero, nonzero_acl)
jt.Var.nonzero = jt.misc.nonzero
jt.floor_int = warp(jt.floor_int, floor_int_acl)
jt.Var.floor_int = lambda x: jt.floor_int(x)
jt.getitem = warp(jt.contrib.getitem, getitem_acl)
fake_getitem = jt.Var.getitem
jt.Var.getitem = lambda x, slices, return_x=None: warp(
fake_getitem, getitem_acl)(x, slices)
jt.Var.slice_var = lambda x, slices, return_x=None: warp(
fake_getitem, getitem_acl)(x, slices)
jt.Var.__getitem__ = lambda x, slices, return_x=None: warp(
fake_getitem, getitem_acl)(x, slices)
jt.setitem = warp(jt.contrib.setitem, setitem_acl)
fake_setitem = jt.Var.setitem
jt.Var.setitem = lambda x, slices, value: warp(
fake_setitem, setitem_acl, name='setitem')(x, slices, value)
jt.Var.__setitem__ = lambda x, slices, value: warp(
fake_setitem, setitem_acl, name='setitem')(x, slices, value)
fake_matmul = jt.Var.matmul
jt.nn.bmm = warp(jt.nn.bmm, bmm_acl)
jt.bmm = warp(jt.bmm, bmm_acl)
jt.nn.matmul = warp(jt.matmul, matmul_acl)
jt.matmul = warp(jt.matmul, matmul_acl)
jt.nn.matmul_transpose = warp(jt.nn.matmul_transpose, matmul_transpose_acl)
jt.nn.bmm_transpose = warp(jt.nn.bmm_transpose, bmm_transpose_acl)
jt.bmm_transpose = warp(jt.bmm_transpose, bmm_transpose_acl)
jt.Var.__matmul__ = lambda x, y: warp(fake_matmul, matmul_acl)(x, y)
jt.transpose = warp(jt.transpose, transpose_acl)
fake_transpose = jt.transpose
jt.Var.transpose = lambda x, *dim: warp(fake_transpose, transpose_acl)(x, *
dim)
# jt.Var.permute = lambda x: warp(fake_transpose, transpose_acl)(x)
# jt.Var.t = lambda x: warp(fake_transpose, transpose_acl)(x)
jt.nn.relu = warp(jt.nn.relu, relu)
jt.nn.ReLU = warp(jt.nn.ReLU, ReLU)
jt.nn.leaky_relu = warp(jt.nn.leaky_relu, leaky_relu)
jt.nn.LeakyReLU = warp(jt.nn.LeakyReLU, LeakyReLU)
# jt.nn.silu = warp(jt.nn.silu, silu_acl)
# jt.nn.SiLU = warp(jt.nn.SiLU, SiLU)
jt.sigmoid = warp(jt.sigmoid, sigmoid_acl)
jt.nn.Sigmoid = warp(jt.nn.Sigmoid, Sigmoid)
# from .aclops.embedding_op import EmbeddingACL
# def embedding_acl(indices, weight):
# return EmbeddingACL()(indices, weight)
# jt.nn.embedding = warp(jt.nn.embedding, embedding_acl)
# jt.nn.Embedding = warp(jt.nn.Embedding, Embedding)
jt.nn.dropout = warp(jt.nn.dropout, dropout_acl)
jt.nn.Dropout = warp(jt.nn.Dropout, Dropout)
jt.nn.softmax = warp(jt.nn.softmax, softmax_acl)
# from .aclops.norms_op import BatchNormACL,LayerNormACL
# jt.nn.BatchNorm = warp(jt.nn.BatchNorm, BatchNormACL)
# jt.nn.LayerNorm = warp(jt.nn.LayerNorm, LayerNormACL)
jt.nn.FlashAttention = warp(jt.nn.FlashAttention, FlashAttentionACL)
jt.isnan = warp(jt.isnan, isnan_acl)
jt.isinf = warp(jt.isinf, isinf_acl)
jt.Var.isnan = jt.isnan
jt.Var.isinf = jt.isinf
jt.nn.rotary_emb = rope_acl


@@ -1,6 +1,6 @@
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
@@ -11,23 +11,27 @@ using std::unordered_map;
typedef int aclError;
static inline unordered_map<aclError,string> gen_map(string s) {
unordered_map<aclError,string> smap;
for (int i=0; i<s.size(); i++) {
if (s[i] == ';') {
int j=s.rfind(" ", i);
int code = std::stoi(s.substr(j+1, i-j-1));
int k = s.rfind(" ", j-1);
int l = s.rfind(" ACL_", k-1);
smap[code] = s.substr(l+1, k-l-1);
static inline unordered_map<aclError, string> gen_map(string s)
{
unordered_map<aclError, string> smap;
for (int i = 0; i < s.size(); i++)
{
if (s[i] == ';')
{
int j = s.rfind(" ", i);
int code = std::stoi(s.substr(j + 1, i - j - 1));
int k = s.rfind(" ", j - 1);
int l = s.rfind(" ACL_", k - 1);
smap[code] = s.substr(l + 1, k - l - 1);
}
}
return smap;
}
string acl_error_to_string(aclError error) {
string acl_error_to_string(aclError error)
{
static unordered_map<aclError,string> acl_error_map = gen_map(R"(
static unordered_map<aclError, string> acl_error_map = gen_map(R"(
// from acl_base.h
static const int ACL_ERROR_INVALID_PARAM = 100000;
static const int ACL_ERROR_UNINITIALIZE = 100001;


@@ -1,6 +1,6 @@
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
@@ -10,266 +10,311 @@
#include "utils/str_utils.h"
#include <chrono>
#include <thread>
#include "aclnn/aclnn.h"
namespace jittor {
namespace jittor
{
uint64_t acl_jittor_tid;
int acl_jittor_thread_running=0;
aclrtContext acl_jittor_context;
uint64_t acl_jittor_tid;
int acl_jittor_thread_running = 0;
aclrtStream aclstream;
void *workspaceAddr = nullptr;
uint64_t nowWorkSpaceSize = 0;
#define CHECK_ACL(x) ASSERTop(x,==,0)
#define CHECK_ACL(x) ASSERTop(x, ==, 0)
static void* acl_jittor_process_callback(void*) {
acl_jittor_thread_running = 1;
int deviceId = 0;
CHECK_ACL(aclrtSetCurrentContext(acl_jittor_context));
while (acl_jittor_thread_running) {
// LOGir << "acl_jittor_process_callback";
auto ret = aclrtProcessReport(1000);
if (ret) {
if (acl_jittor_thread_running && ret != ACL_ERROR_RT_REPORT_TIMEOUT)
LOGir << "aclrtProcessReport:" << ret << acl_error_to_string(ret);
break;
void mallocWorkSpace(uint64_t size)
{
uint64_t alloc_size = size + 32;
alloc_size = ((alloc_size - 1) / 32 + 1) * 32;
if (alloc_size > nowWorkSpaceSize)
{
aclrtFree(workspaceAddr);
nowWorkSpaceSize = alloc_size;
auto ret = aclrtMalloc(&workspaceAddr, nowWorkSpaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("allocate workspace failed. ERROR: %d\n", ret); return);
}
}
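// Note on the workspace cache above: requested sizes are rounded up to a
// 32-byte multiple and the buffer is re-allocated only when an op needs
// more space than is currently cached; the cached buffer is freed in
// ~acl_jittor_initer when nowWorkSpaceSize > 0.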
acl_jittor_thread_running = 0;
return (void*)0;
}
static void *acl_jittor_process_callback(void *)
{
acl_jittor_thread_running = 1;
// void aaa(void*) {
// LOGir << "haha";
// }
struct acl_jittor_initer {
acl_jittor_initer() {
CHECK_ACL(aclInit(nullptr));
uint device_count = 0;
// Get the number of available devices
CHECK_ACL(aclrtGetDeviceCount(&device_count));
LOGi << "Found ACL device number:" << device_count;
CHECK_ACL(aclrtSetDevice(0));
CHECK_ACL(aclrtCreateContext(&acl_jittor_context, 0));
CHECK_ACL(aclrtSetCurrentContext(acl_jittor_context));
pthread_create(&acl_jittor_tid, nullptr, acl_jittor_process_callback, 0);
// subscribe for default stream
CHECK_ACL(aclrtSubscribeReport(acl_jittor_tid,0));
// simple callback test
// aclrtStream stream;
// CHECK_ACL(aclrtCreateStream(&stream));
// CHECK_ACL(aclrtSubscribeReport(acl_jittor_tid,stream));
// CHECK_ACL(aclrtLaunchCallback((aclrtCallback)&aaa, 0, ACL_CALLBACK_NO_BLOCK, stream));
// CHECK_ACL(aclrtLaunchCallback((aclrtCallback)&aaa, 0, ACL_CALLBACK_NO_BLOCK, 0));
}
~acl_jittor_initer() {
acl_jittor_thread_running = 0;
CHECK_ACL(aclrtUnSubscribeReport(acl_jittor_tid,0));
CHECK_ACL(aclrtDestroyContext(acl_jittor_context));
CHECK_ACL(aclFinalize());
}
} _acl_jittor_initer;
string process_acl(const string& src, const string& name, const map<string,string>& kargs) {
if (endswith(name, "_jittor.cc"))
return src;
// static vector<string> dont_compile = {"fp16_emu.cc"};
// for (auto& s : dont_compile)
// if (endswith(name, s))
// return " ";
static unordered_set<string> cuda_headers = {
"cuda_runtime", "cudnn", "driver_types",
"cuda_fp16", "cuda_runtime_api", "fp16_emu",
"cudnn_rnn_descriptor", "cublas_v2", "cublas_wrapper",
"curand", "curand_wrapper", "cufft", "cufftXt",
"CudaUtils", "cutt"
};
static unordered_set<string> fake_class = {
"cudnnHandle_t", "cudnnConvolutionBwdFilterAlgo_t",
"cudnnConvolutionBwdDataAlgo_t", "cudnnConvolutionFwdAlgo_t",
"cufftHandle"
};
try {
auto tokens = token_split(src);
int edit = 0;
for (int i=0; i<tokens.size(); i++) {
auto& token = tokens[i];
if (cuda_headers.count(token)) token = "acl_jittor", edit ++; else
if (fake_class.count(token)) token = "int", edit ++; else
if (token == "CUDA") token = "ACL", edit ++; else
if (startswith(token, "cuda")) {
if (token.size()>=5 && token[4] >= 'A' && token[4] <= 'Z') {
if (token == "cudaGetDeviceCount") {
token_replace(tokens, i, "($1);", "((uint*)$1);");
} else if (token == "cudaLaunchHostFunc") {
// ACL_CALLBACK_BLOCK for 310
token_replace(tokens, i, "LaunchHostFunc($1,$2,$3)",
"LaunchCallback($2,$3,ACL_CALLBACK_NO_BLOCK,$1)");
} else if (token == "cudaMemcpy")
token_replace(tokens, i, "cudaMemcpy($1,$2,$3,",
"aclrtMemcpy($1,$3,$2,$3,");
else if (token == "cudaMemcpyAsync")
token_replace(tokens, i, "cudaMemcpyAsync($1,$2,$3,",
"aclrtMemcpyAsync($1,$3,$2,$3,");
else if (token == "cudaMemcpyDeviceToHost") token = "ACL_MEMCPY_DEVICE_TO_HOST";
else if (token == "cudaMemcpyHostToDevice") token = "ACL_MEMCPY_HOST_TO_DEVICE";
else if (token == "cudaMemcpyDeviceToDevice") token = "ACL_MEMCPY_DEVICE_TO_DEVICE";
else if (token == "cudaMallocManaged" || token == "cudaMalloc") {
// unified address not supported
token = "aclrtMalloc";
token_replace(tokens, i, "($1,$2)",
"($1,$2,ACL_MEM_MALLOC_HUGE_FIRST)");
} else if (token == "cudaMemGetInfo")
token_replace(tokens, i, "cudaMemGetInfo($1,$2)",
"aclrtGetMemInfo(ACL_DDR_MEM,$1,$2)");
else if (token == "cudaGetLastError")
token_replace(tokens, i, "cudaGetLastError()", "0");
else if (token == "cudaStreamCreateWithFlags")
token_replace(tokens, i-1,
"(cudaStreamCreateWithFlags($1,$2));",
"(aclrtCreateStream($1)); checkAclErrors(aclrtSubscribeReport(acl_jittor_tid,*$1));");
else if (token == "cudaEventCreate")
token_replace(tokens, i,
"cudaEventCreate($1,$2)",
"aclrtCreateEvent($1)");
else if (token == "cudaDeviceSynchronize")
token = "aclrtSynchronizeDevice";
else if (token == "cudaStreamDestroy")
token_replace(tokens, i, "cudaStreamDestroy($1)",
"(aclrtUnSubscribeReport(acl_jittor_tid,$1), aclrtDestroyStream($1))");
else if (token == "cudaEventDestroy")
token = "aclrtDestroyEvent";
else if (token == "cudaEventRecord")
token = "aclrtRecordEvent";
else if (token == "cudaStreamWaitEvent")
token_replace(tokens, i,
"cudaStreamWaitEvent($1,$2,$3)",
"aclrtStreamWaitEvent($1,$2)");
if (token.size() && token[0] == 'c')
token = "aclrt" + token.substr(4);
if (endswith(token, "_t"))
token = token.substr(0, token.size()-2);
edit ++;
}
} else
if (token == "_cudaGetErrorEnum") {
token_replace(tokens, i, "_cudaGetErrorEnum($1)", "(acl_error_to_string($1))");
edit ++;
} else
if (token == "checkCudaErrors")
token = "checkAclErrors";
else if (token == "JPU") {
edit ++;
string new_code;
if (tokens[i+2] == "op_compiler")
token_replace(tokens, i,
"JPU(op_compiler($1,$2,$3))",
"acl_jittor_op_compiler($1,$2,$3)");
else if (tokens[i+2] == "header")
new_code = "#include \"acl_jittor.h\"";
if (new_code.size())
token_replace(tokens, i, "JPU($1)", new_code);
} else if (token == "use_cuda_managed_allocator" && tokens[i+1][0]==',') {
tokens[i+2] = "0"; // disable unified address
}
}
if (!edit) return src;
string new_src = join(tokens, "");
// if (name == "executor.cc") {
// new_src = string("#include <Python.h>\n#include <pystate.h>\n#include <common.h>\n")+
// "namespace jittor { void acl_op_exec(Op*); }\n" +
// replace(new_src, "op->do_run_after_prepare(jkl);",
// R"({
// acl_op_exec(op);
// })");
// }
if (name == "profiler.cc") {
new_src = token_replace_all(new_src, ".cc", ".tikcc");
}
// LOGir << name << (name == "pass_manager.cc");
if (name == "pass_manager.cc") {
LOGir << "replace" << name;
new_src = token_replace_all(new_src, "run_pass<FloatAtomicFixPass>();", "WTF");
}
// ????????
return new_src;
} catch (const std::exception& e) {
LOGe << "process acl error:" << e.what();
LOGe << "name:" << name;
throw;
}
}
void acl_jittor_op_compiler(string& filename, string& src, bool is_acl, string& extra_flags) {
if (!is_acl) return;
// extra_flags += " --tik-soc-version=Ascend910 ";
// filename = replace(filename, ".cc", ".tikcc");
// LOGir << filename;
string new_src = process_acl(src, "", {});
new_src = replace(new_src, R"(#include "misc/cuda_atomic.h")", "");
new_src = replace(new_src, R"(#include "misc/cuda_limits.h")", "");
new_src = replace(new_src, "__global__", "__ai_device_entry__");
new_src = token_replace_all(new_src, "__launch_bounds__($1)", "");
new_src = token_replace_all(new_src, "int thread_num = $1;", "int thread_num = 1;");
new_src = token_replace_all(new_src, "tn0=std::max(tn0, $1);", "");
new_src = token_replace_all(new_src, "<<<$1>>>", "<<<1,0>>>");
new_src = token_replace_all(new_src, "int thread_id = $1;", "int thread_id = 1;");
// for inc error
new_src = token_replace_all(new_src, "for ($1+=$2)", "for ($1++)");
// bit op error
new_src = token_replace_all(new_src, "int tnum$1;", "");
new_src = token_replace_all(new_src, "int p1$1;", "");
new_src = token_replace_all(new_src, "int p2$1;", "");
new_src = token_replace_all(new_src, "int tn$1=$2;", "int tn$1=0;");
new_src = token_replace_all(new_src, "int tid$1=$2;", "int tid$1=0;");
src = new_src;
new_src = token_replace_all(new_src, "atomicAdd(&$1,$2);", "$1=$1+$2;");
new_src = token_replace_all(new_src, "bool", "int8");
new_src = token_replace_all(new_src, "::numeric_min<float32>()", "-1e30");
new_src = token_replace_all(new_src, "::numeric_max<float32>()", "1e30");
// TODO: support max
unordered_map<string,string> opmap = {
// {"::max","tikcc::scalar_max"},
{"::sqrtf", "tikcc::scalar_sqrt"}
};
auto ss = split(new_src, ";");
for (auto &s : ss) {
if (s.find("?") != string::npos) {
s = token_replace_all(s+";", "auto $1=$2?$3:$4;", "auto $1=$3;if (!($2)) $1=$4;");
}
if (s.find("::max") != string::npos) {
if (s.find("auto") == string::npos) {
s = token_replace_all(s+";", " $1=$4::max($2,$3);", " $1=$2;if ($2 < $3) $1=$3;");
} else {
s = token_replace_all(s+";", "auto $1=$4::max($2,$3);", "auto $1=$2;if ($2 < $3) $1=$3;");
while (acl_jittor_thread_running)
{
// LOGir << "acl_jittor_process_callback";
auto ret = aclrtProcessReport(1000);
if (ret)
{
if (acl_jittor_thread_running && ret != ACL_ERROR_RT_REPORT_TIMEOUT && ret != ACL_ERROR_RT_THREAD_SUBSCRIBE)
LOGir << "aclrtProcessReport:" << ret << acl_error_to_string(ret);
break;
}
}
for (auto& kv : opmap) {
if (s.find(kv.first) != string::npos) {
if (s.find("auto") == string::npos) {
// $1 = op($2) --> op($1, $2)
s = token_replace_all(s+";", " $1= "+kv.first+"($2);", kv.second+"($1, $2);");
} else {
// auto $1 = op($2) --> float32 $1; op($1, $2);
s = token_replace_all(s+";", "auto $1= "+kv.first+"($2);", "float32 $1; " + kv.second+"($1, $2);");
acl_jittor_thread_running = 0;
return (void *)0;
}
struct acl_jittor_initer
{
int32_t deviceId;
acl_jittor_initer()
{
CHECK_ACL(aclInit(nullptr));
uint device_count = 0;
deviceId = 0;
// Get the number of available devices
CHECK_ACL(aclrtGetDeviceCount(&device_count));
LOGi << "Found ACL device number:" << device_count;
CHECK_ACL(aclrtSetDevice(deviceId));
CHECK_ACL(aclrtCreateStream(&aclstream));
// pthread_create(&acl_jittor_tid, nullptr, acl_jittor_process_callback, 0);
}
~acl_jittor_initer()
{
acl_jittor_thread_running = 0;
// CHECK_ACL(aclrtUnSubscribeReport(acl_jittor_tid, 0));
aclrtDestroyStream(aclstream);
aclrtResetDevice(deviceId);
CHECK_ACL(aclFinalize());
if (nowWorkSpaceSize > 0)
{
aclrtFree(workspaceAddr);
}
}
} _acl_jittor_initer;
string process_acl(const string &src, const string &name, const map<string, string> &kargs)
{
if (endswith(name, "_jittor.cc"))
return src;
// static vector<string> dont_compile = {"fp16_emu.cc"};
// for (auto& s : dont_compile)
// if (endswith(name, s))
// return " ";
static unordered_set<string> cuda_headers = {
"cuda_runtime", "cudnn", "driver_types",
"cuda_fp16", "cuda_runtime_api", "fp16_emu",
"cudnn_rnn_descriptor", "cublas_v2", "cublas_wrapper",
"curand", "curand_wrapper", "cufft", "cufftXt",
"CudaUtils", "cutt", "cudnn_wrapper", "cuda_bf16"};
static unordered_set<string> fake_class = {
"cudnnHandle_t", "cudnnConvolutionBwdFilterAlgo_t",
"cudnnConvolutionBwdDataAlgo_t", "cudnnConvolutionFwdAlgo_t",
"cufftHandle"};
try
{
auto tokens = token_split(src);
int edit = 0;
for (int i = 0; i < tokens.size(); i++)
{
auto &token = tokens[i];
if (cuda_headers.count(token))
token = "acl_jittor", edit++;
else if (fake_class.count(token))
token = "int", edit++;
else if (token == "CUDA")
token = "ACL", edit++;
else if (startswith(token, "cuda"))
{
if (token.size() >= 5 && token[4] >= 'A' && token[4] <= 'Z')
{
if (token == "cudaGetDeviceCount")
{
token_replace(tokens, i, "($1);", "((uint*)$1);");
}
else if (token == "cudaLaunchHostFunc")
{
// ACL_CALLBACK_BLOCK for 310
token_replace(tokens, i, "LaunchHostFunc($1,$2,$3)",
"LaunchCallback($2,$3,ACL_CALLBACK_NO_BLOCK,$1)");
}
else if (token == "cudaMemcpy")
token_replace(tokens, i, "cudaMemcpy($1,$2,$3,",
"aclrtMemcpy($1,$3,$2,$3,");
else if (token == "cudaMemcpyAsync")
token_replace(tokens, i, "cudaMemcpyAsync($1,$2,$3,",
"aclrtMemcpyAsync($1,$3,$2,$3,");
else if (token == "cudaMemcpyDeviceToHost")
token = "ACL_MEMCPY_DEVICE_TO_HOST";
else if (token == "cudaMemcpyDefault")
token = "ACL_MEMCPY_HOST_TO_DEVICE";
else if (token == "cudaMemcpyHostToDevice")
token = "ACL_MEMCPY_HOST_TO_DEVICE";
else if (token == "cudaMemcpyDeviceToDevice")
token = "ACL_MEMCPY_DEVICE_TO_DEVICE";
else if (token == "cudaMallocManaged" || token == "cudaMalloc")
{
// unified address not supported
token = "aclrtMalloc";
token_replace(tokens, i, "($1,$2)",
"($1,$2,ACL_MEM_MALLOC_HUGE_FIRST)");
}
else if (token == "cudaMemGetInfo")
token_replace(tokens, i, "cudaMemGetInfo($1,$2)",
"aclrtGetMemInfo(ACL_DDR_MEM,$1,$2)");
else if (token == "cudaGetLastError")
token_replace(tokens, i, "cudaGetLastError()", "0");
else if (token == "cudaStreamCreateWithFlags")
token_replace(tokens, i - 1,
"(cudaStreamCreateWithFlags($1,$2));",
"(aclrtCreateStream($1)); checkAclErrors(aclrtSubscribeReport(acl_jittor_tid,*$1));");
else if (token == "cudaEventCreate")
token_replace(tokens, i,
"cudaEventCreate($1,$2)",
"aclrtCreateEvent($1)");
else if (token == "cudaDeviceSynchronize")
token = "aclrtSynchronizeDevice";
else if (token == "cudaStreamDestroy")
token_replace(tokens, i, "cudaStreamDestroy($1)",
"(aclrtUnSubscribeReport(acl_jittor_tid,$1), aclrtDestroyStream($1))");
else if (token == "cudaEventDestroy")
token = "aclrtDestroyEvent";
else if (token == "cudaEventRecord")
token = "aclrtRecordEvent";
else if (token == "cudaStreamWaitEvent")
token_replace(tokens, i,
"cudaStreamWaitEvent($1,$2,$3)",
"aclrtStreamWaitEvent($1,$2)");
if (token.size() && token[0] == 'c')
token = "aclrt" + token.substr(4);
if (endswith(token, "_t"))
token = token.substr(0, token.size() - 2);
edit++;
}
}
else if (token == "_cudaGetErrorEnum")
{
token_replace(tokens, i, "_cudaGetErrorEnum($1)", "(acl_error_to_string($1))");
edit++;
}
else if (token == "checkCudaErrors")
token = "checkAclErrors";
else if (token == "JPU")
{
edit++;
string new_code;
if (tokens[i + 2] == "op_compiler")
token_replace(tokens, i,
"JPU(op_compiler($1,$2,$3))",
"acl_jittor_op_compiler($1,$2,$3)");
else if (tokens[i + 2] == "header")
new_code = "#include \"acl_jittor.h\"";
if (new_code.size())
token_replace(tokens, i, "JPU($1)", new_code);
}
else if (token == "use_cuda_managed_allocator" && tokens[i + 1][0] == ',')
{
tokens[i + 2] = "0"; // disable unified address
}
}
if (!edit)
return src;
string new_src = join(tokens, "");
// if (name == "executor.cc") {
// new_src = string("#include <Python.h>\n#include <pystate.h>\n#include <common.h>\n")+
// "namespace jittor { void acl_op_exec(Op*); }\n" +
// replace(new_src, "op->do_run_after_prepare(jkl);",
// R"({
// acl_op_exec(op);
// })");
// }
if (name == "profiler.cc")
{
new_src = token_replace_all(new_src, ".cc", ".tikcc");
}
// LOGir << name << (name == "pass_manager.cc");
if (name == "pass_manager.cc")
{
LOGir << "replace" << name;
new_src = token_replace_all(new_src, "run_pass<FloatAtomicFixPass>();", "WTF");
}
// ????????
return new_src;
}
catch (const std::exception &e)
{
LOGe << "process acl error:" << e.what();
LOGe << "name:" << name;
throw;
}
// s = token_replace_all(s+";", "auto $1=$2?$3:$4;", "auto $1=$3;if (!($2)) $1=$4;");
// s = token_replace_all(s+";", "auto $1=$2?$3:$4;", "auto $1=$3;if (!($2)) $1=$4;");
// if (s.find("::max") != string::npos) {
// s = token_replace_all(s+";", " $1= ::max($2);", "tikcc::scalar_max($1, $2);");
// }
}
new_src = join(ss, ";");
src = new_src;
}
void acl_jittor_op_compiler(string &filename, string &src, bool is_acl, string &extra_flags)
{
if (!is_acl)
return;
string new_src = process_acl(src, "", {});
new_src = replace(new_src, R"(#include "misc/cuda_atomic.h")", "");
new_src = replace(new_src, R"(#include "misc/cuda_limits.h")", "");
new_src = replace(new_src, "__global__", "__ai_device_entry__");
new_src = token_replace_all(new_src, "__launch_bounds__($1)", "");
new_src = token_replace_all(new_src, "int thread_num = $1;", "int thread_num = 1;");
new_src = token_replace_all(new_src, "tn0=std::max(tn0, $1);", "");
new_src = token_replace_all(new_src, "<<<$1>>>", "<<<1,0>>>");
new_src = token_replace_all(new_src, "int thread_id = $1;", "int thread_id = 1;");
// for inc error
new_src = token_replace_all(new_src, "for ($1+=$2)", "for ($1++)");
// bit op error
new_src = token_replace_all(new_src, "int tnum$1;", "");
new_src = token_replace_all(new_src, "int p1$1;", "");
new_src = token_replace_all(new_src, "int p2$1;", "");
new_src = token_replace_all(new_src, "int tn$1=$2;", "int tn$1=0;");
new_src = token_replace_all(new_src, "int tid$1=$2;", "int tid$1=0;");
src = new_src;
new_src = token_replace_all(new_src, "atomicAdd(&$1,$2);", "$1=$1+$2;");
// new_src = token_replace_all(new_src, "bool", "int8");
new_src = token_replace_all(new_src, "::numeric_min<float32>()", "-1e30");
new_src = token_replace_all(new_src, "::numeric_max<float32>()", "1e30");
// TODO: support max
unordered_map<string, string> opmap = {
// {"::max","tikcc::scalar_max"},
{"::sqrtf", "tikcc::scalar_sqrt"}};
auto ss = split(new_src, ";");
for (auto &s : ss)
{
if (s.find("?") != string::npos)
{
s = token_replace_all(s + ";", "auto $1=$2?$3:$4;", "auto $1=$3;if (!($2)) $1=$4;");
}
if (s.find("::max") != string::npos)
{
if (s.find("auto") == string::npos)
{
s = token_replace_all(s + ";", " $1=$4::max($2,$3);", " $1=$2;if ($2 < $3) $1=$3;");
}
else
{
s = token_replace_all(s + ";", "auto $1=$4::max($2,$3);", "auto $1=$2;if ($2 < $3) $1=$3;");
}
}
for (auto &kv : opmap)
{
if (s.find(kv.first) != string::npos)
{
if (s.find("auto") == string::npos)
{
// $1 = op($2) --> op($1, $2)
s = token_replace_all(s + ";", " $1= " + kv.first + "($2);", kv.second + "($1, $2);");
}
else
{
// auto $1 = op($2) --> float32 $1; op($1, $2);
s = token_replace_all(s + ";", "auto $1= " + kv.first + "($2);", "float32 $1; " + kv.second + "($1, $2);");
}
}
}
// s = token_replace_all(s+";", "auto $1=$2?$3:$4;", "auto $1=$3;if (!($2)) $1=$4;");
// s = token_replace_all(s+";", "auto $1=$2?$3:$4;", "auto $1=$3;if (!($2)) $1=$4;");
// if (s.find("::max") != string::npos) {
// s = token_replace_all(s+";", " $1= ::max($2);", "tikcc::scalar_max($1, $2);");
// }
}
new_src = join(ss, ";");
src = new_src;
}
}


@@ -1,19 +1,700 @@
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "common.h"
#include "aclnn/aclnn.h"
#include <acl/acl.h>
std::string acl_error_to_string(aclError error);
namespace jittor {
namespace jittor
{
EXTERN_LIB uint64_t acl_jittor_tid;
EXTERN_LIB aclrtStream aclstream;
EXTERN_LIB void *workspaceAddr;
void acl_jittor_op_compiler(string& filename, string& src, bool is_acl, string& extra_flags);
void mallocWorkSpace(uint64_t size);
}
void acl_jittor_op_compiler(string &filename, string &src, bool is_acl, string &extra_flags);
struct AclOpFunctions
{
// for Unary and Nonzero
std::function<aclnnStatus(aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncUnaryNonzero;
// for Cast
std::function<aclnnStatus(aclTensor *, aclDataType, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncCast;
// for Binary
std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncBinary;
// for Add and Sub
std::function<aclnnStatus(aclTensor *, aclTensor *, aclScalar *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncAdd;
// for Expand, permute, flip
std::function<aclnnStatus(aclTensor *, aclIntArray *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncExpand;
// for bmm and matmul
std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, int8_t, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncMatmul;
// for conv
std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, bool, aclIntArray *, int64_t, aclTensor *, int8_t, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncConv;
// for reducesum, mean
std::function<aclnnStatus(aclTensor *, aclIntArray *, bool, aclDataType, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncReduceSum;
// for amax and amin
std::function<aclnnStatus(aclTensor *, aclIntArray *, bool, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncAmax;
// for conv backward
std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, aclIntArray *, bool, aclIntArray *, int, aclBoolArray *, int8_t, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncConvBackward;
// for proddim
std::function<aclnnStatus(aclTensor *, float, float, float, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncProdDim;
// for select, where
std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncSelect;
// for random_uniform and random_normal
std::function<aclnnStatus(aclTensor *, int64_t, int64_t, int64_t, int64_t, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncRandom;
// for maxpool
std::function<aclnnStatus(aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, aclIntArray *, bool, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncMaxPool;
// for maxpool backward
std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, aclIntArray *, bool, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncMaxPoolBackward;
// for avgpool
std::function<aclnnStatus(aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, bool, bool, int64_t, int8_t, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncAvgPool;
// for avgpool backward
std::function<aclnnStatus(aclTensor *, aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, bool, bool, int64_t, int8_t, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncAvgPoolBackward;
// for concat
std::function<aclnnStatus(aclTensorList *, uint64_t, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncConcat;
// for gather
std::function<aclnnStatus(aclTensor *, uint64_t, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncGather;
// for cumsum
std::function<aclnnStatus(aclTensor *, uint64_t, aclDataType, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncCumsum;
// for scatter
std::function<aclnnStatus(aclTensor *, uint64_t, aclTensor *, aclTensor *, uint64_t, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncScatter;
// for index
std::function<aclnnStatus(aclTensor *, aclTensorList *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncIndex;
// for stridesliceassignv2
std::function<aclnnStatus(aclTensor *, aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, aclIntArray *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncStridedSliceAssignV2;
// for slicev2
std::function<aclnnStatus(aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, aclIntArray *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncSliceV2;
// for indexputimpl
std::function<aclnnStatus(aclTensor *, aclTensorList *, aclTensor *, bool, bool, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncIndexPutImpl;
// for range
std::function<aclnnStatus(aclScalar *, aclScalar *, aclScalar *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncRange;
// for leaky_relu
std::function<aclnnStatus(aclTensor *, aclScalar *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncLeakyRelu;
// for leaky_relu backward
std::function<aclnnStatus(aclTensor *, aclTensor *, aclScalar *, bool, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncLeakyReluBackward;
// for dropout
std::function<aclnnStatus(aclTensor *, double, bool, int64_t, int64_t, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncDropout;
// for dropout backward
std::function<aclnnStatus(aclTensor *, aclTensor *, double, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncDropoutBackward;
// for split with size
std::function<aclnnStatus(aclTensor *, aclIntArray *, int64_t, aclTensorList *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncSplitWithSize;
// for silu
// std::function<aclnnStatus(aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncSilu;
// for silu backward
// std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncSiluBackward;
// for sigmoid
// std::function<aclnnStatus(aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncSigmoid;
// for sigmoid backward
// std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncSigmoidBackward;
// for embedding
// std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncEmbedding;
// for embedding backward
std::function<aclnnStatus(aclTensor *, aclTensor *, uint64_t, uint64_t, bool, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncEmbeddingBackward;
// for InplaceMaskedScatter MaskedSelect
// std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncInplaceMaskedScatter;
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, aclrtStream)> executeFunc;
// for flashattention
std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *,
aclIntArray *, aclIntArray *, aclIntArray *, double, double, int64_t, int64_t, int64_t, char *, int64_t, int64_t, int64_t,
aclTensor *, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)>
getWorkspaceSizeFuncFalshAttention;
// for flashattention backward
std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *,
aclIntArray *, aclIntArray *, aclIntArray *, double, double, int64_t, int64_t, int64_t, char *, int64_t, int64_t, int64_t,
aclTensor *, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)>
getWorkspaceSizeFuncFalshAttentionBackward;
// for batchnorm
std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, bool, double, double, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncBatchNorm;
// for batchnorm backward
std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, bool, double, aclBoolArray *, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncBatchNormBackward;
// for layernorm
std::function<aclnnStatus(aclTensor *, aclIntArray *, aclTensor *, aclTensor *, double, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncLayerNorm;
// for ROPE
std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclTensor *, int64_t, uint64_t *, aclOpExecutor **)>
getWorkspaceSizeFuncRotaryPosEmb;
// Add a default constructor
AclOpFunctions() = default;
// for Unary and Nonzero
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, aclrtStream)> execf)
: getWorkspaceSizeFuncUnaryNonzero(gwsf), executeFunc(execf) {}
// for Cast
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclDataType, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, aclrtStream)> execf)
: getWorkspaceSizeFuncCast(gwsf), executeFunc(execf) {}
// for Binary
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncBinary(gwsf), executeFunc(execf) {}
// for Add and Sub
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclScalar *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncAdd(gwsf), executeFunc(execf) {}
// for Expand, flip
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclIntArray *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncExpand(gwsf), executeFunc(execf) {}
// for Matmul
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, int8_t, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncMatmul(gwsf), executeFunc(execf) {}
// for conv
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, bool, aclIntArray *, int64_t, aclTensor *, int8_t, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncConv(gwsf), executeFunc(execf) {}
// for reducesum, mean
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclIntArray *, bool, aclDataType, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncReduceSum(gwsf), executeFunc(execf) {}
// for amax amin
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclIntArray *, bool, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncAmax(gwsf), executeFunc(execf) {}
// for conv backward
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, aclIntArray *, bool, aclIntArray *, int, aclBoolArray *, int8_t, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncConvBackward(gwsf), executeFunc(execf) {}
// for proddim
AclOpFunctions(std::function<aclnnStatus(const aclTensor *, float, float, float, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncProdDim(gwsf), executeFunc(execf) {}
// for select, where
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncSelect(gwsf), executeFunc(execf) {}
// for random_normal
AclOpFunctions(std::function<aclnnStatus(aclTensor *, int64_t, int64_t, int64_t, int64_t, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncRandom(gwsf), executeFunc(execf) {}
// for maxpool
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, aclIntArray *, bool, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncMaxPool(gwsf), executeFunc(execf) {}
// for maxpool backward
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, aclIntArray *, bool, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncMaxPoolBackward(gwsf), executeFunc(execf) {}
// for avgpool
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, bool, bool, int64_t, int8_t, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncAvgPool(gwsf), executeFunc(execf) {}
// for avgpool backward
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, bool, bool, int64_t, int8_t, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncAvgPoolBackward(gwsf), executeFunc(execf) {}
// for concat
AclOpFunctions(std::function<aclnnStatus(aclTensorList *, int64_t, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncConcat(gwsf), executeFunc(execf) {}
// for gather
AclOpFunctions(std::function<aclnnStatus(aclTensor *, int64_t, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncGather(gwsf), executeFunc(execf) {}
// for cumsum
AclOpFunctions(std::function<aclnnStatus(aclTensor *, int64_t, aclDataType, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncCumsum(gwsf), executeFunc(execf) {}
// for scatter
AclOpFunctions(std::function<aclnnStatus(aclTensor *, uint64_t, aclTensor *, aclTensor *, uint64_t, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncScatter(gwsf), executeFunc(execf) {}
// for index
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensorList *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncIndex(gwsf), executeFunc(execf) {}
// for stridesliceassignv2
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, aclIntArray *, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncStridedSliceAssignV2(gwsf), executeFunc(execf) {}
// for slicev2
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, aclIntArray *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncSliceV2(gwsf), executeFunc(execf) {}
// for indexputimpl
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensorList *, aclTensor *, bool, bool, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncIndexPutImpl(gwsf), executeFunc(execf) {}
// for range
AclOpFunctions(std::function<aclnnStatus(aclScalar *, aclScalar *, aclScalar *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncRange(gwsf), executeFunc(execf) {}
// for leaky_relu
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclScalar *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncLeakyRelu(gwsf), executeFunc(execf) {}
// for leaky_relu backward
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclScalar *, bool, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncLeakyReluBackward(gwsf), executeFunc(execf) {}
// for dropout
AclOpFunctions(std::function<aclnnStatus(aclTensor *, double, bool, int64_t, int64_t, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncDropout(gwsf), executeFunc(execf) {}
// for dropout backward
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, double, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncDropoutBackward(gwsf), executeFunc(execf) {}
// for embedding backward
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, uint64_t, uint64_t, bool, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncEmbeddingBackward(gwsf), executeFunc(execf) {}
// for split with size
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclIntArray *, int64_t, aclTensorList *, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncSplitWithSize(gwsf), executeFunc(execf) {}
// for flash attention
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *,
aclIntArray *, aclIntArray *, aclIntArray *, double, double, int64_t, int64_t, int64_t, char *, int64_t, int64_t, int64_t,
aclTensor *, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)>
gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncFalshAttention(gwsf), executeFunc(execf) {}
// for flash attention backward
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *,
aclIntArray *, aclIntArray *, aclIntArray *, double, double, int64_t, int64_t, int64_t, char *, int64_t, int64_t, int64_t,
aclTensor *, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)>
gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncFalshAttentionBackward(gwsf), executeFunc(execf) {}
// for batchnorm
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, bool, double, double, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)>
gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncBatchNorm(gwsf), executeFunc(execf) {}
// for batchnorm backward
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, bool, double, aclBoolArray *, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)>
gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncBatchNormBackward(gwsf), executeFunc(execf) {}
// for layernorm
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclIntArray *, aclTensor *, aclTensor *, double, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)>
gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncLayerNorm(gwsf), executeFunc(execf) {}
// for ROPE
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, const aclTensor *, const aclTensor *, int64_t, uint64_t *, aclOpExecutor **)> gwsf,
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
: getWorkspaceSizeFuncRotaryPosEmb(gwsf), executeFunc(execf) {}
};
static std::unordered_map<std::string, AclOpFunctions> aclOpFuncMap = {
{"Abs", AclOpFunctions(aclnnAbsGetWorkspaceSize, aclnnAbs)},
{"Exp", AclOpFunctions(aclnnExpGetWorkspaceSize, aclnnExp)},
{"Log", AclOpFunctions(aclnnLogGetWorkspaceSize, aclnnLog)},
{"Sqrt", AclOpFunctions(aclnnSqrtGetWorkspaceSize, aclnnSqrt)},
{"Ceil", AclOpFunctions(aclnnCeilGetWorkspaceSize, aclnnCeil)},
{"Floor", AclOpFunctions(aclnnFloorGetWorkspaceSize, aclnnFloor)},
{"Round", AclOpFunctions(aclnnRoundGetWorkspaceSize, aclnnRound)},
{"Sin", AclOpFunctions(aclnnSinGetWorkspaceSize, aclnnSin)},
{"Cos", AclOpFunctions(aclnnCosGetWorkspaceSize, aclnnCos)},
{"Tan", AclOpFunctions(aclnnTanGetWorkspaceSize, aclnnTan)},
{"Asin", AclOpFunctions(aclnnAsinGetWorkspaceSize, aclnnAsin)},
{"Acos", AclOpFunctions(aclnnAcosGetWorkspaceSize, aclnnAcos)},
{"Atan", AclOpFunctions(aclnnAtanGetWorkspaceSize, aclnnAtan)},
{"Sinh", AclOpFunctions(aclnnSinhGetWorkspaceSize, aclnnSinh)},
{"Cosh", AclOpFunctions(aclnnCoshGetWorkspaceSize, aclnnCosh)},
{"Tanh", AclOpFunctions(aclnnTanhGetWorkspaceSize, aclnnTanh)},
{"Asinh", AclOpFunctions(aclnnAsinhGetWorkspaceSize, aclnnAsinh)},
{"Acosh", AclOpFunctions(aclnnAcoshGetWorkspaceSize, aclnnAcosh)},
{"Atanh", AclOpFunctions(aclnnAtanhGetWorkspaceSize, aclnnAtanh)},
{"Sigmoid", AclOpFunctions(aclnnSigmoidGetWorkspaceSize, aclnnSigmoid)},
{"Erf", AclOpFunctions(aclnnErfGetWorkspaceSize, aclnnErf)},
{"Erfinv", AclOpFunctions(aclnnErfinvGetWorkspaceSize, aclnnErfinv)},
{"LogicalNot", AclOpFunctions(aclnnLogicalNotGetWorkspaceSize, aclnnLogicalNot)},
{"BitwiseNot", AclOpFunctions(aclnnBitwiseNotGetWorkspaceSize, aclnnBitwiseNot)},
{"Neg", AclOpFunctions(aclnnNegGetWorkspaceSize, aclnnNeg)},
{"Cast", AclOpFunctions(aclnnCastGetWorkspaceSize, aclnnCast)},
{"Maximum", AclOpFunctions(aclnnMaximumGetWorkspaceSize, aclnnMaximum)},
{"Minimum", AclOpFunctions(aclnnMinimumGetWorkspaceSize, aclnnMinimum)},
{"Add", AclOpFunctions(aclnnAddGetWorkspaceSize, aclnnAdd)},
{"Sub", AclOpFunctions(aclnnSubGetWorkspaceSize, aclnnSub)},
{"Mul", AclOpFunctions(aclnnMulGetWorkspaceSize, aclnnMul)},
{"RealDiv", AclOpFunctions(aclnnDivGetWorkspaceSize, aclnnDiv)},
{"FloorDiv", AclOpFunctions(aclnnFloorDivideGetWorkspaceSize, aclnnFloorDivide)},
{"LessEqual", AclOpFunctions(aclnnLeTensorGetWorkspaceSize, aclnnLeTensor)},
{"Less", AclOpFunctions(aclnnLtTensorGetWorkspaceSize, aclnnLtTensor)},
{"GreaterEqual", AclOpFunctions(aclnnGeTensorGetWorkspaceSize, aclnnGeTensor)},
{"Greater", AclOpFunctions(aclnnGtTensorGetWorkspaceSize, aclnnGtTensor)},
{"Equal", AclOpFunctions(aclnnEqTensorGetWorkspaceSize, aclnnEqTensor)},
{"NotEqual", AclOpFunctions(aclnnNeTensorGetWorkspaceSize, aclnnNeTensor)},
{"LogicalAnd", AclOpFunctions(aclnnLogicalAndGetWorkspaceSize, aclnnLogicalAnd)},
{"LogicalOr", AclOpFunctions(aclnnLogicalOrGetWorkspaceSize, aclnnLogicalOr)},
{"LogicalXor", AclOpFunctions(aclnnLogicalXorGetWorkspaceSize, aclnnLogicalXor)},
{"BitwiseAnd", AclOpFunctions(aclnnBitwiseAndTensorGetWorkspaceSize, aclnnBitwiseAndTensor)},
{"BitwiseOr", AclOpFunctions(aclnnBitwiseOrTensorGetWorkspaceSize, aclnnBitwiseOrTensor)},
{"BitwiseXor", AclOpFunctions(aclnnBitwiseXorTensorGetWorkspaceSize, aclnnBitwiseXorTensor)},
{"Pow", AclOpFunctions(aclnnPowTensorTensorGetWorkspaceSize, aclnnPowTensorTensor)},
{"Expand", AclOpFunctions(aclnnExpandGetWorkspaceSize, aclnnExpand)},
{"MatMul", AclOpFunctions(aclnnMatmulGetWorkspaceSize, aclnnMatmul)},
{"BatchMatMul", AclOpFunctions(aclnnBatchMatMulGetWorkspaceSize, aclnnBatchMatMul)},
{"ReduceMax", AclOpFunctions(aclnnAmaxGetWorkspaceSize, aclnnAmax)},
{"ReduceMin", AclOpFunctions(aclnnAminGetWorkspaceSize, aclnnAmin)},
{"ReduceSum", AclOpFunctions(aclnnReduceSumGetWorkspaceSize, aclnnReduceSum)},
{"Triu", AclOpFunctions(aclnnTriuGetWorkspaceSize, aclnnTriu)},
{"Conv2d", AclOpFunctions(aclnnConvolutionGetWorkspaceSize, aclnnConvolution)},
{"Conv2dBackward", AclOpFunctions(aclnnConvolutionBackwardGetWorkspaceSize, aclnnConvolutionBackward)},
{"ReduceMean", AclOpFunctions(aclnnMeanGetWorkspaceSize, aclnnMean)},
// {"ReduceProd", AclOpFunctions(aclnnProdDimGetWorkspaceSize, aclnnProdDim)},
{"Select", AclOpFunctions(aclnnSWhereGetWorkspaceSize, aclnnSWhere)},
{"RandomUniform", AclOpFunctions(aclnnInplaceUniformGetWorkspaceSize, aclnnInplaceUniform)},
{"RandomNormal", AclOpFunctions(aclnnInplaceNormalGetWorkspaceSize, aclnnInplaceNormal)},
{"Transpose", AclOpFunctions(aclnnPermuteGetWorkspaceSize, aclnnPermute)},
{"Maxpool", AclOpFunctions(aclnnMaxPool2dWithIndicesGetWorkspaceSize, aclnnMaxPool2dWithIndices)},
{"MaxpoolBackward", AclOpFunctions(aclnnMaxPool2dWithIndicesBackwardGetWorkspaceSize, aclnnMaxPool2dWithIndicesBackward)},
{"Avgpool", AclOpFunctions(aclnnAvgPool2dGetWorkspaceSize, aclnnAvgPool2d)},
{"AvgpoolBackward", AclOpFunctions(aclnnAvgPool2dBackwardGetWorkspaceSize, aclnnAvgPool2dBackward)},
{"Flip", AclOpFunctions(aclnnFlipGetWorkspaceSize, aclnnFlip)},
{"Concat", AclOpFunctions(aclnnCatGetWorkspaceSize, aclnnCat)},
{"Gather", AclOpFunctions(aclnnGatherGetWorkspaceSize, aclnnGather)},
{"Cumsum", AclOpFunctions(aclnnCumsumGetWorkspaceSize, aclnnCumsum)},
{"Index", AclOpFunctions(aclnnIndexGetWorkspaceSize, aclnnIndex)},
{"Scatter", AclOpFunctions(aclnnScatterGetWorkspaceSize, aclnnScatter)},
{"Nonzero", AclOpFunctions(aclnnNonzeroGetWorkspaceSize, aclnnNonzero)},
{"Where", AclOpFunctions(aclnnSWhereGetWorkspaceSize, aclnnSWhere)},
{"Floor", AclOpFunctions(aclnnFloorGetWorkspaceSize, aclnnFloor)},
{"StridedSliceAssignV2", AclOpFunctions(aclnnStridedSliceAssignV2GetWorkspaceSize, aclnnStridedSliceAssignV2)},
{"SliceV2", AclOpFunctions(aclnnSliceV2GetWorkspaceSize, aclnnSliceV2)},
{"IndexPutImpl", AclOpFunctions(aclnnIndexPutImplGetWorkspaceSize, aclnnIndexPutImpl)},
{"IndexPutImplAccumulate", AclOpFunctions(aclnnIndexPutImplGetWorkspaceSize, aclnnIndexPutImpl)},
{"Range", AclOpFunctions(aclnnRangeGetWorkspaceSize, aclnnRange)},
{"ReLU", AclOpFunctions(aclnnReluGetWorkspaceSize, aclnnRelu)},
{"LeakyReLU", AclOpFunctions(aclnnLeakyReluGetWorkspaceSize, aclnnLeakyRelu)},
{"LeakyReLUBackward", AclOpFunctions(aclnnLeakyReluBackwardGetWorkspaceSize, aclnnLeakyReluBackward)},
{"Dropout", AclOpFunctions(aclnnDropoutGetWorkspaceSize, aclnnDropout)},
{"DropoutBackward", AclOpFunctions(aclnnDropoutBackwardGetWorkspaceSize, aclnnDropoutBackward)},
{"SiLU", AclOpFunctions(aclnnSiluGetWorkspaceSize, aclnnSilu)},
{"SiLUBackward", AclOpFunctions(aclnnSiluBackwardGetWorkspaceSize, aclnnSiluBackward)},
{"Sigmoid", AclOpFunctions(aclnnSigmoidGetWorkspaceSize, aclnnSigmoid)},
{"SigmoidBackward", AclOpFunctions(aclnnSigmoidBackwardGetWorkspaceSize, aclnnSigmoidBackward)},
{"Embedding", AclOpFunctions(aclnnEmbeddingGetWorkspaceSize, aclnnEmbedding)},
{"EmbeddingBackward", AclOpFunctions(aclnnEmbeddingDenseBackwardGetWorkspaceSize, aclnnEmbeddingDenseBackward)},
{"InplaceMaskedScatter", AclOpFunctions(aclnnInplaceMaskedScatterGetWorkspaceSize, aclnnInplaceMaskedScatter)},
{"MaskedSelect", AclOpFunctions(aclnnMaskedSelectGetWorkspaceSize, aclnnMaskedSelect)},
{"SplitWithSize", AclOpFunctions(aclnnSplitWithSizeGetWorkspaceSize, aclnnSplitWithSize)},
{"Softmax", AclOpFunctions(aclnnSoftmaxGetWorkspaceSize, aclnnSoftmax)},
{"SoftmaxBackward", AclOpFunctions(aclnnSoftmaxBackwardGetWorkspaceSize, aclnnSoftmaxBackward)},
{"FlashAttention", AclOpFunctions(aclnnFlashAttentionScoreV2GetWorkspaceSize, aclnnFlashAttentionScoreV2)},
{"FlashAttentionBackward", AclOpFunctions(aclnnFlashAttentionScoreGradV2GetWorkspaceSize, aclnnFlashAttentionScoreGradV2)},
{"BatchNorm", AclOpFunctions(aclnnBatchNormGetWorkspaceSize, aclnnBatchNorm)},
{"BatchNormBackward", AclOpFunctions(aclnnBatchNormBackwardGetWorkspaceSize, aclnnBatchNormBackward)},
{"LayerNorm", AclOpFunctions(aclnnLayerNormGetWorkspaceSize, aclnnLayerNorm)},
{"RotaryPosEmb", AclOpFunctions(aclnnApplyRotaryPosEmbGetWorkspaceSize, aclnnApplyRotaryPosEmb)},
{"Stack", AclOpFunctions(aclnnStackGetWorkspaceSize, aclnnStack)},
{"NanToNum", AclOpFunctions(aclnnNanToNumGetWorkspaceSize, aclnnNanToNum)},
};
struct AclOpAttr
{
virtual ~AclOpAttr() {}
};
struct ConvAttr : AclOpAttr
{
vector<int64_t> convStrides;
vector<int64_t> convPads;
vector<int64_t> convOutPads;
vector<int64_t> convDilations;
bool convWithBias;
bool is_transposed;
int64_t group;
// Destructor
~ConvAttr()
{
convStrides.clear();
convPads.clear();
convOutPads.clear();
convDilations.clear();
}
};
struct ReduceAttr : AclOpAttr
{
vector<int64_t> axes;
// for proddim
int64_t prod_dim;
bool keepdims;
~ReduceAttr()
{
axes.clear();
}
};
struct RandomAttr : AclOpAttr
{
int64_t seed, offset;
~RandomAttr()
{
}
};
struct TriuAttr : AclOpAttr
{
int64_t diagonal;
~TriuAttr()
{
}
};
struct PoolAttr : AclOpAttr
{
vector<int64_t> kernel_size;
vector<int64_t> poolStrides;
vector<int64_t> poolPads;
vector<int64_t> poolDilations;
bool poolCeil;
bool countIncludePad;
// divisorOverride (const int64_t, computed input): the divisor used when averaging; data type INT64. The default value 0 means the feature is disabled.
// https://www.hiascend.com/document/detail/zh/canncommercial/80RC2/apiref/appdevgapi/context/aclnnAvgPool2d.md
int64_t divisorOverride = 0;
// cubeMathType (int8_t, computed input): host-side integer selecting which computation mode the Cube unit uses; data type INT8. Unless otherwise noted, computation keeps the original input data type. Supported enum values:
// 0: KEEP_DTYPE - keep the input data type for computation. For FLOAT input, Atlas training products and Atlas inference products (Ascend 310P) do not support this yet; passing 0 raises an error.
// 1: ALLOW_FP32_DOWN_PRECISION - allow the input to be computed at reduced precision. For FLOAT input, Atlas training products and Atlas inference products (Ascend 310P) may convert it to FLOAT16.
// 2: USE_FP16 - allow conversion to FLOAT16 for computation. FLOAT input is converted to FLOAT16.
// 3: USE_HF32 - allow conversion to HFLOAT32 for computation. For FLOAT input, Atlas training products, Atlas inference products (Ascend 310P), and Atlas A2 training / Atlas 800I A2 inference products do not support this yet; passing 3 raises an error.
// https://www.hiascend.com/document/detail/zh/canncommercial/80RC2/apiref/appdevgapi/context/aclnnAvgPool2d.md
int8_t cubeMathType = 0;
// Destructor
~PoolAttr()
{
kernel_size.clear();
poolStrides.clear();
poolPads.clear();
poolDilations.clear();
}
};
struct ConcatAttr : AclOpAttr
{
int64_t tensorNum;
int64_t dim;
~ConcatAttr()
{
}
};
struct GatherAttr : AclOpAttr
{
int64_t dim;
~GatherAttr()
{
}
};
struct ScatterAttr : AclOpAttr
{
int64_t axis;
int64_t reduction;
~ScatterAttr()
{
}
};
struct StrideAttr : AclOpAttr
{
vector<int64_t> begins;
vector<int64_t> ends;
vector<int64_t> steps;
vector<int64_t> axes;
~StrideAttr()
{
begins.clear();
ends.clear();
steps.clear();
axes.clear();
}
};
struct RangeAttr : AclOpAttr
{
int64_t start;
int64_t end;
int64_t step;
~RangeAttr()
{
}
};
struct LeakyReluAttr : AclOpAttr
{
float negativeSlope;
bool selfIsResult;
~LeakyReluAttr()
{
}
};
struct DropoutAttr : AclOpAttr
{
float p;
bool train;
int64_t seed;
int64_t offset;
float scale;
~DropoutAttr()
{
}
};
struct EmbeddingAttr : AclOpAttr
{
int64_t numEmbeddings;
// int64_t embeddingDim;
int64_t paddingIdx;
bool scaleGradByFreq;
// bool sparse;
// bool isSparse;
// bool isDense;
~EmbeddingAttr()
{
}
};
struct SplitWithSizeAttr : AclOpAttr
{
vector<int64_t> splitSize;
int64_t dim;
~SplitWithSizeAttr()
{
splitSize.clear();
}
};
struct SoftmaxAttr : AclOpAttr
{
int64_t dim;
~SoftmaxAttr()
{
}
};
struct BatchNormAttr : AclOpAttr
{
bool is_train;
float momentum;
float eps;
~BatchNormAttr()
{
}
};
struct LayerNormAttr : AclOpAttr
{
float eps;
vector<int64_t> normalizedShape;
int64_t size;
~LayerNormAttr()
{
normalizedShape.clear();
}
};
struct FlashAttentionAttr : AclOpAttr
{
vector<int64_t> prefix;
vector<int64_t> qStartIdx;
vector<int64_t> kvStartIdx;
float scale;
float keepProb;
int64_t preToken;
int64_t nextToken;
int64_t headNum;
string inputLayout;
int64_t innerPrecise;
int64_t sparseMode;
int64_t psetype;
bool hasRealshift;
bool hasDropmask;
bool hasPaddingmask;
bool hasAttentmask;
~FlashAttentionAttr()
{
prefix.clear();
qStartIdx.clear();
kvStartIdx.clear();
}
};
struct NanToNumAttr : AclOpAttr
{
float nan;
float posinf;
float neginf;
~NanToNumAttr()
{
}
};
}

File diff suppressed because it is too large

python/jittor/extern/acl/aclnn/aclnn.cc
@@ -0,0 +1,58 @@
#include <iostream>
#include <vector>
#include "aclnn.h"
int64_t GetShapeSize(const std::vector<int64_t>& shape) {
int64_t shapeSize = 1;
for (auto i : shape) {
shapeSize *= i;
}
return shapeSize;
}
void PrintOutResult(std::vector<int64_t> &shape, void** deviceAddr) {
auto size = GetShapeSize(shape);
std::vector<int> resultData(size, 0);
auto ret = aclrtMemcpy(resultData.data(), resultData.size() * sizeof(resultData[0]),
*deviceAddr, size * sizeof(resultData[0]), ACL_MEMCPY_DEVICE_TO_HOST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return);
for (int64_t i = 0; i < size; i++) {
LOG_PRINT("mean result[%ld] is: %d\n", i, resultData[i]);
}
}
/*int Init(int32_t deviceId) {
// Standard AscendCL initialization
auto ret = aclInit(nullptr);
CHECK_RET(ret == ACL_SUCCESS or ret == 100002, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret);
ret = aclrtSetDevice(deviceId);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret); return ret);
//ret = aclrtCreateStream(stream);
//CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret);
return 0;
}*/
/*
template <typename T>
int CreateAclTensor(const std::vector<T>& hostData, const std::vector<int64_t>& shape, void** deviceAddr,
aclDataType dataType, aclTensor** tensor) {
auto size = GetShapeSize(shape) * sizeof(T);
// Call aclrtMalloc to allocate device-side memory
auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret);
// Call aclrtMemcpy to copy host-side data to the device-side memory
ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret);
// Compute the strides of a contiguous tensor
std::vector<int64_t> strides(shape.size(), 1);
for (int64_t i = shape.size() - 2; i >= 0; i--) {
strides[i] = shape[i + 1] * strides[i + 1];
}
// Call aclCreateTensor to create an aclTensor
*tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND,
shape.data(), shape.size(), *deviceAddr);
return 0;
}*/

python/jittor/extern/acl/aclnn/aclnn.h
@@ -0,0 +1,134 @@
#include <iostream>
#include <vector>
#include "acl.h"
// unary
#include "aclnnop/aclnn_abs.h"
#include "aclnnop/aclnn_neg.h"
#include "aclnnop/aclnn_exp.h"
#include "aclnnop/aclnn_log.h"
#include "aclnnop/aclnn_sqrt.h"
#include "aclnnop/aclnn_ceil.h"
#include "aclnnop/aclnn_floor.h"
#include "aclnnop/aclnn_round.h"
#include "aclnnop/aclnn_sin.h"
#include "aclnnop/aclnn_cos.h"
#include "aclnnop/aclnn_tan.h"
#include "aclnnop/aclnn_asin.h"
#include "aclnnop/aclnn_acos.h"
#include "aclnnop/aclnn_atan.h"
#include "aclnnop/aclnn_sinh.h"
#include "aclnnop/aclnn_cosh.h"
#include "aclnnop/aclnn_tanh.h"
#include "aclnnop/aclnn_asinh.h"
#include "aclnnop/aclnn_acosh.h"
#include "aclnnop/aclnn_atanh.h"
#include "aclnnop/aclnn_sigmoid.h"
#include "aclnnop/aclnn_erf.h"
#include "aclnnop/aclnn_erfinv.h"
#include "aclnnop/aclnn_logical_not.h"
#include "aclnnop/aclnn_bitwise_not.h"
#include "aclnnop/aclnn_cast.h"
#include "aclnnop/aclnn_nonzero.h"
// binary
#include "aclnnop/aclnn_maximum.h"
#include "aclnnop/aclnn_minimum.h"
#include "aclnnop/aclnn_add.h"
#include "aclnnop/aclnn_sub.h"
#include "aclnnop/aclnn_mul.h"
#include "aclnnop/aclnn_div.h"
#include "aclnnop/aclnn_floor_divide.h"
#include "aclnnop/aclnn_le_tensor.h"
#include "aclnnop/aclnn_lt_tensor.h"
#include "aclnnop/aclnn_ge_tensor.h"
#include "aclnnop/aclnn_gt_tensor.h"
#include "aclnnop/aclnn_eq_tensor.h"
#include "aclnnop/aclnn_ne_tensor.h"
#include "aclnnop/aclnn_logical_and.h"
#include "aclnnop/aclnn_logical_or.h"
#include "aclnnop/aclnn_logical_xor.h"
#include "aclnnop/aclnn_bitwise_and_tensor.h"
#include "aclnnop/aclnn_bitwise_or_tensor.h"
#include "aclnnop/aclnn_bitwise_xor_tensor.h"
#include "aclnnop/aclnn_pow_tensor_tensor.h"
#include "aclnnop/aclnn_expand.h"
#include "aclnnop/aclnn_matmul.h"
#include "aclnnop/aclnn_batch_matmul.h"
#include "aclnnop/aclnn_convolution.h"
#include "aclnnop/aclnn_convolution_backward.h"
#include "aclnnop/aclnn_reduce_sum.h"
#include "aclnnop/aclnn_amax.h"
#include "aclnnop/aclnn_amin.h"
#include "aclnnop/aclnn_mean.h"
#include "aclnnop/aclnn_prod.h"
#include "aclnnop/aclnn_triu.h"
#include "aclnnop/aclnn_s_where.h"
#include "aclnnop/aclnn_random.h"
#include "aclnnop/aclnn_normal.h"
#include "aclnnop/aclnn_permute.h"
#include "aclnnop/aclnn_max_pool2d_with_indices.h"
#include "aclnnop/aclnn_max_pool2d_with_indices_backward.h"
#include "aclnnop/aclnn_avgpool2d.h"
#include "aclnnop/aclnn_avgpool2d_backward.h"
#include "aclnnop/aclnn_flip.h"
#include "aclnnop/aclnn_cat.h"
#include "aclnnop/aclnn_gather.h"
#include "aclnnop/aclnn_cumsum.h"
#include "aclnnop/aclnn_index.h"
#include "aclnnop/aclnn_scatter.h"
#include "aclnnop/aclnn_index.h"
#include "aclnnop/aclnn_strided_slice_assign_v2.h"
#include "aclnnop/aclnn_slice_v2.h"
#include "aclnnop/aclnn_index_put_impl.h"
#include "aclnnop/aclnn_range.h"
#include "aclnnop/aclnn_relu.h"
#include "aclnnop/aclnn_dropout.h"
#include "aclnnop/aclnn_dropout_backward.h"
#include "aclnnop/aclnn_leaky_relu.h"
#include "aclnnop/aclnn_leaky_relu_backward.h"
#include "aclnnop/aclnn_uniform.h"
#include "aclnnop/aclnn_silu.h"
#include "aclnnop/aclnn_silu_backward.h"
#include "aclnnop/aclnn_sigmoid.h"
#include "aclnnop/aclnn_sigmoid_backward.h"
#include "aclnnop/aclnn_embedding.h"
#include "aclnnop/aclnn_embedding_dense_backward.h"
#include "aclnnop/aclnn_masked_scatter.h"
#include "aclnnop/aclnn_masked_select.h"
#include "aclnnop/aclnn_split_with_size.h"
#include "aclnnop/aclnn_flash_attention_score.h"
#include "aclnnop/aclnn_flash_attention_score_grad.h"
#include "aclnnop/aclnn_softmax.h"
#include "aclnnop/aclnn_softmax_backward.h"
#include "aclnnop/aclnn_batch_norm.h"
#include "aclnnop/aclnn_batch_norm_backward.h"
#include "aclnnop/aclnn_layer_norm.h"
#include "aclnnop/aclnn_apply_rotary_pos_emb.h"
#include "aclnnop/aclnn_stack.h"
#include "aclnnop/aclnn_nan_to_num.h"
#define CHECK_RET(cond, return_expr) \
do \
{ \
if (!(cond)) \
{ \
return_expr; \
} \
} while (0)
#define LOG_PRINT(message, ...) \
do \
{ \
printf(message, ##__VA_ARGS__); \
} while (0)
int64_t GetShapeSize(const std::vector<int64_t> &shape);
void PrintOutResult(std::vector<int64_t> &shape, void **deviceAddr);
//int Init(int32_t deviceId);
/*
template <typename T>
int CreateAclTensor(const std::vector<T>& hostData, const std::vector<int64_t>& shape, void** deviceAddr,
aclDataType dataType, aclTensor** tensor);
*/

@@ -0,0 +1,33 @@
#pragma once
#include <acl/aclops/binary_op_acl.h>
#include <acl/aclops/unary_op_acl.h>
#include <acl/aclops/conv_op_acl.h>
#include <acl/aclops/ternary_op_acl.h>
#include <acl/aclops/reduce_op_acl.h>
#include <acl/aclops/expand_op_acl.h>
#include <acl/aclops/getitem_op_acl.h>
#include <acl/aclops/setitem_op_acl.h>
#include <acl/aclops/matmul_op_acl.h>
#include <acl/aclops/random_op_acl.h>
#include <acl/aclops/bmm_op_acl.h>
#include <acl/aclops/pool_op_acl.h>
#include <acl/aclops/flip_op_acl.h>
#include <acl/aclops/concat_op_acl.h>
#include <acl/aclops/gather_scatter_op_acl.h>
#include <acl/aclops/cumsum_op_acl.h>
#include <acl/aclops/index_op_acl.h>
#include <acl/aclops/where_op_acl.h>
#include <acl/aclops/floor_op_acl.h>
#include <acl/aclops/transpose_op_acl.h>
#include <acl/aclops/flashattention_op_acl.h>
#include <acl/aclops/relu_op_acl.h>
#include <acl/aclops/dropout_op_acl.h>
#include <acl/aclops/silu_op_acl.h>
#include <acl/aclops/sigmoid_op_acl.h>
#include <acl/aclops/softmax_op_acl.h>
#include <acl/aclops/stack_op_acl.h>
#include <acl/aclops/nantonum_op_acl.h>
#include <acl/aclops/rope_op_acl.h>
#include <acl/aclops/triu_op_acl.h>
#include <acl/aclops/embedding_op_acl.h>
#include <acl/aclops/norms_op_acl.h>

@@ -0,0 +1,56 @@
#pragma once
#include "utils.h"
#include "acl_jittor.h"
namespace jittor
{
extern int sync_run;
class BaseOpRunner
{
protected:
vector<Var *> in_;
vector<Var *> out_;
int ret = -1;
uint64_t workspaceSize = 0;
aclOpExecutor *executor;
bool is_group_op = false;
std::vector<std::vector<int64_t>> inputShapes;
std::vector<std::vector<int64_t>> outputShapes;
std::vector<aclTensor *> inputTensors;
std::vector<aclTensor *> outputTensors;
public:
string name;
string jt_name;
std::unique_ptr<AclOpAttr> op_attr;
bool use_nchw = false;
BaseOpRunner(const string &name = "") : name(name) {}
virtual ~BaseOpRunner() = default;
// Common functionality for adding input/output variables
void add(Var *v, bool is_input);
virtual void setupInputDesc();
void cleanupDesc();
virtual void setupOutputDesc();
virtual void syncRun();
void checkRet(aclnnStatus ret);
// Base run method with common operator lookup logic
void run();
protected:
// Virtual method for specific operator execution
virtual void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) = 0;
void cleanupAttr();
};
}

@@ -0,0 +1,152 @@
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "binary_op_acl.h"
#include "base_op.h"
namespace jittor
{
extern int sync_run;
// Common functionality for adding input/output variables
void BaseOpRunner::add(Var *v, bool is_input)
{
if (is_input)
{
in_.push_back(v);
}
else
{
out_.push_back(v);
}
return;
}
void BaseOpRunner::setupInputDesc()
{
auto input_num = in_.size();
for (int input_idx = 0; input_idx < input_num; input_idx++)
{
std::vector<int64_t> shape;
for (int j = 0; j < in_[input_idx]->shape.size(); j++)
{
shape.push_back(in_[input_idx]->shape[j]);
}
inputShapes.push_back(shape);
}
for (int idx = 0; idx < input_num; idx++)
{
inputTensors.push_back(nullptr);
auto ret = CreateAclTensor(inputShapes[idx], in_[idx]->mem_ptr, in_[idx]->size, get_dtype(in_[idx]->dtype()), &inputTensors[idx], use_nchw);
CHECK_RET(ret == ACL_SUCCESS, return);
}
}
void BaseOpRunner::cleanupDesc()
{
auto input_num = in_.size();
auto output_num = out_.size();
for (int idx = 0; idx < input_num; idx++)
{
aclDestroyTensor(inputTensors[idx]);
}
for (int idx = 0; idx < output_num; idx++)
{
aclDestroyTensor(outputTensors[idx]);
}
}
void BaseOpRunner::setupOutputDesc()
{
auto output_num = out_.size();
for (int output_idx = 0; output_idx < output_num; output_idx++)
{
std::vector<int64_t> shape;
for (int j = 0; j < out_[output_idx]->shape.size(); j++)
{
shape.push_back(out_[output_idx]->shape[j]);
}
outputShapes.push_back(shape);
}
for (int idx = 0; idx < output_num; idx++)
{
outputTensors.push_back(nullptr);
auto ret = CreateAclTensor(outputShapes[idx], out_[idx]->mem_ptr, out_[idx]->size, get_dtype(out_[idx]->dtype()), &outputTensors[idx], use_nchw);
CHECK_RET(ret == ACL_SUCCESS, return);
}
}
void BaseOpRunner::syncRun()
{
if (sync_run)
{
// ret = aclrtSynchronizeStream(aclstream);
// CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclrtSynchronizeStream failed. ERROR: %d\n", name.c_str(), ret); return);
}
}
void BaseOpRunner::checkRet(aclnnStatus ret)
{
if (ret != ACL_SUCCESS)
{
auto tmp_err_msg = aclGetRecentErrMsg();
LOGir << name << ", " << tmp_err_msg;
}
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnxxxGetWorkspaceSize failed. ERROR: %d\n", name.c_str(), ret); return);
}
// Base run method with common operator lookup logic
void BaseOpRunner::run()
{
if (is_group_op)
{
auto it = aclOpFuncMap.find(name);
if (it == aclOpFuncMap.end())
{
LOGir << "aclOpFuncMap Not supported op: " << name;
throw std::runtime_error("Unsupported operation type.");
}
setupInputDesc();
setupOutputDesc();
executeOp(it);
cleanupDesc();
}
else
{
auto it = aclOpFuncMap.find(name);
setupInputDesc();
setupOutputDesc();
executeOp(it);
cleanupDesc();
}
}
}

@@ -0,0 +1,124 @@
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "binary_op_acl.h"
namespace jittor
{
BinaryOpRunner::BinaryOpRunner() : BaseOpRunner("binary")
{
use_nchw = false;
is_group_op = true;
}
void BinaryOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
aclScalar *alpha = nullptr;
if (name == string("Add") || name == string("Sub"))
{
if (get_dtype(in_[0]->dtype()) == ACL_FLOAT)
{
float alphaValue = 1.0;
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
}
else if (get_dtype(in_[0]->dtype()) == ACL_FLOAT16)
{
__fp16 alphaValue = 1.0;
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
}
else if (get_dtype(in_[0]->dtype()) == ACL_INT64)
{
int64_t alphaValue = 1;
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
}
else if (get_dtype(in_[0]->dtype()) == ACL_INT32)
{
int alphaValue = 1;
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
}
else if (get_dtype(in_[0]->dtype()) == ACL_INT8)
{
int8_t alphaValue = 1;
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
}
else if (get_dtype(in_[0]->dtype()) == ACL_INT16)
{
int16_t alphaValue = 1;
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
}
else if (get_dtype(in_[0]->dtype()) == ACL_UINT8)
{
uint8_t alphaValue = 1;
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
}
else if (get_dtype(in_[0]->dtype()) == ACL_UINT16)
{
uint16_t alphaValue = 1;
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
}
else if (get_dtype(in_[0]->dtype()) == ACL_UINT32)
{
uint32_t alphaValue = 1;
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
}
else if (get_dtype(in_[0]->dtype()) == ACL_BOOL)
{
bool alphaValue = true;
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
}
else
{
LOGf << "Not supported dtype: " << in_[0]->dtype();
}
CHECK_RET(alpha != nullptr, return);
ret = it->second.getWorkspaceSizeFuncAdd(inputTensors[0], inputTensors[1], alpha, outputTensors[0], &workspaceSize, &executor);
}
else
{
ret = it->second.getWorkspaceSizeFuncBinary(inputTensors[0], inputTensors[1], outputTensors[0], &workspaceSize, &executor);
}
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = it->second.executeFunc(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnxxx failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
aclDestroyScalar(alpha);
return;
}
}

@@ -0,0 +1,14 @@
#pragma once
#include "utils.h"
#include "base_op.h"
namespace jittor
{
struct BinaryOpRunner : public BaseOpRunner
{
BinaryOpRunner();
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
};
}

@@ -0,0 +1,128 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def acl_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None,
extra_data: dict = {}):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
#include "acl/aclops/aclops.h"
'''
outputs_ = []
if outputs is not None:
outputs_ = outputs
else:
assert output_dtypes is not None
assert output_shapes is not None
assert len(output_dtypes) == len(output_shapes)
for i in range(len(output_shapes)):
outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
input_code = ''
for i in range(len(inputs)):
input_code += f"op.add(in{i}, true);\n"
return jt.code(outputs=outputs_,
inputs=inputs,
cuda_header=attr_header + cuda_header,
cuda_src=f"""
// aclop
BatchMatMulOpRunner op;
{input_code}
op.add(out0, false);
{attr_code}
op.run();""",
data=extra_data)
class BmmACL(jt.Function):
def __init__(self, trans_x2=False):
super(BmmACL, self).__init__()
self.trans_x2 = trans_x2
def execute(self, x1, x2):
self.input = [x1, x2]
result = acl_cmd("BatchMatMul", [x1, x2],
output_dtypes=[x1.dtype],
output_shapes=[
x1.shape[:-1] + x2.shape[-2:-1] if self.trans_x2
else x1.shape[:-1] + x2.shape[-1:]
],
attr_code="op.jt_name=\"bmm_trans_1\";"
if self.trans_x2 else "op.jt_name=\"bmm\";")[0]
return result
def grad(self, grad_output):
x1, x2 = self.input
if len(x1) != len(x2):
reshape_grad_x2 = True
else:
reshape_grad_x2 = False
grad_x1 = acl_cmd(
"BatchMatMul", [grad_output, x2],
output_dtypes=[x1.dtype],
output_shapes=[
grad_output.shape[:-1] + x2.shape[-2:-1] if not self.trans_x2
else grad_output.shape[:-1] + x1.shape[-1:]
],
attr_code="op.jt_name=\"bmm_trans_1\";"
if not self.trans_x2 else "op.jt_name=\"bmm\";")[0]
if self.trans_x2:
if reshape_grad_x2:
output_shape = grad_output.shape[1:-2] + grad_output.shape[
-1:] + x1.shape[-1:]
grad_x2 = acl_cmd("BatchMatMul", [
grad_output.reshape(-1, grad_output.shape[-1]),
x1.reshape(-1, x1.shape[-1])
],
output_dtypes=[x2.dtype],
output_shapes=[output_shape],
attr_code="op.jt_name=\"bmm_trans_0\";")[0]
else:
output_shape = grad_output.shape[:-2] + grad_output.shape[
-1:] + x1.shape[-1:]
grad_x2 = acl_cmd("BatchMatMul", [grad_output, x1],
output_dtypes=[x2.dtype],
output_shapes=[output_shape],
attr_code="op.jt_name=\"bmm_trans_0\";")[0]
else:
if reshape_grad_x2:
output_shape = x1.shape[1:-2] + x1.shape[
-1:] + grad_output.shape[-1:]
grad_x2 = acl_cmd("BatchMatMul", [
x1.reshape(-1, x1.shape[-1]),
grad_output.reshape(-1, grad_output.shape[-1])
],
output_dtypes=[x2.dtype],
output_shapes=[output_shape],
attr_code="op.jt_name=\"bmm_trans_0\";")[0]
else:
output_shape = x1.shape[:-2] + x1.shape[
-1:] + grad_output.shape[-1:]
grad_x2 = acl_cmd("BatchMatMul", [x1, grad_output],
output_dtypes=[x2.dtype],
output_shapes=[output_shape],
attr_code="op.jt_name=\"bmm_trans_0\";")[0]
if len(grad_x1.shape) > len(x1.shape):
grad_x1 = grad_x1.sum(0)
if len(grad_x2.shape) > len(x2.shape):
grad_x2 = grad_x2.sum(0)
return grad_x1, grad_x2
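A minimal usage sketch of BmmACL (illustrative, not part of the diff; the shapes and the trans_x2 flag are assumed here, and jt.rand is used only to produce example inputs):

import jittor as jt
x1 = jt.rand(4, 8, 16)                                # (batch, n, k)
x2 = jt.rand(4, 16, 32)                               # (batch, k, m)
y = BmmACL()(x1, x2)                                  # y.shape == (4, 8, 32)
y_t = BmmACL(trans_x2=True)(x1, jt.rand(4, 32, 16))   # second input treated as transposed on its last two dims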

@@ -0,0 +1,77 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "bmm_op_acl.h"
namespace jittor
{
BatchMatMulOpRunner::BatchMatMulOpRunner() : BaseOpRunner("BatchMatMulMatMul")
{
}
void BatchMatMulOpRunner::setupInputDesc()
{
auto input_num = in_.size();
for (int input_idx = 0; input_idx < input_num; input_idx++)
{
std::vector<int64_t> shape;
for (int j = 0; j < in_[input_idx]->shape.size(); j++)
{
shape.push_back(in_[input_idx]->shape[j]);
}
inputShapes.push_back(shape);
}
for (int idx = 0; idx < input_num; idx++)
{
inputTensors.push_back(nullptr);
if ((jt_name == "bmm_trans_1" && idx == 1) || (jt_name == "bmm_trans_0" && idx == 0))
{
auto ret = CreateFakeTransAclTensor(inputShapes[idx], in_[idx]->mem_ptr, in_[idx]->size, get_dtype(in_[idx]->dtype()), &inputTensors[idx], use_nchw);
CHECK_RET(ret == ACL_SUCCESS, return);
}
else
{
auto ret = CreateAclTensor(inputShapes[idx], in_[idx]->mem_ptr, in_[idx]->size, get_dtype(in_[idx]->dtype()), &inputTensors[idx], use_nchw);
CHECK_RET(ret == ACL_SUCCESS, return);
}
}
}
void BatchMatMulOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
ret = aclnnBatchMatMulGetWorkspaceSize(inputTensors[0], inputTensors[1], outputTensors[0], 1, &workspaceSize, &executor);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnBatchMatMulGetWorkspaceSize failed. ERROR: %d\n", name.c_str(), ret); return);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnBatchMatMul(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnBatchMatMul failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
}
}

@@ -0,0 +1,17 @@
#pragma once
#include "utils.h"
#include "base_op.h"
namespace jittor
{
class BatchMatMulOpRunner : public BaseOpRunner
{
protected:
void setupInputDesc() override;
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
BatchMatMulOpRunner();
};
}

@@ -0,0 +1,186 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def concat_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
#include "acl/aclops/aclops.h"
'''
outputs_ = []
if outputs is not None:
outputs_ = outputs
else:
assert output_dtypes is not None
assert output_shapes is not None
assert len(output_dtypes) == len(output_shapes)
for i in range(len(output_shapes)):
outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
input_code = ''
for i in range(len(inputs)):
input_code += f"op.add(in{i}, true);\n"
output_code = ''
for i in range(len(outputs_)):
output_code += f"op.add(out{i}, false);\n"
return jt.code(outputs=outputs_,
inputs=inputs,
cuda_header=attr_header + cuda_header,
cuda_src=f"""
// aclop
{name}OpRunner op;
{input_code}
{output_code}
{attr_code}
op.run();""")
class ConcatACL(jt.Function):
def __init__(self):
super(ConcatACL, self).__init__()
def __call__(self, *args):
assert isinstance(args[0], (list, tuple))
assert isinstance(args[1], int)
if jt.flags.no_grad:
return self.execute(*args)
backup = args
args = list(args)
taped_inputs = []
taped_outputs = []
input_mask = [-1] * (len(args[0]) + 1)
newargs = [list(), args[1]]
for i, v in enumerate(args[0]):
if isinstance(v, jt.Var):
if v.is_stop_grad():
# -2 in input_mask represents it is stop_grad
input_mask[i] = -2
newargs[0].append(v)
continue
v = v.tape()
newargs[0].append(v)
input_mask[i] = len(taped_inputs)
taped_inputs.append(v)
ori_res = self.execute(*newargs)
if not isinstance(ori_res, Sequence):
res = [ori_res]
else:
res = list(ori_res)
output_mask = [-1] * len(res)
for i, v in enumerate(res):
if isinstance(v, jt.Var):
v = v.tape()
output_mask[i] = len(taped_outputs)
res[i] = v
taped_outputs.append(v)
self.input_mask = input_mask
self.output_mask = output_mask
# tape output and input together so
# backward treat them as one operator
jt.tape_together(taped_inputs, taped_outputs, self._grad)
if isinstance(ori_res, Sequence):
return res
else:
return res[0]
def execute(self, input_tensors, dim=0):
for _ in input_tensors:
if not (-_.ndim <= dim < _.ndim):
print(_.shape, dim)
raise ValueError("dim out of range")
if dim < 0:
dim += input_tensors[0].ndim
self.input = input_tensors
self.dim = dim
for i in range(len(input_tensors)):
if input_tensors[i].dtype != input_tensors[0].dtype:
raise ValueError("All input tensors must have the same dtype")
if input_tensors[i].shape[:dim] != input_tensors[
0].shape[:dim] or input_tensors[i].shape[
dim + 1:] != input_tensors[0].shape[dim + 1:]:
raise ValueError("All input tensors must have the same shape")
attr_code = f"""
op.jt_name = "concat";
ConcatAttr *attr = new ConcatAttr();
attr->tensorNum = {len(input_tensors)};
attr->dim = {dim};
op.op_attr.reset(attr);
"""
result = concat_cmd(
"Concat",
input_tensors,
output_dtypes=[input_tensors[0].dtype],
output_shapes=[
jt.empty(self.calculate_output_shape(input_tensors, dim)).shape
],
attr_code=attr_code)[0]
return result
def _grad(self, *args):
new_args = ((args[i] if i >= 0 else None) for i in self.output_mask)
ret = self.grad(*new_args)
new_ret = []
for i, r in enumerate(ret):
j = self.input_mask[i]
if j < 0:
# -2 in input_mask represents it is stop_grad
assert r is None or j==-2, f"{type(self)}'s {i}-th returned grad should be None, "\
"because the input value is not jittor variable."
else:
new_ret.append(r)
return new_ret
def grad(self, grad_output):
grad_inputs = self.split_grad(grad_output, self.input, self.dim)
return grad_inputs
def calculate_output_shape(self, input_tensors, axis):
shape = list(input_tensors[0].shape)
for tensor in input_tensors[1:]:
shape[axis] += tensor.shape[axis]
return tuple(shape)
def split_grad(self, grad_output, input_tensors, axis):
offset = []
shapeVec = []
dtypeVec = []
for tensor in input_tensors:
offset.append(tensor.shape[axis])
dtypeVec.append(tensor.dtype)
shapeVec.append(tensor.shape)
attr_code = f"""
op.jt_name = "splitwithsize";
auto *attr = new SplitWithSizeAttr();
attr->splitSize = {{ {", ".join(map(str, offset))} }};
attr->dim = {axis};
op.op_attr.reset(attr);
"""
result = concat_cmd("SplitWithSize", [grad_output],
output_dtypes=dtypeVec,
output_shapes=shapeVec,
attr_code=attr_code)
return result
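A minimal usage sketch of ConcatACL (illustrative, not part of the diff). The forward pass concatenates along dim, and grad() splits grad_output back into per-input slices via SplitWithSize:

import jittor as jt
a = jt.rand(2, 3)
b = jt.rand(2, 5)
y = ConcatACL()([a, b], 1)   # y.shape == (2, 8); gradients are split back into (2, 3) and (2, 5)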

@@ -0,0 +1,89 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "concat_op_acl.h"
namespace jittor
{
ConcatOpRunner::ConcatOpRunner() : BaseOpRunner("Concat")
{
}
void ConcatOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto input_num = in_.size();
std::vector<aclTensor *> concatTensorList = {};
for (int i = 0; i < input_num; i++)
{
concatTensorList.push_back(inputTensors[i]);
}
auto concatTensorListInput = aclCreateTensorList(&concatTensorList[0], input_num);
auto attr = dynamic_cast<ConcatAttr *>(op_attr.get());
ret = aclnnCatGetWorkspaceSize(concatTensorListInput, attr->dim, outputTensors[0], &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnCat(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnCat failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
return;
}
SplitWithSizeOpRunner::SplitWithSizeOpRunner() : BaseOpRunner("SplitWithSize")
{
}
void SplitWithSizeOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto output_num = out_.size();
auto attr = dynamic_cast<SplitWithSizeAttr *>(op_attr.get());
auto splitSize = aclCreateIntArray(attr->splitSize.data(), attr->splitSize.size());
auto tensorList = aclCreateTensorList(&outputTensors[0], output_num);
ret = aclnnSplitWithSizeGetWorkspaceSize(inputTensors[0], splitSize, attr->dim, tensorList, &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnSplitWithSize(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnSplitWithSize failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
return;
}
}

@@ -0,0 +1,26 @@
#pragma once
#include "utils.h"
#include "base_op.h"
namespace jittor
{
class ConcatOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
ConcatOpRunner();
};
class SplitWithSizeOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
SplitWithSizeOpRunner();
};
}

@@ -0,0 +1,160 @@
import os
import jittor_utils
from jittor_utils import env_or_try_find
import ctypes
import glob
import jittor as jt
import jittor.compiler as compiler
import math
import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def _ntuple(n):
def parse(x):
if isinstance(x, Iterable):
return x
return tuple([x] * n)
return parse
_pair = _ntuple(2)
def conv_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
#include "acl/aclops/aclops.h"
'''
outputs_ = []
if outputs is not None:
outputs_ = outputs
else:
assert output_dtypes is not None
assert output_shapes is not None
assert len(output_dtypes) == len(output_shapes)
for i in range(len(output_shapes)):
outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
input_code = ''
for i in range(len(inputs)):
input_code += f"op.add(in{i}, true);\n"
output_code = ''
for i in range(len(outputs_)):
output_code += f"op.add(out{i}, false);\n"
return jt.code(outputs=outputs_,
inputs=inputs,
cuda_header=attr_header + cuda_header,
cuda_src=f"""
// aclop
{name}OpRunner op;
{input_code}
{output_code}
{attr_code}
op.run();""")
class ConvACL(jt.Function):
def execute(self,
x,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
groups=1):
self.input = x
self.weight = weight
self.bias = bias
padding = _pair(padding)
stride = _pair(stride)
dilation = _pair(dilation)
out_channels = weight.shape[0]
if groups <= 0:
raise ValueError("groups must be a positive integer")
self.padding = padding
self.stride = stride
self.dilation = dilation
self.groups = groups
attr_code = f"""
op.jt_name = "conv2d";
ConvAttr *attr = new ConvAttr();
attr->convStrides = {{ {stride[0]}, {stride[1]} }};
attr->convPads = {{ {padding[0]}, {padding[1]} }};
attr->convDilations = {{ {dilation[0]}, {dilation[1]} }};
attr->group = {groups};
attr->convOutPads = {{1,1}};
op.op_attr.reset(attr);
"""
input_height, input_width = x.shape[-2:]
kernel_height, kernel_width = weight.shape[-2:]
output_height = (input_height + 2 * padding[0] - dilation[0] *
(kernel_height - 1) - 1) // stride[0] + 1
output_width = (input_width + 2 * padding[1] - dilation[1] *
(kernel_width - 1) - 1) // stride[1] + 1
output_shape = (x.shape[0], out_channels, output_height, output_width)
inputs = [x, weight]
if bias is not None:
inputs.append(bias)
result = conv_cmd(
"Conv2d",
inputs,
output_dtypes=[x.dtype],
output_shapes=[output_shape],
attr_code=attr_code,
)[0]
return result
def grad(self, grad_output):
x = self.input
weight = self.weight
bias = self.bias
inputs = [grad_output, x, weight]
if bias is not None:
inputs.append(bias)
output_shapes = [x.shape, weight.shape]
output_dtypes = [x.dtype, weight.dtype]
if bias is not None:
output_shapes.append(bias.shape)
output_dtypes.append(bias.dtype)
else:
output_shapes.append([weight.shape[0]])
output_dtypes.append(x.dtype)
padding = self.padding
stride = self.stride
dilation = self.dilation
groups = self.groups
attr_code = f"""
op.jt_name = "conv2dbackward";
ConvAttr *attr = new ConvAttr();
attr->convStrides = {{ {stride[0]}, {stride[1]} }};
attr->convPads = {{ {padding[0]}, {padding[1]} }};
attr->convDilations = {{ {dilation[0]}, {dilation[1]} }};
attr->group = {groups};
attr->convOutPads = {{ 1,1}};
op.op_attr.reset(attr);
"""
results = conv_cmd("Conv2dBackward",
inputs,
output_dtypes=output_dtypes,
output_shapes=output_shapes,
attr_code=attr_code)
if self.bias is None:
return results[0], results[1]
return results
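A worked example of the output-shape arithmetic in execute (illustrative values, not part of the diff): with a 224x224 input, a 3x3 kernel, stride 1, padding 1 and dilation 1, output_height = (224 + 2*1 - 1*(3-1) - 1) // 1 + 1 = 224, and likewise for the width:

import jittor as jt
x = jt.rand(1, 3, 224, 224)
w = jt.rand(16, 3, 3, 3)
b = jt.rand(16)
y = ConvACL()(x, w, b, 1, 1, 1, 1)   # stride=1, padding=1, dilation=1, groups=1; y.shape == (1, 16, 224, 224)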

@@ -0,0 +1,152 @@
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "conv_op_acl.h"
namespace jittor
{
Conv2dOpRunner::Conv2dOpRunner() : BaseOpRunner("Conv2d")
{
use_nchw = true;
}
void Conv2dOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
// for conv
aclIntArray *strides = nullptr;
aclIntArray *pads = nullptr;
aclIntArray *outPads = nullptr;
aclIntArray *dilations = nullptr;
auto attr = dynamic_cast<ConvAttr *>(op_attr.get());
strides = aclCreateIntArray(attr->convStrides.data(), 2);
pads = aclCreateIntArray(attr->convPads.data(), 2);
outPads = aclCreateIntArray(attr->convOutPads.data(), 2);
dilations = aclCreateIntArray(attr->convDilations.data(), 2);
aclTensor *bias = nullptr;
auto input_num = in_.size();
if (input_num == 3)
bias = inputTensors[2];
ret = aclnnConvolutionGetWorkspaceSize(inputTensors[0], inputTensors[1], bias, strides, pads, dilations, false, outPads, attr->group, outputTensors[0], 0, &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnConvolution(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnConvolution failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
aclDestroyIntArray(strides);
aclDestroyIntArray(pads);
aclDestroyIntArray(outPads);
aclDestroyIntArray(dilations);
return;
}
Conv2dBackwardOpRunner::Conv2dBackwardOpRunner() : BaseOpRunner("Conv2dBackward")
{
use_nchw = true;
}
void Conv2dBackwardOpRunner::setupOutputDesc()
{
auto output_num = out_.size();
for (int output_idx = 0; output_idx < output_num; output_idx++)
{
std::vector<int64_t> shape;
for (int j = 0; j < out_[output_idx]->shape.size(); j++)
{
shape.push_back(out_[output_idx]->shape[j]);
}
outputShapes.push_back(shape);
}
for (int idx = 0; idx < 2; idx++)
{
outputTensors.push_back(nullptr);
auto ret = CreateAclTensor(outputShapes[idx], out_[idx]->mem_ptr, out_[idx]->size, get_dtype(out_[idx]->dtype()), &outputTensors[idx], use_nchw);
CHECK_RET(ret == ACL_SUCCESS, return);
}
// biasgrad nd format
{
outputTensors.push_back(nullptr);
auto ret = CreateAclTensor(outputShapes[2], out_[2]->mem_ptr, out_[2]->size, get_dtype(out_[2]->dtype()), &outputTensors[2], false);
CHECK_RET(ret == ACL_SUCCESS, return);
}
}
void Conv2dBackwardOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
// for conv
aclIntArray *strides = nullptr;
aclIntArray *pads = nullptr;
aclIntArray *outPads = nullptr;
aclIntArray *dilations = nullptr;
auto attr = dynamic_cast<ConvAttr *>(op_attr.get());
strides = aclCreateIntArray(attr->convStrides.data(), 2);
pads = aclCreateIntArray(attr->convPads.data(), 2);
outPads = aclCreateIntArray(attr->convOutPads.data(), 2);
dilations = aclCreateIntArray(attr->convDilations.data(), 2);
bool outputMask[3] = {true, true, true};
auto input_num = in_.size();
if (input_num == 3)
{
outputMask[2] = false;
}
aclBoolArray *outMask = aclCreateBoolArray(outputMask, 3);
auto biasSizes = aclCreateIntArray(&outputShapes[2][0], outputShapes[2].size());
ret = aclnnConvolutionBackwardGetWorkspaceSize(inputTensors[0], inputTensors[1], inputTensors[2], biasSizes, strides, pads, dilations, false, outPads, attr->group, outMask, 0, outputTensors[0], outputTensors[1], outputTensors[2], &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnConvolutionBackward(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnConvolutionBackward failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
aclDestroyIntArray(strides);
aclDestroyIntArray(pads);
aclDestroyIntArray(outPads);
aclDestroyIntArray(dilations);
return;
}
}

@@ -0,0 +1,27 @@
#pragma once
#include "utils.h"
#include "base_op.h"
namespace jittor
{
class Conv2dOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
Conv2dOpRunner();
};
class Conv2dBackwardOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
void setupOutputDesc() override;
public:
Conv2dBackwardOpRunner();
};
}

@@ -0,0 +1,101 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def cumsum_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
#include "acl/aclops/aclops.h"
'''
outputs_ = []
if outputs is not None:
outputs_ = outputs
else:
assert output_dtypes is not None
assert output_shapes is not None
assert len(output_dtypes) == len(output_shapes)
for i in range(len(output_shapes)):
outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
input_code = ''
for i in range(len(inputs)):
input_code += f"op.add(in{i}, true);\n"
output_code = ''
for i in range(len(outputs_)):
output_code += f"op.add(out{i}, false);\n"
return jt.code(outputs=outputs_,
inputs=inputs,
cuda_header=attr_header + cuda_header,
cuda_src=f"""
// aclop
{name}OpRunner op;
{input_code}
{output_code}
{attr_code}
op.run();""")
class CumsumACL(jt.Function):
def __init__(self):
super(CumsumACL, self).__init__()
def execute(self, input, dim=-1):
self.dim = dim
attr_code = f"""
op.jt_name = "cumsum";
GatherAttr *attr = new GatherAttr();
attr->dim = {dim};
op.op_attr.reset(attr);
"""
result = cumsum_cmd("Cumsum", [input],
output_dtypes=[input.dtype],
output_shapes=[input.shape],
attr_code=attr_code)[0]
return result
def grad(self, grad_output):
cumsum_attr_code = f"""
op.jt_name = "cumsum";
GatherAttr *attr = new GatherAttr();
attr->dim = {self.dim};
op.op_attr.reset(attr);
"""
flip_attr_code = f"""
op.jt_name = "flip";
ReduceAttr *attr = new ReduceAttr();
attr->axes = {{{self.dim}}};
attr->prod_dim = {{{1}}};
op.op_attr.reset(attr);
"""
flipped_grad_output = cumsum_cmd("Flip", [grad_output],
output_dtypes=[grad_output.dtype],
output_shapes=[grad_output.shape],
attr_code=flip_attr_code)[0]
cumulative_grad = cumsum_cmd("Cumsum", [flipped_grad_output],
output_dtypes=[grad_output.dtype],
output_shapes=[grad_output.shape],
attr_code=cumsum_attr_code)[0]
grad_input = cumsum_cmd("Flip", [cumulative_grad],
output_dtypes=[grad_output.dtype],
output_shapes=[grad_output.shape],
attr_code=flip_attr_code)[0]
return grad_input
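The backward pass above uses the identity grad_cumsum(g) = flip(cumsum(flip(g))) along the same dim. A minimal sketch with illustrative values (not part of the diff):

import jittor as jt
x = jt.array([1.0, 2.0, 3.0])
y = CumsumACL()(x, 0)   # [1, 3, 6]
# For loss = y.sum(), each x[i] contributes to y[i:], so dloss/dx = [3, 2, 1],
# which equals flip(cumsum(flip([1, 1, 1]))).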

@@ -0,0 +1,57 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "cumsum_op_acl.h"
namespace jittor
{
CumsumOpRunner::CumsumOpRunner() : BaseOpRunner("Cumsum")
{
}
void CumsumOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto attr = dynamic_cast<GatherAttr *>(op_attr.get());
ret = aclnnCumsumGetWorkspaceSize(inputTensors[0], attr->dim, get_dtype(out_[0]->dtype()), outputTensors[0], &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnCumsum(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnCumsum failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
return;
}
}

@@ -0,0 +1,17 @@
#pragma once
#include "utils.h"
#include "base_op.h"
namespace jittor
{
class CumsumOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
CumsumOpRunner();
};
}

@@ -0,0 +1,94 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def dropout_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
#include "acl/aclops/aclops.h"
'''
outputs_ = []
if outputs is not None:
outputs_ = outputs
else:
assert output_dtypes is not None
assert output_shapes is not None
assert len(output_dtypes) == len(output_shapes)
for i in range(len(output_shapes)):
outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
input_code = ''
for i in range(len(inputs)):
input_code += f"op.add(in{i}, true);\n"
output_code = ''
for i in range(len(outputs_)):
output_code += f"op.add(out{i}, false);\n"
return jt.code(outputs=outputs_,
inputs=inputs,
cuda_header=attr_header + cuda_header,
cuda_src=f"""
// aclop
{name}OpRunner op;
{input_code}
{output_code}
{attr_code}
op.run();""")
class DropoutACL(jt.Function):
def __init__(self):
super(DropoutACL, self).__init__()
def execute(self, x, p=0.5, is_train=False):
self.input = x
num_elements = x.numel()
aligned_elements = (num_elements + 127) // 128 * 128
mask_shape = (aligned_elements // 8, )
attr_code = f"""
op.jt_name = "dropout";
DropoutAttr *attr = new DropoutAttr();
attr->p = {p};
attr->train = {"true" if is_train else "false"};
attr->seed = 0;
attr->offset = 0;
op.op_attr.reset(attr);
"""
result = dropout_cmd("Dropout", [x],
output_dtypes=[x.dtype, "uint8"],
output_shapes=[x.shape, mask_shape],
attr_code=attr_code)
self.maskout = result[1]
return result[0]
def grad(self, grad_output):
attr_code = f"""
op.jt_name = "dropoutbackward";
DropoutAttr *attr = new DropoutAttr();
attr->scale = 1.0;
op.op_attr.reset(attr);
"""
grad_input = dropout_cmd("DropoutBackward",
[grad_output, self.maskout],
output_dtypes=[grad_output.dtype],
output_shapes=[grad_output.shape],
attr_code=attr_code)[0]
return grad_input
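The mask returned by execute is a bit mask: one bit per element, padded up to a multiple of 128 elements and stored as uint8 bytes. A worked example of that arithmetic (illustrative values, not part of the diff):

num_elements = 1000
aligned_elements = (num_elements + 127) // 128 * 128   # 1024
mask_bytes = aligned_elements // 8                      # mask_shape == (128,), dtype uint8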

@@ -0,0 +1,82 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "dropout_op_acl.h"
namespace jittor
{
DropoutOpRunner::DropoutOpRunner() : BaseOpRunner("Dropout")
{
}
void DropoutOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto attr = dynamic_cast<DropoutAttr *>(op_attr.get());
ret = aclnnDropoutGetWorkspaceSize(inputTensors[0], attr->p, attr->train, attr->seed, attr->offset, outputTensors[0], outputTensors[1], &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnDropout(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnDropout failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
return;
}
DropoutBackwardOpRunner::DropoutBackwardOpRunner() : BaseOpRunner("DropoutBackward")
{
}
void DropoutBackwardOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto attr = dynamic_cast<DropoutAttr *>(op_attr.get());
ret = aclnnDropoutBackwardGetWorkspaceSize(inputTensors[0], inputTensors[1], attr->scale, outputTensors[0], &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnDropoutBackward(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnDropoutBackward failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
return;
}
}

@@ -0,0 +1,27 @@
#pragma once
#include "utils.h"
#include "base_op.h"
namespace jittor
{
class DropoutOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
DropoutOpRunner();
};
class DropoutBackwardOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
DropoutBackwardOpRunner();
};
}

@@ -0,0 +1,91 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def embedding_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
#include "acl/aclops/aclops.h"
'''
outputs_ = []
if outputs is not None:
outputs_ = outputs
else:
assert output_dtypes is not None
assert output_shapes is not None
assert len(output_dtypes) == len(output_shapes)
for i in range(len(output_shapes)):
outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
input_code = ''
for i in range(len(inputs)):
input_code += f"op.add(in{i}, true);\n"
output_code = ''
for i in range(len(outputs_)):
output_code += f"op.add(out{i}, false);\n"
return jt.code(outputs=outputs_,
inputs=inputs,
cuda_header=attr_header + cuda_header,
cuda_src=f"""
// aclop
{name}OpRunner op;
{input_code}
{output_code}
{attr_code}
op.run();""")
class EmbeddingACL(jt.Function):
def __init__(self):
super(EmbeddingACL, self).__init__()
def execute(
self,
indices,
weight,
):
inputs = [weight, indices]
self.indices = indices
self.weight_shape = weight.shape
output_shape = list(indices.shape) + list(weight.shape[1:])
outputs = [jt.empty(output_shape, weight.dtype)]
attr_code = f"""
op.jt_name = "embedding";
"""
result = embedding_cmd("Embedding",
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
return result
def grad(self, grad_output):
inputs = [grad_output, self.indices]
outputs = [jt.empty(self.weight_shape, grad_output.dtype)]
attr_code = f"""
op.jt_name = "embeddingbackward";
EmbeddingAttr *attr = new EmbeddingAttr();
attr->numEmbeddings = {self.weight_shape[0]};
op.op_attr.reset(attr);
"""
grad_weight = embedding_cmd("EmbeddingBackward",
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
return None, grad_weight
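A minimal usage sketch for EmbeddingACL, assuming the ACL backend is active; shapes and values are illustrative:
import jittor as jt
weight = jt.randn(10, 4)                 # vocabulary of 10 entries, embedding dim 4
indices = jt.array([[1, 3], [2, 0]])
emb = EmbeddingACL()(indices, weight)    # shape (2, 2, 4), forward uses aclnnEmbedding
gw = jt.grad(emb.sum(), weight)          # backward goes through aclnnEmbeddingDenseBackward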

View File

@ -0,0 +1,82 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "embedding_op_acl.h"
namespace jittor
{
EmbeddingOpRunner::EmbeddingOpRunner() : BaseOpRunner("Embedding")
{
}
void EmbeddingOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
ret = aclnnEmbeddingGetWorkspaceSize(inputTensors[0], inputTensors[1], outputTensors[0], &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnEmbedding(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnEmbedding failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
return;
}
EmbeddingBackwardOpRunner::EmbeddingBackwardOpRunner() : BaseOpRunner("EmbeddingBackward")
{
}
void EmbeddingBackwardOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto attr = dynamic_cast<EmbeddingAttr *>(op_attr.get());
auto numEmbeddings = attr->numEmbeddings;
ret = aclnnEmbeddingDenseBackwardGetWorkspaceSize(inputTensors[0], inputTensors[1], numEmbeddings, 0, false, outputTensors[0], &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnEmbeddingDenseBackward(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnEmbeddingDenseBackward failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
return;
}
}

View File

@ -0,0 +1,25 @@
#pragma once
#include "utils.h"
#include "base_op.h"
namespace jittor
{
class EmbeddingOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
EmbeddingOpRunner();
};
class EmbeddingBackwardOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
EmbeddingBackwardOpRunner();
};
}

View File

@ -0,0 +1,58 @@
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "expand_op_acl.h"
namespace jittor
{
ExpandOpRunner::ExpandOpRunner() : BaseOpRunner("ternary")
{
use_nchw = false;
}
void ExpandOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
aclIntArray *size = nullptr;
size = aclCreateIntArray(&outputShapes[0][0], outputShapes[0].size());
ret = aclnnExpandGetWorkspaceSize(inputTensors[0], size, outputTensors[0], &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnExpand(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnExpand failed. ERROR: %d\n", name.c_str(), ret); return);
aclDestroyIntArray(size);
return;
}
}

View File

@ -0,0 +1,14 @@
#pragma once
#include "utils.h"
#include "base_op.h"
namespace jittor
{
struct ExpandOpRunner : public BaseOpRunner
{
ExpandOpRunner();
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
};
}

View File

@ -0,0 +1,209 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def flashattention_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
#include "acl/aclops/aclops.h"
'''
outputs_ = []
if outputs is not None:
outputs_ = outputs
else:
assert output_dtypes is not None
assert output_shapes is not None
assert len(output_dtypes) == len(output_shapes)
for i in range(len(output_shapes)):
outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
input_code = ''
for i in range(len(inputs)):
input_code += f"op.add(in{i}, true);\n"
output_code = ''
for i in range(len(outputs_)):
output_code += f"op.add(out{i}, false);\n"
return jt.code(outputs=outputs_,
inputs=inputs,
cuda_header=attr_header + cuda_header,
cuda_src=f"""
// aclop
{name}OpRunner op;
{input_code}
{output_code}
{attr_code}
op.run();""")
class FlashAttentionACL(jt.Function):
def __init__(self,
headnum,
layout="BNSD",
prefix=None,
qstart=None,
kvstart=None,
scale=1.0,
prob=1.0,
pretokens=2147483647,
nexttokens=2147483647,
innerprecise=0,
sparsemode=0,
psetype=1):
self.headnum = headnum
self.layout = layout
self.scale = scale
self.prob = prob
self.pretokens = pretokens
self.nexttokens = nexttokens
self.innerprecise = innerprecise
self.sparsemode = sparsemode
self.psetype = psetype
self.prefix = prefix
self.qstart = qstart
self.kvstart = kvstart
def execute(
self,
q,
k,
v,
realshift=None,
dropMask=None,
paddingMask=None,
attenMask=None,
):
if self.layout == 'BSH':
B, SQ, H = q.shape
SKV = k.shape[1]
N = self.headnum
D = H // N
elif self.layout == 'SBH':
SQ, B, H = q.shape
SKV = k.shape[0]
N = self.headnum
D = H // N
elif self.layout == 'BSND':
B, SQ, N, D = q.shape
SKV = k.shape[1]
elif self.layout == 'BNSD':
B, N, SQ, D = q.shape
SKV = k.shape[2]
else:
raise ValueError(f"got invalid input layout {self.layout}")
output_shape = (B, N, SQ, 8)
self.q = q
self.k = k
self.v = v
self.prefix = self.prefix if self.prefix else [0 for _ in range(B)]
self.qstart = self.qstart if self.qstart else [0 for _ in range(B)]
self.kvstart = self.kvstart if self.kvstart else [0 for _ in range(B)]
self.hasRealshift = realshift is not None
self.hasDropmask = dropMask is not None
self.hasPaddingmask = paddingMask is not None
self.hasAttenmask = attenMask is not None
# TBD: when an optional mask is absent, a placeholder tensor is created here and the ACL call receives nullptr instead
self.realshift = realshift if self.hasRealshift else jt.zeros(B, N, SQ, SKV)
self.dropMask = dropMask if self.hasDropmask else jt.ones(B, N, SQ, SKV)
self.paddingMask = paddingMask if self.hasPaddingmask else jt.zeros(
B, N, SQ, SKV)
self.attenMask = attenMask if self.hasAttenmask else jt.zeros(SQ, SKV)
attr_code = f"""
op.jt_name = "flashattention";
FlashAttentionAttr *attr = new FlashAttentionAttr();
attr->scale = {self.scale};
attr->keepProb = {self.prob};
attr->preToken = {self.pretokens};
attr->nextToken = {self.nexttokens};
attr->headNum = {self.headnum};
attr->inputLayout = "{self.layout}";
attr->innerPrecise = {self.innerprecise};
attr->sparseMode = {self.sparsemode};
attr->psetype = {self.psetype};
attr->prefix = {{ {", ".join(map(str, self.prefix))} }};
attr->qStartIdx = {{ {", ".join(map(str, self.qstart))} }};
attr->kvStartIdx = {{ {", ".join(map(str, self.kvstart))} }};
attr->hasRealshift = {"true" if self.hasRealshift else "false"};
attr->hasDropmask = {"true" if self.hasDropmask else "false"};
attr->hasPaddingmask = {"true" if self.hasPaddingmask else "false"};
attr->hasAttentmask = {"true" if self.hasAttenmask else "false"};
op.op_attr.reset(attr);
"""
inputs = [
q, k, v, self.realshift, self.dropMask, self.paddingMask,
self.attenMask
]
result = flashattention_cmd(
"FlashAttention",
inputs,
output_dtypes=["float", "float", q.dtype],
output_shapes=[output_shape, output_shape, q.shape],
attr_code=attr_code)
self.maxout = result[0]
self.sumout = result[1]
self.attenout = result[2]
return self.attenout
def grad(self, dy):
attr_code = f"""
op.jt_name = "flashattentionbackward";
FlashAttentionAttr *attr = new FlashAttentionAttr();
attr->scale = {self.scale};
attr->keepProb = {self.prob};
attr->preToken = {self.pretokens};
attr->nextToken = {self.nexttokens};
attr->headNum = {self.headnum};
attr->inputLayout = "{self.layout}";
attr->innerPrecise = {self.innerprecise};
attr->sparseMode = {self.sparsemode};
attr->psetype = {self.psetype};
attr->prefix = {{ {", ".join(map(str, self.prefix))} }};
attr->qStartIdx = {{ {", ".join(map(str, self.qstart))} }};
attr->kvStartIdx = {{ {", ".join(map(str, self.kvstart))} }};
attr->hasRealshift = {"true" if self.hasRealshift else "false"};
attr->hasDropmask = {"true" if self.hasDropmask else "false"};
attr->hasPaddingmask = {"true" if self.hasPaddingmask else "false"};
attr->hasAttentmask = {"true" if self.hasAttenmask else "false"};
op.op_attr.reset(attr);
"""
inputs = [
self.q, self.k, self.v, dy, self.realshift, self.dropMask,
self.paddingMask, self.attenMask, self.maxout, self.sumout,
self.attenout
]
result = flashattention_cmd(
"FlashAttentionBackward",
inputs,
output_dtypes=[self.q.dtype, self.k.dtype, self.v.dtype],
output_shapes=[self.q.shape, self.k.shape, self.v.shape],
attr_code=attr_code)
return result
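A minimal usage sketch for FlashAttentionACL in the default BNSD layout, assuming an ACL device; the 1/sqrt(D) scale follows the usual attention convention, and depending on the CANN version the kernel may require float16 inputs:
import math
import jittor as jt
B, N, S, D = 2, 4, 16, 32
q, k, v = jt.randn(B, N, S, D), jt.randn(B, N, S, D), jt.randn(B, N, S, D)
attn = FlashAttentionACL(headnum=N, layout="BNSD", scale=1.0 / math.sqrt(D))
out = attn(q, k, v)    # shape (B, N, S, D); softmax max/sum are cached for the backward pass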

View File

@ -0,0 +1,88 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "flashattention_op_acl.h"
namespace jittor
{
FlashAttentionOpRunner::FlashAttentionOpRunner() : BaseOpRunner("FlashAttention")
{
}
void FlashAttentionOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto attr = dynamic_cast<FlashAttentionAttr *>(op_attr.get());
auto prefix = aclCreateIntArray(attr->prefix.data(), attr->prefix.size());
auto qstart = aclCreateIntArray(attr->qStartIdx.data(), attr->qStartIdx.size());
auto kvstart = aclCreateIntArray(attr->kvStartIdx.data(), attr->kvStartIdx.size());
char *layout = const_cast<char *>(attr->inputLayout.data());
ret = aclnnFlashAttentionScoreV2GetWorkspaceSize(inputTensors[0], inputTensors[1], inputTensors[2], attr->hasRealshift ? inputTensors[3] : nullptr, attr->hasDropmask ? inputTensors[4] : nullptr, nullptr, attr->hasAttentmask ? inputTensors[6] : nullptr, prefix, qstart, kvstart, attr->scale, attr->keepProb, attr->preToken, attr->nextToken, attr->headNum, layout, attr->innerPrecise, attr->sparseMode, attr->psetype, outputTensors[0], outputTensors[1], nullptr, outputTensors[2], &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnFlashAttentionScoreV2(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnFlashAttentionScoreV2 failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
return;
}
FlashAttentionBackwardOpRunner::FlashAttentionBackwardOpRunner() : BaseOpRunner("FlashAttentionBackward")
{
}
void FlashAttentionBackwardOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto attr = dynamic_cast<FlashAttentionAttr *>(op_attr.get());
auto prefix = aclCreateIntArray(attr->prefix.data(), attr->prefix.size());
auto qstart = aclCreateIntArray(attr->qStartIdx.data(), attr->qStartIdx.size());
auto kvstart = aclCreateIntArray(attr->kvStartIdx.data(), attr->kvStartIdx.size());
char *layout = const_cast<char *>(attr->inputLayout.data());
ret = aclnnFlashAttentionScoreGradV2GetWorkspaceSize(inputTensors[0], inputTensors[1], inputTensors[2], inputTensors[3], attr->hasRealshift ? inputTensors[4] : nullptr, attr->hasDropmask ? inputTensors[5] : nullptr, nullptr, attr->hasAttentmask ? inputTensors[7] : nullptr, inputTensors[8], inputTensors[9], nullptr, inputTensors[10], prefix, qstart, kvstart, attr->scale, attr->keepProb, attr->preToken, attr->nextToken, attr->headNum, layout, attr->innerPrecise, attr->sparseMode, attr->psetype, outputTensors[0], outputTensors[1], outputTensors[2], nullptr, &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnFlashAttentionScoreGradV2(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnFlashAttentionScoreGradV2 failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
return;
}
}

View File

@ -0,0 +1,27 @@
#pragma once
#include "utils.h"
#include "base_op.h"
namespace jittor
{
class FlashAttentionOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
FlashAttentionOpRunner();
};
class FlashAttentionBackwardOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
FlashAttentionBackwardOpRunner();
};
}

View File

@ -0,0 +1,85 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def flip_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
#include "acl/aclops/aclops.h"
'''
outputs_ = []
if outputs is not None:
outputs_ = outputs
else:
assert output_dtypes is not None
assert output_shapes is not None
assert len(output_dtypes) == len(output_shapes)
for i in range(len(output_shapes)):
outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
input_code = ''
for i in range(len(inputs)):
input_code += f"op.add(in{i}, true);\n"
output_code = ''
for i in range(len(outputs_)):
output_code += f"op.add(out{i}, false);\n"
return jt.code(outputs=outputs_,
inputs=inputs,
cuda_header=attr_header + cuda_header,
cuda_src=f"""
// aclop
{name}OpRunner op;
{input_code}
{output_code}
{attr_code}
op.run();""")
class FlipACL(jt.Function):
def __init__(self):
super(FlipACL, self).__init__()
def execute(self, input, dim):
if type(dim) is tuple:
dim = list(dim)
if type(dim) is not list:
dim = [dim]
attr_code = f"""
op.jt_name = "flip";
ReduceAttr *attr = new ReduceAttr();
attr->axes = {{{', '.join(map(str, (list(dim))))}}};
attr->prod_dim = {len(dim)};
op.op_attr.reset(attr);
"""
self.attr_code = attr_code
result = flip_cmd("Flip", [input],
output_dtypes=[input.dtype],
output_shapes=[input.shape],
attr_code=self.attr_code)[0]
return result
def grad(self, grad_output):
grad_input = flip_cmd("Flip", [grad_output],
output_dtypes=[grad_output.dtype],
output_shapes=[grad_output.shape],
attr_code=self.attr_code)[0]
return grad_input
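A minimal usage sketch for FlipACL, assuming an ACL device:
import jittor as jt
x = jt.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
y = FlipACL()(x, 1)        # flip along dim 1 -> [[3, 2, 1], [6, 5, 4]]
z = FlipACL()(x, (0, 1))   # flip along both dims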

View File

@ -0,0 +1,58 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "flip_op_acl.h"
namespace jittor
{
FlipOpRunner::FlipOpRunner() : BaseOpRunner("Flip")
{
}
void FlipOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto attr = dynamic_cast<ReduceAttr *>(op_attr.get());
auto dim = aclCreateIntArray(attr->axes.data(), attr->axes.size());
ret = aclnnFlipGetWorkspaceSize(inputTensors[0], dim, outputTensors[0], &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnFlip(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnFlip failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
return;
}
}

View File

@ -0,0 +1,16 @@
#pragma once
#include "utils.h"
#include "base_op.h"
namespace jittor
{
class FlipOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
FlipOpRunner();
};
}

View File

@ -0,0 +1,70 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def floor_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
#include "acl/aclops/aclops.h"
'''
outputs_ = []
if outputs is not None:
outputs_ = outputs
else:
assert output_dtypes is not None
assert output_shapes is not None
assert len(output_dtypes) == len(output_shapes)
for i in range(len(output_shapes)):
outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
input_code = ''
for i in range(len(inputs)):
input_code += f"op.add(in{i}, true);\n"
output_code = ''
for i in range(len(outputs_)):
output_code += f"op.add(out{i}, false);\n"
return jt.code(outputs=outputs_,
inputs=inputs,
cuda_header=attr_header + cuda_header,
cuda_src=f"""
// aclop
{name}OpRunner op;
{input_code}
{output_code}
{attr_code}
op.run();""")
class FloorIntACL(jt.Function):
def __init__(self):
super(FloorIntACL, self).__init__()
def execute(self, input):
self.shape = input.shape
result = floor_cmd("Floor", [input],
output_dtypes=[input.dtype],
output_shapes=[input.shape],
attr_code="op.jt_name=\"floor\";")[0]
return result
def grad(self, grad_output):
return jt.zeros(self.shape, dtype=grad_output.dtype)
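A minimal usage sketch for FloorIntACL, assuming an ACL device; note that despite the name the output keeps the input dtype, and the gradient is defined as all zeros:
import jittor as jt
x = jt.array([1.7, -0.3, 2.0])
y = FloorIntACL()(x)         # [1.0, -1.0, 2.0]
gx = jt.grad(y.sum(), x)     # zeros, as defined in grad() above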

View File

@ -0,0 +1,56 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "floor_op_acl.h"
namespace jittor
{
FloorOpRunner::FloorOpRunner() : BaseOpRunner("Floor")
{
}
void FloorOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
ret = aclnnFloorGetWorkspaceSize(inputTensors[0], outputTensors[0], &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnFloor(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnFloor failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
return;
}
}

View File

@ -0,0 +1,16 @@
#pragma once
#include "utils.h"
#include "base_op.h"
namespace jittor
{
class FloorOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
FloorOpRunner();
};
}

View File

@ -0,0 +1,126 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def gather_scatter_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
#include "acl/aclops/aclops.h"
'''
outputs_ = []
if outputs is not None:
outputs_ = outputs
else:
assert output_dtypes is not None
assert output_shapes is not None
assert len(output_dtypes) == len(output_shapes)
for i in range(len(output_shapes)):
outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
input_code = ''
for i in range(len(inputs)):
input_code += f"op.add(in{i}, true);\n"
output_code = ''
for i in range(len(outputs_)):
output_code += f"op.add(out{i}, false);\n"
return jt.code(outputs=outputs_,
inputs=inputs,
cuda_header=attr_header + cuda_header,
cuda_src=f"""
// aclop
{name}OpRunner op;
{input_code}
{output_code}
{attr_code}
op.run();""")
class GatherACL(jt.Function):
def __init__(self):
super(GatherACL, self).__init__()
def execute(self, input, dim, index):
self.dim = dim
self.index = index
self.input_shape = input.shape
attr_code = f"""
op.jt_name = "gather";
GatherAttr *attr = new GatherAttr();
attr->dim = {dim};
op.op_attr.reset(attr);
"""
result = gather_scatter_cmd("Gather", [input, index],
output_dtypes=[input.dtype],
output_shapes=[index.shape],
attr_code=attr_code)[0]
return result
def grad(self, grad_output):
# scatter-add the incoming gradient into a zero tensor shaped like the forward input
tmp = jt.zeros(self.input_shape, dtype=grad_output.dtype)
attr_code = f"""
op.jt_name = "scatter";
ScatterAttr *attr = new ScatterAttr();
attr->axis = {self.dim};
attr->reduction = {1};
op.op_attr.reset(attr);
"""
grad_input = gather_scatter_cmd("Scatter",
[tmp, self.index, grad_output],
output_dtypes=[grad_output.dtype],
output_shapes=[tmp.shape],
attr_code=attr_code)[0]
return grad_input
class ScatterACL(jt.Function):
def __init__(self):
super(ScatterACL, self).__init__()
def execute(self, input, dim, index, src, reduce='void'):
self.dim = dim
self.index = index
self.reduce = reduce
attr_code = f"""
op.jt_name = "scatter";
ScatterAttr *attr = new ScatterAttr();
attr->axis = {dim};
attr->reduction = {1 if reduce == 'add' else 2 if reduce == 'mul' else 0};
op.op_attr.reset(attr);
"""
result = gather_scatter_cmd("Scatter", [input, self.index, src],
output_dtypes=[input.dtype],
output_shapes=[input.shape],
attr_code=attr_code)[0]
return result
def grad(self, grad_output):
attr_code = f"""
op.jt_name = "gather";
GatherAttr *attr = new GatherAttr();
attr->dim = {self.dim};
op.op_attr.reset(attr);
"""
grad_input = gather_scatter_cmd("Gather", [grad_output, self.index],
output_dtypes=[grad_output.dtype],
output_shapes=[self.index.shape],
attr_code=attr_code)[0]
return grad_output, None, None, grad_input
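A minimal usage sketch for GatherACL and ScatterACL, assuming an ACL device; the semantics mirror torch-style gather/scatter along a dimension:
import jittor as jt
x = jt.array([[1.0, 2.0], [3.0, 4.0]])
idx = jt.array([[0, 0], [1, 0]])
g = GatherACL()(x, 1, idx)                                   # [[1, 1], [4, 3]]
s = ScatterACL()(jt.zeros((2, 2)), 1, idx, g, reduce='add')  # scatter-add g back along dim 1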

View File

@ -0,0 +1,80 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "gather_scatter_op_acl.h"
namespace jittor
{
GatherOpRunner::GatherOpRunner() : BaseOpRunner("Gather")
{
}
void GatherOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto attr = dynamic_cast<GatherAttr *>(op_attr.get());
ret = aclnnGatherGetWorkspaceSize(inputTensors[0], attr->dim, inputTensors[1], outputTensors[0], &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnGather(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnGather failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
return;
}
ScatterOpRunner::ScatterOpRunner() : BaseOpRunner("Scatter")
{
}
void ScatterOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto attr = dynamic_cast<ScatterAttr *>(op_attr.get());
ret = aclnnScatterGetWorkspaceSize(inputTensors[0], attr->axis, inputTensors[1], inputTensors[2], attr->reduction, outputTensors[0], &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnScatter(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnScatter failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
return;
}
}

View File

@ -0,0 +1,26 @@
#pragma once
#include "utils.h"
#include "base_op.h"
namespace jittor
{
class GatherOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
GatherOpRunner();
};
class ScatterOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
ScatterOpRunner();
};
}

View File

@ -0,0 +1,419 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def getitem_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
#include "acl/aclops/aclops.h"
'''
outputs_ = []
if outputs is not None:
outputs_ = outputs
else:
assert output_dtypes is not None
assert output_shapes is not None
assert len(output_dtypes) == len(output_shapes)
for i in range(len(output_shapes)):
outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
input_code = ''
for i in range(len(inputs)):
input_code += f"op.add(in{i}, true);\n"
output_code = ''
for i in range(len(outputs_)):
output_code += f"op.add(out{i}, false);\n"
return jt.code(outputs=outputs_,
inputs=inputs,
cuda_header=attr_header + cuda_header,
cuda_src=f"""
// aclop
{name}OpRunner op;
{input_code}
{output_code}
{attr_code}
op.run();""")
def getitem_forward(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None,
extra_data: dict = {}):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
#include "acl/aclops/aclops.h"
'''
outputs_ = []
if outputs is not None:
outputs_ = outputs
else:
assert output_dtypes is not None
assert output_shapes is not None
assert len(output_dtypes) == len(output_shapes)
for i in range(len(output_shapes)):
outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
input_code = ''
for i in range(len(inputs)):
input_code += f"op.add(in{i}, true);\n"
output_code = ''
for i in range(len(outputs_)):
output_code += f"op.add(out{i}, false);\n"
return jt.code(outputs=outputs_,
inputs=inputs,
cuda_header=attr_header + cuda_header,
cuda_src=f"""
// aclop
{name}OpRunner op;
{input_code}
op.add(out0, false);
{attr_code}
op.run();""",
data=extra_data)
def caculate_shape(tensors):
if isinstance(tensors, jt.Var):
# tensors = tensors[0]
return tensors.shape
elif isinstance(tensors, (int, float)):
return []
elif isinstance(tensors, (list, tuple)):
# return [caculate_shape(tensor) for tensor in tensors]
sub_shape = caculate_shape(tensors[0])
return [len(tensors)] + list(sub_shape)
else:
assert False, f"not implemented for {type(tensors)}"
def can_broadcast_and_shape(shape1, shape2):
"""
检查两个张量是否可以广播并返回广播后的形状
参数:
- shape1: 第一个张量的形状tuple list
- shape2: 第二个张量的形状tuple list
返回:
- can_broadcast: 布尔值表示是否可以广播
- broadcast_shape: 如果可以广播返回广播后的形状否则返回 None
"""
# Convert the shapes to tuples in case the inputs are lists
shape1 = tuple(shape1)
shape2 = tuple(shape2)
# Pad the shorter shape with leading 1s so both shapes have the same length
len1, len2 = len(shape1), len(shape2)
if len1 < len2:
shape1 = (1, ) * (len2 - len1) + shape1
elif len2 < len1:
shape2 = (1, ) * (len1 - len2) + shape2
broadcast_shape = []
# Check each dimension, starting from the last
for dim1, dim2 in zip(shape1, shape2):
if dim1 == dim2:
broadcast_shape.append(dim1)
elif dim1 == 1:
broadcast_shape.append(dim2)
elif dim2 == 1:
broadcast_shape.append(dim1)
else:
# Incompatible along this dimension, so broadcasting is impossible
return False, None
return True, tuple(broadcast_shape)
class GetItemACL(jt.Function):
def __init__(self):
self.type_ = 'notype'
def stride(self, x, dim):
stride = 1
for i in range(dim + 1, len(x.shape)):
stride *= x.shape[i]
return stride
def execute(self, x, slices, return_x=None):
if isinstance(slices, jt.Var) and slices.dtype == 'bool':
# assert False, "not support bool type now"
# TODO: optimize this path
assert x.shape == slices.shape, "shape not match"
output_len = slices.sum().item()
# output = jt.empty((output_len,),dtype=x.dtype)
x_len = x.numel()
output = jt.empty((x_len,), dtype=x.dtype)
outputs = [output]
inputs = [x, slices]
# print(inputs,outputs)
# print(output.shape)
self.mask = slices
self.type_ = 'mask'
attr_code = f"""
op.jt_name = "maskedselect";
"""
result = getitem_cmd("MaskedSelect",
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
result = result[:output_len]
result.sync()
return result
self.x_shape = x.shape
if not isinstance(slices, tuple):
slices = (slices, )
slices = list(slices)
for i, s in enumerate(slices):
if isinstance(s, int) and s < 0:
slices[i] = s + x.shape[i]
slices = tuple(slices)
slices_list = list(slices)
# if not isinstance(slices[0], slice):
#check slices contains slice type
contains_slice = False
for s in slices:
if not isinstance(s, jt.Var) and (isinstance(s, slice)
or s == Ellipsis):
contains_slice = True
break
if not contains_slice:
indices = []
output_shape = []
slices_len = len(slices)
boardcast_shape = caculate_shape(slices_list[0])
for ii in range(1, len(slices)):
dd, boardcast_shape = can_broadcast_and_shape(
boardcast_shape, caculate_shape(slices_list[ii]))
assert dd is True, "can not broadcast"
output_shape = boardcast_shape
output_shape += x.shape[slices_len:]
if output_shape == []:
output_shape = [1]
for ii in slices:
indices.append(jt.Var(ii).int32())
if isinstance(slices[0],
jt.Var) or isinstance(slices[0], int) or isinstance(
slices[0], list) or isinstance(slices[0], tuple):
self.indices = indices
inputs = [x] + indices
attr_code = f"""
op.jt_name = "index";
"""
self.type_ = 'index'
result = getitem_cmd("Index",
inputs=inputs,
output_dtypes=[x.dtype],
output_shapes=[output_shape],
attr_code=attr_code)[0]
result.sync()
return result
assert contains_slice, "slice type error"
x_dim = len(x.shape)
slices = list(slices)
for s in slices:
if not isinstance(s, jt.Var) and s == Ellipsis:
slices = slices[:slices.index(s)] + [
slice(None, None, None)
] * (x_dim - len(slices) + 1) + slices[slices.index(s) + 1:]
break
slices = tuple(slices)
if len(slices) < x_dim:
slices += (slice(None, None, None), ) * (x_dim - len(slices))
inputs = [x]
sizes = []
begins = []
ends = []
steps = []
dims = []
squeeze_dims = []
extra_data = {}
if len(slices):
extra_data["a"] = len(slices)
for dim, s in enumerate(slices):
if isinstance(s, int):
s = slice(s, s + 1, 1)
squeeze_dims.append(dim)
if isinstance(s, jt.Var):
assert False, "jt.Var not supported"
start, stop, step = s.indices(x.size(dim))
size = (stop - start - 1) // step + 1
# stride = self.stride(x, dim) * step
sizes.append(size)
extra_data[str(dim * 3)] = start
extra_data[str(dim * 3 + 1)] = stop
extra_data[str(dim * 3 + 2)] = step
steps.append(step)
begins.append(start)
ends.append(stop)
dims.append(dim)
else:
extra_data["a"] = -1
sizes = [1]
steps = [1]
self.type_ = 'slicev2'
# for backward
self.begins = begins
self.ends = ends
self.steps = steps
self.dims = dims
self.slices = slices
attr_code = """
op.jt_name = "slicev2";
StrideAttr *attr = new StrideAttr();
int slice_dim = data["a"];
if(slice_dim == -1) {
attr->begins = {};
attr->ends = {};
attr->steps = {1};
attr->axes = {};
} else {
vector<long int> begins;
vector<long int> ends;
vector<long int> steps;
vector<long int> dims;
for(int dim = 0; dim < slice_dim; dim++) {
dims.push_back(dim);
begins.push_back(data[std::to_string(dim*3)]);
ends.push_back(data[std::to_string(dim*3+1)]);
steps.push_back(data[std::to_string(dim*3+2)]);
}
attr->begins = begins;
attr->ends = ends;
attr->steps = steps;
attr->axes = dims;
}
op.op_attr.reset(attr);
"""
result = getitem_forward("SliceV2",
inputs,
output_dtypes=[x.dtype],
output_shapes=[jt.empty(sizes).shape],
attr_code=attr_code,
extra_data=extra_data)[0]
self.squeeze_dims = squeeze_dims
for dim in squeeze_dims[::-1]:
result = jt.squeeze(result, dim)
result.sync()
return result
def grad(self, grad_output):
if self.type_ == 'index':
indices = self.indices
inputs = [grad_output] + indices
attr_code = f"""
op.jt_name = "indexputimplaccumulate";
"""
outputs = [jt.zeros(self.x_shape, dtype=grad_output.dtype)]
# breakpoint()
result = getitem_cmd("IndexPutImplAccumulate",
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
result.sync()
return result, None
elif self.type_ == 'slicev2':
begins = self.begins
ends = self.ends
steps = self.steps
dims = self.dims
slices = self.slices
# Note: the forward pass may have squeezed dimensions, so restore them here
for dim in self.squeeze_dims:
grad_output = jt.unsqueeze(grad_output, dim)
# Work around a Huawei ACL requirement: the step of the last dimension must be 1
expand_dim = False
if isinstance(slices[-1], slice):
if slices[-1].step is not None and slices[-1].step != 1:
slices = slices + (slice(None, None, None), )
expand_dim = True
elif isinstance(slices[-1], int):
# Note: the last dimension is an integer index
slices = list(slices)
slices[-1] = slice(slices[-1], slices[-1] + 1, 1)
slices = tuple(slices)
slices = slices + (slice(None, None, None), )
expand_dim = True
else:
assert False, "not supported"
# x = x.unsqueeze(-1)
if expand_dim:
grad_output = grad_output.unsqueeze(-1)
self.x_shape = self.x_shape + (1, )
sizes = []
begins = []
ends = []
steps = []
dims = []
for dim, s in enumerate(slices):
if isinstance(s, int):
s = slice(s, s + 1, 1)
# squeeze_dims.append(dim)
if isinstance(s, jt.Var):
assert False, "jt.Var not supported"
start, stop, step = s.indices(self.x_shape[dim])
size = (stop - start - 1) // step + 1
# stride = self.stride(x, dim) * step
sizes.append(size)
steps.append(step)
begins.append(start)
ends.append(stop)
dims.append(dim)
if not sizes:
sizes = [1]
steps = [1]
attr_code = f"""
op.jt_name = "stridedsliceassignv2";
StrideAttr *attr = new StrideAttr();
attr->begins = {{ {", ".join(map(str, begins))} }};
attr->ends = {{ {", ".join(map(str, ends))} }};
attr->steps = {{ {", ".join(map(str, steps))} }};
attr->axes = {{ {", ".join(map(str, dims))} }};
op.op_attr.reset(attr);
"""
inputs = [grad_output]
outputs = [jt.zeros(self.x_shape, dtype=grad_output.dtype)]
result = getitem_cmd("StridedSliceAssignV2",
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
result.sync()
if expand_dim:
result = result.squeeze(-1)
return result, None
elif self.type_ == 'mask':
return self.mask.float()
else:
assert False, f"grad not implemented for {self.type_}"
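A minimal usage sketch for GetItemACL covering its three paths (strided slice, fancy indexing, boolean mask), assuming an ACL device:
import jittor as jt
x = jt.randn(4, 5)
y = GetItemACL()(x, (slice(0, 4, 2), slice(1, 5, 2)))   # SliceV2 path, shape (2, 2)
rows = GetItemACL()(x, ([0, 2],))                        # Index path, shape (2, 5)
pos = GetItemACL()(x, x > 0)                             # MaskedSelect path, 1-D result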

View File

@ -0,0 +1,165 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "getitem_op_acl.h"
namespace jittor
{
MaskedSelectOpRunner::MaskedSelectOpRunner() : BaseOpRunner("MaskedSelect")
{
}
void MaskedSelectOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
ret = aclnnMaskedSelectGetWorkspaceSize(inputTensors[0], inputTensors[1], outputTensors[0], &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnMaskedSelect(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnMaskedSelect failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
return;
}
IndexOpRunner::IndexOpRunner() : BaseOpRunner("Index")
{
}
void IndexOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto input_num = in_.size();
auto indexTensorList = aclCreateTensorList(&inputTensors[1], input_num - 1);
ret = aclnnIndexGetWorkspaceSize(inputTensors[0], indexTensorList, outputTensors[0], &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnIndex(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnIndex failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
return;
}
SliceV2OpRunner::SliceV2OpRunner() : BaseOpRunner("SliceV2")
{
}
void SliceV2OpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto attr = dynamic_cast<StrideAttr *>(op_attr.get());
auto begins = aclCreateIntArray(attr->begins.data(), attr->begins.size());
auto ends = aclCreateIntArray(attr->ends.data(), attr->ends.size());
auto steps = aclCreateIntArray(attr->steps.data(), attr->steps.size());
auto axes = aclCreateIntArray(attr->axes.data(), attr->axes.size());
ret = aclnnSliceV2GetWorkspaceSize(inputTensors[0], begins, ends, axes, steps, outputTensors[0], &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnSliceV2(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnSliceV2 failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
return;
}
IndexPutImplAccumulateOpRunner::IndexPutImplAccumulateOpRunner() : BaseOpRunner("IndexPutImplAccumulate")
{
}
void IndexPutImplAccumulateOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto input_num = in_.size();
std::vector<aclTensor *> indexTensorList = {};
for (int i = 1; i < input_num; i++)
{
indexTensorList.push_back(inputTensors[i]);
}
auto indexTensorListInput = aclCreateTensorList(&indexTensorList[0], input_num - 1);
ret = aclnnIndexPutImplGetWorkspaceSize(outputTensors[0], indexTensorListInput, inputTensors[0], true, true, &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnIndexPutImpl(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnIndexPutImpl failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
return;
}
StridedSliceAssignV2OpRunner::StridedSliceAssignV2OpRunner() : BaseOpRunner("StridedSliceAssignV2")
{
}
void StridedSliceAssignV2OpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto attr = dynamic_cast<StrideAttr *>(op_attr.get());
auto begins = aclCreateIntArray(attr->begins.data(), attr->begins.size());
auto ends = aclCreateIntArray(attr->ends.data(), attr->ends.size());
auto steps = aclCreateIntArray(attr->steps.data(), attr->steps.size());
auto axes = aclCreateIntArray(attr->axes.data(), attr->axes.size());
ret = aclnnStridedSliceAssignV2GetWorkspaceSize(outputTensors[0], inputTensors[0], begins, ends, steps, axes, &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnStridedSliceAssignV2(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnStridedSliceAssignV2 failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
return;
}
}

View File

@ -0,0 +1,57 @@
#pragma once
#include "utils.h"
#include "base_op.h"
namespace jittor
{
class MaskedSelectOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
MaskedSelectOpRunner();
};
class IndexOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
IndexOpRunner();
};
class SliceV2OpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
SliceV2OpRunner();
};
class IndexPutImplAccumulateOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
IndexPutImplAccumulateOpRunner();
};
class StridedSliceAssignV2OpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
StridedSliceAssignV2OpRunner();
};
}

View File

@ -0,0 +1,107 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def range_forward(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None,
extra_data: dict = {}):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
#include "acl/aclops/aclops.h"
'''
outputs_ = []
if outputs is not None:
outputs_ = outputs
else:
assert output_dtypes is not None
assert output_shapes is not None
assert len(output_dtypes) == len(output_shapes)
for i in range(len(output_shapes)):
outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
input_code = ''
for i in range(len(inputs)):
input_code += f"op.add(in{i}, true);\n"
output_code = ''
for i in range(len(outputs_)):
output_code += f"op.add(out{i}, false);\n"
return jt.code(outputs=outputs_,
inputs=inputs,
cuda_header=attr_header + cuda_header,
cuda_src=f"""
// aclop
{name}OpRunner op;
{input_code}
op.add(out0, false);
{attr_code}
op.run();""",
data=extra_data)
class IndexACL(jt.Function):
def __init__(self):
super(IndexACL, self).__init__()
def execute(self, inshape: list, dim=None, dtype="int32"):
# build index grids for a tensor of shape `inshape` with the given dtype
dim_input = dim
if dim is None:
dim = [i for i in range(len(inshape))]
elif isinstance(dim, int):
dim = [dim]
results = []
extra_data = {}
extra_data["dim_count"] = len(dim)
for i, d in enumerate(dim):
max_len = inshape[d]
extra_data[f"dim_{i}_start"] = 0
extra_data[f"dim_{i}_end"] = max_len
extra_data[f"dim_{i}_step"] = 1
tmp = jt.zeros(max_len, dtype=dtype)
range_attr_code = f"""
op.jt_name = "range";
RangeAttr *attr = new RangeAttr();
attr->start = data["dim_{i}_start"];
attr->end = data["dim_{i}_end"];
attr->step = data["dim_{i}_step"];
op.op_attr.reset(attr);
"""
result = range_forward("Range", [],
output_dtypes=[tmp.dtype],
output_shapes=[tmp.shape],
attr_code=range_attr_code,
extra_data=extra_data)[0]
broadcast_dims = list(range(len(inshape)))
broadcast_dims.remove(d)
result = jt.broadcast(result, shape=inshape, dims=broadcast_dims)
results.append(result)
if len(results) != 1 or dim_input is None:
return tuple(results)
else:
return results[0]
def grad(self, grad_output):
return grad_output
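A minimal usage sketch for IndexACL, assuming an ACL device; it produces index grids similar to jt.index:
import jittor as jt
ii, jj = IndexACL()([2, 3])          # two int32 grids of shape (2, 3): row and column indices
i0 = IndexACL()([2, 3], dim=0)       # a single grid for dim 0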

Some files were not shown because too many files have changed in this diff.