Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions src/main/java/org/ohdsi/webapi/ShiroConfiguration.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import java.util.Collection;
import java.util.Set;
import org.apache.shiro.SecurityUtils;
import org.apache.shiro.realm.Realm;
import org.apache.shiro.spring.web.ShiroFilterFactoryBean;
import org.apache.shiro.web.mgt.DefaultWebSecurityManager;
import org.apache.shiro.web.servlet.AbstractShiroFilter;
import org.ohdsi.webapi.shiro.AtlasWebSecurityManager;
import org.ohdsi.webapi.shiro.lockout.*;
import org.ohdsi.webapi.shiro.management.DataSourceAccessBeanPostProcessor;
Expand All @@ -13,14 +15,18 @@
import org.ohdsi.webapi.shiro.management.datasource.DataSourceAccessParameterResolver;
import org.ohdsi.webapi.shiro.realms.JwtAuthRealm;
import org.ohdsi.webapi.shiro.subject.WebDelegatingRunAsSubjectFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.autoconfigure.condition.ConditionalOnExpression;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.DependsOn;

import jakarta.servlet.Filter;
import java.util.Map;
Expand All @@ -33,6 +39,8 @@
@Configuration
public class ShiroConfiguration {

private static final Logger log = LoggerFactory.getLogger(ShiroConfiguration.class);

@Value("${security.maxLoginAttempts}")
private int maxLoginAttempts;
@Value("${security.duration.initial}")
Expand All @@ -44,7 +52,7 @@ public class ShiroConfiguration {
@Autowired
protected ApplicationEventPublisher eventPublisher;

@Bean
@Bean(name = "shiroFilter")
public ShiroFilterFactoryBean shiroFilter(Security security, LockoutPolicy lockoutPolicy) {

ShiroFilterFactoryBean shiroFilter = new ShiroFilterFactoryBean();
Expand All @@ -53,7 +61,10 @@ public ShiroFilterFactoryBean shiroFilter(Security security, LockoutPolicy locko
Map<String, Filter> filters = security.getFilters().entrySet().stream()
.collect(Collectors.toMap(f -> f.getKey().getTemplateName(), Map.Entry::getValue));
shiroFilter.setFilters(filters);
shiroFilter.setFilterChainDefinitionMap(security.getFilterChain());

Map<String, String> filterChain = security.getFilterChain();

shiroFilter.setFilterChainDefinitionMap(filterChain);

return shiroFilter;
}
Expand All @@ -73,6 +84,9 @@ public DefaultWebSecurityManager securityManager(Security security, LockoutPolic

securityManager.setSubjectFactory(new WebDelegatingRunAsSubjectFactory());

// Initialize SecurityUtils for programmatic access throughout the application
SecurityUtils.setSecurityManager(securityManager);

return securityManager;
}

Expand Down
34 changes: 25 additions & 9 deletions src/main/java/org/ohdsi/webapi/shiro/TokenManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,13 @@ public static String getSubject(String jwt) throws JwtException {
}

public static Claims getBody(String jwt) {

// Get untrusted subject for secret key retrieval
// Extract subject without signature verification to retrieve signing key
String untrustedSubject = getUntrustedSubject(jwt);
if (untrustedSubject == null) {
throw new UnsupportedJwtException("Cannot extract subject from the token");
}

// Pick all secret keys: latest one + previous keys, which were just invalidated (to overcome concurrency issue)
// Retrieve signing keys: current key + grace period keys for concurrency handling
List<SecretKey> keyOptions = gracePeriodInvalidTokens.get(untrustedSubject);
if (userToKeyMap.containsKey(untrustedSubject)) {
keyOptions.add(0, userToKeyMap.get(untrustedSubject));
Expand All @@ -86,12 +85,30 @@ public static Claims getBody(String jwt) {
}

protected static String getUntrustedSubject(String jws) {
int i = jws.lastIndexOf('.');
if (i == -1) {
try {
// Split JWT into header.payload.signature components
String[] parts = jws.split("\\.");
if (parts.length != 3) {
return null;
}

// Base64-decode payload to extract subject claim
String payload = new String(java.util.Base64.getUrlDecoder().decode(parts[1]));

// Extract "sub" field from JSON payload
int subIndex = payload.indexOf("\"sub\"");
if (subIndex == -1) {
return null;
}

int colonIndex = payload.indexOf(":", subIndex);
int startQuote = payload.indexOf("\"", colonIndex);
int endQuote = payload.indexOf("\"", startQuote + 1);

return payload.substring(startQuote + 1, endQuote);
} catch (Exception e) {
return null;
}
String untrustedJwtString = jws.substring(0, i+1);
return Jwts.parser().unsecured().build().parseUnsecuredClaims(untrustedJwtString).getPayload().getSubject();
}

public static Boolean invalidate(String jwt) {
Expand Down Expand Up @@ -127,7 +144,6 @@ public static String extractToken(ServletRequest request) {
if (headerParts.length != 2)
return null;

String jwt = headerParts[1];
return jwt;
return headerParts[1];
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,21 @@
import org.ohdsi.webapi.shiro.filters.AtlasAuthFilter;
import org.ohdsi.webapi.shiro.tokens.JwtAuthToken;
import org.ohdsi.webapi.shiro.TokenManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public final class AtlasJwtAuthFilter extends AtlasAuthFilter {

private static final Logger logger = LoggerFactory.getLogger(AtlasJwtAuthFilter.class);

@Override
protected JwtAuthToken createToken(ServletRequest request, ServletResponse response) throws Exception {
String jwt = TokenManager.extractToken(request);
try {
String subject = TokenManager.getSubject(jwt);
return new JwtAuthToken(subject);
} catch (JwtException e) {
logger.warn("JWT validation failed: {}", e.getMessage());
throw new AuthenticationException(e);
}
}
Expand All @@ -33,6 +38,7 @@ protected boolean onAccessDenied(ServletRequest request, ServletResponse respons
loggedIn = executeLogin(request, response);
}
catch(AuthenticationException ae) {
logger.debug("JWT authentication failed: {}", ae.getMessage());
loggedIn = false;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,7 @@ protected FilterChainBuilder getFilterChainBuilder() {
// login/logout
.addRestPath("/user/refresh", JWT_AUTHC, UPDATE_TOKEN, SEND_TOKEN_IN_HEADER)
.addProtectedRestPath("/user/runas", RUN_AS, UPDATE_TOKEN, SEND_TOKEN_IN_HEADER)
.addProtectedRestPath("/user/me")
.addRestPath("/user/logout", LOGOUT);

// MUST be called before adding OAuth filters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Component;


import static org.ohdsi.webapi.shiro.management.FilterTemplates.*;
/**
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ public FilterChainBuilder addPath(String path, FilterTemplates... filters) {

public FilterChainBuilder addPath(String path, String filters) {
path = path.replaceAll("/+$", "");

// Prepend /WebAPI to match JAX-RS @ApplicationPath("/WebAPI")
if (!path.startsWith("/WebAPI") && !path.equals("/**") && !path.equals("/*")) {
path = "/WebAPI" + path;
}

this.filterChain.put(path, filters);

// If path ends with non wildcard character, need to add two paths -
Expand Down