
BuckeyeCTF – nitwit


Category: Crypto
Prompt: I’m in my post-quantum crypto phase. Now featuring Winternitz one-time signatures.
Artifact: nitwit.py
Flag format: bctf{...}


Intro

This challenge implements a toy Winternitz one-time signature (WOTS) scheme.
We are allowed to get one signature on a message without the substring "admin",
and then we must forge a valid signature on a new message that does contain "admin"
to obtain the flag.


Challenge Description

From the provided nitwit.py:

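  • A secret key is expanded into 66 random chain starting values x_i, one per exponent position (64 message nibbles plus 2 checksum nibbles).
  • Each x_i is hashed d = 15 times to get the chain end y_i = H^d(x_i), where H is SHA-256.
  • The public key is pk = H(y_0 || y_1 || ... || y_65).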
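A minimal sketch of the key generation described above, assuming H = SHA-256 (what the solver uses), d = 15 and 66 chains. The challenge expands its secret key into the x_i; drawing each x_i directly with `secrets.token_bytes(32)` is an assumption made here for brevity, and `keygen` is a hypothetical name.

```python
import hashlib
import secrets

D = 15           # chain length d, as in the solver's df_vec
NUM_CHAINS = 66  # 64 message nibbles + 2 checksum nibbles


def h(x: bytes) -> bytes:
    """H is SHA-256, matching the solver."""
    return hashlib.sha256(x).digest()


def chain(x: bytes, t: int) -> bytes:
    """Return H^t(x), i.e. x hashed t times."""
    for _ in range(t):
        x = h(x)
    return x


def keygen() -> tuple[list[bytes], bytes]:
    # Assumption: each chain start x_i is drawn directly as 32 random bytes.
    sk = [secrets.token_bytes(32) for _ in range(NUM_CHAINS)]
    # Chain ends y_i = H^d(x_i); pk = H(y_0 || y_1 || ... || y_65).
    ys = [chain(x, D) for x in sk]
    return sk, h(b"".join(ys))


sk, pk = keygen()
print(len(sk), len(pk))  # 66 32
```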

To sign a message m:

  1. m is treated as a 256-bit integer and written in base-16 as a vector of n_0 = 64 nibbles.
  2. A checksum c = d * n_0 - sum(m_vec) is computed. With d = 15, c ranges from 0 (every nibble 0xf) to 960 (every nibble 0).
  3. Bug: the checksum is encoded in only n_1 = 2 hex digits, although 960 = 0x3C0 needs 3.
  4. The final exponent vector s is m_vec || c_vec (length 66).
  5. The signature is sig[i] = H^{s[i]}(x_i).
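The encoding and signing steps can be sketched as follows. This assumes H = SHA-256 and copies the digit loop of the solver's `int_to_vec_buggy` into a helper called `to_digits`; `exponents` and `sign` are hypothetical names, not functions from nitwit.py.

```python
import hashlib

D, N0, N1 = 15, 64, 2  # N1 = 2 is the bug: c can reach 960 = 0x3C0 (3 digits)


def h(x: bytes) -> bytes:
    return hashlib.sha256(x).digest()


def chain(x: bytes, t: int) -> bytes:
    for _ in range(t):
        x = h(x)
    return x


def to_digits(m: int, length: int) -> list[int]:
    """Base-16 digits, most significant first (same loop as int_to_vec_buggy).

    If m needs more than `length` digits, i goes negative and Python's negative
    indexing wraps the extra digits back onto the end of the list.
    """
    digits = [0] * length
    i = length - 1
    while m > 0:
        digits[i] = m % 16
        m //= 16
        i -= 1
    return digits


def exponents(msg: bytes) -> list[int]:
    m_vec = to_digits(int.from_bytes(msg, "big"), N0)  # 64 nibbles
    c = D * N0 - sum(m_vec)                             # checksum, 0..960
    return m_vec + to_digits(c, N1)                     # s = m_vec || c_vec


def sign(sk: list[bytes], msg: bytes) -> list[bytes]:
    # sig[i] = H^{s[i]}(x_i)
    return [chain(x, s_i) for x, s_i in zip(sk, exponents(msg))]


s = exponents(b"\x00" * 32)
print(len(s), s[:4], s[-2:])  # 66 [0, 0, 0, 0] [12, 3]
```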

Verification recomputes the exponent vector s from the submitted message, hashes each signature element d - s[i] more times to reach H^d(x_i), and checks that the hash of the concatenated chain ends equals the stored public key.
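Verification can be sketched the same way, again assuming SHA-256 and the solver's digit loop; `verify` and the throwaway key pair used for the round trip are illustrative, not the challenge's code.

```python
import hashlib
import secrets

D, N0, N1 = 15, 64, 2


def h(x: bytes) -> bytes:
    return hashlib.sha256(x).digest()


def chain(x: bytes, t: int) -> bytes:
    for _ in range(t):
        x = h(x)
    return x


def to_digits(m: int, length: int) -> list[int]:
    # Base-16 digits, most significant first (same loop as int_to_vec_buggy).
    digits = [0] * length
    i = length - 1
    while m > 0:
        digits[i] = m % 16
        m //= 16
        i -= 1
    return digits


def exponents(msg: bytes) -> list[int]:
    m_vec = to_digits(int.from_bytes(msg, "big"), N0)
    return m_vec + to_digits(D * N0 - sum(m_vec), N1)


def verify(pk: bytes, msg: bytes, sig: list[bytes]) -> bool:
    s = exponents(msg)
    if len(sig) != len(s):
        return False
    # sig[i] should be H^{s[i]}(x_i); d - s[i] more hashes give y_i = H^d(x_i).
    ys = [chain(sig_i, D - s_i) for sig_i, s_i in zip(sig, s)]
    return h(b"".join(ys)) == pk


# Round trip with a throwaway key pair.
sk = [secrets.token_bytes(32) for _ in range(N0 + N1)]
pk = h(b"".join(chain(x, D) for x in sk))
msg = b"hello".ljust(32, b"\x00")
sig = [chain(x, s_i) for x, s_i in zip(sk, exponents(msg))]
print(verify(pk, msg, sig))                           # True
print(verify(pk, b"other".ljust(32, b"\x00"), sig))  # False
```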

Because the checksum does not fit in its two digits, the encoding is no longer domination-free: the exponent vector of one valid message can be component-wise greater than or equal to that of another.
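A quick check of the collision that breaks domination-freeness, using the solver's digit loop. 960 is the checksum of the all-zero message and 195 is the checksum of the solver's m_new. With the buggy 2-digit encoding they produce identical digits, while a 3-digit encoding (what a checksum of up to 960 needs) keeps them ordered, so the smaller checksum cannot dominate. `dominates` is a hypothetical helper.

```python
def to_digits(m: int, length: int) -> list[int]:
    # Base-16 digits, most significant first, same loop as the solver's
    # int_to_vec_buggy. Digits that do not fit are written to negative
    # indices, which Python wraps around to the end of the list.
    digits = [0] * length
    i = length - 1
    while m > 0:
        digits[i] = m % 16
        m //= 16
        i -= 1
    return digits


def dominates(new: list[int], old: list[int]) -> bool:
    """True if new[i] >= old[i] for every position."""
    return all(a >= b for a, b in zip(new, old))


# Checksums of the all-zero message (960) and of the solver's m_new (195).
for n1 in (2, 3):
    old, new = to_digits(960, n1), to_digits(195, n1)
    print(n1, old, new, dominates(new, old))
# 2 [12, 3] [12, 3] True
# 3 [3, 12, 0] [0, 12, 3] False
```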


Exploit Strategy

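In WOTS, a signature on m_old with exponent vector s_old reveals the intermediate chain values H^{s_old[i]}(x_i). Anyone can hash these further, so sig_old lets an attacker sign any message whose vector s_new satisfies:

forall i: s_new[i] >= s_old[i]

A correct checksum blocks this: if no message digit decreases and at least one increases, c strictly decreases, so with a full-width encoding at least one checksum digit must decrease and the condition fails.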

Here, the truncated checksum lets us find two messages such that:

  • s_new[i] >= s_old[i] for all positions, and
  • both vectors are exactly what the (broken) encoding produces for their messages, so the verifier recomputes the same s_new from m_new.

That means we can:

  1. Ask the server to sign a harmless message m_old (no "admin").
  2. Locally compute the exponents s_old and s_new (for a message containing "admin").
  3. For each component:

    forged[i] = H^{s_new[i] - s_old[i]}(sig_old[i])

    which effectively extends each hash chain from depth s_old[i] to s_new[i].

  4. Send (m_new, forged) as if it were a real signature.
    Verification will complete every chain up to depth d and accept, revealing the flag.
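The whole attack can be rehearsed offline against a local key pair before touching the server. This sketch assumes SHA-256 and the solver's encoding; `local_sign` and `local_verify` are hypothetical stand-ins for the remote service, and the two messages are the ones the solver uses.

```python
import hashlib
import secrets

D, N0, N1 = 15, 64, 2


def h(x: bytes) -> bytes:
    return hashlib.sha256(x).digest()


def chain(x: bytes, t: int) -> bytes:
    for _ in range(t):
        x = h(x)
    return x


def to_digits(m: int, length: int) -> list[int]:
    # Same loop as the solver's int_to_vec_buggy (overflow wraps via index -1).
    digits = [0] * length
    i = length - 1
    while m > 0:
        digits[i] = m % 16
        m //= 16
        i -= 1
    return digits


def exponents(msg: bytes) -> list[int]:
    m_vec = to_digits(int.from_bytes(msg, "big"), N0)
    return m_vec + to_digits(D * N0 - sum(m_vec), N1)


# Local stand-in for the server (hypothetical; the real key stays remote).
sk = [secrets.token_bytes(32) for _ in range(N0 + N1)]
pk = h(b"".join(chain(x, D) for x in sk))


def local_sign(msg: bytes) -> list[bytes]:
    return [chain(x, s_i) for x, s_i in zip(sk, exponents(msg))]


def local_verify(msg: bytes, sig: list[bytes]) -> bool:
    ys = [chain(sig_i, D - s_i) for sig_i, s_i in zip(sig, exponents(msg))]
    return h(b"".join(ys)) == pk


m_old = b"\x00" * 32
m_new = b"admin" + b"\xff" * 23 + b"\x40" + b"\x00" * 3
s_old, s_new = exponents(m_old), exponents(m_new)
assert all(b >= a for a, b in zip(s_old, s_new)), "s_new must dominate s_old"

sig_old = local_sign(m_old)  # the single legitimate signature
# Extend every chain from depth s_old[i] to depth s_new[i].
forged = [chain(sig_i, b - a) for sig_i, a, b in zip(sig_old, s_old, s_new)]
print(local_verify(m_new, forged))  # True
```

The forged list is identical to what an honest signer would produce for m_new, which is why the verifier cannot tell the difference.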

Steps

  1. Pick a simple base message m_old

    Use 32 zero bytes, so every message nibble is 0 and the checksum takes its maximum value c = 960:

    m_old = b"\x00" * 32

    Get its signature from the remote service.

  2. Re-implement the broken vector function

    • Convert the 256-bit message to 64 base-16 digits.
    • Compute the checksum c = d * 64 - sum(digits).
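    • Encode c in only 2 hex digits (this is the core bug). In the solver's int_to_vec_buggy, the leading digit 3 of 960 = 0x3C0 is written to index -1, which Python wraps to the last slot, overwriting the trailing 0, so 960 encodes as [12, 3].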
    • Concatenate message digits and checksum digits.
  3. Choose m_new containing "admin"

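    Craft a 32-byte message that contains b"admin" and whose exponent vector s_new is component-wise greater than or equal to s_old. Because the message part of s_old is all zeros, only the two checksum digits constrain the choice: s_new needs checksum digits of at least [12, 3]. Padding with 0xff bytes raises the nibble sum and pushes the checksum down into two-digit range. The solver uses b"admin" + b"\xff" * 23 + b"\x40" + b"\x00" * 3, whose nibbles sum to 765, giving c = 960 - 765 = 195 = 0xC3, which encodes as [12, 3], the same digits as in s_old.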

  4. Forge the signature

    For each position i:

    forged[i] = hash_chain(sig_old[i], s_new[i] - s_old[i])

    Then submit m_new and the forged signature.

  5. Receive the flag

    Once the server verifies that m_new contains "admin" and the signature matches the public key, it prints the flag.
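Step 3 does not require guesswork. Assuming the same layout as the solver's message (b"admin", 23 bytes of 0xff, one free byte, three zero bytes), trying all 256 values of the free byte shows which ones give a dominating s_new. The helper name and this search strategy are illustrative; the writeup itself only states the final choice.

```python
D, N0, N1 = 15, 64, 2


def to_digits(m: int, length: int) -> list[int]:
    # Same loop as the solver's int_to_vec_buggy (overflow wraps via index -1).
    digits = [0] * length
    i = length - 1
    while m > 0:
        digits[i] = m % 16
        m //= 16
        i -= 1
    return digits


def exponents(msg: bytes) -> list[int]:
    m_vec = to_digits(int.from_bytes(msg, "big"), N0)
    return m_vec + to_digits(D * N0 - sum(m_vec), N1)


def candidate_bytes() -> list[int]:
    """Hypothetical helper: which free byte makes s_new dominate s_old?

    Layout (taken from the solver): b"admin" + 23 * 0xff + <free byte> + 3 * 0x00.
    """
    s_old = exponents(b"\x00" * 32)
    good = []
    for byte in range(256):
        msg = b"admin" + b"\xff" * 23 + bytes([byte]) + b"\x00" * 3
        if all(b >= a for a, b in zip(s_old, exponents(msg))):
            good.append(byte)
    return good


good = candidate_bytes()
print(len(good), [hex(v) for v in good[:3]], 0x40 in good)
# 15 ['0x0', '0x1', '0x2'] True

solver_msg = b"admin" + b"\xff" * 23 + b"\x40" + b"\x00" * 3
s = exponents(solver_msg)
print(sum(s[:64]), s[64:])  # 765 [12, 3]
```

Any byte whose two nibbles sum to at most 4 works. Bytes with a nibble sum of exactly 4, such as the solver's 0x40, make the checksum digits equal to those of s_old, so the last two chains need no extra hashing.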


Solver

#!/usr/bin/env python3
import socket, ssl, hashlib, ast, re, sys, time

HOST = "nitwit.challs.pwnoh.io"
PORT = 1337
CONNECT_TIMEOUT = 15
READ_TIMEOUT = 2.5
POST_SEND_READ_LOOPS = 40

def h(x: bytes) -> bytes:
    return hashlib.sha256(x).digest()

def chain(x: bytes, t: int) -> bytes:
    for _ in range(t):
        x = h(x)
    return x

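def int_to_vec_buggy(m: int, vec_len: int, base: int = 16):
    # Re-implements the challenge's digit conversion (most significant digit first).
    # If m needs more than vec_len digits, i goes negative and Python's negative
    # indexing wraps the extra digits onto the end of the list: 960 = 0x3C0 with
    # vec_len=2 becomes [12, 3] instead of the correct 3-digit [3, 12, 0].
    digits = [0] * vec_len
    i = vec_len - 1
    while m > 0:
        digits[i] = m % base
        m //= base
        i -= 1
    return digits

def df_vec(msg_bytes: bytes):
    d = 15   # maximum digit value (chain length)
    n0 = 64  # a 256-bit message gives 64 hex digits
    n1 = 2   # the bug: the checksum can reach 960 = 0x3C0, which needs 3 digits
    m_int = int.from_bytes(msg_bytes, "big")
    m_vec = int_to_vec_buggy(m_int, n0, 16)
    c = d * n0 - sum(m_vec)
    c_vec = int_to_vec_buggy(c, n1, 16)
    return m_vec + c_vec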

TOKEN_RE = re.compile(rb"""b(['"])(?:\\.|(?!\1).)*\1""", re.S)

def extract_sig_list(buf: bytes, expect: int = 66):
    sig = []
    for m in TOKEN_RE.finditer(buf):
        lit = m.group(0).decode("ascii", "strict")
        try:
            sig.append(ast.literal_eval(lit))
        except Exception:
            continue
        if len(sig) == expect:
            return sig
    return None

# Match the challenge's bctf{...} flag format as well as a few common alternatives.
FLAG_RE = re.compile(rb'(?i)(bctf\{[^}]+\}|flag\{[^}]+\}|pwnoh\{[^}]+\}|pwn\{[^}]+\})')

def main():
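    # All-zero message: every nibble is 0, checksum 960 encodes as [12, 3].
    m0 = b"\x00" * 32
    # Nibble sum 765 -> checksum 195 = 0xC3 -> [12, 3], so s_new dominates s_old.
    m_new = b"admin" + b"\xff" * 23 + b"\x40" + b"\x00" * 3
    assert len(m_new) == 32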

    s_old = df_vec(m0)
    s_new = df_vec(m_new)
    for i, (a, b) in enumerate(zip(s_old, s_new)):
        if b < a:
            raise RuntimeError(f"delta negative at idx {i}: {b} < {a}")

    ctx = ssl.create_default_context()
    with socket.create_connection((HOST, PORT), timeout=CONNECT_TIMEOUT) as raw:
        with ctx.wrap_socket(raw, server_hostname=HOST) as s:
            s.settimeout(CONNECT_TIMEOUT)

            buf = b""
            while b">>> " not in buf:
                chunk = s.recv(8192)
                if not chunk:
                    break
                buf += chunk
                sys.stdout.write(chunk.decode(errors="ignore"))
                sys.stdout.flush()

            s.sendall((m0.hex() + "\n").encode())

            data = b""
            sig = None
            while True:
                chunk = s.recv(8192)
                if not chunk:
                    break
                data += chunk
                sys.stdout.write(chunk.decode(errors="ignore"))
                sys.stdout.flush()
                if sig is None:
                    sig = extract_sig_list(data, expect=66)
                if sig is not None and (b">>> " in data):
                    break

            if sig is None or len(sig) != 66:
                raise RuntimeError(f"[x] Failed to parse signature list (got {0 if sig is None else len(sig)})")

            s.sendall((m_new.hex() + "\n").encode())

            data2 = b""
            while b">>> " not in data2:
                chunk = s.recv(8192)
                if not chunk:
                    break
                data2 += chunk
                sys.stdout.write(chunk.decode(errors="ignore"))
                sys.stdout.flush()

            # Extend each chain from depth s_old[i] to depth s_new[i].
            forged = [chain(sig[i], s_new[i] - s_old[i]) for i in range(len(s_new))]
            payload = repr(forged) + "\n"
            s.sendall(payload.encode())

            s.settimeout(READ_TIMEOUT)
            full = b""
            got_flag = None
            for _ in range(POST_SEND_READ_LOOPS):
                try:
                    chunk = s.recv(8192)
                    if not chunk:
                        break
                    full += chunk
                    sys.stdout.write(chunk.decode(errors="ignore"))
                    sys.stdout.flush()
                    m = FLAG_RE.search(full)
                    if m:
                        got_flag = m.group(1).decode(errors="ignore")
                        break
                except socket.timeout:
                    continue

            if got_flag:
                print("\n[+] FLAG:", got_flag)
            else:
                time.sleep(0.2)
                try:
                    while True:
                        chunk = s.recv(8192)
                        if not chunk:
                            break
                        full += chunk
                        sys.stdout.write(chunk.decode(errors="ignore"))
                        sys.stdout.flush()
                        m = FLAG_RE.search(full)
                        if m:
                            print("\n[+] FLAG:", m.group(1).decode(errors="ignore"))
                            break
                except Exception:
                    pass

if __name__ == "__main__":
    main()


Output:

FLAG: bctf{i_f0rg0t_h0w_t0_r3ad_m4th_n0t4t10n}
This post is licensed under CC BY 4.0 by the author.