diff --git a/main.py b/main.py index 44fecc6..4846826 100644 --- a/main.py +++ b/main.py @@ -26,6 +26,7 @@ from nacl.encoding import HexEncoder from minichain import Transaction, Blockchain, Block, State, Mempool, P2PNetwork, mine_block +from minichain.rpc import JSONRPCServer from minichain.validators import is_valid_receiver from minichain.block import calculate_receipt_root @@ -113,7 +114,7 @@ def mine_and_process_block(chain, mempool, miner_pk): # Network message handler # ────────────────────────────────────────────── -def make_network_handler(chain, mempool): +def make_network_handler(chain, mempool, network): """Return an async callback that processes incoming P2P messages.""" async def handler(data): @@ -159,7 +160,33 @@ async def handler(data): # Drop only confirmed transactions so higher nonces can remain queued. mempool.remove_transactions(block.transactions) else: - logger.warning("📥 Received Block #%s — rejected", block.index) + if block.index > chain.last_block.index: + logger.warning("📥 Received Block #%s — ahead of us (tip: %s). Requesting chain sync...", block.index, chain.last_block.index) + asyncio.create_task(network.broadcast_chain_request()) + else: + logger.warning("📥 Received Block #%s — rejected", block.index) + + elif msg_type == "chain_request": + logger.info("📡 Peer requested chain sync. Broadcasting our chain...") + blocks_dicts = [b.to_dict() for b in chain.chain] + payload = {"type": "chain_response", "data": {"blocks": blocks_dicts}} + asyncio.create_task(network._broadcast_raw(payload)) + + elif msg_type == "chain_response": + blocks_payload = payload.get("blocks", []) + new_chain = [] + try: + new_chain = [Block.from_dict(b) for b in blocks_payload] + except Exception as e: + logger.warning("❌ Failed to parse chain_response: %s", e) + return + + if new_chain: + success, orphans = chain.resolve_conflicts(new_chain) + if success: + logger.info("🔄 Reorg complete! Restoring %d orphaned txs to mempool.", len(orphans)) + for tx in orphans: + mempool.add_transaction(tx) return handler @@ -389,8 +416,10 @@ async def run_node(port: int, host: str, connect_to: str | None, fund: int, data mempool = Mempool() network = P2PNetwork() - handler = make_network_handler(chain, mempool) + handler = make_network_handler(chain, mempool, network) network.register_handler(handler) + + rpc_server = JSONRPCServer(chain, mempool, network) # When a new peer connects, send our state so they can sync async def on_peer_connected(writer): @@ -406,6 +435,10 @@ async def on_peer_connected(writer): network.register_on_peer_connected(on_peer_connected) await network.start(port=port, host=host) + + # Start RPC server on a port correlated to the node port (e.g. 8545 if P2P is 9000) + rpc_port = 8545 + (port - 9000) + rpc_task = asyncio.create_task(rpc_server.start(host="127.0.0.1", port=rpc_port)) # Fund this node's wallet so it can transact in the demo if fund > 0: @@ -431,6 +464,9 @@ async def on_peer_connected(writer): logger.info("Chain saved to '%s'", datadir) except Exception as e: logger.error("Failed to save chain during shutdown: %s", e) + + if rpc_task: + rpc_task.cancel() await network.stop() diff --git a/minichain/chain.py b/minichain/chain.py index b37dcad..af7a43d 100644 --- a/minichain/chain.py +++ b/minichain/chain.py @@ -89,6 +89,9 @@ def _create_genesis_block(self, genesis_path): genesis_block.hash = computed_hash self.chain.append(genesis_block) + + # Snapshot the state exactly after genesis allocation for clean reorg rebuilds + self._genesis_state_snapshot = self.state.snapshot() @property def last_block(self): @@ -98,6 +101,16 @@ def last_block(self): with self._lock: # Acquire lock for thread-safe access return self.chain[-1] + def get_total_work(self, chain_list=None): + """ + Calculates the cumulative PoW of a chain. + Work is proportional to 2^difficulty. + """ + if chain_list is None: + with self._lock: + chain_list = self.chain + return sum(2 ** (block.difficulty or 1) for block in chain_list) + def add_block(self, block): """ Validates and adds a block to the chain if all transactions succeed. @@ -147,3 +160,77 @@ def add_block(self, block): self.state = temp_state self.chain.append(block) return True + + def resolve_conflicts(self, new_chain_list) -> tuple[bool, list]: + """ + Evaluates a competing chain. If it has strictly greater cumulative work, + attempts a reorg. Rebuilds state from genesis to guarantee validity. + Returns: (success_bool, list_of_orphaned_transactions) + """ + if not new_chain_list: + return False, [] + + with self._lock: + current_work = self.get_total_work() + new_work = self.get_total_work(new_chain_list) + + if new_work <= current_work: + logger.debug("Incoming chain (work: %s) is not heavier than local chain (work: %s). Rejecting.", new_work, current_work) + return False, [] + + # 1. Verify genesis block matches + if new_chain_list[0].hash != self.chain[0].hash: + logger.warning("Reorg failed: Genesis hash mismatch.") + return False, [] + + logger.info("Incoming chain is heavier (%s > %s). Attempting reorg...", new_work, current_work) + + # 2. Snapshot current state and chain in case reorg fails validation + state_snapshot = self.state.snapshot() + original_chain = list(self.chain) + + # 3. Rebuild state entirely from genesis using the new chain + temp_state = State() + temp_state.restore(self._genesis_state_snapshot) + + # Verify and apply blocks 1 to N + for i in range(1, len(new_chain_list)): + prev_block = new_chain_list[i-1] + block = new_chain_list[i] + + try: + validate_block_link_and_hash(prev_block, block) + except ValueError as exc: + logger.warning("Reorg failed at block %s: %s", block.index, exc) + return False, [] + + receipts = [] + for tx in block.transactions: + receipt = temp_state.validate_and_apply(tx) + if receipt is None: + logger.warning("Reorg failed: Transaction validation failed in block %s", block.index) + return False, [] + receipts.append(receipt) + + total_fees = sum(getattr(r, 'gas_used', 0) for r in receipts) + if block.miner: + temp_state.credit_mining_reward(block.miner, reward=temp_state.DEFAULT_MINING_REWARD + total_fees) + + computed_receipt_root = calculate_receipt_root(receipts) + if block.receipt_root != computed_receipt_root: + logger.warning("Reorg failed: Invalid receipt root at block %s. Expected %s, got %s", block.index, computed_receipt_root, block.receipt_root) + return False, [] + + if block.state_root != temp_state.state_root(): + logger.warning("Reorg failed: Invalid state root at block %s", block.index) + return False, [] + + # 4. Success! Compute orphaned transactions. + old_txs = {tx.tx_id: tx for b in original_chain[1:] for tx in b.transactions} + new_tx_ids = {tx.tx_id for b in new_chain_list[1:] for tx in b.transactions} + orphans = [tx for tx_id, tx in old_txs.items() if tx_id not in new_tx_ids] + + self.chain = new_chain_list + self.state = temp_state + logger.info("Reorg successful! Switched to new chain tip: Block %s", self.last_block.index) + return True, orphans diff --git a/minichain/mempool.py b/minichain/mempool.py index 6e3d8d9..946ca13 100644 --- a/minichain/mempool.py +++ b/minichain/mempool.py @@ -5,8 +5,7 @@ class Mempool: def __init__(self, max_size=1000, transactions_per_block=100): - self._pool = {} - self._size = 0 + self._list = [] # Single sorted list self._lock = threading.Lock() self.max_size = max_size self.transactions_per_block = transactions_per_block @@ -17,64 +16,62 @@ def add_transaction(self, tx): return False with self._lock: - existing = self._pool.get(tx.sender, {}).get(tx.nonce) + existing_idx = None + i_min = 0 + i_max = len(self._list) + + for i, existing_tx in enumerate(self._list): + if existing_tx.sender == tx.sender: + if existing_tx.nonce == tx.nonce: + existing_idx = i + elif existing_tx.nonce < tx.nonce: + # Must insert AFTER the largest lower-nonce transaction + i_min = max(i_min, i + 1) + elif existing_tx.nonce > tx.nonce: + # Must insert BEFORE the smallest higher-nonce transaction + i_max = min(i_max, i) - if existing: - if existing.tx_id == tx.tx_id: + if existing_idx is not None: + existing_tx = self._list[existing_idx] + if existing_tx.tx_id == tx.tx_id: logger.warning("Mempool: Duplicate transaction rejected %s", tx.tx_id) return False - # Fix: Guard against older replacements (e.g. rejected block restore) - # Only allow overwrite if it's a genuinely newer replacement - if tx.timestamp <= existing.timestamp: + if tx.timestamp <= existing_tx.timestamp: logger.warning("Mempool: Ignoring older replacement %s", tx.tx_id) return False + self._list.pop(existing_idx) + if i_max > existing_idx: + i_max -= 1 + if i_min > existing_idx: + i_min -= 1 else: - if self._size >= self.max_size: + if len(self._list) >= self.max_size: logger.warning("Mempool: Full, rejecting transaction") return False - self._size += 1 - self._pool.setdefault(tx.sender, {})[tx.nonce] = tx - return True - - def get_transactions_for_block(self): - with self._lock: - snapshot = {s: list(pool.values()) for s, pool in self._pool.items()} - for txs in snapshot.values(): - txs.sort(key=lambda t: t.nonce) + i_min = min(i_min, i_max) - selected = [] - while len(selected) < self.transactions_per_block: - best_tx = None - best_sender = None + # Insert before the first tx in [i_min, i_max] that has a lower fee + insert_idx = i_max + for j in range(i_min, i_max): + if getattr(self._list[j], 'fee', 0) < getattr(tx, 'fee', 0): + insert_idx = j + break - for sender, txs in snapshot.items(): - if txs: - current_criteria = (-getattr(txs[0], 'fee', 0), txs[0].timestamp, sender, txs[0].nonce) - best_criteria = (-getattr(best_tx, 'fee', 0), best_tx.timestamp, best_sender, best_tx.nonce) if best_tx else None - if best_tx is None or current_criteria < best_criteria: - best_tx = txs[0] - best_sender = sender - - if not best_tx: - break - - selected.append(best_tx) - snapshot[best_sender].pop(0) + self._list.insert(insert_idx, tx) + return True - return selected + def get_transactions_for_block(self): + with self._lock: + # O(1) retrieval! The list is strictly ordered upon insertion. + return list(self._list[:self.transactions_per_block]) def remove_transactions(self, transactions): with self._lock: - for tx in transactions: - pool = self._pool.get(tx.sender) - if pool and tx.nonce in pool: - del pool[tx.nonce] - self._size -= 1 - if not pool: - del self._pool[tx.sender] + keys_to_remove = {(tx.sender, tx.nonce) for tx in transactions} + self._list = [tx for tx in self._list if (tx.sender, tx.nonce) not in keys_to_remove] def __len__(self): with self._lock: - return self._size + return len(self._list) diff --git a/minichain/p2p.py b/minichain/p2p.py index 7bb1e34..262d102 100644 --- a/minichain/p2p.py +++ b/minichain/p2p.py @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) TOPIC = "minichain-global" -SUPPORTED_MESSAGE_TYPES = {"sync", "tx", "block"} +SUPPORTED_MESSAGE_TYPES = {"sync", "tx", "block", "chain_request", "chain_response"} class P2PNetwork: @@ -228,6 +228,21 @@ def _validate_block_payload(self, payload): for tx_payload in payload["transactions"] ) + def _validate_chain_request(self, payload): + if not isinstance(payload, dict): + return False + return True + + def _validate_chain_response(self, payload): + if not isinstance(payload, dict) or "blocks" not in payload: + return False + if not isinstance(payload["blocks"], list): + return False + for block_payload in payload["blocks"]: + if not self._validate_block_payload(block_payload): + return False + return True + def _validate_message(self, message): # FIX: Check if message is a dictionary first to prevent crashes if not isinstance(message, dict): @@ -249,6 +264,8 @@ def _validate_message(self, message): "sync": self._validate_sync_payload, "tx": self._validate_transaction_payload, "block": self._validate_block_payload, + "chain_request": self._validate_chain_request, + "chain_response": self._validate_chain_response, } return validators[msg_type](payload) @@ -385,6 +402,21 @@ async def broadcast_block(self, block): self._mark_seen("block", payload["data"]) await self._broadcast_raw(payload) + async def broadcast_chain_request(self): + logger.info("Network: Broadcasting chain request") + payload = {"type": "chain_request", "data": {}} + await self._broadcast_raw(payload) + + async def send_chain_response(self, blocks_dicts, writer): + logger.info("Network: Sending chain response with %d blocks", len(blocks_dicts)) + payload = {"type": "chain_response", "data": {"blocks": blocks_dicts}} + line = (canonical_json_dumps(payload) + "\n").encode() + try: + writer.write(line) + await writer.drain() + except Exception as e: + logger.error("Network: Failed to send chain response: %s", e) + @property def peer_count(self) -> int: return len(self._peers) diff --git a/minichain/rpc.py b/minichain/rpc.py new file mode 100644 index 0000000..542edfd --- /dev/null +++ b/minichain/rpc.py @@ -0,0 +1,87 @@ +import logging +import json +import asyncio +from aiohttp import web +from minichain.transaction import Transaction + +logger = logging.getLogger(__name__) + +class JSONRPCServer: + def __init__(self, chain, mempool, network): + self.chain = chain + self.mempool = mempool + self.network = network + self.app = web.Application() + self.app.add_routes([web.post('/', self.handle_rpc)]) + + async def start(self, host="127.0.0.1", port=8545): + runner = web.AppRunner(self.app) + await runner.setup() + site = web.TCPSite(runner, host, port) + await site.start() + logger.info("🚀 JSON-RPC Server running on http://%s:%d", host, port) + + async def handle_rpc(self, request): + try: + req_data = await request.json() + except json.JSONDecodeError: + return web.json_response({"jsonrpc": "2.0", "error": {"code": -32700, "message": "Parse error"}, "id": None}) + + if isinstance(req_data, list): + responses = [] + for req in req_data: + responses.append(await self._process_single(req)) + return web.json_response(responses) + else: + response = await self._process_single(req_data) + return web.json_response(response) + + async def _process_single(self, req): + if not isinstance(req, dict) or "method" not in req or req.get("jsonrpc") != "2.0": + return {"jsonrpc": "2.0", "error": {"code": -32600, "message": "Invalid Request"}, "id": req.get("id")} + + method = req["method"] + params = req.get("params", []) + req_id = req.get("id") + + try: + if method == "mc_blockNumber": + result = self.chain.last_block.index + elif method == "mc_getBlockByNumber": + if not params: + raise ValueError("Missing block number") + idx = params[0] + if idx == "latest": + block = self.chain.last_block + else: + idx = int(idx) + if idx < 0 or idx >= len(self.chain.chain): + block = None + else: + block = self.chain.chain[idx] + result = block.to_dict() if block else None + elif method == "mc_getBalance": + if not params: + raise ValueError("Missing address") + address = params[0] + account = self.chain.state.get_account(address) + result = account["balance"] if account else 0 + elif method == "mc_sendTransaction": + if not params: + raise ValueError("Missing transaction payload") + tx_data = params[0] + tx = Transaction.from_dict(tx_data) + if not tx.verify(): + raise ValueError("Invalid signature") + if self.mempool.add_transaction(tx): + asyncio.create_task(self.network.broadcast_transaction(tx)) + result = tx.tx_id + else: + raise ValueError("Transaction rejected by Mempool") + else: + return {"jsonrpc": "2.0", "error": {"code": -32601, "message": f"Method not found: {method}"}, "id": req_id} + + return {"jsonrpc": "2.0", "result": result, "id": req_id} + except Exception as e: + logger.error("RPC Error processing %s: %s", method, e) + return {"jsonrpc": "2.0", "error": {"code": -32000, "message": str(e)}, "id": req_id} diff --git a/minichain/state.py b/minichain/state.py index f817c6c..73759f8 100644 --- a/minichain/state.py +++ b/minichain/state.py @@ -67,6 +67,18 @@ def copy(self): new_state.contract_machine = ContractMachine(new_state) # Reinitialize contract_machine return new_state + def snapshot(self): + """ + Returns a deep copy of the current accounts dictionary for rollback safety. + """ + return copy.deepcopy(self.accounts) + + def restore(self, snapshot_data): + """ + Restores the state's accounts dictionary from a snapshot. + """ + self.accounts = copy.deepcopy(snapshot_data) + def validate_and_apply(self, tx): """ Validate and apply a transaction. diff --git a/tests/test_reorg.py b/tests/test_reorg.py new file mode 100644 index 0000000..f678bea --- /dev/null +++ b/tests/test_reorg.py @@ -0,0 +1,117 @@ +import pytest +import os +import json +import time + +from minichain.chain import Blockchain +from minichain.transaction import Transaction +from minichain.mempool import Mempool + +import sys +import os +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +from main import mine_and_process_block + +from nacl.signing import SigningKey +from nacl.encoding import HexEncoder + +@pytest.fixture +def genesis_file(tmp_path): + path = tmp_path / "genesis_reorg.json" + sk = SigningKey.generate() + pk = sk.verify_key.encode(encoder=HexEncoder).decode() + data = { + "timestamp": int(time.time()), + "difficulty": 0, + "alloc": { + pk: {"balance": 1000} + } + } + with open(path, "w") as f: + json.dump(data, f) + return str(path), sk, pk + +def test_resolve_conflicts_heavier_chain(genesis_file): + g_path, sk, pk = genesis_file + + node_a = Blockchain(genesis_path=g_path) + node_b = Blockchain(genesis_path=g_path) + + assert node_a.get_total_work() == node_b.get_total_work() + + pool_b = Mempool() + tx = Transaction(sender=pk, receiver="b"*64, amount=10, nonce=0, fee=1) + tx.sign(sk) + pool_b.add_transaction(tx) + + mined_b = mine_and_process_block(node_b, pool_b, pk) + assert mined_b is not None + assert node_b.get_total_work() > node_a.get_total_work() + + # Node A receives Node B's chain + success, orphans = node_a.resolve_conflicts(node_b.chain) + + assert success is True + assert node_a.last_block.hash == node_b.last_block.hash + assert node_a.state.accounts == node_b.state.accounts + assert len(orphans) == 0 + +def test_resolve_conflicts_reorg_with_orphans(genesis_file): + g_path, sk, pk = genesis_file + + node_a = Blockchain(genesis_path=g_path) + node_b = Blockchain(genesis_path=g_path) + + pool_a = Mempool() + pool_b = Mempool() + + # Node A mines tx1 (nonce 0) + tx1 = Transaction(sender=pk, receiver="a"*64, amount=10, nonce=0, fee=1) + tx1.sign(sk) + pool_a.add_transaction(tx1) + mine_and_process_block(node_a, pool_a, pk) + + # Node B mines tx2 (nonce 0, competing transaction) + tx2 = Transaction(sender=pk, receiver="b"*64, amount=20, nonce=0, fee=1) + tx2.sign(sk) + pool_b.add_transaction(tx2) + mine_and_process_block(node_b, pool_b, pk) + + # Node B mines tx3 (nonce 1) to become the heavier chain + tx3 = Transaction(sender=pk, receiver="c"*64, amount=30, nonce=1, fee=1) + tx3.sign(sk) + pool_b.add_transaction(tx3) + block_b2 = mine_and_process_block(node_b, pool_b, pk) + + assert node_b.get_total_work() > node_a.get_total_work() + + # Node A attempts reorg using B's heavier chain + success, orphans = node_a.resolve_conflicts(node_b.chain) + + assert success is True + assert node_a.last_block.hash == block_b2.hash + + # tx1 was in A's chain but NOT in B's chain. It should be orphaned. + assert len(orphans) == 1 + assert orphans[0].tx_id == tx1.tx_id + +def test_resolve_conflicts_rejects_lighter_chain(genesis_file): + g_path, sk, pk = genesis_file + + node_a = Blockchain(genesis_path=g_path) + node_b = Blockchain(genesis_path=g_path) + + pool_a = Mempool() + + # Node A mines a block + tx1 = Transaction(sender=pk, receiver="a"*64, amount=10, nonce=0, fee=1) + tx1.sign(sk) + pool_a.add_transaction(tx1) + mine_and_process_block(node_a, pool_a, pk) + + # Node B is empty. It tries to reorg Node A with its shorter chain. + success, orphans = node_a.resolve_conflicts(node_b.chain) + + assert success is False + assert len(orphans) == 0 + assert node_a.get_total_work() > node_b.get_total_work() diff --git a/tests/test_rpc.py b/tests/test_rpc.py new file mode 100644 index 0000000..1ac92d3 --- /dev/null +++ b/tests/test_rpc.py @@ -0,0 +1,62 @@ +import pytest +import aiohttp +import asyncio +from minichain.chain import Blockchain +from minichain.mempool import Mempool +from minichain.p2p import P2PNetwork +from minichain.rpc import JSONRPCServer + +@pytest.fixture +def anyio_backend(): + return 'asyncio' + +@pytest.fixture +async def rpc_server(free_tcp_port): + chain = Blockchain() + mempool = Mempool() + network = P2PNetwork() + + server = JSONRPCServer(chain, mempool, network) + port = free_tcp_port + await server.start(host="127.0.0.1", port=port) + + yield server, port, chain, mempool + + await server.app.cleanup() + +@pytest.mark.anyio +async def test_rpc_blockNumber(rpc_server): + server, port, chain, mempool = rpc_server + + async with aiohttp.ClientSession() as session: + payload = {"jsonrpc": "2.0", "method": "mc_blockNumber", "id": 1} + async with session.post(f"http://127.0.0.1:{port}/", json=payload) as resp: + assert resp.status == 200 + data = await resp.json() + assert data["result"] == 0 + assert data["id"] == 1 + +@pytest.mark.anyio +async def test_rpc_getBlockByNumber(rpc_server): + server, port, chain, mempool = rpc_server + + async with aiohttp.ClientSession() as session: + payload = {"jsonrpc": "2.0", "method": "mc_getBlockByNumber", "params": [0], "id": 2} + async with session.post(f"http://127.0.0.1:{port}/", json=payload) as resp: + assert resp.status == 200 + data = await resp.json() + assert data["result"]["index"] == 0 + assert data["id"] == 2 + +@pytest.mark.anyio +async def test_rpc_invalid_method(rpc_server): + server, port, chain, mempool = rpc_server + + async with aiohttp.ClientSession() as session: + payload = {"jsonrpc": "2.0", "method": "mc_unknown", "id": 3} + async with session.post(f"http://127.0.0.1:{port}/", json=payload) as resp: + assert resp.status == 200 + data = await resp.json() + assert "error" in data + assert data["error"]["code"] == -32601 + assert data["id"] == 3