From 928cc0e2a0b1bbf48b4f4047708f04c74f1edc1a Mon Sep 17 00:00:00 2001
From: "lff@Snode" <junzhao.liang@spacemit.com>
Date: Mon, 25 Mar 2024 10:57:54 +0800
Subject: [PATCH] RVV optimized chacha20

---
 crypto/chacha/chacha_enc.c       | 223 ++++++++++++++++++++++++++++++-
 crypto/evp/e_chacha20_poly1305.c |  11 +-
 include/crypto/chacha.h          |   7 +
 3 files changed, 239 insertions(+), 2 deletions(-)

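Notes (not part of the diff): QUARTERROUND_RVV below vectorizes the standard
ChaCha20 quarter round across 8 blocks at a time, building each 32-bit rotate
from a shift-left/shift-right/OR triple, since no vector rotate intrinsic is
used. For reference, this is a minimal scalar sketch of the quarter round each
vector lane computes (per RFC 8439); ROTL32 and chacha_qr are illustrative
names only:

#include <stdint.h>

#define ROTL32(v, n) (((v) << (n)) | ((v) >> (32 - (n))))

/* One ChaCha20 quarter round; the rotate amounts 16/12/8/7 match the
 * vsll/vsrl shift pairs used in QUARTERROUND_RVV. */
static void chacha_qr(uint32_t *a, uint32_t *b, uint32_t *c, uint32_t *d)
{
    *a += *b; *d ^= *a; *d = ROTL32(*d, 16);
    *c += *d; *b ^= *c; *b = ROTL32(*b, 12);
    *a += *b; *d ^= *a; *d = ROTL32(*d, 8);
    *c += *d; *b ^= *c; *b = ROTL32(*b, 7);
}
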
diff --git a/crypto/chacha/chacha_enc.c b/crypto/chacha/chacha_enc.c
index 18251ea..0231b8f 100644
--- a/crypto/chacha/chacha_enc.c
+++ b/crypto/chacha/chacha_enc.c
@@ -11,7 +11,7 @@
 
 #include <string.h>
 
-#include "crypto/chacha.h"
+#include "include/crypto/chacha.h"
 #include "crypto/ctype.h"
 
 typedef unsigned int u32;
@@ -128,3 +128,224 @@ void ChaCha20_ctr32(unsigned char *out, const unsigned char *inp,
         input[12]++;
     }
 }
+
+#if defined(__riscv_vector)
+#include <riscv_vector.h>
+#define QUARTERROUND_RVV(n, vl) \
+    { \
+        va = __riscv_vadd_vv_u32m##n(va, vb, vl); \
+        vd = __riscv_vxor_vv_u32m##n(vd, va, vl); \
+        vd_t = __riscv_vsll_vx_u32m##n(vd, 16, vl); \
+        vd = __riscv_vsrl_vx_u32m##n(vd, 16, vl); \
+        vd = __riscv_vor_vv_u32m##n(vd, vd_t, vl); \
+        \
+        vc = __riscv_vadd_vv_u32m##n(vc, vd, vl); \
+        vb = __riscv_vxor_vv_u32m##n(vb, vc, vl); \
+        vb_t = __riscv_vsll_vx_u32m##n(vb, 12, vl); \
+        vb = __riscv_vsrl_vx_u32m##n(vb, 20, vl); \
+        vb = __riscv_vor_vv_u32m##n(vb, vb_t, vl); \
+        \
+        va = __riscv_vadd_vv_u32m##n(va, vb, vl); \
+        vd = __riscv_vxor_vv_u32m##n(vd, va, vl); \
+        vd_t = __riscv_vsll_vx_u32m##n(vd, 8, vl); \
+        vd = __riscv_vsrl_vx_u32m##n(vd, 24, vl); \
+        vd = __riscv_vor_vv_u32m##n(vd, vd_t, vl); \
+        \
+        vc = __riscv_vadd_vv_u32m##n(vc, vd, vl); \
+        vb = __riscv_vxor_vv_u32m##n(vb, vc, vl); \
+        vb_t = __riscv_vsll_vx_u32m##n(vb, 7, vl); \
+        vb = __riscv_vsrl_vx_u32m##n(vb, 25, vl); \
+        vb = __riscv_vor_vv_u32m##n(vb, vb_t, vl); \
+    }
+
+void ChaCha20_ctr32_r(unsigned char *out, const unsigned char *inp,
+                      size_t len, size_t blocks, const unsigned int key[8],
+                      const unsigned int counter[4])
+{
+    size_t i, vl;
+    u8 outbuf[4 * 16 * 8];      /* 4 bytes x 16 elems x 8 blocks */
+
+    vuint32m1_t v00, v01, v02, v03, v04, v05, v06, v07, v08, v09, v10, v11, v12, v13, v14, v15;
+    vuint8m8_t vkey, vsrc;
+    vuint32m4_t va, vb, vc, vd, vb_t, vd_t;
+    vuint32m1_t vtmp0, vtmp1, vtmp2, vtmp3;
+
+    /* deal with 8 blocks at a time */
+    vuint32m1_t v12_og = __riscv_vid_v_u32m1(8);
+    v12_og = __riscv_vadd_vx_u32m1(v12_og, counter[0], 8);
+
+    while (len > 0) {
+        /* prepare 16 vectors, one for each state element */
+        v00 = __riscv_vmv_v_x_u32m1(0x61707865, 8);
+        v01 = __riscv_vmv_v_x_u32m1(0x3320646e, 8);
+        v02 = __riscv_vmv_v_x_u32m1(0x79622d32, 8);
+        v03 = __riscv_vmv_v_x_u32m1(0x6b206574, 8);
+        v04 = __riscv_vmv_v_x_u32m1(key[0], 8);
+        v05 = __riscv_vmv_v_x_u32m1(key[1], 8);
+        v06 = __riscv_vmv_v_x_u32m1(key[2], 8);
+        v07 = __riscv_vmv_v_x_u32m1(key[3], 8);
+        v08 = __riscv_vmv_v_x_u32m1(key[4], 8);
+        v09 = __riscv_vmv_v_x_u32m1(key[5], 8);
+        v10 = __riscv_vmv_v_x_u32m1(key[6], 8);
+        v11 = __riscv_vmv_v_x_u32m1(key[7], 8);
+        v12 = v12_og;
+        v13 = __riscv_vmv_v_x_u32m1(counter[1], 8);
+        v14 = __riscv_vmv_v_x_u32m1(counter[2], 8);
+        v15 = __riscv_vmv_v_x_u32m1(counter[3], 8);
+
+        /* combine and compute 4 vectors simultaneously */
+        va = __riscv_vset_v_u32m1_u32m4(va, 0, v00);
+        va = __riscv_vset_v_u32m1_u32m4(va, 1, v01);
+        va = __riscv_vset_v_u32m1_u32m4(va, 2, v02);
+        va = __riscv_vset_v_u32m1_u32m4(va, 3, v03);
+        vb = __riscv_vset_v_u32m1_u32m4(vb, 0, v04);
+        vb = __riscv_vset_v_u32m1_u32m4(vb, 1, v05);
+        vb = __riscv_vset_v_u32m1_u32m4(vb, 2, v06);
+        vb = __riscv_vset_v_u32m1_u32m4(vb, 3, v07);
+        vc = __riscv_vset_v_u32m1_u32m4(vc, 0, v08);
+        vc = __riscv_vset_v_u32m1_u32m4(vc, 1, v09);
+        vc = __riscv_vset_v_u32m1_u32m4(vc, 2, v10);
+        vc = __riscv_vset_v_u32m1_u32m4(vc, 3, v11);
+        vd = __riscv_vset_v_u32m1_u32m4(vd, 0, v12);
+        vd = __riscv_vset_v_u32m1_u32m4(vd, 1, v13);
+        vd = __riscv_vset_v_u32m1_u32m4(vd, 2, v14);
+        vd = __riscv_vset_v_u32m1_u32m4(vd, 3, v15);
+
+        for (i = 0; i < 10; ++i) {
+            /* first set of quarter rounds (columns) */
+            QUARTERROUND_RVV(4, 32);
+
+            /* rearrange the lanes for the diagonal rounds */
+            vtmp0 = __riscv_vget_v_u32m4_u32m1(vb, 0);
+            vtmp1 = __riscv_vget_v_u32m4_u32m1(vb, 1);
+            vtmp2 = __riscv_vget_v_u32m4_u32m1(vb, 2);
+            vtmp3 = __riscv_vget_v_u32m4_u32m1(vb, 3);
+            vb = __riscv_vset_v_u32m1_u32m4(vb, 0, vtmp1);
+            vb = __riscv_vset_v_u32m1_u32m4(vb, 1, vtmp2);
+            vb = __riscv_vset_v_u32m1_u32m4(vb, 2, vtmp3);
+            vb = __riscv_vset_v_u32m1_u32m4(vb, 3, vtmp0);
+            vtmp0 = __riscv_vget_v_u32m4_u32m1(vc, 0);
+            vtmp1 = __riscv_vget_v_u32m4_u32m1(vc, 1);
+            vtmp2 = __riscv_vget_v_u32m4_u32m1(vc, 2);
+            vtmp3 = __riscv_vget_v_u32m4_u32m1(vc, 3);
+            vc = __riscv_vset_v_u32m1_u32m4(vc, 0, vtmp2);
+            vc = __riscv_vset_v_u32m1_u32m4(vc, 1, vtmp3);
+            vc = __riscv_vset_v_u32m1_u32m4(vc, 2, vtmp0);
+            vc = __riscv_vset_v_u32m1_u32m4(vc, 3, vtmp1);
+            vtmp0 = __riscv_vget_v_u32m4_u32m1(vd, 0);
+            vtmp1 = __riscv_vget_v_u32m4_u32m1(vd, 1);
+            vtmp2 = __riscv_vget_v_u32m4_u32m1(vd, 2);
+            vtmp3 = __riscv_vget_v_u32m4_u32m1(vd, 3);
+            vd = __riscv_vset_v_u32m1_u32m4(vd, 0, vtmp3);
+            vd = __riscv_vset_v_u32m1_u32m4(vd, 1, vtmp0);
+            vd = __riscv_vset_v_u32m1_u32m4(vd, 2, vtmp1);
+            vd = __riscv_vset_v_u32m1_u32m4(vd, 3, vtmp2);
+
+            /* second set of quarter rounds (diagonals) */
+            QUARTERROUND_RVV(4, 32);
+
+            /* recover the original lane order */
+            vtmp1 = __riscv_vget_v_u32m4_u32m1(vb, 0);
+            vtmp2 = __riscv_vget_v_u32m4_u32m1(vb, 1);
+            vtmp3 = __riscv_vget_v_u32m4_u32m1(vb, 2);
+            vtmp0 = __riscv_vget_v_u32m4_u32m1(vb, 3);
+            vb = __riscv_vset_v_u32m1_u32m4(vb, 0, vtmp0);
+            vb = __riscv_vset_v_u32m1_u32m4(vb, 1, vtmp1);
+            vb = __riscv_vset_v_u32m1_u32m4(vb, 2, vtmp2);
+            vb = __riscv_vset_v_u32m1_u32m4(vb, 3, vtmp3);
+            vtmp2 = __riscv_vget_v_u32m4_u32m1(vc, 0);
+            vtmp3 = __riscv_vget_v_u32m4_u32m1(vc, 1);
+            vtmp0 = __riscv_vget_v_u32m4_u32m1(vc, 2);
+            vtmp1 = __riscv_vget_v_u32m4_u32m1(vc, 3);
+            vc = __riscv_vset_v_u32m1_u32m4(vc, 0, vtmp0);
+            vc = __riscv_vset_v_u32m1_u32m4(vc, 1, vtmp1);
+            vc = __riscv_vset_v_u32m1_u32m4(vc, 2, vtmp2);
+            vc = __riscv_vset_v_u32m1_u32m4(vc, 3, vtmp3);
+            vtmp3 = __riscv_vget_v_u32m4_u32m1(vd, 0);
+            vtmp0 = __riscv_vget_v_u32m4_u32m1(vd, 1);
+            vtmp1 = __riscv_vget_v_u32m4_u32m1(vd, 2);
+            vtmp2 = __riscv_vget_v_u32m4_u32m1(vd, 3);
+            vd = __riscv_vset_v_u32m1_u32m4(vd, 0, vtmp0);
+            vd = __riscv_vset_v_u32m1_u32m4(vd, 1, vtmp1);
+            vd = __riscv_vset_v_u32m1_u32m4(vd, 2, vtmp2);
+            vd = __riscv_vset_v_u32m1_u32m4(vd, 3, vtmp3);
+
+        }
+
+        /* split */
+        v00 = __riscv_vget_v_u32m4_u32m1(va, 0);
+        v01 = __riscv_vget_v_u32m4_u32m1(va, 1);
+        v02 = __riscv_vget_v_u32m4_u32m1(va, 2);
+        v03 = __riscv_vget_v_u32m4_u32m1(va, 3);
+        v04 = __riscv_vget_v_u32m4_u32m1(vb, 0);
+        v05 = __riscv_vget_v_u32m4_u32m1(vb, 1);
+        v06 = __riscv_vget_v_u32m4_u32m1(vb, 2);
+        v07 = __riscv_vget_v_u32m4_u32m1(vb, 3);
+        v08 = __riscv_vget_v_u32m4_u32m1(vc, 0);
+        v09 = __riscv_vget_v_u32m4_u32m1(vc, 1);
+        v10 = __riscv_vget_v_u32m4_u32m1(vc, 2);
+        v11 = __riscv_vget_v_u32m4_u32m1(vc, 3);
+        v12 = __riscv_vget_v_u32m4_u32m1(vd, 0);
+        v13 = __riscv_vget_v_u32m4_u32m1(vd, 1);
+        v14 = __riscv_vget_v_u32m4_u32m1(vd, 2);
+        v15 = __riscv_vget_v_u32m4_u32m1(vd, 3);
+
+        /* x[i] + input[i] */
+        v00 = __riscv_vadd_vx_u32m1(v00, 0x61707865, 8);
+        v01 = __riscv_vadd_vx_u32m1(v01, 0x3320646e, 8);
+        v02 = __riscv_vadd_vx_u32m1(v02, 0x79622d32, 8);
+        v03 = __riscv_vadd_vx_u32m1(v03, 0x6b206574, 8);
+        v04 = __riscv_vadd_vx_u32m1(v04, key[0], 8);
+        v05 = __riscv_vadd_vx_u32m1(v05, key[1], 8);
+        v06 = __riscv_vadd_vx_u32m1(v06, key[2], 8);
+        v07 = __riscv_vadd_vx_u32m1(v07, key[3], 8);
+        v08 = __riscv_vadd_vx_u32m1(v08, key[4], 8);
+        v09 = __riscv_vadd_vx_u32m1(v09, key[5], 8);
+        v10 = __riscv_vadd_vx_u32m1(v10, key[6], 8);
+        v11 = __riscv_vadd_vx_u32m1(v11, key[7], 8);
+        v12 = __riscv_vadd_vv_u32m1(v12, v12_og, 8);
+        v13 = __riscv_vadd_vx_u32m1(v13, counter[1], 8);
+        v14 = __riscv_vadd_vx_u32m1(v14, counter[2], 8);
+        v15 = __riscv_vadd_vx_u32m1(v15, counter[3], 8);
+
+        /* counter++ */
+        v12_og = __riscv_vadd_vx_u32m1(v12_og, 8, 8);
+
+        /* store the keystream, then XOR it with the input */
+        int blk = blocks > 8 ? 8 : blocks;
+
+        __riscv_vsse32_v_u32m1((u32 *)outbuf + 0, 64, v00, blk);
+        __riscv_vsse32_v_u32m1((u32 *)outbuf + 1, 64, v01, blk);
+        __riscv_vsse32_v_u32m1((u32 *)outbuf + 2, 64, v02, blk);
+        __riscv_vsse32_v_u32m1((u32 *)outbuf + 3, 64, v03, blk);
+        __riscv_vsse32_v_u32m1((u32 *)outbuf + 4, 64, v04, blk);
+        __riscv_vsse32_v_u32m1((u32 *)outbuf + 5, 64, v05, blk);
+        __riscv_vsse32_v_u32m1((u32 *)outbuf + 6, 64, v06, blk);
+        __riscv_vsse32_v_u32m1((u32 *)outbuf + 7, 64, v07, blk);
+        __riscv_vsse32_v_u32m1((u32 *)outbuf + 8, 64, v08, blk);
+        __riscv_vsse32_v_u32m1((u32 *)outbuf + 9, 64, v09, blk);
+        __riscv_vsse32_v_u32m1((u32 *)outbuf + 10, 64, v10, blk);
+        __riscv_vsse32_v_u32m1((u32 *)outbuf + 11, 64, v11, blk);
+        __riscv_vsse32_v_u32m1((u32 *)outbuf + 12, 64, v12, blk);
+        __riscv_vsse32_v_u32m1((u32 *)outbuf + 13, 64, v13, blk);
+        __riscv_vsse32_v_u32m1((u32 *)outbuf + 14, 64, v14, blk);
+        __riscv_vsse32_v_u32m1((u32 *)outbuf + 15, 64, v15, blk);
+
+        blocks -= blk;
+
+        /* e32m1*16 = e8m8*2: up to 512 bytes of keystream */
+        for (i = 0; (len > 0) && (i < 2); ++i) {
+            vl = __riscv_vsetvl_e8m8(len);
+            vsrc = __riscv_vle8_v_u8m8(inp, vl);
+            vkey = __riscv_vle8_v_u8m8(outbuf + i * 256, vl);
+            vsrc = __riscv_vxor_vv_u8m8(vsrc, vkey, vl);
+
+            __riscv_vse8_v_u8m8(out, vsrc, vl);
+
+            out += vl;
+            inp += vl;
+            len -= vl;
+        }
+    }
+}
+#endif /* __riscv_vector */
\ No newline at end of file
diff --git a/crypto/evp/e_chacha20_poly1305.c b/crypto/evp/e_chacha20_poly1305.c
index bdc406b..feaf7a6 100644
--- a/crypto/evp/e_chacha20_poly1305.c
+++ b/crypto/evp/e_chacha20_poly1305.c
@@ -8,6 +8,7 @@
  */
 
 #include <stdio.h>
+#include <riscv_vector.h>
 #include "internal/cryptlib.h"
 
 #ifndef OPENSSL_NO_CHACHA
@@ -16,7 +17,7 @@
 # include <openssl/objects.h>
 # include "evp_local.h"
 # include "crypto/evp.h"
-# include "crypto/chacha.h"
+# include "include/crypto/chacha.h"
 
 typedef struct {
     union {
@@ -102,11 +103,19 @@ static int chacha_cipher(EVP_CIPHER_CTX * ctx, unsigned char *out,
             blocks -= ctr32;
             ctr32 = 0;
         }
+
+#if defined(__riscv_vector)
+        ChaCha20_ctr32_r(out, inp, len, blocks, key->key.d, key->counter);
+        inp += len;
+        out += len;
+        len -= len;
+#else
         blocks *= CHACHA_BLK_SIZE;
         ChaCha20_ctr32(out, inp, blocks, key->key.d, key->counter);
         len -= blocks;
         inp += blocks;
         out += blocks;
+#endif
 
         key->counter[0] = ctr32;
         if (ctr32 == 0) key->counter[1]++;
diff --git a/include/crypto/chacha.h b/include/crypto/chacha.h
index 4029400..7ebf4d8 100644
--- a/include/crypto/chacha.h
+++ b/include/crypto/chacha.h
@@ -26,6 +26,13 @@
 void ChaCha20_ctr32(unsigned char *out, const unsigned char *inp,
                     size_t len, const unsigned int key[8],
                     const unsigned int counter[4]);
+
+#if defined(__riscv_vector)
+void ChaCha20_ctr32_r(unsigned char *out, const unsigned char *inp,
+                      size_t len, size_t blocks, const unsigned int key[8],
+                      const unsigned int counter[4]);
+#endif
+
 /*
  * You can notice that there is no key setup procedure. Because it's
  * as trivial as collecting bytes into 32-bit elements, it's reckoned
-- 
2.25.1
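
Usage note (illustrative, not part of the patch): unlike ChaCha20_ctr32(),
which takes only a byte length, the new ChaCha20_ctr32_r() takes the exact
byte length plus the number of 64-byte blocks it spans, and the vector path
appears to assume 8 blocks per iteration (vl = 8 on 32-bit elements at
LMUL=1, i.e. VLEN >= 256). A hypothetical caller sketch under those
assumptions; chacha20_xor is an illustrative helper, not an OpenSSL function:

#include <stddef.h>
#include "crypto/chacha.h"

static void chacha20_xor(unsigned char *out, const unsigned char *in,
                         size_t len, const unsigned int key[8],
                         const unsigned int counter[4])
{
#if defined(__riscv_vector)
    /* exact byte length plus the number of 64-byte blocks it covers */
    ChaCha20_ctr32_r(out, in, len, (len + 63) / 64, key, counter);
#else
    ChaCha20_ctr32(out, in, len, key, counter);
#endif
}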