-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathsolve.py
151 lines (120 loc) · 4.07 KB
/
solve.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from pwn import process, remote
from gmpy2 import iroot
from random import Random
from sage.all import Zmod, PolynomialRing
from Crypto.Util.number import long_to_bytes
from tqdm import tqdm
# Padding geometry and RSA public exponent this script assumes the server uses.
N_BITS = 1024
PAD_SIZE = 64
e = 11

# Local testing alternative:
# io = process(["python", "server.py"], env={"FLAG": "AIS3{bad_padding_and_bad_random}"})
io = remote("localhost", 6007)

# The server announces its modulus as a "n = <int>" line.
io.recvuntil(b"n = ")
n = int(io.recvlineS().strip())
def get_enc(x):
    """Submit *x* at the server's "> " prompt and return its integer reply.

    Uses the module-level connection ``io``.
    """
    io.sendlineafter(b"> ", x)
    reply = io.recvlineS()
    return int(reply.strip())
# k has a 1 bit at every multiple of PAD_SIZE below N_BITS, so tiling a
# PAD_SIZE-bit value v across N_BITS bits equals v * k.
k = sum(1 << (PAD_SIZE * i) for i in range(N_BITS // PAD_SIZE))
# k^(-e) mod n: multiplying a ciphertext of (v * k) by this leaves v^e mod n.
kinv = pow(k, -e, n)
def get_outputs():
    """Leak the next two 32-bit PRNG outputs through one encryption query.

    Encrypts b"0". This script treats the reply as (pad * k)^e mod n, where
    pad is the server's PAD_SIZE-bit random padding (presumably a
    ``getrandbits(64)`` draw -- confirm against the server). Multiplying by
    ``kinv`` leaves pad^e mod n. If pad^e < n, that residue is an exact
    integer e-th power.

    Returns:
        list[int]: the two 32-bit words of pad, low word first.

    Raises:
        ValueError: if the stripped ciphertext is not a perfect e-th power.
    """
    c = get_enc(b"0")
    # Use the shared exponent rather than a literal so the two stay in sync.
    r, exact = iroot((kinv * c) % n, e)
    if not exact:
        raise ValueError("stripped ciphertext is not a perfect e-th power")
    r = int(r)
    # CPython's getrandbits(64) places its first 32-bit draw in the low word.
    return [r & 0xFFFFFFFF, r >> 32]
# Modified from https://github.com/eboda/mersenne-twister-recover
class MT19937Recover:
    """Clone a Python MT19937 generator from its observed 32-bit outputs.

    Any 624 *consecutive* outputs determine all later outputs. Untempering
    them yields 624 consecutive state words. Loading those words with
    internal index 624 makes the next draw run the twist, which rebuilds
    exactly the words the real generator produces next. It does not matter
    whether observation began right after a twist, so no index search is
    needed; any extra outputs are only used for verification.

    See also https://en.wikipedia.org/wiki/Mersenne_Twister#Pseudocode .
    """

    def unshiftRight(self, x, shift):
        """Invert ``y = v ^ (v >> shift)`` for 32-bit v, given y as *x*."""
        res = x
        # Each pass fixes at least `shift` more high bits; 32 passes suffice.
        for _ in range(32):
            res = x ^ (res >> shift)
        return res

    def unshiftLeft(self, x, shift, mask):
        """Invert ``y = v ^ ((v << shift) & mask)`` for 32-bit v, given y as *x*."""
        res = x
        for _ in range(32):
            res = x ^ ((res << shift) & mask)
        return res

    def untemper(self, v):
        """Undo MT19937's output tempering, returning the raw 32-bit state word."""
        # Tempering steps, applied in reverse order.
        v = self.unshiftRight(v, 18)
        v = self.unshiftLeft(v, 15, 0xEFC60000)
        v = self.unshiftLeft(v, 7, 0x9D2C5680)
        v = self.unshiftRight(v, 11)
        return v

    def go(self, outputs, forward=True):
        """Build a ``random.Random`` that continues the observed output stream.

        Args:
            outputs (List[int]): at least 624 consecutive 32-bit outputs of
                the target generator, oldest first. Outputs after the first
                624 are checked against the recovered state.
            forward (bool): if True, return the generator advanced past every
                observed output, so its next draw is the first unseen output.
                If False, return it positioned right after ``outputs[623]``.

        Returns:
            random.Random: the cloned generator.

        Raises:
            ValueError: if fewer than 624 outputs are given, or if an output
                beyond the 624th disagrees with the recovered state.
        """
        if len(outputs) < 624:
            raise ValueError(f"need at least 624 outputs, got {len(outputs)}")
        words = [self.untemper(out) for out in outputs[:624]]
        # Version-3 state tuple as produced by Random.getstate(); index 624
        # marks the buffer as exhausted, forcing a twist on the next draw.
        state = (3, tuple(words + [624]), None)
        rand = Random()
        rand.setstate(state)
        # Replay the extra observations on a separate clone so that `rand`
        # stays positioned after outputs[623] when forward is False.
        clone = Random()
        clone.setstate(state)
        for pos, expected in enumerate(outputs[624:], start=624):
            if clone.getrandbits(32) != expected:
                raise ValueError(f"output {pos} does not match the recovered state")
        return clone if forward else rand
# Each oracle query leaks two 32-bit outputs; 624 consecutive outputs are
# enough to clone the (presumably Python `random`) generator behind the padding.
outputs = []
with tqdm(desc="Retrieving outputs", total=624) as pb:
    while len(outputs) < 624:
        batch = get_outputs()
        outputs += batch
        # Advance by what was actually received; the previous code set pb.n and
        # then called update(), which added one more and over-reported.
        pb.update(len(batch))
mt = MT19937Recover()
rand = mt.go(outputs)
def generate_padding(rand):
    """Draw one PAD_SIZE-bit value from *rand* and tile it across N_BITS bits.

    Consumes exactly one ``getrandbits(PAD_SIZE)`` call from *rand*.
    """
    chunk = rand.getrandbits(PAD_SIZE)
    copies = N_BITS // PAD_SIZE
    # chunk < 2**PAD_SIZE, so the shifted copies never overlap and summing
    # them is the same as OR-ing them together.
    return sum(chunk << (PAD_SIZE * slot) for slot in range(copies))
# Encrypt the flag twice; the cloned generator predicts both paddings.
c1 = get_enc(b"flag")
pad1 = generate_padding(rand)
c2 = get_enc(b"flag")
pad2 = generate_padding(rand)
# Assumed encryption model:
# (flag+pad1)^e=c1
# (flag+pad2)^e=c2
# Both polynomials below vanish at x = flag, so (x - flag) divides their gcd.
Z = Zmod(n)
P = PolynomialRing(Z, "x")
x = P.gen()
f = (x + pad1) ** e - c1
g = (x + pad2) ** e - c2
# Euclid's algorithm done by hand; Sage's built-in gcd is (reportedly) not
# available for polynomials over Zmod(n) with composite n.
while g:
    f, g = g, f % g
f = f.monic()
# Expect exactly (x - flag). Anything else means the attack failed, e.g.
# pad1 == pad2 or a mispredicted padding.
if f.degree() != 1:
    raise RuntimeError(f"polynomial gcd has degree {f.degree()}, expected 1")
print(long_to_bytes(int(-f[0])).decode())