BuckeyeCTF – nitwit
nitwit
Category: Crypto
Prompt: I’m in my post-quantum crypto phase. Now featuring Winternitz one-time signatures.
Artifact: nitwit.py
Flag format: bctf{...}
Intro
This challenge implements a toy Winternitz one-time signature (WOTS) scheme.
We are allowed to get one signature on a message without the substring "admin",
and then we must forge a valid signature on a new message that does contain "admin"
to obtain the flag.
Challenge Description
From the provided nitwit.py:
- A secret key is expanded into many random chain starts x_i.
- Each chain is hashed d times to get y_i = H^d(x_i).
- The public key is pk = H(y_0 || y_1 || ...).

To sign a message m:
- m is treated as a 256-bit integer and written in base 16 as a vector m_vec of n_0 = 64 nibbles.
- A checksum c = d * n_0 - sum(m_vec) is computed.
- Bug: the checksum is encoded in only 2 hex digits, although values up to d * n_0 = 15 * 64 = 960 = 0x3C0 are possible (which would need 3 digits).
- The final exponent vector s is m_vec || c_vec (length 66).
- The signature is sig[i] = H^{s[i]}(x_i).

Verification recomputes the exponents from the submitted message, advances each signature element the remaining d - s[i] steps so that every chain reaches depth d, and checks that the hash of the concatenated chain ends matches the stored public key.
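For reference, here is a minimal Python sketch of the scheme as described. It is not a copy of nitwit.py: the parameters (d = 15, n_0 = 64, n_1 = 2) and the digit loop mirror the solver's df_vec and int_to_vec_buggy, while the key generation (32 random bytes per chain from os.urandom) and the names keygen, sign and verify are assumptions made for illustration.

```python
import hashlib
import os

D, N0, N1 = 15, 64, 2  # chain length, message digits, checksum digits (as in the solver)


def H(x: bytes) -> bytes:
    return hashlib.sha256(x).digest()


def chain(x: bytes, t: int) -> bytes:
    """Apply H to x exactly t times, i.e. H^t(x)."""
    for _ in range(t):
        x = H(x)
    return x


def to_digits(m: int, n: int, base: int = 16) -> list[int]:
    """Big-endian base-16 digits, same loop as the solver's int_to_vec_buggy."""
    digits = [0] * n
    i = n - 1
    while m > 0:
        digits[i] = m % base  # once i < 0 this wraps to the end of the list
        m //= base
        i -= 1
    return digits


def exponents(msg: bytes) -> list[int]:
    m_vec = to_digits(int.from_bytes(msg, "big"), N0)
    c = D * N0 - sum(m_vec)
    return m_vec + to_digits(c, N1)  # only 2 checksum digits: the bug


def keygen() -> tuple[list[bytes], bytes]:
    # Assumption: 32 random bytes per chain; nitwit.py may derive x_i differently
    sk = [os.urandom(32) for _ in range(N0 + N1)]
    pk = H(b"".join(chain(x, D) for x in sk))
    return sk, pk


def sign(sk: list[bytes], msg: bytes) -> list[bytes]:
    return [chain(x, s) for x, s in zip(sk, exponents(msg))]


def verify(pk: bytes, msg: bytes, sig: list[bytes]) -> bool:
    # Walk each element the remaining d - s[i] steps, then hash the chain ends
    ends = [chain(sg, D - s) for sg, s in zip(sig, exponents(msg))]
    return H(b"".join(ends)) == pk


sk, pk = keygen()
msg = b"\x00" * 32
print(verify(pk, msg, sign(sk, msg)))
```

Verification never touches x_i: it only walks each chain forward from whatever it is given, which is the property the exploit relies on.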
Because of the truncated checksum, the mapping from messages to exponent vectors is no longer domination-free: the exponent vector of one message can be component-wise greater than or equal to that of another.
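The domination property can be checked by brute force on toy parameters (chosen here for illustration, not taken from the challenge): base 4, so d = 3, with two message digits. The checksum then ranges over 0..6, which needs two base-4 digits. Using the same digit loop as the solver's int_to_vec_buggy, this sketch lists pairs of distinct messages where one exponent vector dominates the other, once with the full two-digit checksum and once with a truncated one-digit checksum.

```python
from itertools import permutations

W, D, N0 = 4, 3, 2  # toy parameters: base 4, max digit 3, two message digits


def to_digits(m: int, n: int, base: int = W) -> list[int]:
    # Same loop shape as the solver's int_to_vec_buggy (wraps on overflow)
    digits = [0] * n
    i = n - 1
    while m > 0:
        digits[i] = m % base
        m //= base
        i -= 1
    return digits


def exponents(m: int, n1: int) -> list[int]:
    m_vec = to_digits(m, N0)
    return m_vec + to_digits(D * N0 - sum(m_vec), n1)


def dominating_pairs(n1: int) -> list[tuple[int, int]]:
    """Pairs (old, new), new != old, where every exponent of new is >= old's."""
    vecs = {m: exponents(m, n1) for m in range(W ** N0)}
    return [
        (a, b)
        for a, b in permutations(vecs, 2)
        if all(y >= x for x, y in zip(vecs[a], vecs[b]))
    ]


print(len(dominating_pairs(2)))  # full-length checksum
print(dominating_pairs(1)[:3])   # truncated checksum
```

With two checksum digits there are no dominating pairs. With one digit, pairs such as (0, 1) appear: message 0 has exponents [0, 0, 1] and message 1 has [0, 1, 1], because checksums 6 and 5 both collapse to the single digit 1.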
Exploit Strategy
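In a correct WOTS, a signature element sig_old[i] = H^{s_old[i]}(x_i) can only be pushed forward along its chain, since going backward would require inverting H. Knowing the signature for a message m_old with exponent vector s_old therefore lets you sign only messages whose vector s_new satisfies:
forall i: s_new[i] >= s_old[i]
and the checksum guarantees that no other message has such a vector: raising any message digit raises sum(m_vec), which lowers c, so at least one checksum digit must drop.
Here, the truncated checksum lets us find two messages such that:
- s_new[i] >= s_old[i] for all positions, and
- both vectors still satisfy the (broken) checksum rule.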
That means we can:
- Ask the server to sign a harmless message m_old (no "admin").
- Locally compute the exponent vectors s_old and s_new (the latter for a message containing "admin").
- For each component, compute
    forged[i] = H^{s_new[i] - s_old[i]}(sig_old[i])
  which extends each hash chain from depth s_old[i] to depth s_new[i].
- Send (m_new, forged) as if it were a real signature.

Verification will complete every chain up to depth d and accept, revealing the flag.
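The plan can be rehearsed offline before touching the server. This sketch generates a local stand-in key with the same shape (32 random bytes per chain is an assumption; the challenge's key derivation may differ), signs the all-zero message, forges a signature for the solver's m_new by extending each chain, and then runs the verifier's chain-completion check against the public key.

```python
import hashlib
import os

D, N0, N1 = 15, 64, 2  # parameters from the solver's df_vec


def H(x: bytes) -> bytes:
    return hashlib.sha256(x).digest()


def chain(x: bytes, t: int) -> bytes:
    """Apply H to x exactly t times."""
    for _ in range(t):
        x = H(x)
    return x


def to_digits(m: int, n: int, base: int = 16) -> list[int]:
    """Same loop as the solver's int_to_vec_buggy (overflow wraps to index -1)."""
    digits = [0] * n
    i = n - 1
    while m > 0:
        digits[i] = m % base
        m //= base
        i -= 1
    return digits


def exponents(msg: bytes) -> list[int]:
    m_vec = to_digits(int.from_bytes(msg, "big"), N0)
    return m_vec + to_digits(D * N0 - sum(m_vec), N1)


# Local stand-in for the server's key (assumption: 32 random bytes per chain)
sk = [os.urandom(32) for _ in range(N0 + N1)]
pk = H(b"".join(chain(x, D) for x in sk))

m_old = b"\x00" * 32
m_new = b"admin" + b"\xff" * 23 + b"\x40" + b"\x00" * 3
s_old, s_new = exponents(m_old), exponents(m_new)

# The single honest signature the attacker receives
sig_old = [chain(x, s) for x, s in zip(sk, s_old)]

# Forge: push every chain forward from depth s_old[i] to depth s_new[i]
assert all(b >= a for a, b in zip(s_old, s_new)), "s_new must dominate s_old"
forged = [chain(sg, b - a) for sg, a, b in zip(sig_old, s_old, s_new)]

# Verifier: finish each chain to depth D and compare with the public key
ends = [chain(f, D - s) for f, s in zip(forged, s_new)]
print("forgery accepted:", H(b"".join(ends)) == pk)
```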
Steps
Pick a simple base message m_old
Use 32 bytes of zero:
```python
m_old = b"\x00" * 32
```
Get its signature from the remote service.
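With the solver's parameters (d = 15, n_0 = 64), the exponents of the zero message work out as:

```latex
m_{\text{old}} = 0
\;\Rightarrow\;
m_{\text{vec}} = (0, 0, \dots, 0),
\qquad
c = d \cdot n_0 - \sum_{i=0}^{63} m_{\text{vec}}[i] = 15 \cdot 64 - 0 = 960 = \mathtt{0x3C0}
```

Every message digit already sits at its minimum, so any 32-byte m_new satisfies s_new[i] >= s_old[i] on the 64 message positions for free. Only the two checksum positions, where 960 needs three hex digits but only two slots exist, need attention.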
Re-implement the broken vector function
- Convert the 256-bit message to 64 base-16 digits.
- Compute the checksum c = d * 64 - sum(digits).
- Encode c in only 2 hex digits (this is the core bug).
- Concatenate the message digits and the checksum digits.
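The encoding step is where the bug bites. A stand-alone copy of the solver's int_to_vec_buggy shows what two slots do with a checksum that needs three digits: because the loop keeps decrementing i past 0, the extra high digit lands at index -1 (Python's negative indexing) and overwrites the lowest slot instead of being dropped.

```python
def int_to_vec_buggy(m: int, vec_len: int, base: int = 16) -> list[int]:
    digits = [0] * vec_len
    i = vec_len - 1
    while m > 0:
        digits[i] = m % base  # when i < 0, this writes from the end of the list
        m //= base
        i -= 1
    return digits


# Zero message: c = 960 = 0x3C0 needs three hex digits
print(int_to_vec_buggy(960, 2))  # -> [12, 3]
# Solver's m_new: nibble sum 765, so c = 960 - 765 = 195 = 0xC3
print(int_to_vec_buggy(195, 2))  # -> [12, 3]
```

If the overflow digit were simply dropped instead, 960 would encode as [12, 0] and the solver's m_new would still dominate (3 >= 0), so the exploit does not hinge on the wrap-around detail.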
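Choose m_new containing "admin"
Craft a 32-byte message that includes b"admin" and whose exponent vector s_new is component-wise greater than or equal to s_old. Since the zero message puts all 64 message digits of s_old at 0, only the two checksum digits matter. The solver pads "admin" with 23 bytes of 0xff, one 0x40 byte and three zero bytes: the nibble sum is 765, so c = 960 - 765 = 195 = 0xC3, which encodes to the digits (12, 3). Under the solver's int_to_vec_buggy, the zero message's checksum 960 also comes out as (12, 3), so the two checksum chains need no extra hash steps.
Forge the signature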
For each position i:
```python
forged[i] = hash_chain(sig_old[i], s_new[i] - s_old[i])
```
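The forged element passes because the verifier applies H another d - s_new[i] times and the exponents add up to d, landing exactly on the public chain end y_i. This only works when s_new[i] - s_old[i] >= 0, which is the domination condition from the exploit strategy.

```latex
H^{\,d - s_{\text{new}}[i]}\big(\mathrm{forged}[i]\big)
  = H^{\,d - s_{\text{new}}[i]}\Big(H^{\,s_{\text{new}}[i] - s_{\text{old}}[i]}\big(H^{\,s_{\text{old}}[i]}(x_i)\big)\Big)
  = H^{\,d}(x_i)
  = y_i
```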
Then submit m_new and the forged signature.
Receive the flag
Once the server verifies that m_new contains "admin" and that the signature matches the public key, it prints the flag.
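The solver hard-codes m_new. A small search over the same shape (the prefix b"admin", a run of 0xff bytes, one adjustable byte, zero padding) can find such messages programmatically. find_admin_message and dominates are hypothetical helpers written for this walkthrough; the vector logic is copied from the solver.

```python
from typing import Optional


def int_to_vec_buggy(m: int, vec_len: int, base: int = 16) -> list[int]:
    digits = [0] * vec_len
    i = vec_len - 1
    while m > 0:
        digits[i] = m % base
        m //= base
        i -= 1
    return digits


def df_vec(msg: bytes) -> list[int]:
    # Same parameters as the solver: d = 15, n0 = 64, n1 = 2
    m_vec = int_to_vec_buggy(int.from_bytes(msg, "big"), 64)
    return m_vec + int_to_vec_buggy(15 * 64 - sum(m_vec), 2)


def dominates(new: list[int], old: list[int]) -> bool:
    return all(b >= a for a, b in zip(old, new))


def find_admin_message(prefix: bytes = b"admin") -> Optional[bytes]:
    """Hypothetical helper: prefix + run of 0xff + one free byte + zero padding."""
    s_old = df_vec(b"\x00" * 32)
    # Try long 0xff runs first; each 0xff byte lowers the checksum by 30
    for n_ff in range(32 - len(prefix) - 1, -1, -1):
        for adj in range(256):
            body = prefix + b"\xff" * n_ff + bytes([adj])
            cand = body + b"\x00" * (32 - len(body))
            if dominates(df_vec(cand), s_old):
                return cand
    return None


m_new = find_admin_message()
print(m_new.hex() if m_new else "no candidate found")
```

With this search order it settles on 23 bytes of 0xff followed by zeros (checksum 199 = 0xC7, digits (12, 7)), a sibling of the solver's choice, which uses 0x40 as the adjustable byte.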
Solver
```python
#!/usr/bin/env python3
import socket, ssl, hashlib, ast, re, sys, time

HOST = "nitwit.challs.pwnoh.io"
PORT = 1337
CONNECT_TIMEOUT = 15
READ_TIMEOUT = 2.5
POST_SEND_READ_LOOPS = 40


def h(x: bytes) -> bytes:
    return hashlib.sha256(x).digest()


def chain(x: bytes, t: int) -> bytes:
    # Advance a hash chain t steps: H^t(x)
    for _ in range(t):
        x = h(x)
    return x


def int_to_vec_buggy(m: int, vec_len: int, base: int = 16):
    # Big-endian base-`base` digits. If m needs more than vec_len digits,
    # i goes negative and the extra digits overwrite slots from the end.
    digits = [0] * vec_len
    i = vec_len - 1
    while m > 0:
        digits[i] = m % base
        m //= base
        i -= 1
    return digits


def df_vec(msg_bytes: bytes):
    d = 15   # max digit / chain length
    n0 = 64  # message digits (256 bits / 4)
    n1 = 2   # checksum digits (too few: 960 needs 3)
    m_int = int.from_bytes(msg_bytes, "big")
    m_vec = int_to_vec_buggy(m_int, n0, 16)
    c = d * n0 - sum(m_vec)
    c_vec = int_to_vec_buggy(c, n1, 16)
    return m_vec + c_vec


# Matches Python bytes literals such as b'...' or b"..." in the server output
TOKEN_RE = re.compile(rb"""b(['"])(?:\\.|(?!\1).)*\1""", re.S)


def extract_sig_list(buf: bytes, expect: int = 66):
    sig = []
    for m in TOKEN_RE.finditer(buf):
        lit = m.group(0).decode("ascii", "strict")
        try:
            sig.append(ast.literal_eval(lit))
        except Exception:
            continue
    if len(sig) == expect:
        return sig
    return None


FLAG_RE = re.compile(rb'(?i)(bctf\{[^}]+\}|flag\{[^}]+\}|pwnoh\{[^}]+\}|pwn\{[^}]+\})')


def main():
    m0 = b"\x00" * 32
    m_new = b"admin" + b"\xff" * 23 + b"\x40" + b"\x00" * 3
    assert len(m_new) == 32

    s_old = df_vec(m0)
    s_new = df_vec(m_new)
    # Chains can only be extended, so every exponent must move forward
    for i, (a, b) in enumerate(zip(s_old, s_new)):
        if b < a:
            raise RuntimeError(f"delta negative at idx {i}: {b} < {a}")

    ctx = ssl.create_default_context()
    with socket.create_connection((HOST, PORT), timeout=CONNECT_TIMEOUT) as raw:
        with ctx.wrap_socket(raw, server_hostname=HOST) as s:
            s.settimeout(CONNECT_TIMEOUT)

            # Read the banner up to the first prompt
            buf = b""
            while b">>> " not in buf:
                chunk = s.recv(8192)
                if not chunk:
                    break
                buf += chunk
                sys.stdout.write(chunk.decode(errors="ignore"))
                sys.stdout.flush()

            # 1) Ask for a signature on the all-zero message
            s.sendall((m0.hex() + "\n").encode())
            data = b""
            sig = None
            while True:
                chunk = s.recv(8192)
                if not chunk:
                    break
                data += chunk
                sys.stdout.write(chunk.decode(errors="ignore"))
                sys.stdout.flush()
                if sig is None:
                    sig = extract_sig_list(data, expect=66)
                if sig is not None and (b">>> " in data):
                    break
            if sig is None or len(sig) != 66:
                raise RuntimeError(f"[x] Failed to parse signature list (got {0 if sig is None else len(sig)})")

            # 2) Send the target message containing "admin"
            s.sendall((m_new.hex() + "\n").encode())
            data2 = b""
            while b">>> " not in data2:
                chunk = s.recv(8192)
                if not chunk:
                    break
                data2 += chunk
                sys.stdout.write(chunk.decode(errors="ignore"))
                sys.stdout.flush()

            # 3) Extend each chain from depth s_old[i] to s_new[i] and submit
            forged = [chain(sig[i], s_new[i] - s_old[i]) for i in range(len(s_new))]
            payload = repr(forged) + "\n"
            s.sendall(payload.encode())

            # Read the response until the flag shows up
            s.settimeout(READ_TIMEOUT)
            full = b""
            got_flag = None
            for _ in range(POST_SEND_READ_LOOPS):
                try:
                    chunk = s.recv(8192)
                    if not chunk:
                        break
                    full += chunk
                    sys.stdout.write(chunk.decode(errors="ignore"))
                    sys.stdout.flush()
                    m = FLAG_RE.search(full)
                    if m:
                        got_flag = m.group(1).decode(errors="ignore")
                        break
                except socket.timeout:
                    continue
            if got_flag:
                print("\n[+] FLAG:", got_flag)
            else:
                time.sleep(0.2)
                try:
                    while True:
                        chunk = s.recv(8192)
                        if not chunk:
                            break
                        full += chunk
                        sys.stdout.write(chunk.decode(errors="ignore"))
                        sys.stdout.flush()
                        m = FLAG_RE.search(full)
                        if m:
                            print("\n[+] FLAG:", m.group(1).decode(errors="ignore"))
                            break
                except Exception:
                    pass


if __name__ == "__main__":
    main()
```
Output:
FLAG: bctf{i_f0rg0t_h0w_t0_r3ad_m4th_n0t4t10n}
