1
//! See [`P256`].
2

            
3
use std::ops::{Add, Mul, Sub};
4

            
5
use digest::{BlockInput, Digest};
6
use generic_array::{typenum::U1, ArrayLength, GenericArray};
7
use opaque_ke::{
8
	errors::InternalError,
9
	key_exchange::group::KeGroup,
10
	rand::{CryptoRng, RngCore},
11
};
12
use p256_::ProjectivePoint;
13
use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer};
14
use subtle::ConstantTimeEq;
15
use voprf::{errors::InternalError as VoprfInternalError, group::Group};
16
use zeroize::Zeroize;
17

            
18
/// Object implementing [`Group`] for P256. This encapsulates
19
/// [`ProjectivePoint`].
20
504
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
21
pub(crate) struct P256(ProjectivePoint);
22

            
23
impl<'de> Deserialize<'de> for P256 {
24
	fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
25
	where
26
		D: Deserializer<'de>,
27
	{
28
		Self::from_element_slice(&GenericArray::deserialize(deserializer)?).map_err(Error::custom)
29
	}
30
}
31

            
32
impl Serialize for P256 {
33
	fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
34
	where
35
		S: Serializer,
36
	{
37
		Group::to_arr(self).serialize(serializer)
38
	}
39
}
40

            
41
impl Add<&Self> for P256 {
42
	type Output = Self;
43

            
44
	fn add(self, other: &Self) -> Self {
45
		Self(Add::add(self.0, other.0))
46
	}
47
}
48

            
49
impl Mul<&Scalar> for P256 {
50
	type Output = Self;
51

            
52
648
	fn mul(self, other: &Scalar) -> Self {
53
648
		Self(Mul::mul(self.0, other.0))
54
648
	}
55
}
56

            
57
impl ConstantTimeEq for P256 {
58
684
	fn ct_eq(&self, other: &Self) -> subtle::Choice {
59
684
		self.0.ct_eq(&other.0)
60
684
	}
61
}
62

            
63
impl KeGroup for P256 {
64
	type PkLen = <ProjectivePoint as KeGroup>::PkLen;
65
	type SkLen = <ProjectivePoint as KeGroup>::SkLen;
66

            
67
696
	fn from_pk_slice(element_bits: &GenericArray<u8, Self::PkLen>) -> Result<Self, InternalError> {
68
696
		ProjectivePoint::from_pk_slice(element_bits).map(Self)
69
696
	}
70

            
71
288
	fn random_sk<R: RngCore + CryptoRng>(rng: &mut R) -> GenericArray<u8, Self::SkLen> {
72
288
		ProjectivePoint::random_sk(rng)
73
288
	}
74

            
75
504
	fn public_key(sk: &GenericArray<u8, Self::SkLen>) -> Self {
76
504
		Self(ProjectivePoint::public_key(sk))
77
504
	}
78

            
79
504
	fn to_arr(&self) -> GenericArray<u8, Self::PkLen> {
80
504
		<ProjectivePoint as KeGroup>::to_arr(&self.0)
81
504
	}
82

            
83
432
	fn diffie_hellman(&self, sk: &GenericArray<u8, Self::SkLen>) -> GenericArray<u8, Self::PkLen> {
84
432
		self.0.diffie_hellman(sk)
85
432
	}
86
}
87

            
88
impl Group for P256 {
89
	type ElemLen = <ProjectivePoint as Group>::ElemLen;
90
	type Scalar = Scalar;
91
	type ScalarLen = <ProjectivePoint as Group>::ScalarLen;
92

            
93
	const SUITE_ID: usize = <ProjectivePoint as Group>::SUITE_ID;
94

            
95
216
	fn hash_to_curve<H: BlockInput + Digest, D: ArrayLength<u8> + Add<U1>>(
96
216
		msg: &[u8],
97
216
		dst: GenericArray<u8, D>,
98
216
	) -> Result<Self, VoprfInternalError>
99
216
	where
100
216
		<D as Add<U1>>::Output: ArrayLength<u8>,
101
216
	{
102
216
		ProjectivePoint::hash_to_curve::<H, _>(msg, dst).map(Self)
103
216
	}
104

            
105
	#[allow(single_use_lifetimes)]
106
648
	fn hash_to_scalar<
107
648
		'a,
108
648
		H: BlockInput + Digest,
109
648
		D: ArrayLength<u8> + Add<U1>,
110
648
		I: IntoIterator<Item = &'a [u8]>,
111
648
	>(
112
648
		input: I,
113
648
		dst: GenericArray<u8, D>,
114
648
	) -> Result<Self::Scalar, VoprfInternalError>
115
648
	where
116
648
		<D as Add<U1>>::Output: ArrayLength<u8>,
117
648
	{
118
648
		ProjectivePoint::hash_to_scalar::<H, _, _>(input, dst).map(Scalar)
119
648
	}
120

            
121
360
	fn from_scalar_slice_unchecked(
122
360
		scalar_bits: &GenericArray<u8, Self::ScalarLen>,
123
360
	) -> Result<Self::Scalar, VoprfInternalError> {
124
360
		ProjectivePoint::from_scalar_slice_unchecked(scalar_bits).map(Scalar)
125
360
	}
126

            
127
216
	fn random_nonzero_scalar<R: RngCore + CryptoRng>(rng: &mut R) -> Self::Scalar {
128
216
		Scalar(ProjectivePoint::random_nonzero_scalar(rng))
129
216
	}
130

            
131
720
	fn scalar_as_bytes(scalar: Self::Scalar) -> GenericArray<u8, Self::ScalarLen> {
132
720
		ProjectivePoint::scalar_as_bytes(scalar.0)
133
720
	}
134

            
135
432
	fn scalar_invert(scalar: &Self::Scalar) -> Self::Scalar {
136
432
		Scalar(ProjectivePoint::scalar_invert(&scalar.0))
137
432
	}
138

            
139
468
	fn from_element_slice_unchecked(
140
468
		element_bits: &GenericArray<u8, Self::ElemLen>,
141
468
	) -> Result<Self, VoprfInternalError> {
142
468
		ProjectivePoint::from_element_slice_unchecked(element_bits).map(Self)
143
468
	}
144

            
145
1368
	fn to_arr(&self) -> GenericArray<u8, Self::ElemLen> {
146
1368
		<ProjectivePoint as Group>::to_arr(&self.0)
147
1368
	}
148

            
149
	fn base_point() -> Self {
150
		Self(ProjectivePoint::base_point())
151
	}
152

            
153
252
	fn is_identity(&self) -> bool {
154
252
		self.0.is_identity()
155
252
	}
156

            
157
1764
	fn identity() -> Self {
158
1764
		Self(ProjectivePoint::identity())
159
1764
	}
160

            
161
360
	fn scalar_zero() -> Self::Scalar {
162
360
		Scalar(ProjectivePoint::scalar_zero())
163
360
	}
164
}
165

            
166
/// Wrapper over [`p256::Scalar`](p256_::Scalar) to implement common traits.
167
1152
#[derive(Clone, Copy, Debug, Eq, PartialEq, Zeroize)]
168
pub(crate) struct Scalar(p256_::Scalar);
169

            
170
#[allow(clippy::derive_hash_xor_eq)]
171
impl std::hash::Hash for Scalar {
172
	fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
173
		self.0.to_bytes().hash(state);
174
	}
175
}
176

            
177
impl<'de> Deserialize<'de> for Scalar {
178
	fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
179
	where
180
		D: Deserializer<'de>,
181
	{
182
		Ok(Self(p256_::Scalar::from_bytes_reduced(
183
			&GenericArray::deserialize(deserializer)?,
184
		)))
185
	}
186
}
187

            
188
impl Serialize for Scalar {
189
	fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
190
	where
191
		S: Serializer,
192
	{
193
		self.0.to_bytes().serialize(serializer)
194
	}
195
}
196

            
197
impl Add<&Self> for Scalar {
198
	type Output = Self;
199

            
200
216
	fn add(self, rhs: &Self) -> Self::Output {
201
216
		Self(self.0.add(&rhs.0))
202
216
	}
203
}
204

            
205
impl Sub<&Self> for Scalar {
206
	type Output = Self;
207

            
208
	fn sub(self, rhs: &Self) -> Self::Output {
209
		Self(self.0.sub(&rhs.0))
210
	}
211
}
212

            
213
impl Mul<&Self> for Scalar {
214
	type Output = Self;
215

            
216
	fn mul(self, rhs: &Self) -> Self::Output {
217
		Self(self.0.mul(&rhs.0))
218
	}
219
}
220

            
221
impl ConstantTimeEq for Scalar {
222
360
	fn ct_eq(&self, other: &Self) -> subtle::Choice {
223
360
		self.0.ct_eq(&other.0)
224
360
	}
225
}