buildroot/package/libopenssl/0009-RVV-optimized-chacha20...

From 928cc0e2a0b1bbf48b4f4047708f04c74f1edc1a Mon Sep 17 00:00:00 2001
From: "lff@Snode" <junzhao.liang@spacemit.com>
Date: Mon, 25 Mar 2024 10:57:54 +0800
Subject: [PATCH] RVV optimized chacha20
---
crypto/chacha/chacha_enc.c | 223 ++++++++++++++++++++++++++++++-
crypto/evp/e_chacha20_poly1305.c | 11 +-
include/crypto/chacha.h | 7 +
3 files changed, 239 insertions(+), 2 deletions(-)
diff --git a/crypto/chacha/chacha_enc.c b/crypto/chacha/chacha_enc.c
index 18251ea..0231b8f 100644
--- a/crypto/chacha/chacha_enc.c
+++ b/crypto/chacha/chacha_enc.c
@@ -11,7 +11,7 @@
#include <string.h>
-#include "crypto/chacha.h"
+#include "include/crypto/chacha.h"
#include "crypto/ctype.h"
typedef unsigned int u32;
@@ -128,3 +128,224 @@ void ChaCha20_ctr32(unsigned char *out, const unsigned char *inp,
input[12]++;
}
}
+
+#if defined(__riscv_vector)
+#include <riscv_vector.h>
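+
+/*
+ * Each rotation in the quarter round (16, 12, 8 and 7 bits) is built from a
+ * shift-left, a shift-right and an OR, since the base RVV 1.0 extension has
+ * no vector rotate instruction.
+ */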
+#define QUARTERROUND_RVV(n, vl) \
+ { \
+ va = __riscv_vadd_vv_u32m##n(va, vb, vl); \
+ vd = __riscv_vxor_vv_u32m##n(vd, va, vl); \
+ vd_t = __riscv_vsll_vx_u32m##n(vd, 16, vl); \
+ vd = __riscv_vsrl_vx_u32m##n(vd, 16, vl); \
+ vd = __riscv_vor_vv_u32m##n(vd, vd_t, vl); \
+ \
+ vc = __riscv_vadd_vv_u32m##n(vc, vd, vl); \
+ vb = __riscv_vxor_vv_u32m##n(vb, vc, vl); \
+ vb_t = __riscv_vsll_vx_u32m##n(vb, 12, vl); \
+ vb = __riscv_vsrl_vx_u32m##n(vb, 20, vl); \
+ vb = __riscv_vor_vv_u32m##n(vb, vb_t, vl); \
+ \
+ va = __riscv_vadd_vv_u32m##n(va, vb, vl); \
+ vd = __riscv_vxor_vv_u32m##n(vd, va, vl); \
+ vd_t = __riscv_vsll_vx_u32m##n(vd, 8, vl); \
+ vd = __riscv_vsrl_vx_u32m##n(vd, 24, vl); \
+ vd = __riscv_vor_vv_u32m##n(vd, vd_t, vl); \
+ \
+ vc = __riscv_vadd_vv_u32m##n(vc, vd, vl); \
+ vb = __riscv_vxor_vv_u32m##n(vb, vc, vl); \
+ vb_t = __riscv_vsll_vx_u32m##n(vb, 7, vl); \
+ vb = __riscv_vsrl_vx_u32m##n(vb, 25, vl); \
+ vb = __riscv_vor_vv_u32m##n(vb, vb_t, vl); \
+ }
+
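+/*
+ * Vectorized ChaCha20 counter mode: up to eight 64-byte blocks are generated
+ * per loop iteration.  Each of the 16 state words lives in its own u32m1
+ * vector, with lane i holding that word for block i; the vectors are grouped
+ * into the m4 registers va/vb/vc/vd so one QUARTERROUND_RVV performs four
+ * quarter rounds at once.  The fixed vl of 8 lanes per u32m1 vector requires
+ * VLEN >= 256.
+ */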
+void ChaCha20_ctr32_r(unsigned char *out, const unsigned char *inp,
+ size_t len, size_t blocks, const unsigned int key[8],
+ const unsigned int counter[4])
+{
+ size_t i, vl;
+ u8 outbuf[4*16*8]; /* 4 bytes x 16 words x 8 blocks = 512 bytes of keystream */
+
+ vuint32m1_t v00, v01, v02, v03, v04, v05, v06, v07, v08, v09, v10, v11, v12, v13, v14, v15;
+ vuint8m8_t vkey, vsrc;
+ vuint32m4_t va, vb, vc, vd, vb_t, vd_t;
+ vuint32m1_t vtmp0, vtmp1, vtmp2, vtmp3;
+
+ /* deal with 8 blocks at a time */
+ vuint32m1_t v12_og = __riscv_vid_v_u32m1(8);
+ v12_og = __riscv_vadd_vx_u32m1(v12_og, counter[0], 8);
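+ /* v12_og holds the per-lane block counters: counter[0] + 0 .. counter[0] + 7 */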
+
+ while (len > 0) {
+ /* prepare 16 vectors, one per state word, with one block per lane */
+ v00 = __riscv_vmv_v_x_u32m1(0x61707865, 8);
+ v01 = __riscv_vmv_v_x_u32m1(0x3320646e, 8);
+ v02 = __riscv_vmv_v_x_u32m1(0x79622d32, 8);
+ v03 = __riscv_vmv_v_x_u32m1(0x6b206574, 8);
+ v04 = __riscv_vmv_v_x_u32m1(key[0], 8);
+ v05 = __riscv_vmv_v_x_u32m1(key[1], 8);
+ v06 = __riscv_vmv_v_x_u32m1(key[2], 8);
+ v07 = __riscv_vmv_v_x_u32m1(key[3], 8);
+ v08 = __riscv_vmv_v_x_u32m1(key[4], 8);
+ v09 = __riscv_vmv_v_x_u32m1(key[5], 8);
+ v10 = __riscv_vmv_v_x_u32m1(key[6], 8);
+ v11 = __riscv_vmv_v_x_u32m1(key[7], 8);
+ v12 = v12_og;
+ v13 = __riscv_vmv_v_x_u32m1(counter[1], 8);
+ v14 = __riscv_vmv_v_x_u32m1(counter[2], 8);
+ v15 = __riscv_vmv_v_x_u32m1(counter[3], 8);
+
+ /* combine into m4 groups so four quarter rounds are computed simultaneously */
+ va = __riscv_vset_v_u32m1_u32m4(va, 0, v00);
+ va = __riscv_vset_v_u32m1_u32m4(va, 1, v01);
+ va = __riscv_vset_v_u32m1_u32m4(va, 2, v02);
+ va = __riscv_vset_v_u32m1_u32m4(va, 3, v03);
+ vb = __riscv_vset_v_u32m1_u32m4(vb, 0, v04);
+ vb = __riscv_vset_v_u32m1_u32m4(vb, 1, v05);
+ vb = __riscv_vset_v_u32m1_u32m4(vb, 2, v06);
+ vb = __riscv_vset_v_u32m1_u32m4(vb, 3, v07);
+ vc = __riscv_vset_v_u32m1_u32m4(vc, 0, v08);
+ vc = __riscv_vset_v_u32m1_u32m4(vc, 1, v09);
+ vc = __riscv_vset_v_u32m1_u32m4(vc, 2, v10);
+ vc = __riscv_vset_v_u32m1_u32m4(vc, 3, v11);
+ vd = __riscv_vset_v_u32m1_u32m4(vd, 0, v12);
+ vd = __riscv_vset_v_u32m1_u32m4(vd, 1, v13);
+ vd = __riscv_vset_v_u32m1_u32m4(vd, 2, v14);
+ vd = __riscv_vset_v_u32m1_u32m4(vd, 3, v15);
+
+ for (i = 0; i < 10; ++i) {
+ /* first half: column quarter rounds */
+ QUARTERROUND_RVV(4, 32);
+
+ /* rearrange: rotate the vb/vc/vd segments so the next round operates on diagonals */
+ vtmp0 = __riscv_vget_v_u32m4_u32m1(vb, 0);
+ vtmp1 = __riscv_vget_v_u32m4_u32m1(vb, 1);
+ vtmp2 = __riscv_vget_v_u32m4_u32m1(vb, 2);
+ vtmp3 = __riscv_vget_v_u32m4_u32m1(vb, 3);
+ vb = __riscv_vset_v_u32m1_u32m4(vb, 0, vtmp1);
+ vb = __riscv_vset_v_u32m1_u32m4(vb, 1, vtmp2);
+ vb = __riscv_vset_v_u32m1_u32m4(vb, 2, vtmp3);
+ vb = __riscv_vset_v_u32m1_u32m4(vb, 3, vtmp0);
+ vtmp0 = __riscv_vget_v_u32m4_u32m1(vc, 0);
+ vtmp1 = __riscv_vget_v_u32m4_u32m1(vc, 1);
+ vtmp2 = __riscv_vget_v_u32m4_u32m1(vc, 2);
+ vtmp3 = __riscv_vget_v_u32m4_u32m1(vc, 3);
+ vc = __riscv_vset_v_u32m1_u32m4(vc, 0, vtmp2);
+ vc = __riscv_vset_v_u32m1_u32m4(vc, 1, vtmp3);
+ vc = __riscv_vset_v_u32m1_u32m4(vc, 2, vtmp0);
+ vc = __riscv_vset_v_u32m1_u32m4(vc, 3, vtmp1);
+ vtmp0 = __riscv_vget_v_u32m4_u32m1(vd, 0);
+ vtmp1 = __riscv_vget_v_u32m4_u32m1(vd, 1);
+ vtmp2 = __riscv_vget_v_u32m4_u32m1(vd, 2);
+ vtmp3 = __riscv_vget_v_u32m4_u32m1(vd, 3);
+ vd = __riscv_vset_v_u32m1_u32m4(vd, 0, vtmp3);
+ vd = __riscv_vset_v_u32m1_u32m4(vd, 1, vtmp0);
+ vd = __riscv_vset_v_u32m1_u32m4(vd, 2, vtmp1);
+ vd = __riscv_vset_v_u32m1_u32m4(vd, 3, vtmp2);
+
+ /* second half: diagonal quarter rounds */
+ QUARTERROUND_RVV(4, 32);
+
+ /* recover the original column order */
+ vtmp1 = __riscv_vget_v_u32m4_u32m1(vb, 0);
+ vtmp2 = __riscv_vget_v_u32m4_u32m1(vb, 1);
+ vtmp3 = __riscv_vget_v_u32m4_u32m1(vb, 2);
+ vtmp0 = __riscv_vget_v_u32m4_u32m1(vb, 3);
+ vb = __riscv_vset_v_u32m1_u32m4(vb, 0, vtmp0);
+ vb = __riscv_vset_v_u32m1_u32m4(vb, 1, vtmp1);
+ vb = __riscv_vset_v_u32m1_u32m4(vb, 2, vtmp2);
+ vb = __riscv_vset_v_u32m1_u32m4(vb, 3, vtmp3);
+ vtmp2 = __riscv_vget_v_u32m4_u32m1(vc, 0);
+ vtmp3 = __riscv_vget_v_u32m4_u32m1(vc, 1);
+ vtmp0 = __riscv_vget_v_u32m4_u32m1(vc, 2);
+ vtmp1 = __riscv_vget_v_u32m4_u32m1(vc, 3);
+ vc = __riscv_vset_v_u32m1_u32m4(vc, 0, vtmp0);
+ vc = __riscv_vset_v_u32m1_u32m4(vc, 1, vtmp1);
+ vc = __riscv_vset_v_u32m1_u32m4(vc, 2, vtmp2);
+ vc = __riscv_vset_v_u32m1_u32m4(vc, 3, vtmp3);
+ vtmp3 = __riscv_vget_v_u32m4_u32m1(vd, 0);
+ vtmp0 = __riscv_vget_v_u32m4_u32m1(vd, 1);
+ vtmp1 = __riscv_vget_v_u32m4_u32m1(vd, 2);
+ vtmp2 = __riscv_vget_v_u32m4_u32m1(vd, 3);
+ vd = __riscv_vset_v_u32m1_u32m4(vd, 0, vtmp0);
+ vd = __riscv_vset_v_u32m1_u32m4(vd, 1, vtmp1);
+ vd = __riscv_vset_v_u32m1_u32m4(vd, 2, vtmp2);
+ vd = __riscv_vset_v_u32m1_u32m4(vd, 3, vtmp3);
+
+ }
+
+ /* split the m4 groups back into 16 m1 vectors */
+ v00 = __riscv_vget_v_u32m4_u32m1(va, 0);
+ v01 = __riscv_vget_v_u32m4_u32m1(va, 1);
+ v02 = __riscv_vget_v_u32m4_u32m1(va, 2);
+ v03 = __riscv_vget_v_u32m4_u32m1(va, 3);
+ v04 = __riscv_vget_v_u32m4_u32m1(vb, 0);
+ v05 = __riscv_vget_v_u32m4_u32m1(vb, 1);
+ v06 = __riscv_vget_v_u32m4_u32m1(vb, 2);
+ v07 = __riscv_vget_v_u32m4_u32m1(vb, 3);
+ v08 = __riscv_vget_v_u32m4_u32m1(vc, 0);
+ v09 = __riscv_vget_v_u32m4_u32m1(vc, 1);
+ v10 = __riscv_vget_v_u32m4_u32m1(vc, 2);
+ v11 = __riscv_vget_v_u32m4_u32m1(vc, 3);
+ v12 = __riscv_vget_v_u32m4_u32m1(vd, 0);
+ v13 = __riscv_vget_v_u32m4_u32m1(vd, 1);
+ v14 = __riscv_vget_v_u32m4_u32m1(vd, 2);
+ v15 = __riscv_vget_v_u32m4_u32m1(vd, 3);
+
+ /* x[i] + input[i] */
+ v00 = __riscv_vadd_vx_u32m1(v00, 0x61707865, 8);
+ v01 = __riscv_vadd_vx_u32m1(v01, 0x3320646e, 8);
+ v02 = __riscv_vadd_vx_u32m1(v02, 0x79622d32, 8);
+ v03 = __riscv_vadd_vx_u32m1(v03, 0x6b206574, 8);
+ v04 = __riscv_vadd_vx_u32m1(v04, key[0], 8);
+ v05 = __riscv_vadd_vx_u32m1(v05, key[1], 8);
+ v06 = __riscv_vadd_vx_u32m1(v06, key[2], 8);
+ v07 = __riscv_vadd_vx_u32m1(v07, key[3], 8);
+ v08 = __riscv_vadd_vx_u32m1(v08, key[4], 8);
+ v09 = __riscv_vadd_vx_u32m1(v09, key[5], 8);
+ v10 = __riscv_vadd_vx_u32m1(v10, key[6], 8);
+ v11 = __riscv_vadd_vx_u32m1(v11, key[7], 8);
+ v12 = __riscv_vadd_vv_u32m1(v12, v12_og, 8);
+ v13 = __riscv_vadd_vx_u32m1(v13, counter[1], 8);
+ v14 = __riscv_vadd_vx_u32m1(v14, counter[2], 8);
+ v15 = __riscv_vadd_vx_u32m1(v15, counter[3], 8);
+
+ /* counter += 8: advance past the blocks just generated */
+ v12_og = __riscv_vadd_vx_u32m1(v12_og, 8, 8);
+
+ /* write the keystream to outbuf, then XOR it with the input */
+ int blk = blocks > 8 ? 8 : blocks;
+
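+ /*
+ * Strided stores: with a 64-byte stride, word i of block j lands at
+ * outbuf[j*64 + i*4], so outbuf ends up holding blk contiguous
+ * keystream blocks.
+ */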
+ __riscv_vsse32_v_u32m1((u32 *)outbuf + 0, 64, v00, blk);
+ __riscv_vsse32_v_u32m1((u32 *)outbuf + 1, 64, v01, blk);
+ __riscv_vsse32_v_u32m1((u32 *)outbuf + 2, 64, v02, blk);
+ __riscv_vsse32_v_u32m1((u32 *)outbuf + 3, 64, v03, blk);
+ __riscv_vsse32_v_u32m1((u32 *)outbuf + 4, 64, v04, blk);
+ __riscv_vsse32_v_u32m1((u32 *)outbuf + 5, 64, v05, blk);
+ __riscv_vsse32_v_u32m1((u32 *)outbuf + 6, 64, v06, blk);
+ __riscv_vsse32_v_u32m1((u32 *)outbuf + 7, 64, v07, blk);
+ __riscv_vsse32_v_u32m1((u32 *)outbuf + 8, 64, v08, blk);
+ __riscv_vsse32_v_u32m1((u32 *)outbuf + 9, 64, v09, blk);
+ __riscv_vsse32_v_u32m1((u32 *)outbuf + 10, 64, v10, blk);
+ __riscv_vsse32_v_u32m1((u32 *)outbuf + 11, 64, v11, blk);
+ __riscv_vsse32_v_u32m1((u32 *)outbuf + 12, 64, v12, blk);
+ __riscv_vsse32_v_u32m1((u32 *)outbuf + 13, 64, v13, blk);
+ __riscv_vsse32_v_u32m1((u32 *)outbuf + 14, 64, v14, blk);
+ __riscv_vsse32_v_u32m1((u32 *)outbuf + 15, 64, v15, blk);
+
+ blocks -= blk;
+
+ /* 16 e32,m1 stores = 512 bytes of keystream; XOR against the input in two e8,m8 passes */
+ for (i = 0; (len > 0) && (i < 2); ++i) {
+ vl = __riscv_vsetvl_e8m8(len);
+ vsrc = __riscv_vle8_v_u8m8(inp, vl);
+ vkey = __riscv_vle8_v_u8m8(outbuf + i * 256, vl);
+ vsrc = __riscv_vxor_vv_u8m8(vsrc, vkey, vl);
+
+ __riscv_vse8_v_u8m8(out, vsrc, vl);
+
+ out += vl;
+ inp += vl;
+ len -= vl;
+ }
+ }
+}
+#endif /* __riscv_vector */
\ No newline at end of file
diff --git a/crypto/evp/e_chacha20_poly1305.c b/crypto/evp/e_chacha20_poly1305.c
index bdc406b..feaf7a6 100644
--- a/crypto/evp/e_chacha20_poly1305.c
+++ b/crypto/evp/e_chacha20_poly1305.c
@@ -8,6 +8,7 @@
*/
#include <stdio.h>
+#include <riscv_vector.h>
#include "internal/cryptlib.h"
#ifndef OPENSSL_NO_CHACHA
@@ -16,7 +17,7 @@
# include <openssl/objects.h>
# include "evp_local.h"
# include "crypto/evp.h"
-# include "crypto/chacha.h"
+# include "include/crypto/chacha.h"
typedef struct {
union {
@@ -102,11 +103,19 @@ static int chacha_cipher(EVP_CIPHER_CTX * ctx, unsigned char *out,
blocks -= ctr32;
ctr32 = 0;
}
+
+#if defined(__riscv_vector)
+ ChaCha20_ctr32_r(out, inp, len, blocks, key->key.d, key->counter);
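+ /* the RVV routine consumes all remaining input in a single call */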
+ inp += len;
+ out += len;
+ len -= len;
+#else
blocks *= CHACHA_BLK_SIZE;
ChaCha20_ctr32(out, inp, blocks, key->key.d, key->counter);
len -= blocks;
inp += blocks;
out += blocks;
+#endif
key->counter[0] = ctr32;
if (ctr32 == 0) key->counter[1]++;
diff --git a/include/crypto/chacha.h b/include/crypto/chacha.h
index 4029400..7ebf4d8 100644
--- a/include/crypto/chacha.h
+++ b/include/crypto/chacha.h
@@ -26,6 +26,13 @@
void ChaCha20_ctr32(unsigned char *out, const unsigned char *inp,
size_t len, const unsigned int key[8],
const unsigned int counter[4]);
+
+#if defined(__riscv_vector)
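+/*
+ * RVV batch variant: |blocks| is the number of 64-byte ChaCha blocks covering
+ * |len| bytes; the key and counter layout match ChaCha20_ctr32().
+ */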
+void ChaCha20_ctr32_r(unsigned char *out, const unsigned char *inp,
+ size_t len, size_t blocks, const unsigned int key[8],
+ const unsigned int counter[4]);
+#endif
+
/*
* You can notice that there is no key setup procedure. Because it's
* as trivial as collecting bytes into 32-bit elements, it's reckoned
--
2.25.1