lang: Add `discriminator` argument to `#[account]` attribute (#3149)
This commit is contained in:
parent
4853cd1da7
commit
d73983d3db
|
@ -439,7 +439,7 @@ jobs:
|
|||
path: tests/safety-checks
|
||||
- cmd: cd tests/custom-coder && anchor test --skip-lint && npx tsc --noEmit
|
||||
path: tests/custom-coder
|
||||
- cmd: cd tests/custom-discriminator && anchor test && npx tsc --noEmit
|
||||
- cmd: cd tests/custom-discriminator && anchor test
|
||||
path: tests/custom-discriminator
|
||||
- cmd: cd tests/validator-clone && anchor test --skip-lint && npx tsc --noEmit
|
||||
path: tests/validator-clone
|
||||
|
|
|
@ -31,6 +31,7 @@ The minor version will be incremented upon a breaking change and the patch versi
|
|||
- client: Add `internal_rpc` method for `mock` feature ([#3135](https://github.com/coral-xyz/anchor/pull/3135)).
|
||||
- lang: Add `#[instruction]` attribute proc-macro to override default instruction discriminators ([#3137](https://github.com/coral-xyz/anchor/pull/3137)).
|
||||
- lang: Use associated discriminator constants instead of hardcoding in `#[account]` ([#3144](https://github.com/coral-xyz/anchor/pull/3144)).
|
||||
- lang: Add `discriminator` argument to `#[account]` attribute ([#3149](https://github.com/coral-xyz/anchor/pull/3149)).
|
||||
|
||||
### Fixes
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ use syn::{
|
|||
parse::{Parse, ParseStream},
|
||||
parse_macro_input,
|
||||
token::{Comma, Paren},
|
||||
Ident, LitStr,
|
||||
Expr, Ident, Lit, LitStr, Token,
|
||||
};
|
||||
|
||||
mod id;
|
||||
|
@ -31,6 +31,22 @@ mod id;
|
|||
/// check this discriminator. If it doesn't match, an invalid account was given,
|
||||
/// and the account deserialization will exit with an error.
|
||||
///
|
||||
/// # Args
|
||||
///
|
||||
/// - `discriminator`: Override the default 8-byte discriminator
|
||||
///
|
||||
/// **Usage:** `discriminator = <CONST_EXPR>`
|
||||
///
|
||||
/// All constant expressions are supported.
|
||||
///
|
||||
/// **Examples:**
|
||||
///
|
||||
/// - `discriminator = 0` (shortcut for `[0]`)
|
||||
/// - `discriminator = [1, 2, 3, 4]`
|
||||
/// - `discriminator = b"hi"`
|
||||
/// - `discriminator = MY_DISC`
|
||||
/// - `discriminator = get_disc(...)`
|
||||
///
|
||||
/// # Zero Copy Deserialization
|
||||
///
|
||||
/// **WARNING**: Zero copy deserialization is an experimental feature. It's
|
||||
|
@ -83,23 +99,21 @@ pub fn account(
|
|||
let account_name_str = account_name.to_string();
|
||||
let (impl_gen, type_gen, where_clause) = account_strct.generics.split_for_impl();
|
||||
|
||||
let discriminator: proc_macro2::TokenStream = {
|
||||
let discriminator = args.discriminator.unwrap_or_else(|| {
|
||||
// Namespace the discriminator to prevent collisions.
|
||||
let discriminator_preimage = {
|
||||
// For now, zero copy accounts can't be namespaced.
|
||||
if namespace.is_empty() {
|
||||
format!("account:{account_name}")
|
||||
} else {
|
||||
format!("{namespace}:{account_name}")
|
||||
}
|
||||
let discriminator_preimage = if namespace.is_empty() {
|
||||
format!("account:{account_name}")
|
||||
} else {
|
||||
format!("{namespace}:{account_name}")
|
||||
};
|
||||
|
||||
let mut discriminator = [0u8; 8];
|
||||
discriminator.copy_from_slice(
|
||||
&anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..8],
|
||||
);
|
||||
format!("{discriminator:?}").parse().unwrap()
|
||||
};
|
||||
let discriminator: proc_macro2::TokenStream = format!("{discriminator:?}").parse().unwrap();
|
||||
quote! { &#discriminator }
|
||||
});
|
||||
let disc = if account_strct.generics.lt_token.is_some() {
|
||||
quote! { #account_name::#type_gen::DISCRIMINATOR }
|
||||
} else {
|
||||
|
@ -159,7 +173,7 @@ pub fn account(
|
|||
|
||||
#[automatically_derived]
|
||||
impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
|
||||
const DISCRIMINATOR: &'static [u8] = &#discriminator;
|
||||
const DISCRIMINATOR: &'static [u8] = #discriminator;
|
||||
}
|
||||
|
||||
// This trait is useful for clients deserializing accounts.
|
||||
|
@ -229,7 +243,7 @@ pub fn account(
|
|||
|
||||
#[automatically_derived]
|
||||
impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
|
||||
const DISCRIMINATOR: &'static [u8] = &#discriminator;
|
||||
const DISCRIMINATOR: &'static [u8] = #discriminator;
|
||||
}
|
||||
|
||||
#owner_impl
|
||||
|
@ -242,7 +256,10 @@ pub fn account(
|
|||
struct AccountArgs {
|
||||
/// `bool` is for deciding whether to use `unsafe` e.g. `Some(true)` for `zero_copy(unsafe)`
|
||||
zero_copy: Option<bool>,
|
||||
/// Account namespace override, `account` if not specified
|
||||
namespace: Option<String>,
|
||||
/// Discriminator override
|
||||
discriminator: Option<proc_macro2::TokenStream>,
|
||||
}
|
||||
|
||||
impl Parse for AccountArgs {
|
||||
|
@ -257,6 +274,9 @@ impl Parse for AccountArgs {
|
|||
AccountArg::Namespace(ns) => {
|
||||
parsed.namespace.replace(ns);
|
||||
}
|
||||
AccountArg::Discriminator(disc) => {
|
||||
parsed.discriminator.replace(disc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -267,6 +287,7 @@ impl Parse for AccountArgs {
|
|||
enum AccountArg {
|
||||
ZeroCopy { is_unsafe: bool },
|
||||
Namespace(String),
|
||||
Discriminator(proc_macro2::TokenStream),
|
||||
}
|
||||
|
||||
impl Parse for AccountArg {
|
||||
|
@ -300,7 +321,24 @@ impl Parse for AccountArg {
|
|||
return Ok(Self::ZeroCopy { is_unsafe });
|
||||
};
|
||||
|
||||
Err(syn::Error::new(ident.span(), "Unexpected argument"))
|
||||
// Named arguments
|
||||
// TODO: Share the common arguments with `#[instruction]`
|
||||
input.parse::<Token![=]>()?;
|
||||
let value = input.parse::<Expr>()?;
|
||||
match ident.to_string().as_str() {
|
||||
"discriminator" => {
|
||||
let value = match value {
|
||||
// Allow `discriminator = 42`
|
||||
Expr::Lit(lit) if matches!(lit.lit, Lit::Int(_)) => quote! { &[#lit] },
|
||||
// Allow `discriminator = [0, 1, 2, 3]`
|
||||
Expr::Array(arr) => quote! { &#arr },
|
||||
expr => expr.to_token_stream(),
|
||||
};
|
||||
|
||||
Ok(Self::Discriminator(value))
|
||||
}
|
||||
_ => Err(syn::Error::new(ident.span(), "Invalid argument")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -39,9 +39,34 @@ pub mod custom_discriminator {
|
|||
pub fn const_fn(_ctx: Context<DefaultIx>) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn account(ctx: Context<CustomAccountIx>, field: u8) -> Result<()> {
|
||||
ctx.accounts.my_account.field = field;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Accounts)]
|
||||
pub struct DefaultIx<'info> {
|
||||
pub signer: Signer<'info>,
|
||||
}
|
||||
|
||||
#[derive(Accounts)]
|
||||
pub struct CustomAccountIx<'info> {
|
||||
#[account(mut)]
|
||||
pub signer: Signer<'info>,
|
||||
#[account(
|
||||
init,
|
||||
payer = signer,
|
||||
space = MyAccount::DISCRIMINATOR.len() + core::mem::size_of::<MyAccount>(),
|
||||
seeds = [b"my_account"],
|
||||
bump
|
||||
)]
|
||||
pub my_account: Account<'info, MyAccount>,
|
||||
pub system_program: Program<'info, System>,
|
||||
}
|
||||
|
||||
#[account(discriminator = 1)]
|
||||
pub struct MyAccount {
|
||||
pub field: u8,
|
||||
}
|
||||
|
|
|
@ -8,7 +8,7 @@ describe("custom-discriminator", () => {
|
|||
const program: anchor.Program<CustomDiscriminator> =
|
||||
anchor.workspace.customDiscriminator;
|
||||
|
||||
describe("Can use custom instruction discriminators", () => {
|
||||
describe("Instructions", () => {
|
||||
const testCommon = async (ixName: keyof typeof program["methods"]) => {
|
||||
const tx = await program.methods[ixName]().transaction();
|
||||
|
||||
|
@ -28,4 +28,26 @@ describe("custom-discriminator", () => {
|
|||
it("Constant", () => testCommon("constant"));
|
||||
it("Const Fn", () => testCommon("constFn"));
|
||||
});
|
||||
|
||||
describe("Accounts", () => {
|
||||
it("Works", async () => {
|
||||
// Verify discriminator
|
||||
const acc = program.idl.accounts.find((acc) => acc.name === "myAccount")!;
|
||||
assert(acc.discriminator.length < 8);
|
||||
|
||||
// Verify regular `init` ix works
|
||||
const field = 5;
|
||||
const { pubkeys, signature } = await program.methods
|
||||
.account(field)
|
||||
.rpcAndKeys();
|
||||
await program.provider.connection.confirmTransaction(
|
||||
signature,
|
||||
"confirmed"
|
||||
);
|
||||
const myAccount = await program.account.myAccount.fetch(
|
||||
pubkeys.myAccount
|
||||
);
|
||||
assert.strictEqual(field, myAccount.field);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
Loading…
Reference in New Issue