Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

/**
* Mutable assembly context passed to each {@link ChainIdentityProvider#setup} during
* credential chain construction.
* {@link IdentityChain} construction.
*
* <p>Providers use this to:
* <ul>
Expand All @@ -37,6 +37,7 @@
public final class ChainSetup {
private final ScheduledExecutorService executor;
private final String profileNameOverride;
private final String regionOverride;
private final Context properties;
private final List<NamedResolver> resolvers = new ArrayList<>();
private final Function<String, String> envFn;
Expand All @@ -48,6 +49,7 @@ public final class ChainSetup {
private ChainSetup(Builder builder) {
this.executor = builder.executor;
this.profileNameOverride = builder.profileNameOverride;
this.regionOverride = builder.regionOverride;
this.properties = Context.create();
this.envFn = builder.envFn;
this.profileFile = builder.profileFile;
Expand Down Expand Up @@ -82,6 +84,19 @@ public String profileNameOverride() {
return profileNameOverride;
}

/**
* Returns the client-specified region override used by providers that resolve credentials via a service call
* (e.g., STS, SSO), or {@code null} to fall back to the {@code AWS_REGION}/{@code AWS_DEFAULT_REGION}
* environment variables and the profile {@code region} property.
*
* <p>This affects only which regional endpoint those providers call; it does not parameterize the chain itself.
*
* @return the region override, or {@code null}.
*/
public String regionOverride() {
return regionOverride;
}

/**
* Returns the value of the given environment variable, or {@code null} if not set.
*
Expand Down Expand Up @@ -214,6 +229,7 @@ public record NamedResolver(String name, Set<CredentialFeatureId> featureIds, Id
public static final class Builder {
private ScheduledExecutorService executor;
private String profileNameOverride;
private String regionOverride;
private Function<String, String> envFn = System::getenv;
private AwsProfileFile profileFile;

Expand Down Expand Up @@ -242,6 +258,19 @@ public Builder profileNameOverride(String profileNameOverride) {
return this;
}

/**
* Sets the region override. When set, providers that resolve credentials via a service call (e.g., STS,
* SSO) use this region for their endpoint instead of resolving it from {@code AWS_REGION},
* {@code AWS_DEFAULT_REGION}, or the profile {@code region} property.
*
* @param regionOverride the region to use, e.g. {@code "us-west-2"}.
* @return this builder.
*/
public Builder regionOverride(String regionOverride) {
this.regionOverride = regionOverride;
return this;
}

/**
* Sets the function used to resolve environment variables. Defaults to
* {@link System#getenv(String)}. Override in tests for isolation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,31 +24,32 @@
import software.amazon.smithy.java.logging.InternalLogger;

/**
* A credential provider chain.
* A chain of identity providers, parameterized by the {@link Identity} type it resolves (e.g.,
* {@code AwsCredentialsIdentity} or {@code TokenIdentity}).
*
* <p>Discovers {@link ChainIdentityProvider} implementations via {@link ServiceLoader}, assembles them into an
* ordered chain based on {@link StandardProvider} slots and relative ordering constraints, and resolves
* credentials by trying each provider in order.
* ordered chain based on {@link StandardProvider} slots and relative ordering constraints, and resolves an
* identity by trying each provider in order.
*
* <p>Usage:
* {@snippet lang="java" :
* var chain = CredentialChain.create();
* {@snippet lang = "java":
* var chain = IdentityChain.create();
* var result = chain.resolveIdentity(Context.empty());
* }
*}
*
* <p>The chain is assembled once at creation time. Providers that are not on the classpath simply don't
* participate: their slots are skipped. If no provider in the chain can resolve credentials, the chain returns an
* participate: their slots are skipped. If no provider in the chain can resolve an identity, the chain returns an
* error result describing which providers were tried.
*/
public final class CredentialChain<I extends Identity> implements IdentityResolver<I>, AutoCloseable {
public final class IdentityChain<I extends Identity> implements IdentityResolver<I>, AutoCloseable {

private static final InternalLogger LOGGER = InternalLogger.getLogger(CredentialChain.class);
private static final InternalLogger LOGGER = InternalLogger.getLogger(IdentityChain.class);

private final Class<I> identityType;
private final List<ChainSetup.NamedResolver> resolvers;
private final ScheduledExecutorService executor;

private CredentialChain(
private IdentityChain(
Class<I> identityType,
List<ChainSetup.NamedResolver> resolvers,
ScheduledExecutorService executor
Expand All @@ -59,60 +60,92 @@ private CredentialChain(
}

/**
* Create a credential chain by discovering providers via ServiceLoader.
* Create an identity chain by discovering providers via ServiceLoader.
*
* @param identityType Identity type to resolve.
* @return the assembled chain.
* @throws IllegalStateException if two providers claim the same standard slot.
*/
public static <I extends Identity> CredentialChain<I> create(Class<I> identityType) {
return create(identityType, Executors.newSingleThreadScheduledExecutor(r2 -> {
public static <I extends Identity> IdentityChain<I> create(Class<I> identityType) {
return create(identityType, defaultExecutor(), null, null);
}

/**
* Create an identity chain by discovering providers via ServiceLoader, using a caller-supplied AWS
* config/credentials file and region, with a default background-refresh executor.
*
* @param identityType Identity type to resolve.
* @param profileFile Already-parsed profile file to use, or {@code null} to load from the default locations.
* @param regionOverride Region for service-calling providers to use, or {@code null} to resolve it normally.
* @return the assembled chain.
* @throws IllegalStateException if two providers claim the same standard slot.
*/
public static <I extends Identity> IdentityChain<I> create(
Class<I> identityType,
AwsProfileFile profileFile,
String regionOverride
) {
return create(identityType, defaultExecutor(), profileFile, regionOverride);
}

private static ScheduledExecutorService defaultExecutor() {
return Executors.newSingleThreadScheduledExecutor(r2 -> {
Thread t = new Thread(r2, "aws-credential-chain-refresh");
t.setDaemon(true);
return t;
}));
});
}

/**
* Create a credential chain by discovering providers via ServiceLoader.
* Create an identity chain by discovering providers via ServiceLoader.
*
* @param identityType Identity type to resolve.
* @param ex Executor used for background resolution.
* @return the assembled chain.
* @throws IllegalStateException if two providers claim the same standard slot.
*/
public static <I extends Identity> CredentialChain<I> create(Class<I> identityType, ScheduledExecutorService ex) {
return create(identityType, ex, null);
public static <I extends Identity> IdentityChain<I> create(Class<I> identityType, ScheduledExecutorService ex) {
return create(identityType, ex, null, null);
}

/**
* Create a credential chain by discovering providers via ServiceLoader, using a caller-supplied AWS
* config/credentials file.
* Create an identity chain by discovering providers via ServiceLoader, using a caller-supplied AWS
* config/credentials file and region.
*
* <p>When {@code profileFile} is non-null, the {@code SHARED_CONFIG} provider uses it instead of reading
* {@code ~/.aws/config} and {@code ~/.aws/credentials} from disk. Use this when the file has already been
* loaded, or to point the chain at a non-default location.
*
* <p>When {@code regionOverride} is non-null, providers that resolve credentials via a service call (e.g.,
* STS, SSO) use it for their endpoint instead of resolving the region from the environment or profile. This is
* how a client's configured region flows into credential resolution.
*
* @param identityType Identity type to resolve.
* @param ex Executor used for background resolution.
* @param profileFile Already-parsed profile file to use, or {@code null} to load from the default locations.
* @param regionOverride Region for service-calling providers to use, or {@code null} to resolve it normally.
* @return the assembled chain.
* @throws IllegalStateException if two providers claim the same standard slot.
*/
public static <I extends Identity> CredentialChain<I> create(
public static <I extends Identity> IdentityChain<I> create(
Class<I> identityType,
ScheduledExecutorService ex,
AwsProfileFile profileFile
AwsProfileFile profileFile,
String regionOverride
) {
List<ChainIdentityProvider> registrations = new ArrayList<>();
for (ChainIdentityProvider r : ServiceLoader.load(ChainIdentityProvider.class)) {
registrations.add(r);
}
ChainSetup setup = ChainSetup.builder().executor(ex).profileFile(profileFile).build();
ChainSetup setup = ChainSetup.builder()
.executor(ex)
.profileFile(profileFile)
.regionOverride(regionOverride)
.build();
return assemble(identityType, registrations, ex, setup);
}

static <I extends Identity> CredentialChain<I> assemble(
static <I extends Identity> IdentityChain<I> assemble(
Class<I> identityType,
List<ChainIdentityProvider> registrations,
ScheduledExecutorService executor
Expand All @@ -124,7 +157,7 @@ static <I extends Identity> CredentialChain<I> assemble(
* Assemble a chain using a caller-supplied {@link ChainSetup}. Lets tests inject a deterministic environment
* and profile rather than reading the real process environment and config files.
*/
static <I extends Identity> CredentialChain<I> assemble(
static <I extends Identity> IdentityChain<I> assemble(
Class<I> identityType,
List<ChainIdentityProvider> registrations,
ScheduledExecutorService executor,
Expand Down Expand Up @@ -153,7 +186,7 @@ static <I extends Identity> CredentialChain<I> assemble(
var ordered = setup.resolvers();

if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Assembled credential chain: {}",
LOGGER.debug("Assembled identity chain: {}",
ordered.stream().map(ChainSetup.NamedResolver::name).collect(Collectors.joining(", ")));
}

Expand All @@ -168,7 +201,7 @@ static <I extends Identity> CredentialChain<I> assemble(
}
}
warnDetectedButUnclaimed(claimed);
return new CredentialChain<>(identityType, Collections.unmodifiableList(ordered), executor);
return new IdentityChain<>(identityType, Collections.unmodifiableList(ordered), executor);
}

private static List<ChainIdentityProvider> sortByOrdering(List<ChainIdentityProvider> providers) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
package software.amazon.smithy.java.aws.credentials.chain;

/**
* Describes where a {@link ChainIdentityProvider} sits in the credential chain.
* Describes where a {@link ChainIdentityProvider} sits in the {@link IdentityChain}.
*
* <p>Three forms:
* <ul>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class FeatureIdTest {

@Test
void successfulProviderEmitsFeatureId() {
var chain = CredentialChain.assemble(AwsCredentialsIdentity.class,
var chain = IdentityChain.assemble(AwsCredentialsIdentity.class,
List.of(
provider("env",
StandardProvider.ENVIRONMENT,
Expand All @@ -44,7 +44,7 @@ void successfulProviderEmitsFeatureId() {

@Test
void failedProviderDoesNotEmitFeatureId() {
var chain = CredentialChain.assemble(AwsCredentialsIdentity.class,
var chain = IdentityChain.assemble(AwsCredentialsIdentity.class,
List.of(
provider("env",
StandardProvider.ENVIRONMENT,
Expand All @@ -68,7 +68,7 @@ void failedProviderDoesNotEmitFeatureId() {

@Test
void multipleFeatureIdsEmitted() {
var chain = CredentialChain.assemble(AwsCredentialsIdentity.class,
var chain = IdentityChain.assemble(AwsCredentialsIdentity.class,
List.of(
provider("proc",
StandardProvider.SHARED_CONFIG,
Expand All @@ -93,7 +93,7 @@ void multipleFeatureIdsEmitted() {

@Test
void noFeatureIdsWhenContextKeyNotSet() {
var chain = CredentialChain.assemble(AwsCredentialsIdentity.class,
var chain = IdentityChain.assemble(AwsCredentialsIdentity.class,
List.of(
provider("env",
StandardProvider.ENVIRONMENT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
import software.amazon.smithy.java.aws.auth.api.identity.AwsCredentialsIdentity;
import software.amazon.smithy.java.context.Context;

class AwsCredentialChainTest {
class IdentityChainTest {
@Test
void standardProvidersAreOrderedByEnumOrder() {
var chain = CredentialChain.assemble(AwsCredentialsIdentity.class,
var chain = IdentityChain.assemble(AwsCredentialsIdentity.class,
List.of(
registration("imds",
new OrderingConstraint.Standard(StandardProvider.EC2_INSTANCE_METADATA),
Expand All @@ -40,7 +40,7 @@ void standardProvidersAreOrderedByEnumOrder() {

@Test
void firstSuccessfulProviderWins() {
var chain = CredentialChain.assemble(AwsCredentialsIdentity.class,
var chain = IdentityChain.assemble(AwsCredentialsIdentity.class,
List.of(
registration("env",
new OrderingConstraint.Standard(StandardProvider.ENVIRONMENT),
Expand All @@ -57,7 +57,7 @@ void firstSuccessfulProviderWins() {

@Test
void allFailReturnsAggregatedError() {
var chain = CredentialChain.assemble(AwsCredentialsIdentity.class,
var chain = IdentityChain.assemble(AwsCredentialsIdentity.class,
List.of(
registration("env",
new OrderingConstraint.Standard(StandardProvider.ENVIRONMENT),
Expand All @@ -76,7 +76,7 @@ void allFailReturnsAggregatedError() {
@Test
void duplicateSlotThrows() {
assertThrows(IllegalStateException.class,
() -> CredentialChain.assemble(AwsCredentialsIdentity.class,
() -> IdentityChain.assemble(AwsCredentialsIdentity.class,
List.of(
registration("a",
new OrderingConstraint.Standard(StandardProvider.ENVIRONMENT),
Expand All @@ -89,7 +89,7 @@ void duplicateSlotThrows() {

@Test
void relativeAfterInsertsCorrectly() {
var chain = CredentialChain.assemble(AwsCredentialsIdentity.class,
var chain = IdentityChain.assemble(AwsCredentialsIdentity.class,
List.of(
registration("env",
new OrderingConstraint.Standard(StandardProvider.ENVIRONMENT),
Expand All @@ -107,7 +107,7 @@ void relativeAfterInsertsCorrectly() {

@Test
void relativeBeforeInsertsCorrectly() {
var chain = CredentialChain.assemble(AwsCredentialsIdentity.class,
var chain = IdentityChain.assemble(AwsCredentialsIdentity.class,
List.of(
registration("env",
new OrderingConstraint.Standard(StandardProvider.ENVIRONMENT),
Expand All @@ -125,7 +125,7 @@ void relativeBeforeInsertsCorrectly() {

@Test
void relativeToUnclaimedSlotAppendsAtEnd() {
var chain = CredentialChain.assemble(AwsCredentialsIdentity.class,
var chain = IdentityChain.assemble(AwsCredentialsIdentity.class,
List.of(
registration("env",
new OrderingConstraint.Standard(StandardProvider.ENVIRONMENT),
Expand All @@ -141,7 +141,7 @@ void relativeToUnclaimedSlotAppendsAtEnd() {
@Test
void duplicateNameThrows() {
assertThrows(IllegalStateException.class,
() -> CredentialChain.assemble(AwsCredentialsIdentity.class,
() -> IdentityChain.assemble(AwsCredentialsIdentity.class,
List.of(
registration("env",
new OrderingConstraint.Standard(StandardProvider.ENVIRONMENT),
Expand All @@ -154,7 +154,7 @@ void duplicateNameThrows() {

@Test
void emptyChainReturnsDescriptiveError() {
var chain = CredentialChain.assemble(AwsCredentialsIdentity.class, List.of(), null);
var chain = IdentityChain.assemble(AwsCredentialsIdentity.class, List.of(), null);
IdentityResult<AwsCredentialsIdentity> result = chain.resolveIdentity(Context.empty());

assertNull(result.identity());
Expand Down Expand Up @@ -182,7 +182,7 @@ public void setup(Class<? extends Identity> identityType, ChainSetup setup) {
}

private static IdentityResolver<AwsCredentialsIdentity> errorResolver(String msg) {
IdentityResult<AwsCredentialsIdentity> result = IdentityResult.ofError(AwsCredentialChainTest.class, msg);
IdentityResult<AwsCredentialsIdentity> result = IdentityResult.ofError(IdentityChainTest.class, msg);
return new IdentityResolver<>() {
public IdentityResult<AwsCredentialsIdentity> resolveIdentity(Context ctx) {
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ public Class<AwsCredentialsIdentity> identityType() {
return AwsCredentialsIdentity.class;
}

// Visible for testing: the STS endpoint configuration this resolver was assembled with.
StsEndpointConfig endpoint() {
return endpoint;
}

@Override
public IdentityResult<AwsCredentialsIdentity> resolveIdentity(Context ctx) {
AwsCredentialsIdentity sourceCredentials = resolveSourceCredentials(source, new HashSet<>());
Expand Down
Loading
Loading