Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,18 @@ public X509Context fetchX509Context() throws X509ContextException {
public void watchX509Context(Watcher<X509Context> watcher) {
Objects.requireNonNull(watcher, "watcher must not be null");

final RetryHandler retryHandler = new RetryHandler(exponentialBackoffPolicy, retryExecutor);
final Context.CancellableContext cancellableContext = Context.current().withCancellation();
synchronized (this) {
assertOpen();

final RetryHandler retryHandler = new RetryHandler(exponentialBackoffPolicy, retryExecutor);
final Context.CancellableContext cancellableContext = Context.current().withCancellation();

final StreamObserver<Workload.X509SVIDResponse> streamObserver =
getX509ContextStreamObserver(watcher, retryHandler, cancellableContext, workloadApiAsyncStub);
final StreamObserver<Workload.X509SVIDResponse> streamObserver =
getX509ContextStreamObserver(watcher, retryHandler, cancellableContext, workloadApiAsyncStub);

cancellableContext.run(() -> workloadApiAsyncStub.fetchX509SVID(newX509SvidRequest(), streamObserver));
this.cancellableContexts.add(cancellableContext);
cancellableContext.run(() -> workloadApiAsyncStub.fetchX509SVID(newX509SvidRequest(), streamObserver));
this.cancellableContexts.add(cancellableContext);
}
}

/**
Expand All @@ -200,14 +204,18 @@ public X509BundleSet fetchX509Bundles() throws X509BundleException {
public void watchX509Bundles(Watcher<X509BundleSet> watcher) {
Objects.requireNonNull(watcher, "watcher must not be null");

final RetryHandler retryHandler = new RetryHandler(exponentialBackoffPolicy, retryExecutor);
final Context.CancellableContext cancellableContext = Context.current().withCancellation();
synchronized (this) {
assertOpen();

final StreamObserver<Workload.X509BundlesResponse> streamObserver =
getX509BundlesStreamObserver(watcher, retryHandler, cancellableContext, workloadApiAsyncStub);
final RetryHandler retryHandler = new RetryHandler(exponentialBackoffPolicy, retryExecutor);
final Context.CancellableContext cancellableContext = Context.current().withCancellation();

cancellableContext.run(() -> workloadApiAsyncStub.fetchX509Bundles(newX509BundlesRequest(), streamObserver));
this.cancellableContexts.add(cancellableContext);
final StreamObserver<Workload.X509BundlesResponse> streamObserver =
getX509BundlesStreamObserver(watcher, retryHandler, cancellableContext, workloadApiAsyncStub);

cancellableContext.run(() -> workloadApiAsyncStub.fetchX509Bundles(newX509BundlesRequest(), streamObserver));
this.cancellableContexts.add(cancellableContext);
}
}

/**
Expand Down Expand Up @@ -331,13 +339,17 @@ public JwtSvid validateJwtSvid(String token, String audience)
public void watchJwtBundles(Watcher<JwtBundleSet> watcher) {
Objects.requireNonNull(watcher, "watcher must not be null");

RetryHandler retryHandler = new RetryHandler(exponentialBackoffPolicy, retryExecutor);
Context.CancellableContext cancellableContext = Context.current().withCancellation();
synchronized (this) {
assertOpen();

RetryHandler retryHandler = new RetryHandler(exponentialBackoffPolicy, retryExecutor);
Context.CancellableContext cancellableContext = Context.current().withCancellation();

StreamObserver<Workload.JWTBundlesResponse> streamObserver = getJwtBundleStreamObserver(watcher, retryHandler, cancellableContext, workloadApiAsyncStub);
StreamObserver<Workload.JWTBundlesResponse> streamObserver = getJwtBundleStreamObserver(watcher, retryHandler, cancellableContext, workloadApiAsyncStub);

cancellableContext.run(() -> workloadApiAsyncStub.fetchJWTBundles(newJwtBundlesRequest(), streamObserver));
this.cancellableContexts.add(cancellableContext);
cancellableContext.run(() -> workloadApiAsyncStub.fetchJWTBundles(newJwtBundlesRequest(), streamObserver));
this.cancellableContexts.add(cancellableContext);
}
}

/**
Expand All @@ -349,7 +361,13 @@ public void close() {
log.log(Level.FINE, "Closing WorkloadAPI client");
synchronized (this) {
if (!closed) {
for (Context.CancellableContext context : cancellableContexts) {
final List<Context.CancellableContext> contexts;
synchronized (cancellableContexts) {
contexts = new ArrayList<>(cancellableContexts);
cancellableContexts.clear();
}

for (Context.CancellableContext context : contexts) {
context.close();
}

Expand All @@ -366,6 +384,12 @@ public void close() {
}


private void assertOpen() {
if (closed) {
throw new IllegalStateException("Cannot register watch on closed Workload API client");
}
}

private X509Context callFetchX509Context() throws X509ContextException {
Iterator<Workload.X509SVIDResponse> x509SvidResponse = workloadApiBlockingStub.fetchX509SVID(newX509SvidRequest());
return GrpcConversionUtils.toX509Context(x509SvidResponse);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.spiffe.workloadapi;

import com.nimbusds.jose.jwk.Curve;
import io.grpc.Context;
import io.grpc.testing.GrpcCleanupRule;
import io.spiffe.bundle.jwtbundle.JwtBundle;
import io.spiffe.bundle.jwtbundle.JwtBundleSet;
Expand All @@ -23,17 +24,22 @@
import uk.org.webcompere.systemstubs.environment.EnvironmentVariables;

import java.io.IOException;
import java.lang.reflect.Field;
import java.security.KeyPair;
import java.util.AbstractList;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

Expand Down Expand Up @@ -156,6 +162,16 @@ void testWatchX509ContextNullWatcher_throwsNullPointerException() {
}
}

@Test
void testWatchX509ContextAfterClose_throwsIllegalStateException() {
workloadApiClient.close();

IllegalStateException exception = assertThrows(IllegalStateException.class,
() -> workloadApiClient.watchX509Context(noopWatcher()));

assertEquals("Cannot register watch on closed Workload API client", exception.getMessage());
}

@Test
void testFetchX509Bundles() {
X509BundleSet x509BundleSet = null;
Expand Down Expand Up @@ -220,6 +236,16 @@ void testWatchX509BundlesNullWatcher_throwsNullPointerException() {
}
}

@Test
void testWatchX509BundlesAfterClose_throwsIllegalStateException() {
workloadApiClient.close();

IllegalStateException exception = assertThrows(IllegalStateException.class,
() -> workloadApiClient.watchX509Bundles(noopWatcher()));

assertEquals("Cannot register watch on closed Workload API client", exception.getMessage());
}


@Test
void testFetchJwtSvid() {
Expand Down Expand Up @@ -454,6 +480,32 @@ void testWatchJwtBundlesNullWatcher_throwsNullPointerException() {
}
}

@Test
void testWatchJwtBundlesAfterClose_throwsIllegalStateException() {
workloadApiClient.close();

IllegalStateException exception = assertThrows(IllegalStateException.class,
() -> workloadApiClient.watchJwtBundles(noopWatcher()));

assertEquals("Cannot register watch on closed Workload API client", exception.getMessage());
}

@Test
void testCloseClosesRegisteredContextsFromSynchronizedSnapshot() throws Exception {
Context.CancellableContext firstContext = Context.current().withCancellation();
Context.CancellableContext secondContext = Context.current().withCancellation();
MonitorCheckingCancellableContexts contexts = new MonitorCheckingCancellableContexts();
contexts.add(firstContext);
contexts.add(secondContext);
replaceCancellableContexts(contexts);

workloadApiClient.close();

assertTrue(firstContext.isCancelled());
assertTrue(secondContext.isCancelled());
assertTrue(contexts.isEmpty());
}


private String generateToken(String sub, List<String> aud) {
Map<String, Object> claims = new HashMap<>();
Expand All @@ -466,4 +518,66 @@ private String generateToken(String sub, List<String> aud) {
return TestUtils.generateToken(claims, keyPair, "authority1");
}

private static <T> Watcher<T> noopWatcher() {
return new Watcher<T>() {
@Override
public void onUpdate(T update) {
}

@Override
public void onError(Throwable e) {
}
};
}

private void replaceCancellableContexts(List<Context.CancellableContext> contexts) throws Exception {
Field field = DefaultWorkloadApiClient.class.getDeclaredField("cancellableContexts");
field.setAccessible(true);
field.set(workloadApiClient, contexts);
}

private static final class MonitorCheckingCancellableContexts extends AbstractList<Context.CancellableContext> {

private final List<Context.CancellableContext> delegate = new ArrayList<>();

@Override
public Context.CancellableContext get(int index) {
return delegate.get(index);
}

@Override
public int size() {
return delegate.size();
}

@Override
public boolean add(Context.CancellableContext context) {
return delegate.add(context);
}

@Override
public Iterator<Context.CancellableContext> iterator() {
assertListMonitorHeld();
return delegate.iterator();
}

@Override
public Object[] toArray() {
assertListMonitorHeld();
return delegate.toArray();
}

@Override
public void clear() {
assertListMonitorHeld();
delegate.clear();
}

private void assertListMonitorHeld() {
if (!Thread.holdsLock(this)) {
throw new AssertionError("cancellableContexts must be accessed while synchronized on the list");
}
}
}

}