Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions src/main/java/org/ohdsi/webapi/ShiroConfiguration.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import java.util.Collection;
import java.util.Set;
import org.apache.shiro.SecurityUtils;
import org.apache.shiro.realm.Realm;
import org.apache.shiro.spring.web.ShiroFilterFactoryBean;
import org.apache.shiro.web.mgt.DefaultWebSecurityManager;
import org.apache.shiro.web.servlet.AbstractShiroFilter;
import org.ohdsi.webapi.shiro.AtlasWebSecurityManager;
import org.ohdsi.webapi.shiro.lockout.*;
import org.ohdsi.webapi.shiro.management.DataSourceAccessBeanPostProcessor;
Expand All @@ -13,14 +15,18 @@
import org.ohdsi.webapi.shiro.management.datasource.DataSourceAccessParameterResolver;
import org.ohdsi.webapi.shiro.realms.JwtAuthRealm;
import org.ohdsi.webapi.shiro.subject.WebDelegatingRunAsSubjectFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.autoconfigure.condition.ConditionalOnExpression;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.DependsOn;

import jakarta.servlet.Filter;
import java.util.Map;
Expand All @@ -33,6 +39,8 @@
@Configuration
public class ShiroConfiguration {

private static final Logger log = LoggerFactory.getLogger(ShiroConfiguration.class);

@Value("${security.maxLoginAttempts}")
private int maxLoginAttempts;
@Value("${security.duration.initial}")
Expand All @@ -44,7 +52,7 @@ public class ShiroConfiguration {
@Autowired
protected ApplicationEventPublisher eventPublisher;

@Bean
@Bean(name = "shiroFilter")
public ShiroFilterFactoryBean shiroFilter(Security security, LockoutPolicy lockoutPolicy) {

ShiroFilterFactoryBean shiroFilter = new ShiroFilterFactoryBean();
Expand All @@ -53,7 +61,10 @@ public ShiroFilterFactoryBean shiroFilter(Security security, LockoutPolicy locko
Map<String, Filter> filters = security.getFilters().entrySet().stream()
.collect(Collectors.toMap(f -> f.getKey().getTemplateName(), Map.Entry::getValue));
shiroFilter.setFilters(filters);
shiroFilter.setFilterChainDefinitionMap(security.getFilterChain());

Map<String, String> filterChain = security.getFilterChain();

shiroFilter.setFilterChainDefinitionMap(filterChain);

return shiroFilter;
}
Expand All @@ -73,6 +84,9 @@ public DefaultWebSecurityManager securityManager(Security security, LockoutPolic

securityManager.setSubjectFactory(new WebDelegatingRunAsSubjectFactory());

// Initialize SecurityUtils for programmatic access throughout the application
SecurityUtils.setSecurityManager(securityManager);

return securityManager;
}

Expand Down
34 changes: 25 additions & 9 deletions src/main/java/org/ohdsi/webapi/shiro/TokenManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,13 @@ public static String getSubject(String jwt) throws JwtException {
}

public static Claims getBody(String jwt) {

// Get untrusted subject for secret key retrieval
// Extract subject without signature verification to retrieve signing key
String untrustedSubject = getUntrustedSubject(jwt);
if (untrustedSubject == null) {
throw new UnsupportedJwtException("Cannot extract subject from the token");
}

// Pick all secret keys: latest one + previous keys, which were just invalidated (to overcome concurrency issue)
// Retrieve signing keys: current key + grace period keys for concurrency handling
List<SecretKey> keyOptions = gracePeriodInvalidTokens.get(untrustedSubject);
if (userToKeyMap.containsKey(untrustedSubject)) {
keyOptions.add(0, userToKeyMap.get(untrustedSubject));
Expand All @@ -86,12 +85,30 @@ public static Claims getBody(String jwt) {
}

protected static String getUntrustedSubject(String jws) {
int i = jws.lastIndexOf('.');
if (i == -1) {
try {
// Split JWT into header.payload.signature components
String[] parts = jws.split("\\.");
if (parts.length != 3) {
return null;
}

// Base64-decode payload to extract subject claim
String payload = new String(java.util.Base64.getUrlDecoder().decode(parts[1]));

// Extract "sub" field from JSON payload
int subIndex = payload.indexOf("\"sub\"");
if (subIndex == -1) {
return null;
}

int colonIndex = payload.indexOf(":", subIndex);
int startQuote = payload.indexOf("\"", colonIndex);
int endQuote = payload.indexOf("\"", startQuote + 1);

return payload.substring(startQuote + 1, endQuote);
} catch (Exception e) {
return null;
}
String untrustedJwtString = jws.substring(0, i+1);
return Jwts.parser().unsecured().build().parseUnsecuredClaims(untrustedJwtString).getPayload().getSubject();
}

public static Boolean invalidate(String jwt) {
Expand Down Expand Up @@ -127,7 +144,6 @@ public static String extractToken(ServletRequest request) {
if (headerParts.length != 2)
return null;

String jwt = headerParts[1];
return jwt;
return headerParts[1];
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,21 @@
import org.ohdsi.webapi.shiro.filters.AtlasAuthFilter;
import org.ohdsi.webapi.shiro.tokens.JwtAuthToken;
import org.ohdsi.webapi.shiro.TokenManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public final class AtlasJwtAuthFilter extends AtlasAuthFilter {

private static final Logger logger = LoggerFactory.getLogger(AtlasJwtAuthFilter.class);

@Override
protected JwtAuthToken createToken(ServletRequest request, ServletResponse response) throws Exception {
String jwt = TokenManager.extractToken(request);
try {
String subject = TokenManager.getSubject(jwt);
return new JwtAuthToken(subject);
} catch (JwtException e) {
logger.warn("JWT validation failed: {}", e.getMessage());
throw new AuthenticationException(e);
}
}
Expand All @@ -33,6 +38,7 @@ protected boolean onAccessDenied(ServletRequest request, ServletResponse respons
loggedIn = executeLogin(request, response);
}
catch(AuthenticationException ae) {
logger.debug("JWT authentication failed: {}", ae.getMessage());
loggedIn = false;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,7 @@ protected FilterChainBuilder getFilterChainBuilder() {
// login/logout
.addRestPath("/user/refresh", JWT_AUTHC, UPDATE_TOKEN, SEND_TOKEN_IN_HEADER)
.addProtectedRestPath("/user/runas", RUN_AS, UPDATE_TOKEN, SEND_TOKEN_IN_HEADER)
.addProtectedRestPath("/user/me")
.addRestPath("/user/logout", LOGOUT);

// MUST be called before adding OAuth filters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Component;


import static org.ohdsi.webapi.shiro.management.FilterTemplates.*;
/**
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ public FilterChainBuilder addPath(String path, FilterTemplates... filters) {

public FilterChainBuilder addPath(String path, String filters) {
path = path.replaceAll("/+$", "");

// Prepend /WebAPI to match JAX-RS @ApplicationPath("/WebAPI")
if (!path.startsWith("/WebAPI") && !path.equals("/**") && !path.equals("/*")) {
path = "/WebAPI" + path;
}

this.filterChain.put(path, filters);

// If path ends with non wildcard character, need to add two paths -
Expand Down
27 changes: 1 addition & 26 deletions src/main/java/org/ohdsi/webapi/trexsql/TrexSQLConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,15 @@

import org.springframework.boot.context.properties.ConfigurationProperties;

import java.util.HashMap;
import java.util.Map;

/**
* Configuration properties for trexsql integration.
* Maps to trexsql.* in application properties.
* Global trexsql configuration. Per-source config is in the source table (is_cache_enabled).
*/
@ConfigurationProperties(prefix = "trexsql")
public class TrexSQLConfig {

private boolean enabled = false;
private String cachePath = "./data/cache";
private String extensionsPath;
private Map<String, TrexSQLSourceConfig> sources = new HashMap<>();

public boolean isEnabled() {
return enabled;
Expand All @@ -40,24 +35,4 @@ public String getExtensionsPath() {
public void setExtensionsPath(String extensionsPath) {
this.extensionsPath = extensionsPath;
}

public Map<String, TrexSQLSourceConfig> getSources() {
return sources;
}

public void setSources(Map<String, TrexSQLSourceConfig> sources) {
this.sources = sources;
}

public TrexSQLSourceConfig getSourceConfig(String sourceKey) {
return sources.get(sourceKey);
}

public boolean isEnabledForSource(String sourceKey) {
if (!enabled) {
return false;
}
TrexSQLSourceConfig sourceConfig = sources.get(sourceKey);
return sourceConfig != null && sourceConfig.isEnabled();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,6 @@ public int getPriority() {
public Collection<Concept> executeSearch(SearchProviderConfig searchConfig, String query, String rows) throws Exception {
String sourceKey = searchConfig.getSourceKey();

if (!trexsqlService.isEnabledForSource(sourceKey)) {
log.debug("TrexSQL not enabled for source {}", sourceKey);
throw new IllegalStateException("TrexSQL not enabled for source: " + sourceKey);
}

if (!trexsqlService.isCacheAvailable(sourceKey)) {
log.debug("Cache not available for source {}", sourceKey);
throw new IllegalStateException("TrexSQL cache not available for source: " + sourceKey);
Expand Down
35 changes: 9 additions & 26 deletions src/main/java/org/ohdsi/webapi/trexsql/TrexSQLService.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import java.util.Map;

/**
* Service for TrexSQL operations used by SearchProvider.
* Service for TrexSQL operations. Cache is available if file exists.
*/
@Service
@ConditionalOnProperty(name = "trexsql.enabled", havingValue = "true", matchIfMissing = false)
Expand All @@ -28,42 +28,25 @@ public TrexSQLService(TrexSQLConfig config, TrexSQLInstanceManager instanceManag
this.instanceManager = instanceManager;
}

public boolean isEnabledForSource(String sourceKey) {
return config.isEnabledForSource(sourceKey);
}

/**
* Check if cache file exists for source.
*/
public boolean isCacheAvailable(String sourceKey) {
TrexSQLSourceConfig sourceConfig = config.getSourceConfig(sourceKey);
if (sourceConfig == null) {
return false;
}
String databaseCode = sourceConfig.getDatabaseCode();
if (databaseCode == null || databaseCode.isEmpty()) {
return false;
}
return Paths.get(config.getCachePath(), databaseCode + ".db")
.toFile().exists();
return Paths.get(config.getCachePath(), sourceKey + ".db").toFile().exists();
}

@SuppressWarnings("unchecked")
public List<Map<String, Object>> searchVocab(String sourceKey, String searchTerm, int maxRows) {
log.debug("Searching vocabulary for source {} with term: {}", sourceKey, searchTerm);

TrexSQLSourceConfig sourceConfig = config.getSourceConfig(sourceKey);
if (sourceConfig == null) {
throw new IllegalStateException("TrexSQL source configuration not found for key: " + sourceKey);
}

String databaseCode = sourceConfig.getDatabaseCode();
if (databaseCode == null || databaseCode.isEmpty()) {
throw new IllegalStateException("TrexSQL database code not configured for source: " + sourceKey);
if (!isCacheAvailable(sourceKey)) {
throw new IllegalStateException("TrexSQL cache not available for source: " + sourceKey);
}

Map<String, Object> options = new HashMap<>();
options.put("database-code", databaseCode);
options.put("database-code", sourceKey);
options.put("max-rows", maxRows);
String cachePath = config.getCachePath();
options.put("cache-path", cachePath != null ? cachePath : "/data/cache");
options.put("cache-path", config.getCachePath());

try {
Object db = instanceManager.getInstance();
Expand Down
27 changes: 0 additions & 27 deletions src/main/java/org/ohdsi/webapi/trexsql/TrexSQLSourceConfig.java

This file was deleted.

Loading