diff --git a/Cargo.toml b/Cargo.toml index 0affa9a..1b8220f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,11 +17,16 @@ edition = "2018" bitvec = { version = "0.22", default-features = false, optional = true } byteorder = { version = "1", default-features = false, optional = true } ff_derive = { version = "0.8", path = "ff_derive", optional = true } +lazy_static = { version = "1.4.0", optional = true } rand_core = { version = "0.6", default-features = false } subtle = { version = "2.2.1", default-features = false, features = ["i128"] } +[target.'cfg(target_arch = "x86_64")'.build-dependencies] +cc = "1.0.50" + [features] default = ["bits", "std"] +asm = ["lazy_static", "std"] bits = ["bitvec"] derive = ["byteorder", "ff_derive"] std = [] diff --git a/asm/mul_4.S b/asm/mul_4.S new file mode 100644 index 0000000..9220d71 --- /dev/null +++ b/asm/mul_4.S @@ -0,0 +1,203 @@ +// A*B +// Schoolbook multiplication of four 64b limbs +// result in r8 - r15 +.macro mul_256 a b + xor %rax, %rax + mov 0x00\a, %rdx + mulx 0x00\b, %r8, %r9 + mulx 0x08\b, %rbx, %r10 + adcx %rbx, %r9 + mulx 0x10\b, %rbx, %r11 + adcx %rbx, %r10 + mulx 0x18\b, %rbx, %r12 + adcx %rbx, %r11 + adcx %rax, %r12 + xor %rax, %rax + mov 0x08\a, %rdx + mulx 0x00\b, %rbp, %rbx + adcx %rbp, %r9 + adox %rbx, %r10 + mulx 0x08\b, %rbp, %rbx + adcx %rbp, %r10 + adox %rbx, %r11 + mulx 0x10\b, %rbp, %rbx + adcx %rbp, %r11 + adox %rbx, %r12 + mulx 0x18\b, %rbp, %r13 + adcx %rbp, %r12 + adox %rax, %r13 + adcx %rax, %r13 + xor %rax, %rax + mov 0x10\a, %rdx + mulx 0x00\b, %rbp, %rbx + adcx %rbp, %r10 + adox %rbx, %r11 + mulx 0x08\b, %rbp, %rbx + adcx %rbp, %r11 + adox %rbx, %r12 + mulx 0x10\b, %rbp, %rbx + adcx %rbp, %r12 + adox %rbx, %r13 + mulx 0x18\b, %rbp, %r14 + adcx %rbp, %r13 + adox %rax, %r14 + adcx %rax, %r14 + xor %rax, %rax + mov 0x18\a, %rdx + mulx 0x00\b, %rbp, %rbx + adcx %rbp, %r11 + adox %rbx, %r12 + mulx 0x08\b, %rbp, %rbx + adcx %rbp, %r12 + adox %rbx, %r13 + mulx 0x10\b, %rbp, %rbx + adcx %rbp, %r13 + adox %rbx, %r14 + mulx 0x18\b, %rbp, %r15 + adcx %rbp, %r14 + adox %rax, %r15 + adcx %rax, %r15 +.endm + +// Montgomery reduction +// expects multiplication result in r8 - r15 +// See algo 14.32 from Handbook of Applied Cryptography +.macro red_256 res name + push %rsi + lea .LM(%rip), %rsi + xor %rax, %rax + mov 0x20(%rsi), %rdx + mulx %r8, %rdx, %rbp + mulx 0x00(%rsi), %rbp, %rbx + adox %rbp, %r8 + adcx %rbx, %r9 + mulx 0x08(%rsi), %rbp, %rbx + adox %rbp, %r9 + adcx %rbx, %r10 + mulx 0x10(%rsi), %rbp, %rbx + adox %rbp, %r10 + adcx %rbx, %r11 + mulx 0x18(%rsi), %rbp, %rbx + adox %rbp, %r11 + adcx %rbx, %r12 + adox %rax, %r12 + adcx %rax, %r13 + adox %rax, %r13 + adcx %rax, %r14 + adox %rax, %r14 + adcx %rax, %r15 + adox %rax, %r15 + mov 0x20(%rsi), %rdx + mulx %r9, %rdx, %rbp + mulx 0x00(%rsi), %rbp, %rbx + adox %rbp, %r9 + adcx %rbx, %r10 + mulx 0x08(%rsi), %rbp, %rbx + adox %rbp, %r10 + adcx %rbx, %r11 + mulx 0x10(%rsi), %rbp, %rbx + adox %rbp, %r11 + adcx %rbx, %r12 + mulx 0x18(%rsi), %rbp, %rbx + adox %rbp, %r12 + adcx %rbx, %r13 + adox %rax, %r13 + adcx %rax, %r14 + adox %rax, %r14 + adcx %rax, %r15 + adox %rax, %r15 + mov 0x20(%rsi), %rdx + mulx %r10, %rdx, %rbp + mulx 0x00(%rsi), %rbp, %rbx + adox %rbp, %r10 + adcx %rbx, %r11 + mulx 0x08(%rsi), %rbp, %rbx + adox %rbp, %r11 + adcx %rbx, %r12 + mulx 0x10(%rsi), %rbp, %rbx + adox %rbp, %r12 + adcx %rbx, %r13 + mulx 0x18(%rsi), %rbp, %rbx + adox %rbp, %r13 + adcx %rbx, %r14 + adox %rax, %r14 + adcx %rax, %r15 + adox %rax, %r15 + mov 0x20(%rsi), %rdx + mulx %r11, %rdx, %rbp + mov 0x00(%rsi), %r8 + mulx %r8, %rbp, %rbx + adox %rbp, %r11 + adcx %rbx, %r12 + mov 0x08(%rsi), %r9 + mulx %r9, %rbp, %rbx + adox %rbp, %r12 + adcx %rbx, %r13 + mov 0x10(%rsi), %r10 + mulx %r10, %rbp, %rbx + adox %rbp, %r13 + adcx %rbx, %r14 + mov 0x18(%rsi), %r11 + mulx %r11, %rbp, %rbx + adox %rbp, %r14 + adcx %rbx, %r15 + adox %rax, %r15 + mov %r12, 0x00\res + mov %r13, 0x08\res + mov %r14, 0x10\res + mov %r15, 0x18\res + sub %r8, %r12 + sbb %r9, %r13 + sbb %r10, %r14 + sbb %r11, %r15 + jb .Lred_256\name + mov %r12, 0x00\res + mov %r13, 0x08\res + mov %r14, 0x10\res + mov %r15, 0x18\res +.Lred_256\name: + pop %rsi +.endm + +.macro mod_mul_256 a b res name + mul_256 \a, \b + red_256 \res, \name +.endm + +// BLS12-381 G1 order r used as modulus +// Montgomery constant -m^-1 mod b +.LM: + .quad 0xffffffff00000001 + .quad 0x53bda402fffe5bfe + .quad 0x3339d80809a1d805 + .quad 0x73eda753299d7d48 + .quad 0xfffffffeffffffff + +#ifdef __APPLE__ +.global _mod_mul_4w +_mod_mul_4w: +#else +.global mod_mul_4w +mod_mul_4w: +#endif + // x = rdi + // y = rsi + // result = rdx + push %rbp + push %rbx + push %r12 + push %r13 + push %r14 + push %r15 + mov %rdx, %rcx // rcx = result + + // x * y + mod_mul_256 (%rdi), (%rsi), (%rcx), mm + + pop %r15 + pop %r14 + pop %r13 + pop %r12 + pop %rbx + pop %rbp + ret diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..fc3e906 --- /dev/null +++ b/build.rs @@ -0,0 +1,14 @@ +#[cfg(target_arch = "x86_64")] +fn main() { + let target_arch = std::env::var("CARGO_CFG_TARGET_ARCH").unwrap(); + + if target_arch == "x86_64" { + cc::Build::new() + .flag("-c") + .file("./asm/mul_4.S") + .compile("libff-derive-crypto.a"); + } +} + +#[cfg(not(target_arch = "x86_64"))] +fn main() {} diff --git a/ff_derive/src/lib.rs b/ff_derive/src/lib.rs index f522c22..888eabf 100644 --- a/ff_derive/src/lib.rs +++ b/ff_derive/src/lib.rs @@ -13,6 +13,9 @@ use std::str::FromStr; mod pow_fixed; +const BLS_381_FR_MODULUS: &str = + "52435875175126190479447740508185965837690552500527637822603658699938581184513"; + enum ReprEndianness { Big, Little, @@ -126,8 +129,9 @@ pub fn prime_field(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let ast: syn::DeriveInput = syn::parse(input).unwrap(); // We're given the modulus p of the prime field - let modulus: BigUint = fetch_attr("PrimeFieldModulus", &ast.attrs) - .expect("Please supply a PrimeFieldModulus attribute") + let modulus_raw = fetch_attr("PrimeFieldModulus", &ast.attrs) + .expect("Please supply a PrimeFieldModulus attribute"); + let modulus: BigUint = modulus_raw .parse() .expect("PrimeFieldModulus should be a number"); @@ -178,6 +182,7 @@ pub fn prime_field(input: proc_macro::TokenStream) -> proc_macro::TokenStream { gen.extend(prime_field_impl( &ast.ident, &repr_ident, + &modulus_raw, &modulus, &endianness, limbs, @@ -637,6 +642,7 @@ fn prime_field_constants_and_sqrt( fn prime_field_impl( name: &syn::Ident, repr: &syn::Ident, + modulus_raw: &str, modulus: &BigUint, endianness: &ReprEndianness, limbs: usize, @@ -807,6 +813,45 @@ fn prime_field_impl( a: proc_macro2::TokenStream, b: proc_macro2::TokenStream, limbs: usize, + modulus_raw: &str, + ) -> proc_macro2::TokenStream { + if limbs == 4 && modulus_raw == BLS_381_FR_MODULUS { + mul_impl_asm4(a, b) + } else { + mul_impl_default(a, b, limbs) + } + } + + fn mul_impl_asm4( + a: proc_macro2::TokenStream, + b: proc_macro2::TokenStream, + ) -> proc_macro2::TokenStream { + // x86_64 asm for four limbs + let default_impl = mul_impl_default(a.clone(), b.clone(), 4); + + let mut gen = proc_macro2::TokenStream::new(); + gen.extend(quote! { + #[cfg(target_arch = "x86_64")] + { + if *::ff::CPU_SUPPORTS_ADX_INSTRUCTION { + ::ff::mod_mul_4w_assign(&mut (#a.0).0, &(#b.0).0); + } else { + #default_impl + } + } + #[cfg(not(target_arch = "x86_64"))] + { + #default_impl + } + }); + + gen + } + + fn mul_impl_default( + a: proc_macro2::TokenStream, + b: proc_macro2::TokenStream, + limbs: usize, ) -> proc_macro2::TokenStream { let mut gen = proc_macro2::TokenStream::new(); @@ -876,9 +921,125 @@ fn prime_field_impl( } } + fn add_assign_impl( + a: proc_macro2::TokenStream, + b: proc_macro2::TokenStream, + limbs: usize, + ) -> proc_macro2::TokenStream { + if limbs == 4 { + add_assign_asm_impl(a, b, limbs) + } else { + add_assign_default_impl(a, b, limbs) + } + } + + fn add_assign_asm_impl( + a: proc_macro2::TokenStream, + b: proc_macro2::TokenStream, + limbs: usize, + ) -> proc_macro2::TokenStream { + let mut gen = proc_macro2::TokenStream::new(); + let default_impl = add_assign_default_impl(a.clone(), b.clone(), limbs); + + gen.extend(quote! { + #[cfg(target_arch = "x86_64")] + { + // This cannot exceed the backing capacity. + use core::arch::x86_64::*; + use core::mem; + + unsafe { + let mut carry = _addcarry_u64( + 0, + (#a.0).0[0], + (#b.0).0[0], + &mut (#a.0).0[0] + ); + carry = _addcarry_u64( + carry, (#a.0).0[1], + (#b.0).0[1], + &mut (#a.0).0[1] + ); + carry = _addcarry_u64( + carry, (#a.0).0[2], + (#b.0).0[2], + &mut (#a.0).0[2] + ); + _addcarry_u64( + carry, + (#a.0).0[3], + (#b.0).0[3], + &mut (#a.0).0[3] + ); + + let mut s_sub: [u64; 4] = mem::uninitialized(); + + carry = _subborrow_u64( + 0, + (#a.0).0[0], + MODULUS.0[0], + &mut s_sub[0] + ); + carry = _subborrow_u64( + carry, + (#a.0).0[1], + MODULUS.0[1], + &mut s_sub[1] + ); + carry = _subborrow_u64( + carry, + (#a.0).0[2], + MODULUS.0[2], + &mut s_sub[2] + ); + carry = _subborrow_u64( + carry, + (#a.0).0[3], + MODULUS.0[3], + &mut s_sub[3] + ); + + if carry == 0 { + // Direct assign fails since size can be 4 or 6 + // Obviously code doesn't work at all for size 6 + // (#a).0 = s_sub; + (#a.0).0[0] = s_sub[0]; + (#a.0).0[1] = s_sub[1]; + (#a.0).0[2] = s_sub[2]; + (#a.0).0[3] = s_sub[3]; + } + } + } + #[cfg(not(target_arch = "x86_64"))] + { + #default_impl + } + }); + + gen + } + + fn add_assign_default_impl( + a: proc_macro2::TokenStream, + b: proc_macro2::TokenStream, + _limbs: usize, + ) -> proc_macro2::TokenStream { + let mut gen = proc_macro2::TokenStream::new(); + + gen.extend(quote! { + // This cannot exceed the backing capacity. + #a.0.add_nocarry(&#b.0); + + // However, it may need to be reduced. + #a.reduce(); + }); + gen + } + let squaring_impl = sqr_impl(quote! {self}, limbs); - let multiply_impl = mul_impl(quote! {self}, quote! {other}, limbs); + let multiply_impl = mul_impl(quote! {self}, quote! {other}, limbs, modulus_raw); let invert_impl = inv_impl(quote! {self}, name, modulus); + let add_assign = add_assign_impl(quote! {self}, quote! {other}, limbs); let montgomery_impl = mont_impl(limbs); // self.0[0].ct_eq(&other.0[0]) & self.0[1].ct_eq(&other.0[1]) & ... @@ -1053,11 +1214,7 @@ fn prime_field_impl( impl<'r> ::core::ops::AddAssign<&'r #name> for #name { #[inline] fn add_assign(&mut self, other: &#name) { - // This cannot exceed the backing capacity. - self.add_nocarry(other); - - // However, it may need to be reduced. - self.reduce(); + #add_assign } } diff --git a/src/asm.rs b/src/asm.rs new file mode 100644 index 0000000..090f833 --- /dev/null +++ b/src/asm.rs @@ -0,0 +1,49 @@ +lazy_static::lazy_static! { + pub static ref CPU_SUPPORTS_ADX_INSTRUCTION: bool = std::is_x86_feature_detected!("adx"); +} + +#[link(name = "ff-derive-crypto", kind = "static")] +extern "C" { + fn mod_mul_4w(a: &[u64; 4], b: &[u64; 4], res: &mut [u64; 4]); +} + +pub fn mod_mul_4w_assign(a: &mut [u64; 4], b: &[u64; 4]) { + let mut res = [0; 4]; + unsafe { + mod_mul_4w(&*a, b, &mut res); + } + let _ = core::mem::replace(a, res); +} + +#[cfg(test)] +mod tests { + use super::*; + + use rand_core::SeedableRng; + + #[test] + fn test_mod_mul() { + let mut x: [u64; 4] = [ + 7665858810281813592, + 16340119633057872346, + 4817051413996267933, + 2960177199463250197, + ]; + let y: [u64; 4] = [ + 12935154801682980781, + 13314970078575206070, + 2674023185838267390, + 551755778115450960, + ]; + let exp: [u64; 4] = [ + 12035708911089303301, + 16867479803567096087, + 8918020714254073494, + 3250221169924948371, + ]; + + mod_mul_4w_assign(&mut x, &y); + + assert_eq!(x[0..4], exp[0..4], "\nMod Mul error\n"); + } +} diff --git a/src/lib.rs b/src/lib.rs index ca18d3c..a23757b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,12 +4,21 @@ #![no_std] #![cfg_attr(docsrs, feature(doc_cfg))] #![deny(broken_intra_doc_links)] -#![forbid(unsafe_code)] +#![cfg_attr(not(feature = "asm"), forbid(unsafe_code))] + +#[cfg(feature = "std")] +extern crate std; #[cfg(feature = "derive")] #[cfg_attr(docsrs, doc(cfg(feature = "derive")))] pub use ff_derive::PrimeField; +#[cfg(all(feature = "asm", target_arch = "x86_64"))] +mod asm; + +#[cfg(all(feature = "asm", target_arch = "x86_64"))] +pub use asm::*; + #[cfg(feature = "bits")] #[cfg_attr(docsrs, doc(cfg(feature = "bits")))] pub use bitvec::view::BitViewSized; @@ -18,6 +27,7 @@ pub use bitvec::view::BitViewSized; use bitvec::{array::BitArray, order::Lsb0}; use core::fmt; use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; + use rand_core::RngCore; use subtle::{ConditionallySelectable, CtOption};