Note: This post is best read on a wide screen, and probably not in an RSS reader. Substack still doesn’t handle wide code blocks well.
Motivation
Back in ~2023 I read a lot of white papers. Grokking these exposed my shaky fundamentals, and encouraged me to go back and rebuild. I searched far and wide for a concise explanation of what I needed to understand threshold cryptography, and did not find anything consolidated. I wanted something like Paul Miller's Learning fast elliptic-curve cryptography, but even lower level.
Miller's post jumps straight into phrases like "we’re working in a finite field over some big prime P" and copying formulas from hyperelliptic. I wanted to both have and prove to myself a deep understanding of what sentences like that mean. I wanted to be able to prove hyperelliptic's formulas - otherwise I felt like I was a cargo cultist copying another's math without understanding it.
I set out. I researched and learned. I returned having grown!
Modular Math
Note: Cryptography-conscious folks might know all about timing attacks, and would be right to point out that many of the following implementations are vulnerable to them. My goal here is to use development to build intuition, not to provide production-ready code. Many of the safe versions of the algorithms use these ideas as a base, and then waste work in order to make each algorithm execute in constant time.
The mod Function
In simple terms, for a positive integer b, we define a mod b, or mod(a, b), as the unique integer r with 0 <= r < b such that a = k * b + r, where k is some integer.
Examples:
mod(55, 9) = 1 since 55 = 6 * 9 + 1
mod(-20, 6) = 4 since -20 = -4 * 6 + 4
We can rearrange the formula to r = a - k * b. To find the right k, we can iterate. If a is positive, we need to keep increasing k. Here is mod(55, 9):
55 = 55 - 0*9
46 = 55 - 1*9
37 = 55 - 2*9
28 = 55 - 3*9
19 = 55 - 4*9
10 = 55 - 5*9
1 = 55 - 6*9
If a is negative, we need to keep decreasing k. Here is mod(-20, 6):
-20 = -20 - 0*6
-14 = -20 - -1*6
-8 = -20 - -2*6
-2 = -20 - -3*6
4 = -20 - -4*6
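The stepping procedure in the two tables above can be written as a loop. This is a minimal sketch with a hypothetical helper name, `modByStepping`, assuming integer `a` and `b > 0`. It takes on the order of |a|/b steps, so it is only for building intuition.

```typescript
// Hypothetical helper: computes mod(a, b) by moving k one step at a time,
// exactly like the tables above. Assumes integer a and b > 0.
function modByStepping(a: number, b: number): number {
  let k = 0
  let r = a - k * b
  // a positive: increase k until r drops below b
  while (r >= b) {
    k++
    r = a - k * b
  }
  // a negative: decrease k until r is no longer negative
  while (r < 0) {
    k--
    r = a - k * b
  }
  return r
}

console.log(modByStepping(55, 9))  // 1 (k ended at 6)
console.log(modByStepping(-20, 6)) // 4 (k ended at -4)
```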
In order to arrive directly at what k needs to be, we can use division:
r = a - k*b
0 <= r < b // from the definition
0 <= a - k*b < b // substitute a - k*b for r. Gives two inequalities.
0 <= a - k*b // start with the first inequality
k*b <= a // rearrange
k <= a/b // upper bound on k
a - k*b < b // continue with the second inequality
a/b - k < 1 // divide by b
a/b - 1 < k // lower bound on k
Thus, a/b - 1 < k <= a/b, and k is an integer. If a = 55 and b = 9, then we can write:
55/9 - 1 < k <= 55/9
5.11 < k <= 6.11
k = 6
If a = -20 and b = 6, then we can write:
-20/6 -1 < k <= -20/6
-4.33 < k <= -3.33
k = -4
The only integer in that range is floor(a/b), the largest integer less than or equal to a/b, so the operation we're looking for is floor. Thus, we can define r by r = a - floor(a / b) * b. This property is why the modulus is thought of as "the remainder after division". For example, 55 divided by 9 is 6 with 1 remaining. The definition above, mod(a, b) = r where 0 <= r < b and a = k * b + r for integer k, is mathematically precise and useful. Some programming languages, such as Python, have a mod operator (%) that works like the above definition (for positive b values). JavaScript (and TypeScript) instead perform truncated division (rounding toward 0), so r's sign matches a's sign.
Here is Python:
❯ python
Python 3.10.8
>>> -55 % 9
8
Here is JavaScript:
❯ node
Welcome to Node.js v12.18.3.
> -55 % 9
-1
Fortunately this is easy to convert:
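function mod(a: number, b: number): number {
  // assumes b > 0; % truncates toward zero, so result has the sign of a
  let result = a % b
  if (result >= 0) {
    return result
  }
  // a was negative: take the final step that truncation skipped
  return b + result
}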
Since JavaScript (and TypeScript) stop one iteration early when a is negative, we perform that last step manually by adding b. Boom!
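For comparison, the closed form derived earlier, r = a - floor(a / b) * b, translates directly into code. This is a sketch assuming b > 0, using a hypothetical name, `floorMod`. Because `Math.floor` rounds toward negative infinity, it agrees with the definition for negative `a` without a fix-up step. One caveat: `a / b` is a floating-point division and can round for very large `a`, while `%` on integer values is exact, so the %-based `mod` is probably the safer choice.

```typescript
// Hypothetical helper: the closed form r = a - floor(a / b) * b.
// Assumes b > 0 and moderately sized integers (a / b is floating point).
function floorMod(a: number, b: number): number {
  return a - Math.floor(a / b) * b
}

console.log(-55 % 9)          // -1, truncated toward zero
console.log(floorMod(-55, 9)) // 8, floored like Python
console.log(floorMod(55, 9))  // 1
console.log(floorMod(-20, 6)) // 4
```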
Modular Multiplication
Now that we understand the mod function, we can make sense of some very fundamental modular math properties that we will use to build our later functions. For instance, we can simplify multiplication: mod(52 * 52, 9) = 4. Instead of calculating
mod(52*52, 9) =
mod(2704, 9) = 4
We can keep the numbers smaller using our definition a = k*b + r:
52*52 = (5*9 + 7) * (5*9 + 7)
= 5*9*5*9 + 7*5*9 + 5*9*7 + 49 // distribute
= 225*9 + 35*9 + 35*9 + 49 // simplify
= 295*9 + 49 // factor out the 9; 49 remains
= 295*9 + (5*9 + 4) // reduce 49 into k*9 + r
= 300*9 + 4 // factor out the 9; 4 remains
After we distributed, every term except 49 had a 9 in it. That means for the purposes of calculating the modulus, all we care about is that term. This is always the case because of how we are representing the numbers: 52 = 5*9 + 7, so when we multiply that by the same representation, only the product of the remainders is relevant. We notice:
mod(52*52, 9) =
mod(7*7, 9) =
mod(49, 9) = 4
To generalize this to mod(x * y, b), we can write:
x = k_x*b + r_x
y = k_y*b + r_y
mod(x, b) = r_x
mod(y, b) = r_y
x*y = (k_x*b + r_x) * (k_y*b + r_y)
x*y = k_x*b*k_y*b + r_x*k_y*b + k_x*b*r_y + r_x*r_y
x*y = b * (k_x*k_y*b + r_x*k_y + k_x*r_y) + r_x*r_y
mod(x*y, b) = mod(b * (k_x*k_y*b + r_x*k_y + k_x*r_y) + r_x*r_y, b)
= mod(r_x*r_y, b)
= mod(mod(x, b)*mod(y, b), b)
If we are multiplying two numbers, we can take the modulus beforehand, to keep our numbers small. If we are taking the modulus of a number with many known factors, we can take the modulus of each one, multiply the results, and take the modulus of that.
This abstracts to any number of terms. If a number has factors w, x, y, z, then:
mod(w*x*y*z, b) =
mod(w * (x*y*z), b) =
mod(mod(w, b)*mod(x*y*z, b), b) =
mod(mod(w, b)*mod(mod(x, b)*mod(y*z, b), b), b) =
mod(mod(w, b)*mod(mod(x, b)*mod(mod(y, b)*mod(z, b), b), b), b)
Algorithmically, it looks like this:
function multiply(factors: number[], b: number): number {
  let result = mod(factors[0], b)
  for (let i = 1; i < factors.length; i++) {
    // reduce each factor, then reduce the running product
    result = mod(result * mod(factors[i], b), b)
  }
  return result
}
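To gain some confidence that the product rule holds for negative inputs too, we can check it exhaustively over a small range. A sketch: `checkProductRule` is a hypothetical helper, and `mod` re-implements the post's version so the snippet runs on its own.

```typescript
// Same behavior as the post's mod for b > 0.
function mod(a: number, b: number): number {
  const r = a % b
  return r >= 0 ? r : r + b
}

// Hypothetical helper: checks mod(x*y, b) === mod(mod(x, b) * mod(y, b), b)
// for every modulus 1..limit and every x, y in -limit..limit.
function checkProductRule(limit: number): boolean {
  for (let b = 1; b <= limit; b++) {
    for (let x = -limit; x <= limit; x++) {
      for (let y = -limit; y <= limit; y++) {
        if (mod(x * y, b) !== mod(mod(x, b) * mod(y, b), b)) {
          return false
        }
      }
    }
  }
  return true
}

console.log(checkProductRule(30)) // true
```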
Modular Exponentiation
The above property comes in very handy for exponentiation. Since mod(52^4, 9) is the same as mod(52 * 52 * 52 * 52, 9), we can keep reducing the problem:
mod(52*52*52*52, 9) =
mod(mod(52, 9) * mod(52*52*52, 9), 9)
= mod(7 * mod(52*52*52, 9), 9)
= mod(7 * mod(mod(52, 9) * mod(52*52, 9), 9), 9)
= mod(7 * mod(7 * mod(52*52, 9), 9), 9)
= mod(7 * mod(7 * mod(mod(52, 9) * mod(52, 9), 9), 9), 9)
= mod(7 * mod(7 * mod(7 * 7, 9), 9), 9)
= mod(7 * mod(7 * mod(49, 9), 9), 9)
= mod(7 * mod(7 * 4, 9), 9)
= mod(7 * mod(28, 9), 9)
= mod(7 * 1, 9)
= 7
Generically, given mod(a^b, m) for b > 0, we can pull factors out one at a time, take their mod, and keep accumulating, so:
mod(a^b, m) =
mod(mod(a, m) * mod(a^(b-1), m), m)
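Applied literally, that rule gives a loop that multiplies in one factor at a time. Here is a minimal sketch (the name `powerModNaive` is hypothetical), assuming `exponent >= 0` and `modulus > 0`. It reduces after every multiplication, so the running value stays small, but it still needs `exponent` multiplications.

```typescript
function mod(a: number, b: number): number {
  const r = a % b
  return r >= 0 ? r : r + b
}

// Hypothetical helper: one multiplication per unit of the exponent.
function powerModNaive(base: number, exponent: number, modulus: number): number {
  const reducedBase = mod(base, modulus)
  let result = mod(1, modulus) // mod(1, 1) is 0, so modulus 1 is handled
  for (let i = 0; i < exponent; i++) {
    result = mod(result * reducedBase, modulus)
  }
  return result
}

console.log(powerModNaive(52, 4, 9))   // 7
console.log(powerModNaive(42, 55, 97)) // 78, after 55 multiplications
```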
We can keep repeating this process to keep all of our numbers nice and small. But we can do better! Say we want to calculate mod(42^55, 97). We are performing the multiplication operation 55 times. Not great. Faster is to notice that
55 = 1 + 2 + 4 + 0 + 16 + 32 // the 0 stands in for 8, which 55 does not contain
= 2^0 + 2^1 + 2^2 + 2^4 + 2^5
42^55 = 42^1 * 42^2 * 42^4 * 42^16 * 42^32
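This helps because doubling an exponent is cheap: squaring 42^k gives 42^(2k) in a single multiplication, so each power of 2 comes from squaring the one before it: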
mod(42^1, 97) = 42
mod(42^2, 97) = mod(42^1 * 42^1, 97)
= mod(42 * 42, 97) // substitute
= mod(1764, 97) // operation 1
= 18
mod(42^4, 97) = mod(42^2 * 42^2, 97)
= mod(18 * 18, 97) // substitute
= mod(324, 97) // operation 2
= 33
mod(42^8, 97) = mod(42^4 * 42^4, 97)
= mod(33 * 33, 97) // substitute
= mod(1089, 97) // operation 3
= 22
mod(42^16, 97) = mod(42^8 * 42^8, 97)
= mod(22 * 22, 97) // substitute
= mod(484, 97) // operation 4
= 96
mod(42^32, 97) = mod(42^16 * 42^16, 97)
= mod(96 * 96, 97) // substitute
= mod(9216, 97) // operation 5
= 1
mod(42^55, 97) = mod(42^1 * 42^2 * 42^4 * 42^16 * 42^32, 97)
mod(42^55, 97) = mod(
mod(42, 97) *
mod(42^2, 97) *
mod(42^4, 97) *
mod(42^16, 97) *
mod(42^32, 97),
97
)
mod(42^55, 97) = mod(42 * 18 * 33 * 96 * 1, 97) // substitute
= mod(mod(42 * 18, 97) * 33 * 96 * 1, 97)
= mod(mod(756, 97) * 33 * 96 * 1, 97) // operation 6
= mod(77 * 33 * 96 * 1, 97)
= mod(mod(77 * 33, 97) * 96 * 1, 97)
= mod(mod(2541, 97) * 96 * 1, 97) // operation 7
= mod(19 * 96 * 1, 97)
= mod(mod(19 * 96, 97) * 1, 97)
= mod(mod(1824, 97) * 1, 97) // operation 8
= mod(78 * 1, 97)
= 78 // operation 9
Now we are performing 5 multiplications to build our powers of 2, and then 4 multiplications to combine them for a total of 9 multiplications. In general, it takes roughly 2*log2(exponent)
multiplications to arrive at the answer, instead of exponent
multiplications. Given that our exponents might be thousands of digits long, this is very important.
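To double-check the arithmetic, here is a sketch that replays the worked example: it builds the repeated squares of 42 modulo 97, multiplies together the ones whose bit is set in 55, and counts multiplications the same way the walkthrough does. `traceSquareAndMultiply` is a hypothetical name, and the bit test uses plain division rather than a separate binary-representation helper.

```typescript
function mod(a: number, b: number): number {
  const r = a % b
  return r >= 0 ? r : r + b
}

// Hypothetical helper that replays the 42^55 mod 97 walkthrough.
function traceSquareAndMultiply(base: number, exponent: number, modulus: number) {
  // squares[i] holds base^(2^i) mod modulus
  const squares: number[] = [mod(base, modulus)]
  let multiplications = 0
  // keep squaring while the next power of 2 still fits in the exponent
  while (2 ** squares.length <= exponent) {
    const last = squares[squares.length - 1]
    squares.push(mod(last * last, modulus))
    multiplications++
  }
  // multiply together the squares whose bit is set, lowest bit first
  let result: number | undefined = undefined
  for (let i = 0; i < squares.length; i++) {
    const bitIsSet = Math.floor(exponent / 2 ** i) % 2 === 1
    if (!bitIsSet) continue
    if (result === undefined) {
      result = squares[i] // the first factor needs no multiplication
    } else {
      result = mod(result * squares[i], modulus)
      multiplications++
    }
  }
  return { squares, result: result ?? mod(1, modulus), multiplications }
}

const trace = traceSquareAndMultiply(42, 55, 97)
console.log(trace.squares)         // [ 42, 18, 33, 22, 96, 1 ]
console.log(trace.result)          // 78
console.log(trace.multiplications) // 9
```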
In order to make this work, we need to figure out how to decompose our exponent into powers of 2. Fortunately, this is exactly what a binary representation is, and it is the native way computers store numbers.
If your language doesn't give you a conveniently accessible binary representation, it is easy to compute.
Notice:
55 = 1 + 2 * 27
= 1 + 2 * (1 + 2*13)
= 1 + 2 * (1 + 2 * (1 + 2*6))
= 1 + 2 * (1 + 2 * (1 + 2 * (2 * 3)))
= 1 + 2 * (1 + 2 * (1 + 2 * (2 * (1 + 2))))
= 1 + 2 + 4 * (1 + 2 * (2 * (1 + 2))) // start distributing
= 1 + 2 + 4 + 8 * (2 * (1 + 2))
= 1 + 2 + 4 + 16 * (1 + 2)
= 1 + 2 + 4 + 16 + 32
The only computation we're doing is dividing by 2 and keeping the quotient and the remainder. Check out the pattern: if the number is odd, we will use that term, since its representation is 1 + 2 * k rather than 2 * k. As we distribute the 2's, the 1 sticks around. This is really clear if we use a power of 2, like 32:
32 = 2 * 16
= 2 * (2 * 8)
= 2 * (2 * (2 * 4))
= 2 * (2 * (2 * (2 * 2)))
Converting the above into an algorithm, we get:
// produces a reversed binary representation,
// so 13 would be 1011 instead of 1101
// assumes n >= 0 (a negative n would never reach 0)
function binary(n: number): number[] {
  if (n === 0) {
    return [0]
  }
  let result: number[] = []
  while (n !== 0) {
    result.push(mod(n, 2)) // the remainder is the next bit
    n = Math.floor(n / 2)
  }
  return result
}
And this lets us build our exponentiation function:
function powerMod(
  base: number,
  exponent: number,
  modulus: number
): number {
  const bits = binary(exponent) // least significant bit first
  let res = 1
  let accumulator = base // congruent to base^(2^i) at the start of step i
  for (let bit of bits) {
    if (bit === 1) {
      res = mod(res * accumulator, modulus)
    }
    accumulator = mod(accumulator * accumulator, modulus)
  }
  return res
}
If we wanted to get fancy and were willing to dip into bitwise operators, then we could leverage the underlying binary representation for a performance boost at the cost of readability:
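function powerMod(
  base: number,
  exponent: number,
  modulus: number
): number {
  // Caution: JavaScript bitwise operators convert to 32-bit integers,
  // so this version is only correct for exponents below 2^32.
  let res = 1
  let accumulator = base
  while (exponent > 0) {
    if (exponent & 1) { // lowest bit is set
      res = mod(res * accumulator, modulus)
    }
    accumulator = mod(accumulator * accumulator, modulus)
    exponent = exponent >>> 1 // drop the lowest bit
  }
  return res
}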
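Both versions above use JavaScript numbers, which are only exact up to `Number.MAX_SAFE_INTEGER` (2^53 - 1). Once the modulus is larger than about 2^26.5, the product of two reduced values can exceed that and silently lose precision, which rules out cryptographic sizes. Here is a sketch of the same square-and-multiply loop over `BigInt`, assuming an ES2020 or later target and a non-negative exponent; `modBig` and `powerModBig` are hypothetical names.

```typescript
// Hypothetical BigInt helpers; requires an ES2020+ target for 1n literals.
function modBig(a: bigint, b: bigint): bigint {
  const r = a % b // BigInt % also truncates toward zero
  return r >= 0n ? r : r + b
}

// Assumes exponent >= 0n and modulus > 0n.
function powerModBig(base: bigint, exponent: bigint, modulus: bigint): bigint {
  let result = modBig(1n, modulus)
  let accumulator = modBig(base, modulus)
  while (exponent > 0n) {
    if ((exponent & 1n) === 1n) {
      result = modBig(result * accumulator, modulus)
    }
    accumulator = modBig(accumulator * accumulator, modulus)
    exponent >>= 1n // BigInt has no >>>; >> is exact for non-negative values
  }
  return result
}

console.log(powerModBig(42n, 55n, 97n)) // 78n

// 2^61 - 1 is prime, so Fermat's little theorem says this is 1n
const p = (1n << 61n) - 1n
console.log(powerModBig(5n, p - 1n, p)) // 1n
```

The last check relies on Fermat's little theorem (mod(a^(p-1), p) = 1 for a prime p that does not divide a), which makes for an easy correctness test at sizes plain numbers cannot handle.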
Modular Inverse
We define the identity of an operation to be the special number e where operation(a, e) = a for every a. For instance, in addition, the identity is 0, since add(x, 0) = x. In multiplication, the identity is 1, since multiply(x, 1) = x.
The inverse of x in an operation is the number y such that operation(x, y) = identity. For addition, this is -x, since add(x, -x) = 0. For multiplication of real numbers, this is 1/x, since multiply(x, 1/x) = 1.
Within the integers, a multiplicative inverse usually does not exist: for example, there is no integer y where multiply(2, y) = 1. Only 1 and -1 have integer inverses.
However, under mod, a multiplicative inverse sometimes exists[1]. For given integers x and modulus m, we are trying to find an integer y such that mod(x * y, m) = 1, as 1 is the modular multiplicative identity (since mod(x * 1, m) = mod(x, m)).
For example, mod(4 * 2, 7) = 1, so 2 is the inverse of 4 (mod 7), and 4 is the inverse of 2 (mod 7).
The brute force way to find inverses is simple: we can search through the numbers one at a time and check whether the product is 1.
mod(1*1, 7) = 1 // inverse pair
mod(2*1, 7) = 2
mod(2*2, 7) = 4
mod(2*3, 7) = 6
mod(2*4, 7) = 1 // inverse pair
mod(3*1, 7) = 3
mod(3*2, 7) = 6
mod(3*3, 7) = 2
mod(3*4, 7) = 5
mod(3*5, 7) = 1 // inverse pair
mod(6*1, 7) = 6
mod(6*2, 7) = 5
mod(6*3, 7) = 4
mod(6*4, 7) = 3
mod(6*5, 7) = 2
mod(6*6, 7) = 1 // inverse pair
This works when the numbers are small, but when they start getting a little bigger, like finding y such that mod(55*y, 97) = 1, we end up doing:
mod(55*1, 97) = 55
mod(55*2, 97) = 13
mod(55*3, 97) = 68
mod(55*4, 97) = 26
...
mod(55*30, 97) = 1
This ends up performing, on average, about m/2 multiplications to find the answer.
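The brute-force search is short to write. A minimal sketch with a hypothetical name, `modInverseBruteForce`, assuming `m > 1`; it returns `NaN` when no inverse exists, the same convention the post's `modInverse` uses.

```typescript
// Hypothetical helper: try y = 1, 2, 3, ... until mod(x * y, m) === 1.
// Assumes m > 1; takes up to m - 1 multiplications.
function modInverseBruteForce(x: number, m: number): number {
  for (let y = 1; y < m; y++) {
    // the + m and second % keep the result non-negative for negative x
    if ((((x * y) % m) + m) % m === 1) {
      return y
    }
  }
  return NaN
}

console.log(modInverseBruteForce(4, 7))   // 2
console.log(modInverseBruteForce(55, 97)) // 30
console.log(modInverseBruteForce(6, 9))   // NaN (6 and 9 share the factor 3)
```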
We can do better!
If we can solve x * y + k * m = 1 for integers y and k, then
mod(x*y + k*m, m) = 1
mod(x*y, m) = 1
So we need to be able to write that equation. Here's how we can do it:
97 = 55 + 42
42 = 97 - 55 // 42 in terms of m and x
55 = 42 + 13
13 = 55 - 42 // rearrange
13 = 55 - (97 - 55) // substitute 42 for (97 - 55)
13 = -1*97 + 2*55 // 13 in terms of m and x
42 = 3*13 + 3
3 = 42 - 3*13 // rearrange
3 = (97 - 55) - 3 * (-1*97 + 2*55) // substitute
3 = 97 - 55 + 3*97 - 6*55
3 = 4*97 - 7*55 // 3 in terms of m and x
13 = 4*3 + 1
1 = 13 - 4*3 // rearrange
1 = (-1*97 + 2*55) - 4 * (4*97 - 7*55) // substitute
1 = -1*97 + 2*55 - 16*97 + 28*55
1 = -17*97 + 30*55 // 1 in terms of m and x
Taking the mod of both sides, we know that
mod(1, 97) = 1
mod(-17*97 + 30*55, 97) = 1
mod(30*55, 97) = 1
Thus, 30 is the modular inverse of 55 (mod 97). How do we turn this into a computer program? Here is one more example, using mod(17*y, 97) = 1, adding some initial relations and making the lines more compact:
97 = 1*97 + 0*17
17 = 0*97 + 1*17
12 = 97 - 5*17
12 = (1*97 + 0*17) - 5 * (0*97 + 1*17)
12 = 1*97 - 5*17
5 = 17 - 1*12
5 = (0*97 + 1*17) - 1 * (1*97 - 5*17)
5 = -1*97 + 6*17
2 = 12 - 2*5
2 = (1*97 - 5*17) - 2 * (-1*97 + 6*17)
2 = 3*97 - 17*17
1 = 5 - 2*2
1 = (-1*97 + 6*17) - 2 * (3*97 - 17*17)
1 = -7*97 + 40*17
See the pattern? We keep writing mod(big, small) in terms of j*x + k*m. Each time:
q = Math.floor(big/small)
j = previous_previous_j - q * previous_j
k = previous_previous_k - q * previous_k
big, small = small, mod(big, small)
previous_previous_k = previous_k
previous_k = k
previous_previous_j = previous_j
previous_j = j
So for instance:
previous_k = 0 // x = 0*m + 1*x; k is the coefficient of m
previous_previous_k = 1 // m = 1*m + 0*x
previous_j = 1 // x = 0*m + 1*x; j is the coefficient of x
previous_previous_j = 0 // m = 1*m + 0*x
big = m = 97
small = x = 17
q = Math.floor(big/small) = Math.floor(97/17) = 5
j = previous_previous_j - q * previous_j = 0 - 5*1 = -5
k = previous_previous_k - q * previous_k = 1 - 5*0 = 1
big, small = small, mod(big, small) = 17, mod(97, 17) = 17, 12
previous_previous_j = previous_j = 1
previous_j = j = -5
previous_previous_k = previous_k = 0
previous_k = k = 1
q = Math.floor(big/small) = Math.floor(17/12) = 1
j = previous_previous_j - q * previous_j = 1 - 1*-5 = 6
k = previous_previous_k - q * previous_k = 0 - 1*1 = -1
big, small = small, mod(big, small) = 12, mod(17, 12) = 12, 5
previous_previous_j = previous_j = -5
previous_j = j = 6
previous_previous_k = previous_k = 1
previous_k = k = -1
q = Math.floor(big/small) = Math.floor(12/5) = 2
j = previous_previous_j - q * previous_j = -5 - 2*6 = -17
k = previous_previous_k - q * previous_k = 1 - 2*-1 = 3
big, small = small, mod(big, small) = 5, mod(12, 5) = 5, 2
previous_previous_j = previous_j = 6
previous_j = j = -17
previous_previous_k = previous_k = -1
previous_k = k = 3
q = Math.floor(big/small) = Math.floor(5/2) = 2
j = previous_previous_j - q * previous_j = 6 - 2*-17 = 40
k = previous_previous_k - q * previous_k = -1 - 2*3 = -7
big, small = small, mod(big, small) = 2, mod(5, 2) = 2, 1
return j
This means we can write a pretty simple program:
function modInverse(x: number, m: number): number {
  let previous_previous_j = 0 // j is the coefficient of x
  let previous_j = 1
  let previous_previous_k = 1 // k is the coefficient of m
  let previous_k = 0 // not needed for the answer, but mirrors the trace above
  let big = m
  let small = mod(x, m) // reduce first so negative or oversized x work too
  while (small > 1) {
    const q = Math.floor(big / small)
    const j = previous_previous_j - q * previous_j
    const k = previous_previous_k - q * previous_k
    const temp = small
    small = mod(big, small)
    big = temp
    previous_previous_j = previous_j
    previous_j = j
    previous_previous_k = previous_k
    previous_k = k
  }
  // small === 0 means x and m share a factor greater than 1: no inverse
  if (small === 0) {
    return NaN
  }
  // previous_j may be negative; shift it into the range 0..m-1
  return mod(previous_j, m)
}
modInverse(2, 97) = 49, and mod(2*49, 97) = mod(98, 97) = 1. Yay! The above algorithm is called the Extended Euclidean Algorithm. I lay it out here in detail because I find other explanations of what is going on and why it works to be very confusing.
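The same algorithm is often written recursively, which can be easier to check against the derivation. This is an alternative formulation, not the post's code: `extendedGcd` is a hypothetical name, and it assumes `a >= 0` and `b >= 0`. It returns the greatest common divisor `g` together with coefficients `s` and `t` such that `s*a + t*b === g`; when `g` is 1, `s` (shifted into 0..b-1) is the inverse of `a` mod `b`.

```typescript
// Hypothetical recursive helper: returns [g, s, t] with s*a + t*b === g,
// where g is the greatest common divisor. Assumes a >= 0 and b >= 0.
function extendedGcd(a: number, b: number): [number, number, number] {
  if (b === 0) {
    return [a, 1, 0] // a = 1*a + 0*b
  }
  const q = Math.floor(a / b)
  // Solve for (b, a - q*b) first: g = s*b + t*(a - q*b)
  const [g, s, t] = extendedGcd(b, a - q * b)
  // ...then regroup the terms around a and b: g = t*a + (s - q*t)*b
  return [g, t, s - q * t]
}

console.log(extendedGcd(55, 97)) // [ 1, 30, -17 ]: 30*55 - 17*97 = 1
console.log(extendedGcd(17, 97)) // [ 1, 40, -7 ]: same j and k as the trace above
```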
Modular Exponentiation With Negative Bases And Powers
Now that we have worked out modular multiplication and inversion rules, we are ready to tackle negative bases and exponents. For example, what is mod((-17)^29, 97)?
Brute force would be to compute mod(-17*-17*-17..., 97), but fortunately we know that mod(a*b, m) = mod(mod(a, m)*mod(b, m), m). This means that
mod((-17)^29, 97) =
mod(mod(-17, 97)^29, 97) =
mod(80^29, 97)
And then we're able to use our regular algorithm with the binary representation of the exponent.
For negative exponents, we know that x^-y = (x^-1)^y, where under mod, x^-1 is the modular inverse of x. So we can write:
mod(x^-y, m) =
mod((x^-1)^y, m) =
mod(modInverse(x, m)^y, m)
Which means we can put together our whole powerMod function:
export function powerMod(
  base: number,
  exponent: number,
  modulus: number
): number {
  // mod() maps any negative (or oversized) base into 0..modulus-1
  base = mod(base, modulus)
  if (exponent < 0) {
    // a^-b mod c === (a^-1)^b mod c
    base = modInverse(base, modulus)
    if (isNaN(base)) {
      return NaN
    }
    exponent = exponent * -1
  }
  const bits = binary(exponent)
  let result = 1
  let accumulator = base
  for (let bit of bits) {
    if (bit === 1) {
      result = mod(result * accumulator, modulus)
    }
    accumulator = mod(accumulator * accumulator, modulus)
  }
  return result
}
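Since this `powerMod` depends on `mod`, `modInverse`, and `binary`, here is a condensed, self-contained sketch for experimenting with negative bases and exponents. The function names mirror the post's, but the compact bodies are a rewrite (the bit loop divides by 2 instead of building a bit list), so treat it as an illustration rather than a drop-in replacement.

```typescript
function mod(a: number, b: number): number {
  const r = a % b
  return r >= 0 ? r : r + b
}

// Condensed Extended Euclidean Algorithm: tracks only the x coefficients.
function modInverse(x: number, m: number): number {
  let big = m
  let small = mod(x, m)
  let previousJ = 0 // coefficient of x in big
  let j = 1 // coefficient of x in small
  while (small > 1) {
    const q = Math.floor(big / small)
    const nextJ = previousJ - q * j
    const nextSmall = big - q * small
    previousJ = j
    j = nextJ
    big = small
    small = nextSmall
  }
  return small === 1 ? mod(j, m) : NaN // small === 0: shared factor, no inverse
}

function powerMod(base: number, exponent: number, modulus: number): number {
  base = mod(base, modulus) // handles negative bases
  if (exponent < 0) {
    base = modInverse(base, modulus) // a^-b = (a^-1)^b
    if (Number.isNaN(base)) return NaN
    exponent = -exponent
  }
  let result = mod(1, modulus)
  while (exponent > 0) {
    if (exponent % 2 === 1) result = mod(result * base, modulus)
    base = mod(base * base, modulus)
    exponent = Math.floor(exponent / 2)
  }
  return result
}

console.log(powerMod(17, -1, 97)) // 40, matching the 17 example above
console.log(powerMod(17, -2, 97)) // 48
console.log(powerMod(-17, 2, 97)) // 95, same as mod(17 * 17, 97)
console.log(powerMod(6, -1, 9))   // NaN: 6 and 9 share the factor 3
```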
That handled the last complication! Now we have a working mod function, quick modular multiplication and exponentiation, and the ability to calculate a modular inverse. We are ready to build Elliptic Curves!
[1] We can find a y for a given x and m such that mod(x*y, m) = 1 if and only if x and m share no common factors greater than 1.
If they do share a common factor f > 1, then we can write:
x = f*a // for integer a
m = f*b // for integer b
1 = mod(x*y, m)
1 = x*y - k*m // for integer k
1 = f*a*y - k*f*b
1 = f*(a*y - k*b)
q = a*y - k*b // define q
1 = f*q // contradiction
q is an integer, and it is impossible to form 1 by multiplying an integer q by f > 1.
If they do not share a common factor, then we can write:
r_1 = x - k_1*m
r_2 = m - k_2*r_1
r_2 = m - k_2*(x - k_1*m)
r_2 = m - k_2*x + k_2*k_1*m
r_2 = -k_2*x + (1 + k_2*k_1)*m
r_3 = r_1 - k_3*r_2
r_3 = x - k_1*m - k_3*(-k_2*x + (1 + k_2*k_1)*m)
r_3 = x - k_1*m + k_3*k_2*x - k_3*m - k_3*k_2*k_1*m
r_3 = (1 + k_3*k_2)*x - (k_1 + k_3 + k_3*k_2*k_1)*m
...
1 = r_(N-2) - k_N*r_(N-1)
In each step, we're able to represent r_n as some constant times x plus some constant times m. If we ever write that last line, 1 = r_(N-2) - k_N*r_(N-1), then, because we can always write r_n as a*x + b*m for integers a and b, we get 1 = a*x + b*m, and a is the number we're looking for.
We can always do this when the greatest common divisor is 1 because of how we're choosing our sequence. If x and m have a greatest common divisor d, then we can write:
x = d*a // for integer a > 0
m = d*b // for integer b > 0
r = m - k*x // for integer r such that 0 <= r < x
r = d*b - k*d*a
r = d(b - k*a)
So every remainder is a multiple of d as well. Conversely, anything that divides both x and r also divides m = k*x + r. That means the next pair (x, r) has exactly the same common divisors as (m, x), so the greatest common divisor stays d at every step. Since 0 <= r < x, the numbers shrink each step, so eventually we reach r = 0. When we do, we know:
0 = d*b - k*d*a
d*b = k*d*a
b = k*a
m = d*k*a
x = d*a
Thus m and x share the factor d*a. Since d is their greatest common divisor, a must be 1; otherwise d*a would be a common divisor larger than d.
When the original x and m have a greatest common divisor of 1, d is 1 at every step. Since a and d are both 1, x must also be 1 at the step that produces r = 0. Since each new x is mod(m, x) from the step before, our previous relation was 1 = m - k*x. Since we can always represent m and x in terms of the original m and x, we know we can write 1 = s*x + t*m for integers s and t, and thus s is the modular inverse of x mod m.
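The footnote's claim is easy to test empirically for small moduli. A sketch with hypothetical helpers: `gcd` is the plain Euclidean algorithm, `hasInverse` is a brute-force search, and `checkInverseCriterion` confirms they agree for every 1 <= x < m.

```typescript
// Plain Euclidean algorithm (assumes a, b >= 0).
function gcd(a: number, b: number): number {
  while (b !== 0) {
    const r = a % b
    a = b
    b = r
  }
  return a
}

// Brute force: does some y in 1..m-1 give mod(x * y, m) === 1?
function hasInverse(x: number, m: number): boolean {
  for (let y = 1; y < m; y++) {
    if ((x * y) % m === 1) return true
  }
  return false
}

// Checks "inverse exists iff gcd(x, m) === 1" for every modulus 2..maxModulus.
function checkInverseCriterion(maxModulus: number): boolean {
  for (let m = 2; m <= maxModulus; m++) {
    for (let x = 1; x < m; x++) {
      if (hasInverse(x, m) !== (gcd(x, m) === 1)) return false
    }
  }
  return true
}

console.log(checkInverseCriterion(50)) // true
```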