diff --git a/java-spiffe-core/src/main/java/io/spiffe/workloadapi/DefaultWorkloadApiClient.java b/java-spiffe-core/src/main/java/io/spiffe/workloadapi/DefaultWorkloadApiClient.java index fd42c9ad..da9b6fa2 100644 --- a/java-spiffe-core/src/main/java/io/spiffe/workloadapi/DefaultWorkloadApiClient.java +++ b/java-spiffe-core/src/main/java/io/spiffe/workloadapi/DefaultWorkloadApiClient.java @@ -171,14 +171,18 @@ public X509Context fetchX509Context() throws X509ContextException { public void watchX509Context(Watcher watcher) { Objects.requireNonNull(watcher, "watcher must not be null"); - final RetryHandler retryHandler = new RetryHandler(exponentialBackoffPolicy, retryExecutor); - final Context.CancellableContext cancellableContext = Context.current().withCancellation(); + synchronized (this) { + assertOpen(); + + final RetryHandler retryHandler = new RetryHandler(exponentialBackoffPolicy, retryExecutor); + final Context.CancellableContext cancellableContext = Context.current().withCancellation(); - final StreamObserver streamObserver = - getX509ContextStreamObserver(watcher, retryHandler, cancellableContext, workloadApiAsyncStub); + final StreamObserver streamObserver = + getX509ContextStreamObserver(watcher, retryHandler, cancellableContext, workloadApiAsyncStub); - cancellableContext.run(() -> workloadApiAsyncStub.fetchX509SVID(newX509SvidRequest(), streamObserver)); - this.cancellableContexts.add(cancellableContext); + cancellableContext.run(() -> workloadApiAsyncStub.fetchX509SVID(newX509SvidRequest(), streamObserver)); + this.cancellableContexts.add(cancellableContext); + } } /** @@ -200,14 +204,18 @@ public X509BundleSet fetchX509Bundles() throws X509BundleException { public void watchX509Bundles(Watcher watcher) { Objects.requireNonNull(watcher, "watcher must not be null"); - final RetryHandler retryHandler = new RetryHandler(exponentialBackoffPolicy, retryExecutor); - final Context.CancellableContext cancellableContext = Context.current().withCancellation(); + synchronized (this) { + assertOpen(); - final StreamObserver streamObserver = - getX509BundlesStreamObserver(watcher, retryHandler, cancellableContext, workloadApiAsyncStub); + final RetryHandler retryHandler = new RetryHandler(exponentialBackoffPolicy, retryExecutor); + final Context.CancellableContext cancellableContext = Context.current().withCancellation(); - cancellableContext.run(() -> workloadApiAsyncStub.fetchX509Bundles(newX509BundlesRequest(), streamObserver)); - this.cancellableContexts.add(cancellableContext); + final StreamObserver streamObserver = + getX509BundlesStreamObserver(watcher, retryHandler, cancellableContext, workloadApiAsyncStub); + + cancellableContext.run(() -> workloadApiAsyncStub.fetchX509Bundles(newX509BundlesRequest(), streamObserver)); + this.cancellableContexts.add(cancellableContext); + } } /** @@ -331,13 +339,17 @@ public JwtSvid validateJwtSvid(String token, String audience) public void watchJwtBundles(Watcher watcher) { Objects.requireNonNull(watcher, "watcher must not be null"); - RetryHandler retryHandler = new RetryHandler(exponentialBackoffPolicy, retryExecutor); - Context.CancellableContext cancellableContext = Context.current().withCancellation(); + synchronized (this) { + assertOpen(); + + RetryHandler retryHandler = new RetryHandler(exponentialBackoffPolicy, retryExecutor); + Context.CancellableContext cancellableContext = Context.current().withCancellation(); - StreamObserver streamObserver = getJwtBundleStreamObserver(watcher, retryHandler, cancellableContext, workloadApiAsyncStub); + StreamObserver streamObserver = getJwtBundleStreamObserver(watcher, retryHandler, cancellableContext, workloadApiAsyncStub); - cancellableContext.run(() -> workloadApiAsyncStub.fetchJWTBundles(newJwtBundlesRequest(), streamObserver)); - this.cancellableContexts.add(cancellableContext); + cancellableContext.run(() -> workloadApiAsyncStub.fetchJWTBundles(newJwtBundlesRequest(), streamObserver)); + this.cancellableContexts.add(cancellableContext); + } } /** @@ -349,7 +361,13 @@ public void close() { log.log(Level.FINE, "Closing WorkloadAPI client"); synchronized (this) { if (!closed) { - for (Context.CancellableContext context : cancellableContexts) { + final List contexts; + synchronized (cancellableContexts) { + contexts = new ArrayList<>(cancellableContexts); + cancellableContexts.clear(); + } + + for (Context.CancellableContext context : contexts) { context.close(); } @@ -366,6 +384,12 @@ public void close() { } + private void assertOpen() { + if (closed) { + throw new IllegalStateException("Cannot register watch on closed Workload API client"); + } + } + private X509Context callFetchX509Context() throws X509ContextException { Iterator x509SvidResponse = workloadApiBlockingStub.fetchX509SVID(newX509SvidRequest()); return GrpcConversionUtils.toX509Context(x509SvidResponse); diff --git a/java-spiffe-core/src/test/java/io/spiffe/workloadapi/DefaultWorkloadApiClientTest.java b/java-spiffe-core/src/test/java/io/spiffe/workloadapi/DefaultWorkloadApiClientTest.java index 2b273d84..9e62752e 100644 --- a/java-spiffe-core/src/test/java/io/spiffe/workloadapi/DefaultWorkloadApiClientTest.java +++ b/java-spiffe-core/src/test/java/io/spiffe/workloadapi/DefaultWorkloadApiClientTest.java @@ -1,6 +1,7 @@ package io.spiffe.workloadapi; import com.nimbusds.jose.jwk.Curve; +import io.grpc.Context; import io.grpc.testing.GrpcCleanupRule; import io.spiffe.bundle.jwtbundle.JwtBundle; import io.spiffe.bundle.jwtbundle.JwtBundleSet; @@ -23,10 +24,14 @@ import uk.org.webcompere.systemstubs.environment.EnvironmentVariables; import java.io.IOException; +import java.lang.reflect.Field; import java.security.KeyPair; +import java.util.AbstractList; +import java.util.ArrayList; import java.util.Collections; import java.util.Date; import java.util.HashMap; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.concurrent.CountDownLatch; @@ -34,6 +39,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; @@ -156,6 +162,16 @@ void testWatchX509ContextNullWatcher_throwsNullPointerException() { } } + @Test + void testWatchX509ContextAfterClose_throwsIllegalStateException() { + workloadApiClient.close(); + + IllegalStateException exception = assertThrows(IllegalStateException.class, + () -> workloadApiClient.watchX509Context(noopWatcher())); + + assertEquals("Cannot register watch on closed Workload API client", exception.getMessage()); + } + @Test void testFetchX509Bundles() { X509BundleSet x509BundleSet = null; @@ -220,6 +236,16 @@ void testWatchX509BundlesNullWatcher_throwsNullPointerException() { } } + @Test + void testWatchX509BundlesAfterClose_throwsIllegalStateException() { + workloadApiClient.close(); + + IllegalStateException exception = assertThrows(IllegalStateException.class, + () -> workloadApiClient.watchX509Bundles(noopWatcher())); + + assertEquals("Cannot register watch on closed Workload API client", exception.getMessage()); + } + @Test void testFetchJwtSvid() { @@ -454,6 +480,32 @@ void testWatchJwtBundlesNullWatcher_throwsNullPointerException() { } } + @Test + void testWatchJwtBundlesAfterClose_throwsIllegalStateException() { + workloadApiClient.close(); + + IllegalStateException exception = assertThrows(IllegalStateException.class, + () -> workloadApiClient.watchJwtBundles(noopWatcher())); + + assertEquals("Cannot register watch on closed Workload API client", exception.getMessage()); + } + + @Test + void testCloseClosesRegisteredContextsFromSynchronizedSnapshot() throws Exception { + Context.CancellableContext firstContext = Context.current().withCancellation(); + Context.CancellableContext secondContext = Context.current().withCancellation(); + MonitorCheckingCancellableContexts contexts = new MonitorCheckingCancellableContexts(); + contexts.add(firstContext); + contexts.add(secondContext); + replaceCancellableContexts(contexts); + + workloadApiClient.close(); + + assertTrue(firstContext.isCancelled()); + assertTrue(secondContext.isCancelled()); + assertTrue(contexts.isEmpty()); + } + private String generateToken(String sub, List aud) { Map claims = new HashMap<>(); @@ -466,4 +518,66 @@ private String generateToken(String sub, List aud) { return TestUtils.generateToken(claims, keyPair, "authority1"); } + private static Watcher noopWatcher() { + return new Watcher() { + @Override + public void onUpdate(T update) { + } + + @Override + public void onError(Throwable e) { + } + }; + } + + private void replaceCancellableContexts(List contexts) throws Exception { + Field field = DefaultWorkloadApiClient.class.getDeclaredField("cancellableContexts"); + field.setAccessible(true); + field.set(workloadApiClient, contexts); + } + + private static final class MonitorCheckingCancellableContexts extends AbstractList { + + private final List delegate = new ArrayList<>(); + + @Override + public Context.CancellableContext get(int index) { + return delegate.get(index); + } + + @Override + public int size() { + return delegate.size(); + } + + @Override + public boolean add(Context.CancellableContext context) { + return delegate.add(context); + } + + @Override + public Iterator iterator() { + assertListMonitorHeld(); + return delegate.iterator(); + } + + @Override + public Object[] toArray() { + assertListMonitorHeld(); + return delegate.toArray(); + } + + @Override + public void clear() { + assertListMonitorHeld(); + delegate.clear(); + } + + private void assertListMonitorHeld() { + if (!Thread.holdsLock(this)) { + throw new AssertionError("cancellableContexts must be accessed while synchronized on the list"); + } + } + } + } \ No newline at end of file