Skip to content

Commit eaf9b0f

Browse files
committed
Update authentication to use session token from login response
- Modified SessionHandler to extract sessionToken directly from login response - Added domain field to SessionCredentials (defaults to "default") - Updated login endpoint from /auth/login to /login - Added both X-Auth-Token and Authorization headers as per DXTrade API spec - Added session expiration check based on login time (1 hour) - Added LoginResponse model for proper response parsing - Added login() and logout() convenience methods to DXtradeClient - Updated tests to reflect new authentication flow
1 parent 59caf1d commit eaf9b0f

5 files changed

Lines changed: 105 additions & 40 deletions

File tree

src/dxtrade/auth.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,9 @@ def __init__(self, credentials: SessionCredentials) -> None:
194194
)
195195
super().__init__(credentials)
196196
self.credentials: SessionCredentials = credentials
197-
self._session_token: Optional[str] = credentials.session_token
197+
self._session_token: Optional[str] = None
198198
self._token_expires_at: Optional[float] = None
199+
self._last_login: Optional[float] = None
199200

200201
async def authenticate(
201202
self,
@@ -221,7 +222,9 @@ async def authenticate(
221222
if not self._session_token:
222223
raise DXtradeAuthenticationError("Failed to obtain session token")
223224

224-
request.headers["Authorization"] = f"Bearer {self._session_token}"
225+
# Add both X-Auth-Token and Authorization headers as shown in the example
226+
request.headers["X-Auth-Token"] = self._session_token
227+
request.headers["Authorization"] = f"DXAPI {self._session_token}"
225228
return request
226229

227230
async def _refresh_session_token(self, client: httpx.AsyncClient) -> None:
@@ -237,26 +240,25 @@ async def _refresh_session_token(self, client: httpx.AsyncClient) -> None:
237240
login_data = {
238241
"username": self.credentials.username,
239242
"password": self.credentials.password,
243+
"domain": self.credentials.domain or "default",
240244
}
241245

242-
response = await client.post("/auth/login", json=login_data)
246+
response = await client.post("/login", json=login_data)
243247
response.raise_for_status()
244248

245249
data = response.json()
246-
if not data.get("success"):
247-
raise DXtradeAuthenticationError(
248-
data.get("message", "Login failed")
249-
)
250-
251-
token_data = data.get("data", {})
252-
self._session_token = token_data.get("token")
253250

254-
# Set token expiration (assuming 1 hour if not provided)
255-
expires_in = token_data.get("expires_in", 3600)
256-
self._token_expires_at = time.time() + expires_in - 300 # 5 min buffer
251+
# Extract session token directly from response
252+
self._session_token = data.get("sessionToken")
257253

258254
if not self._session_token:
259-
raise DXtradeAuthenticationError("No token in login response")
255+
raise DXtradeAuthenticationError(
256+
data.get("message", "No session token in response")
257+
)
258+
259+
# Set token expiration (default to 1 hour as per example)
260+
self._token_expires_at = time.time() + 3600 - 300 # 1 hour with 5 min buffer
261+
self._last_login = time.time()
260262

261263
except httpx.HTTPError as e:
262264
raise DXtradeAuthenticationError(f"Login request failed: {e}") from e
@@ -267,8 +269,13 @@ def _is_token_expired(self) -> bool:
267269
Returns:
268270
True if token is expired or expiring soon
269271
"""
270-
if not self._token_expires_at:
272+
if not self._token_expires_at or not self._last_login:
273+
return True
274+
275+
# Re-login if session is older than 1 hour (as per example)
276+
if (time.time() - self._last_login) > 3600:
271277
return True
278+
272279
return time.time() >= self._token_expires_at
273280

274281
def get_auth_type(self) -> AuthType:
@@ -290,14 +297,18 @@ async def logout(self, client: httpx.AsyncClient) -> None:
290297

291298
try:
292299
# Try to invalidate token on server
293-
headers = {"Authorization": f"Bearer {self._session_token}"}
294-
await client.post("/auth/logout", headers=headers)
300+
headers = {
301+
"X-Auth-Token": self._session_token,
302+
"Authorization": f"DXAPI {self._session_token}"
303+
}
304+
await client.post("/logout", headers=headers)
295305
except httpx.HTTPError:
296306
# Ignore logout errors - we'll clear the token anyway
297307
pass
298308
finally:
299309
self._session_token = None
300310
self._token_expires_at = None
311+
self._last_login = None
301312

302313

303314
class AuthFactory:

src/dxtrade/client.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,33 @@ async def close(self) -> None:
182182

183183
# Convenience methods for common operations
184184

185+
async def login(self) -> bool:
186+
"""Perform login for session-based authentication.
187+
188+
This is automatically called when needed for session auth,
189+
but can be called manually to pre-authenticate.
190+
191+
Returns:
192+
True if login successful
193+
194+
Raises:
195+
DXtradeAuthenticationError: Login failed
196+
"""
197+
from dxtrade.auth import SessionHandler
198+
199+
if isinstance(self._auth_handler, SessionHandler):
200+
# Force refresh of session token
201+
await self._auth_handler._refresh_session_token(self._http_client._client)
202+
return True
203+
return False
204+
205+
async def logout(self) -> None:
206+
"""Perform logout for session-based authentication."""
207+
from dxtrade.auth import SessionHandler
208+
209+
if isinstance(self._auth_handler, SessionHandler):
210+
await self._auth_handler.logout(self._http_client._client)
211+
185212
async def get_server_time(self):
186213
"""Get server time (convenience method)."""
187214
return await self.instruments.get_server_time()

src/dxtrade/models.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,19 @@ class SessionCredentials(Credentials):
138138
"""Session-based authentication credentials."""
139139
username: str = Field(..., description="Username")
140140
password: str = Field(..., description="Password", repr=False)
141-
session_token: Optional[str] = Field(None, description="Session token")
141+
domain: str = Field("default", description="Login domain")
142+
143+
144+
# ============================================================================
145+
# Authentication Response Models
146+
# ============================================================================
147+
148+
class LoginResponse(DXtradeBaseModel):
149+
"""Login response from DXtrade API."""
150+
sessionToken: str = Field(..., description="Session authentication token")
151+
expiresIn: Optional[int] = Field(None, description="Token expiration time in seconds")
152+
userId: Optional[str] = Field(None, description="User identifier")
153+
accounts: Optional[List[str]] = Field(None, description="Available account IDs")
142154

143155

144156
# ============================================================================

tests/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ def session_credentials():
5454
"""Session credentials fixture."""
5555
return SessionCredentials(
5656
username="test_user",
57-
password="test_password"
57+
password="test_password",
58+
domain="default"
5859
)
5960

6061

tests/test_auth.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -183,46 +183,48 @@ def test_init_with_invalid_credentials(self, bearer_token_credentials):
183183
with pytest.raises(DXtradeConfigurationError):
184184
SessionHandler(bearer_token_credentials)
185185

186-
def test_init_with_existing_token(self):
187-
"""Test initialization with existing session token."""
186+
def test_init_without_token(self):
187+
"""Test initialization without existing session token."""
188188
credentials = SessionCredentials(
189189
username="test_user",
190190
password="test_password",
191-
session_token="existing_token"
191+
domain="default"
192192
)
193193
handler = SessionHandler(credentials)
194-
assert handler._session_token == "existing_token"
194+
assert handler._session_token is None
195+
assert handler._last_login is None
195196

196197
@pytest.mark.asyncio
197198
async def test_authenticate_with_valid_token(self, session_credentials):
198199
"""Test authentication with valid session token."""
199200
handler = SessionHandler(session_credentials)
200201
handler._session_token = "valid_token"
201202
handler._token_expires_at = time.time() + 3600 # 1 hour from now
203+
handler._last_login = time.time() # Just logged in
202204

203205
request = MagicMock(spec=httpx.Request)
204206
request.headers = {}
205207
client = AsyncMock(spec=httpx.AsyncClient)
206208

207209
authenticated_request = await handler.authenticate(request, client)
208210

209-
assert authenticated_request.headers["Authorization"] == "Bearer valid_token"
211+
assert authenticated_request.headers["X-Auth-Token"] == "valid_token"
212+
assert authenticated_request.headers["Authorization"] == "DXAPI valid_token"
210213

211214
@pytest.mark.asyncio
212215
async def test_authenticate_requires_login(self, session_credentials):
213216
"""Test authentication that requires login."""
214217
handler = SessionHandler(session_credentials)
215218
# No existing token
216219

217-
# Mock successful login response
220+
# Mock successful login response (DXTrade format)
218221
login_response = MagicMock(spec=httpx.Response)
219222
login_response.raise_for_status.return_value = None
220223
login_response.json.return_value = {
221-
"success": True,
222-
"data": {
223-
"token": "new_session_token",
224-
"expires_in": 3600
225-
}
224+
"sessionToken": "new_session_token",
225+
"expiresIn": 3600,
226+
"userId": "user123",
227+
"accounts": ["account1", "account2"]
226228
}
227229

228230
client = AsyncMock(spec=httpx.AsyncClient)
@@ -233,27 +235,28 @@ async def test_authenticate_requires_login(self, session_credentials):
233235

234236
authenticated_request = await handler.authenticate(request, client)
235237

236-
# Verify login was called
238+
# Verify login was called with domain
237239
client.post.assert_called_once_with(
238-
"/auth/login",
239-
json={"username": "test_user", "password": "test_password"}
240+
"/login",
241+
json={"username": "test_user", "password": "test_password", "domain": "default"}
240242
)
241243

242244
# Verify token was set
243245
assert handler._session_token == "new_session_token"
244-
assert authenticated_request.headers["Authorization"] == "Bearer new_session_token"
246+
assert authenticated_request.headers["X-Auth-Token"] == "new_session_token"
247+
assert authenticated_request.headers["Authorization"] == "DXAPI new_session_token"
245248

246249
@pytest.mark.asyncio
247250
async def test_authenticate_login_failure(self, session_credentials):
248251
"""Test authentication with login failure."""
249252
handler = SessionHandler(session_credentials)
250253

251-
# Mock failed login response
254+
# Mock failed login response (no session token)
252255
login_response = MagicMock(spec=httpx.Response)
253256
login_response.raise_for_status.return_value = None
254257
login_response.json.return_value = {
255-
"success": False,
256258
"message": "Invalid credentials"
259+
# No sessionToken field
257260
}
258261

259262
client = AsyncMock(spec=httpx.AsyncClient)
@@ -288,10 +291,17 @@ def test_token_expiration_check(self, session_credentials):
288291

289292
# Expired token
290293
handler._token_expires_at = time.time() - 3600 # 1 hour ago
294+
handler._last_login = time.time() - 7200 # 2 hours ago
295+
assert handler._is_token_expired() is True
296+
297+
# Valid token but session older than 1 hour
298+
handler._token_expires_at = time.time() + 3600 # 1 hour from now
299+
handler._last_login = time.time() - 3700 # >1 hour ago
291300
assert handler._is_token_expired() is True
292301

293-
# Valid token
302+
# Valid token and recent session
294303
handler._token_expires_at = time.time() + 3600 # 1 hour from now
304+
handler._last_login = time.time() - 1800 # 30 minutes ago
295305
assert handler._is_token_expired() is False
296306

297307
@pytest.mark.asyncio
@@ -305,10 +315,13 @@ async def test_logout(self, session_credentials):
305315

306316
await handler.logout(client)
307317

308-
# Verify logout request was made
318+
# Verify logout request was made with correct headers
309319
client.post.assert_called_once_with(
310-
"/auth/logout",
311-
headers={"Authorization": "Bearer test_token"}
320+
"/logout",
321+
headers={
322+
"X-Auth-Token": "test_token",
323+
"Authorization": "DXAPI test_token"
324+
}
312325
)
313326

314327
# Verify token was cleared
@@ -329,6 +342,7 @@ async def test_logout_network_error(self, session_credentials):
329342

330343
# Token should still be cleared
331344
assert handler._session_token is None
345+
assert handler._last_login is None
332346

333347

334348
class TestAuthFactory:

0 commit comments

Comments
 (0)