diff --git a/ref/compat.h b/ref/compat.h new file mode 100644 index 00000000..9f664bb9 --- /dev/null +++ b/ref/compat.h @@ -0,0 +1,10 @@ +#ifndef COMPAT_H +#define COMPAT_H + +#ifdef _MSC_VER +#define KYBER_NOINLINE __declspec(noinline) +#else +#define KYBER_NOINLINE __attribute__((noinline)) +#endif + +#endif diff --git a/ref/poly.c b/ref/poly.c index 9d3b1358..a8809cec 100644 --- a/ref/poly.c +++ b/ref/poly.c @@ -174,8 +174,7 @@ void poly_frommsg(poly *r, const uint8_t msg[KYBER_INDCPA_MSGBYTES]) for(i=0;icoeffs[8*i+j] = 0; - cmov_int16(r->coeffs+8*i+j, ((KYBER_Q+1)/2), (msg[i] >> j)&1); + r->coeffs[8*i+j] = cmov_int16(0, ((KYBER_Q+1)/2), (msg[i] >> j)&1); } } } diff --git a/ref/verify.c b/ref/verify.c index 4bcb1e5f..0fae0f47 100644 --- a/ref/verify.c +++ b/ref/verify.c @@ -1,6 +1,7 @@ #include #include #include "verify.h" +#include "compat.h" /************************************************* * Name: verify @@ -59,16 +60,19 @@ void cmov(uint8_t *r, const uint8_t *x, size_t len, uint8_t b) /************************************************* * Name: cmov_int16 * -* Description: Copy input v to *r if b is 1, don't modify *r if b is 0. -* Requires b to be in {0,1}; +* Description: Returns a if select == 1, or b if select == 0. +* Requires select to be in {0,1}. * Runs in constant time. * -* Arguments: int16_t *r: pointer to output int16_t -* int16_t v: input int16_t -* uint8_t b: Condition bit; has to be in {0,1} +* Arguments: int16_t a: first int16_t input +* int16_t b: second int16_t input +* uint8_t select: Condition bit; has to be in {0,1} +* +* Returns a if select == 1 or b if select == 0, or +* an undefined value otherwise. **************************************************/ -void cmov_int16(int16_t *r, int16_t v, uint16_t b) +KYBER_NOINLINE int16_t cmov_int16(int16_t a, int16_t b, uint16_t select) { - b = -b; - *r ^= b & ((*r) ^ v); + const uint16_t mask = -select; + return a ^ (mask & (a ^ b)); } diff --git a/ref/verify.h b/ref/verify.h index 09f0ad50..d0e4e8b6 100644 --- a/ref/verify.h +++ b/ref/verify.h @@ -12,6 +12,6 @@ int verify(const uint8_t *a, const uint8_t *b, size_t len); void cmov(uint8_t *r, const uint8_t *x, size_t len, uint8_t b); #define cmov_int16 KYBER_NAMESPACE(cmov_int16) -void cmov_int16(int16_t *r, int16_t v, uint16_t b); +int16_t cmov_int16(int16_t a, int16_t b, uint16_t select); #endif