Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public static JwtBundleSet of(Collection<JwtBundle> bundles) {
}
final Map<TrustDomain, JwtBundle> bundleMap = new ConcurrentHashMap<>();
for (JwtBundle bundle : bundles) {
Objects.requireNonNull(bundle, "bundle must not be null");
bundleMap.put(bundle.getTrustDomain(), bundle);
}
return new JwtBundleSet(bundleMap);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public static X509BundleSet of(Collection<X509Bundle> bundles) {

final Map<TrustDomain, X509Bundle> bundleMap = new ConcurrentHashMap<>();
for (X509Bundle bundle : bundles) {
Objects.requireNonNull(bundle, "bundle must not be null");
bundleMap.put(bundle.getTrustDomain(), bundle);
}
return new X509BundleSet(bundleMap);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,11 @@ public static TrustDomain parse(String idOrName) {
throw new IllegalArgumentException("Trust domain is missing");
}

// Something looks kinda like a scheme separator, let's try to parse as
// an ID. We use :/ instead of :// since the diagnostics are better for
// a bad input like spiffe:/trustdomain.
// Heuristic: if the input resembles a SPIFFE ID or a URI scheme
// (e.g. spiffe://..., spiffe:/..., or <scheme>:/...), delegate parsing
// to SpiffeId.parse() so scheme-related errors are reported consistently.
if (idOrName.contains(":/")) {
SpiffeId spiffeId = SpiffeId.parse(idOrName);
return spiffeId.getTrustDomain();
return SpiffeId.parse(idOrName).getTrustDomain();
}

validateTrustDomainName(idOrName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,9 @@ private static SpiffeId getSpiffeIdOfSubject(final JWTClaimsSet claimsSet) throw

// expected audiences must be a subset of the audience claim in the token
private static void validateAudience(List<String> audClaim, Set<String> expectedAudiences) throws JwtSvidException {
if (audClaim == null || audClaim.isEmpty()) {
throw new JwtSvidException("Token missing audience claim");
}
if (!audClaim.containsAll(expectedAudiences)) {
throw new JwtSvidException(String.format("expected audience in %s (audience=%s)", expectedAudiences, audClaim));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import io.spiffe.svid.jwtsvid.JwtSvid;
import org.apache.commons.lang3.tuple.ImmutablePair;

import java.io.Closeable;
import java.io.IOException;
import java.time.Clock;
import java.time.Duration;
Expand Down Expand Up @@ -228,7 +227,7 @@ private List<JwtSvid> getJwtSvids(SpiffeId subject, String audience, String... e
ImmutablePair<SpiffeId, Set<String>> cacheKey = new ImmutablePair<>(subject, audiencesSet);

List<JwtSvid> svidList = jwtSvids.get(cacheKey);
if (svidList != null && !isTokenPastHalfLifetime(svidList.get(0))) {
if (svidList != null && !svidList.isEmpty() && !isTokenPastHalfLifetime(svidList.get(0))) {
return svidList;
}

Expand All @@ -238,7 +237,7 @@ private List<JwtSvid> getJwtSvids(SpiffeId subject, String audience, String... e
// If it does not exist or the JWT-SVID has passed half its lifetime, call the Workload API to fetch new JWT-SVIDs,
// add them to the cache map, and return the list of JWT-SVIDs.
svidList = jwtSvids.get(cacheKey);
if (svidList != null && !isTokenPastHalfLifetime(svidList.get(0))) {
if (svidList != null && !svidList.isEmpty() && !isTokenPastHalfLifetime(svidList.get(0))) {
return svidList;
}

Expand All @@ -247,6 +246,9 @@ private List<JwtSvid> getJwtSvids(SpiffeId subject, String audience, String... e
} else {
svidList = workloadApiClient.fetchJwtSvids(cacheKey.left, audience, extraAudiences);
}
if (svidList == null || svidList.isEmpty()) {
throw new JwtSvidException("Workload API returned empty JWT SVID list");
}
jwtSvids.put(cacheKey, svidList);
return svidList;
}
Expand Down Expand Up @@ -333,4 +335,16 @@ private static WorkloadApiClient createClient(JwtSourceOptions options)
void setClock(Clock clock) {
this.clock = clock;
}

// Visible for testing only.
// This method exists to allow deterministic testing of cache edge cases
// (e.g. empty cached lists) without relying on reflection or timing-based
// behavior, which would be more brittle and less safe.
void putCachedJwtSvidsForTest(SpiffeId subject, Set<String> audiences, List<JwtSvid> svids) {
Objects.requireNonNull(subject, "subject must not be null");
Objects.requireNonNull(audiences, "audiences must not be null");
Objects.requireNonNull(svids, "svids must not be null");
ImmutablePair<SpiffeId, Set<String>> cacheKey = new ImmutablePair<>(subject, new HashSet<>(audiences));
jwtSvids.put(cacheKey, new ArrayList<>(svids));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

Expand Down Expand Up @@ -153,4 +154,12 @@ void add_null_throwsNullPointerException() {
assertEquals("jwtBundle must not be null", e.getMessage());
}
}
}

@Test
void testOf_nullElementInCollection_throwsNullPointerException() {
JwtBundle jwtBundle1 = new JwtBundle(TrustDomain.parse("example.org"));
List<JwtBundle> bundles = Arrays.asList(jwtBundle1, null);
NullPointerException exception = assertThrows(NullPointerException.class, () -> JwtBundleSet.of(bundles));
assertEquals("bundle must not be null", exception.getMessage());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

Expand Down Expand Up @@ -149,4 +150,12 @@ void testgetBundleForTrustDomain_nullTrustDomain_throwsException() throws Bundle
assertEquals("trustDomain must not be null", e.getMessage());
}
}
}

@Test
void testOf_nullElementInCollection_throwsNullPointerException() {
X509Bundle x509Bundle1 = new X509Bundle(TrustDomain.parse("example.org"));
List<X509Bundle> bundles = Arrays.asList(x509Bundle1, null);
NullPointerException exception = assertThrows(NullPointerException.class, () -> X509BundleSet.of(bundles));
assertEquals("bundle must not be null", exception.getMessage());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import static io.spiffe.spiffeid.SpiffeIdTest.TD_CHARS;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.fail;

class TrustDomainTest {
Expand Down Expand Up @@ -104,4 +105,28 @@ void test_toIdString() {
final TrustDomain trustDomain = TrustDomain.parse("domain.test");
assertEquals("spiffe://domain.test", trustDomain.toIdString());
}

@Test
void testParseFromSpiffeIdWithPath_extractsTrustDomain() {
TrustDomain trustDomain = TrustDomain.parse("spiffe://example.org/foo");
assertEquals("example.org", trustDomain.getName());
}

@Test
void testParseInvalidScheme_spiffeWithSingleSlash_throwsInvalidScheme() {
assertThrows(InvalidSpiffeIdException.class,
() -> TrustDomain.parse("spiffe:/example.org"));
}

@Test
void testParseInvalidScheme_httpScheme_throwsInvalidScheme() {
assertThrows(InvalidSpiffeIdException.class,
() -> TrustDomain.parse("http://example.org"));
}

@Test
void testParseColonNotFollowedBySlash_validatesAsTrustDomain() {
assertThrows(InvalidSpiffeIdException.class,
() -> TrustDomain.parse("trustdomain:test"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,22 @@ static Stream<Arguments> provideSuccessScenarios() {
TestUtils.generateToken(claims, key3, "authority3"),
""
))
.build()),
Arguments.of(TestCase.builder()
.name("audience contains expected - success")
.jwtBundle(jwtBundle)
.expectedAudience(Collections.singleton("audience1"))
.generateToken(() -> TestUtils.generateToken(claims, key1, "authority1"))
.expectedException(null)
.expectedJwtSvid(newJwtSvidInstance(
trustDomain.newSpiffeId("host"),
audience,
issuedAt,
expiration,
claims.getClaims(),
TestUtils.generateToken(claims, key1, "authority1"),
null
))
.build())
);
}
Expand Down Expand Up @@ -243,6 +259,27 @@ static Stream<Arguments> provideFailureScenarios() {
.generateToken(() -> TestUtils.generateToken(claims, key1, "authority1"))
.expectedException(new JwtSvidException("expected audience in [another] (audience=[audience2, audience1])"))
.build()),
Arguments.of(TestCase.builder()
.name("missing audience claim")
.jwtBundle(jwtBundle)
.expectedAudience(audience)
.generateToken(() -> TestUtils.generateToken(new JWTClaimsSet.Builder()
.subject(spiffeId.toString())
.expirationTime(expiration)
.build(), key1, "authority1"))
.expectedException(new JwtSvidException("Token missing audience claim"))
.build()),
Arguments.of(TestCase.builder()
.name("empty audience claim")
.jwtBundle(jwtBundle)
.expectedAudience(audience)
.generateToken(() -> TestUtils.generateToken(new JWTClaimsSet.Builder()
.subject(spiffeId.toString())
.expirationTime(expiration)
.audience(Collections.emptyList())
.build(), key1, "authority1"))
.expectedException(new JwtSvidException("Token missing audience claim"))
.build()),
Arguments.of(TestCase.builder()
.name("invalid subject claim")
.jwtBundle(jwtBundle)
Expand Down Expand Up @@ -388,4 +425,4 @@ public TestCase build() {
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import java.time.Instant;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
Expand All @@ -30,6 +32,9 @@
import static org.junit.jupiter.api.Assertions.*;

class CachedJwtSourceTest {
private static final SpiffeId TEST_SUBJECT = SpiffeId.parse("spiffe://example.org/workload-server");
private static final String TEST_AUDIENCE = "aud1";

private CachedJwtSource jwtSource;
private WorkloadApiClientStub workloadApiClient;
private WorkloadApiClientErrorStub workloadApiClientErrorStub;
Expand Down Expand Up @@ -519,4 +524,109 @@ void newSource_noSocketAddress() throws Exception {
}
});
}

@Test
void testFetchJwtSvids_cacheContainsEmptyList_refetchesFromWorkloadApi() throws JwtSvidException, JwtSourceException, SocketEndpointAddressException {
// Test that if cache somehow contains empty list (edge case), it refetches
JwtSourceOptions options = JwtSourceOptions.builder()
.workloadApiClient(workloadApiClient)
.initTimeout(Duration.ofSeconds(0))
.build();
CachedJwtSource customJwtSource = (CachedJwtSource) CachedJwtSource.newSource(options);
customJwtSource.setClock(clock);

try {
// Seed cache with empty list to simulate edge case
Set<String> audiences = Collections.singleton(TEST_AUDIENCE);
customJwtSource.putCachedJwtSvidsForTest(TEST_SUBJECT, audiences, Collections.emptyList());

int initialCallCount = workloadApiClient.getFetchJwtSvidCallCount();

// Fetch should refetch from Workload API (empty list in cache triggers refetch)
List<JwtSvid> svids = customJwtSource.fetchJwtSvids(TEST_SUBJECT, TEST_AUDIENCE);
assertNotNull(svids);
assertEquals(1, svids.size());
// Should have called Workload API
assertEquals(initialCallCount + 1, workloadApiClient.getFetchJwtSvidCallCount());

// Subsequent fetch should NOT call Workload API again (proves valid list was cached after refetch)
List<JwtSvid> svids2 = customJwtSource.fetchJwtSvids(TEST_SUBJECT, TEST_AUDIENCE);
assertNotNull(svids2);
assertEquals(1, svids2.size());
assertEquals(initialCallCount + 1, workloadApiClient.getFetchJwtSvidCallCount());
} finally {
customJwtSource.close();
}
}

@Test
void testFetchJwtSvids_workloadApiReturnsEmptyList_throwsJwtSvidException() throws JwtSourceException, SocketEndpointAddressException {
// Create a custom client that always returns empty list
WorkloadApiClientStub emptyListClient = new WorkloadApiClientStub() {
@Override
public List<JwtSvid> fetchJwtSvids(SpiffeId subject, String audience, String... extraAudience) throws JwtSvidException {
super.fetchJwtSvids(subject, audience, extraAudience); // increment counter
return Collections.emptyList();
}
};
emptyListClient.setClock(clock);

JwtSourceOptions options = JwtSourceOptions.builder()
.workloadApiClient(emptyListClient)
.initTimeout(Duration.ofSeconds(0))
.build();
CachedJwtSource customJwtSource = (CachedJwtSource) CachedJwtSource.newSource(options);
customJwtSource.setClock(clock);

try {
JwtSvidException exception = assertThrows(JwtSvidException.class,
() -> customJwtSource.fetchJwtSvids(TEST_SUBJECT, TEST_AUDIENCE));
assertEquals("Workload API returned empty JWT SVID list", exception.getMessage());
} finally {
customJwtSource.close();
}
}

@Test
void testFetchJwtSvids_emptyListNeverCached() throws JwtSvidException, JwtSourceException, SocketEndpointAddressException {
// Create a custom client that returns empty list on first call, then valid SVIDs
final int[] callCount = new int[1];
WorkloadApiClientStub customClient = new WorkloadApiClientStub() {
@Override
public List<JwtSvid> fetchJwtSvids(SpiffeId subject, String audience, String... extraAudience) throws JwtSvidException {
callCount[0]++;
if (callCount[0] == 1) {
return Collections.emptyList();
} else {
return super.fetchJwtSvids(subject, audience, extraAudience);
}
}
};
customClient.setClock(clock);

JwtSourceOptions options = JwtSourceOptions.builder()
.workloadApiClient(customClient)
.initTimeout(Duration.ofSeconds(0))
.build();
CachedJwtSource customJwtSource = (CachedJwtSource) CachedJwtSource.newSource(options);
customJwtSource.setClock(clock);

try {
// First call returns empty, should throw (empty list is not cached)
JwtSvidException exception = assertThrows(JwtSvidException.class,
() -> customJwtSource.fetchJwtSvids(TEST_SUBJECT, TEST_AUDIENCE));
assertEquals("Workload API returned empty JWT SVID list", exception.getMessage());

// Verify empty list was not cached: second call should fetch again and succeed
int callCountBeforeSecondCall = callCount[0];
List<JwtSvid> svids = customJwtSource.fetchJwtSvids(TEST_SUBJECT, TEST_AUDIENCE);
assertNotNull(svids);
assertEquals(1, svids.size());
// Verify that second call actually made a fetch (callCount increased)
assertEquals(callCountBeforeSecondCall + 1, callCount[0]);
} finally {
customJwtSource.close();
}
}
}