Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 54 additions & 1 deletion packages/client/src/client/actions.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest';
import type { SolanaClientRuntime } from '../rpc/types';
import type { ClientActions } from '../types';
import { createWalletRegistry } from '../wallet/registry';
import type { WalletConnector, WalletRegistry } from '../wallet/types';
import type { WalletAccount, WalletConnector, WalletRegistry } from '../wallet/types';
import { createActions } from './actions';
import { createDefaultClientStore } from './createClientStore';

Expand Down Expand Up @@ -256,4 +256,57 @@ describe('client actions', () => {
expect(signature).toBe(AIRDROP_SIGNATURE);
expect(airdropFactoryMock).toHaveBeenCalled();
});

it('updates the store account when the user switches accounts in their wallet', async () => {
let accountsChangedListener: ((accounts: WalletAccount[]) => void) | undefined;
walletConnector.connect = vi.fn(async () => ({
account: { address: ACCOUNT_ADDRESS, publicKey: new Uint8Array(32) },
connector: { id: 'wallet-1', name: 'Wallet 1' },
disconnect: vi.fn(async () => undefined),
signTransaction: vi.fn(),
onAccountsChanged: (listener: (accounts: WalletAccount[]) => void) => {
accountsChangedListener = listener;
return () => {
accountsChangedListener = undefined;
};
},
}));

await actions.connectWallet('wallet-1');
expect(store.getState().wallet.status).toBe('connected');

const NEW_ADDRESS = 'new-addr' as Address;
accountsChangedListener?.([{ address: NEW_ADDRESS, publicKey: new Uint8Array(32) }]);

const state = store.getState();
expect(state.wallet.status).toBe('connected');
if (state.wallet.status === 'connected') {
expect(state.wallet.session.account.address).toBe(NEW_ADDRESS);
}
});

it('disconnects when the wallet reports no accounts via onAccountsChanged', async () => {
let accountsChangedListener: ((accounts: WalletAccount[]) => void) | undefined;
walletConnector.connect = vi.fn(async () => ({
account: { address: ACCOUNT_ADDRESS, publicKey: new Uint8Array(32) },
connector: { id: 'wallet-1', name: 'Wallet 1' },
disconnect: vi.fn(async () => undefined),
signTransaction: vi.fn(),
onAccountsChanged: (listener: (accounts: WalletAccount[]) => void) => {
accountsChangedListener = listener;
return () => {
accountsChangedListener = undefined;
};
},
}));

await actions.connectWallet('wallet-1');
expect(store.getState().wallet.status).toBe('connected');

accountsChangedListener?.([]);
// disconnectWallet is called with void — drain all pending microtasks
await new Promise<void>((resolve) => setTimeout(resolve));

expect(store.getState().wallet.status).toBe('disconnected');
});
});
14 changes: 14 additions & 0 deletions packages/client/src/client/actions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,20 @@ export function createActions({ connectors, logger: inputLogger, runtime, store
walletEventsCleanup?.();
walletEventsCleanup = undefined;
void disconnectWallet();
} else {
updateState(store, {
wallet: {
autoConnect: autoConnectPreference,
connectorId: resolvedConnectorId,
session: { ...session, account: accounts[0] },
status: 'connected',
},
});
logger({
data: { address: accounts[0].address.toString(), connectorId: resolvedConnectorId },
level: 'info',
message: 'wallet account changed',
});
}
});
}
Expand Down
30 changes: 19 additions & 11 deletions packages/client/src/wallet/standard.ts
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,14 @@ export type WalletStandardSessionOptions = Readonly<{

export function createWalletStandardSession(options: WalletStandardSessionOptions): WalletSession {
const { account, defaultChain, disconnect, metadata, onAccountsChanged, wallet } = options;
let currentAccount = account;
let sessionAccount = toSessionAccount(currentAccount);

// Mutable session state: updated by account change listeners.
// All signing operations read from this object to ensure they always
// use the most recent account after a wallet-side account switch.
const sessionState = {
currentAccount: account,
sessionAccount: toSessionAccount(account),
};

const signMessageFeature = wallet.features[SolanaSignMessage] as
| SolanaSignMessageFeature[typeof SolanaSignMessage]
Expand All @@ -126,12 +132,12 @@ export function createWalletStandardSession(options: WalletStandardSessionOption
| SolanaSignAndSendTransactionFeature[typeof SolanaSignAndSendTransaction]
| undefined;

const resolvedChain = defaultChain ?? getChain(currentAccount);
const resolvedChain = defaultChain ?? getChain(sessionState.currentAccount);

const signMessage = signMessageFeature
? async (message: Uint8Array) => {
const [output] = await signMessageFeature.signMessage({
account: currentAccount,
account: sessionState.currentAccount,
message,
});
return output.signature;
Expand All @@ -143,12 +149,12 @@ export function createWalletStandardSession(options: WalletStandardSessionOption
const wireBytes = new Uint8Array(transactionEncoder.encode(transaction));
const request = resolvedChain
? {
account: currentAccount,
account: sessionState.currentAccount,
chain: resolvedChain,
transaction: wireBytes,
}
: {
account: currentAccount,
account: sessionState.currentAccount,
transaction: wireBytes,
};
const [output] = await signTransactionFeature.signTransaction(request);
Expand All @@ -159,9 +165,9 @@ export function createWalletStandardSession(options: WalletStandardSessionOption
const sendTransaction = signAndSendFeature
? async (transaction: SendableTransaction & Transaction, config?: Readonly<{ commitment?: Commitment }>) => {
const wireBytes = new Uint8Array(transactionEncoder.encode(transaction));
const chain: IdentifierString = defaultChain ?? getChain(currentAccount) ?? 'solana:mainnet-beta';
const chain: IdentifierString = defaultChain ?? getChain(sessionState.currentAccount) ?? 'solana:mainnet-beta';
const [output] = await signAndSendFeature.signAndSendTransaction({
account: currentAccount,
account: sessionState.currentAccount,
chain,
options: {
commitment: mapCommitment(config?.commitment),
Expand All @@ -187,8 +193,8 @@ export function createWalletStandardSession(options: WalletStandardSessionOption
listener([]);
return;
}
currentAccount = accounts[0];
sessionAccount = toSessionAccount(currentAccount);
sessionState.currentAccount = accounts[0];
sessionState.sessionAccount = toSessionAccount(sessionState.currentAccount);
listener(accounts.map(toSessionAccount));
});
changeUnsubscribe = off;
Expand All @@ -200,7 +206,9 @@ export function createWalletStandardSession(options: WalletStandardSessionOption
: undefined;

return {
account: sessionAccount,
get account() {
return sessionState.sessionAccount;
},
connector: metadata,
disconnect: disconnectSession,
onAccountsChanged: handleAccountsChanged,
Expand Down