Commit 83cf375

refactor: replace query parameters with Pydantic models for input validation in API endpoints
1 parent 9b66806 commit 83cf375

1 file changed: sentences_chunker.py

62 additions & 48 deletions
@@ -9,7 +9,7 @@
 from pathlib import Path
 from enum import Enum
 from typing import Dict, List, Any, Optional
-from fastapi import FastAPI, File, UploadFile, HTTPException, Query
+from fastapi import FastAPI, File, UploadFile, HTTPException, Depends
 from fastapi.concurrency import run_in_threadpool
 from pydantic import BaseModel, Field
 from contextlib import asynccontextmanager
@@ -171,6 +171,48 @@ class FileChunkingResult(BaseModel):
     )
 
 
+class SplitSentencesInput(BaseModel):
+    """Input parameters for split sentences endpoint."""
+    model_name: SaTModelName = Field(
+        default=DEFAULT_SAT_MODEL_NAME,
+        description="The SaT model to use for sentence segmentation"
+    )
+    split_threshold: float = Field(
+        default=DEFAULT_SAT_SPLIT_THRESHOLD,
+        description="Threshold value for sentence splitting (confidence score for sentence boundaries)",
+        ge=0.0,
+        le=1.0
+    )
+
+
+class FileChunkerInput(BaseModel):
+    """Input parameters for file chunking endpoint."""
+    model_name: SaTModelName = Field(
+        default=DEFAULT_SAT_MODEL_NAME,
+        description="The SaT model to use for sentence segmentation"
+    )
+    split_threshold: float = Field(
+        default=DEFAULT_SAT_SPLIT_THRESHOLD,
+        description="Threshold value for sentence splitting (confidence score for sentence boundaries)",
+        ge=0.0,
+        le=1.0
+    )
+    max_chunk_tokens: int = Field(
+        default=500,
+        description="Maximum number of tokens per final chunk",
+        gt=0
+    )
+    overlap_sentences: int = Field(
+        default=1,
+        description="Number of sentences to overlap between consecutive chunks",
+        ge=0
+    )
+    strict_mode: bool = Field(
+        default=False,
+        description="If True, an error is returned if any chunk cannot strictly adhere to token/overlap limits"
+    )
+
+
 # Create a singleton for model caching with expiration
 _model_cache = {}
 _model_last_used = {}
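
For reference, FastAPI treats a Pydantic model wrapped in Depends() as a class dependency: each model field is read from the query string, so the wire format of these endpoints is unchanged. A minimal sketch of the pattern, using a hypothetical PagingInput model rather than anything from this file:

from fastapi import Depends, FastAPI
from pydantic import BaseModel, Field

app = FastAPI()

# Hypothetical model for illustration; mirrors the pattern used in this commit.
class PagingInput(BaseModel):
    limit: int = Field(default=10, gt=0, description="Max items to return")
    offset: int = Field(default=0, ge=0, description="Items to skip")

@app.get("/items/")
async def list_items(paging: PagingInput = Depends()):
    # FastAPI populates `paging` from ?limit=...&offset=... query parameters,
    # enforcing the Field constraints as it would for Query(...).
    return {"limit": paging.limit, "offset": paging.offset}

The Field constraints (ge, le, gt) are validated during dependency resolution, so invalid values still produce a 422 response exactly as with Query(...).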
@@ -227,7 +269,7 @@ async def lifespan(app: FastAPI):
 app = FastAPI(
     title="Text Chunker API",
     description="API for chunking text documents into smaller segments with control over token count and overlap",
-    version="0.6.4",
+    version="0.6.5",
     lifespan=lifespan,
 )
 
@@ -881,16 +923,7 @@ async def split_sentences_endpoint(
     file: UploadFile = File(
         ..., description="Text file (.txt or .md) to split into sentences"
     ),
-    model_name: SaTModelName = Query(
-        DEFAULT_SAT_MODEL_NAME,
-        description="The SaT model to use for sentence segmentation",
-    ),
-    split_threshold: float = Query(
-        DEFAULT_SAT_SPLIT_THRESHOLD,
-        description="Threshold value for sentence splitting (confidence score for sentence boundaries)",
-        ge=0.0,
-        le=1.0,
-    ),
+    input_data: SplitSentencesInput = Depends(),
 ):
     """Split text file into sentences using WTPSplit's advanced segmentation.
@@ -921,7 +954,7 @@ async def split_sentences_endpoint(
     # Split the text into sentences
     sentences = await run_in_threadpool(
         lambda: split_sentences_NLP(
-            text, model_name=model_name, split_threshold=split_threshold
+            text, model_name=input_data.model_name, split_threshold=input_data.split_threshold
         )
     )
 
@@ -952,8 +985,8 @@
         avg_tokens_per_sentence=int(total_tokens / len(chunks)) if chunks else 0,
         max_tokens_in_sentence=max_tokens,
         min_tokens_in_sentence=min_tokens,
-        sat_model_name=model_name,
-        split_threshold=split_threshold,
+        sat_model_name=input_data.model_name,
+        split_threshold=input_data.split_threshold,
         source=file.filename,
         processing_time=round(processing_time, 4),
     )
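
Clients therefore keep passing plain query parameters next to the multipart upload. A hedged sketch with requests; the /split-sentences/ route path, host, port, and model value are assumptions, since the diff shows the handler but not its route decorator:

import requests

# Assumed URL; the diff does not show this endpoint's path.
url = "http://localhost:8000/split-sentences/"

with open("document.txt", "rb") as f:
    resp = requests.post(
        url,
        # Fields of SplitSentencesInput still arrive as query parameters.
        # "sat-3l-sm" is an assumed SaTModelName member value.
        params={"model_name": "sat-3l-sm", "split_threshold": 0.5},
        files={"file": ("document.txt", f, "text/plain")},
    )
resp.raise_for_status()
print(resp.json())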
@@ -976,26 +1009,7 @@ async def split_sentences_endpoint(
 @app.post("/file-chunker/", response_model=FileChunkingResult, tags=["Chunking"])
 async def file_chunker_endpoint(
     file: UploadFile = File(..., description="Text file (.txt or .md) to chunk"),
-    model_name: SaTModelName = Query(
-        DEFAULT_SAT_MODEL_NAME,
-        description="The SaT model to use for initial sentence segmentation",
-    ),
-    split_threshold: float = Query(
-        DEFAULT_SAT_SPLIT_THRESHOLD,
-        description="Threshold for SaT sentence splitting (confidence for boundaries)",
-        ge=0.0,
-        le=1.0,
-    ),
-    max_chunk_tokens: int = Query(
-        500, description="Maximum number of tokens per final chunk", gt=0
-    ),
-    overlap_sentences: int = Query(
-        1, description="Number of sentences to overlap between consecutive chunks", ge=0
-    ),
-    strict_mode: bool = Query(
-        False,
-        description="If True, an error is returned if any chunk cannot strictly adhere to token/overlap limits.",
-    ),
+    input_data: FileChunkerInput = Depends(),
 ):
     """Split text into token-limited chunks with optional sentence overlap.
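
The /file-chunker/ route is visible in this hunk, so a request sketch can target it directly. Parameter names come from FileChunkerInput; the host, port, and payload are illustrative:

import requests

with open("notes.md", "rb") as f:  # file name is illustrative
    resp = requests.post(
        "http://localhost:8000/file-chunker/",  # assumed host/port
        params={
            "split_threshold": 0.5,   # within the ge=0.0/le=1.0 bounds
            "max_chunk_tokens": 300,  # overrides the default of 500
            "overlap_sentences": 2,
            "strict_mode": "true",
        },
        files={"file": ("notes.md", f, "text/markdown")},
    )
resp.raise_for_status()
result = resp.json()  # FileChunkingResult serialized as JSON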
@@ -1019,9 +1033,9 @@ async def file_chunker_endpoint(
     """
     start_time = time.time()
     logger.info(
-        f"Processing file {file.filename} with model={model_name.value}, "
-        f"threshold={split_threshold}, max_tokens={max_chunk_tokens}, "
-        f"overlap={overlap_sentences}, strict_mode={strict_mode}"
+        f"Processing file {file.filename} with model={input_data.model_name.value}, "
+        f"threshold={input_data.split_threshold}, max_tokens={input_data.max_chunk_tokens}, "
+        f"overlap={input_data.overlap_sentences}, strict_mode={input_data.strict_mode}"
     )
 
     # Validate file type
@@ -1040,16 +1054,16 @@
     # Split into sentences using SaT
     sentences = await run_in_threadpool(
         lambda: split_sentences_NLP(
-            text, model_name=model_name, split_threshold=split_threshold
+            text, model_name=input_data.model_name, split_threshold=input_data.split_threshold
         )
     )
 
     if not sentences:
         # Handle empty input
         metadata = FileChunkingMetadata(
-            split_threshold=split_threshold,
-            configured_max_chunk_tokens=max_chunk_tokens,
-            configured_overlap_sentences=overlap_sentences,
+            split_threshold=input_data.split_threshold,
+            configured_max_chunk_tokens=input_data.max_chunk_tokens,
+            configured_overlap_sentences=input_data.overlap_sentences,
             n_input_sentences=0,
             avg_tokens_per_input_sentence=0,
             max_tokens_in_input_sentence=0,
@@ -1058,7 +1072,7 @@
             avg_tokens_per_chunk=0,
             max_tokens_in_chunk=0,
             min_tokens_in_chunk=0,
-            sat_model_name=model_name,
+            sat_model_name=input_data.model_name,
             source=file.filename,
             processing_time=round(time.time() - start_time, 4),
         )
@@ -1087,7 +1101,7 @@
     # Group sentences into chunks
     try:
         chunks_data = await _chunk_sentences_by_token_limit(
-            sentences_data, max_chunk_tokens, overlap_sentences, strict_mode
+            sentences_data, input_data.max_chunk_tokens, input_data.overlap_sentences, input_data.strict_mode
        )
     except StrictChunkingError as e:
         logger.warning(f"Strict mode chunking failed for {file.filename}: {str(e)}")
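
_chunk_sentences_by_token_limit itself is not part of this diff. As a rough, standalone sketch of what greedy token-limited chunking with sentence overlap typically looks like (not the repo's implementation; the whitespace token counter is a stand-in):

from typing import Callable, List

def chunk_by_token_limit(
    sentences: List[str],
    max_chunk_tokens: int,
    overlap_sentences: int,
    count_tokens: Callable[[str], int] = lambda s: len(s.split()),
) -> List[List[str]]:
    """Greedily pack sentences into chunks of at most max_chunk_tokens,
    repeating the last overlap_sentences sentences at each boundary."""
    chunks: List[List[str]] = []
    current: List[str] = []
    current_tokens = 0
    for sentence in sentences:
        n = count_tokens(sentence)
        if current and current_tokens + n > max_chunk_tokens:
            chunks.append(current)
            # Seed the next chunk with the trailing overlap sentences.
            overlap = current[-overlap_sentences:] if overlap_sentences else []
            current = list(overlap)
            current_tokens = sum(count_tokens(s) for s in overlap)
        current.append(sentence)
        current_tokens += n
    if current:
        chunks.append(current)
    return chunks

A single oversized sentence can still exceed the limit in this sketch; presumably that is the kind of violation the endpoint's strict_mode turns into a StrictChunkingError.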
@@ -1142,9 +1156,9 @@
 
     # Create metadata
     metadata = FileChunkingMetadata(
-        split_threshold=split_threshold,
-        configured_max_chunk_tokens=max_chunk_tokens,
-        configured_overlap_sentences=overlap_sentences,
+        split_threshold=input_data.split_threshold,
+        configured_max_chunk_tokens=input_data.max_chunk_tokens,
+        configured_overlap_sentences=input_data.overlap_sentences,
         n_input_sentences=len(sentences_data),
         avg_tokens_per_input_sentence=avg_input_tokens,
         max_tokens_in_input_sentence=max_input_tokens,
@@ -1153,7 +1167,7 @@
         avg_tokens_per_chunk=avg_output_tokens,
         max_tokens_in_chunk=max_output_tokens,
         min_tokens_in_chunk=min_output_tokens,
-        sat_model_name=model_name,
+        sat_model_name=input_data.model_name,
         source=file.filename,
         processing_time=round(time.time() - start_time, 4),
     )
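
Because Depends() preserves the query-string interface, a quick regression check with FastAPI's TestClient can confirm existing clients are unaffected. The module name sentences_chunker is taken from the changed file; the payload and assertion are made up:

from fastapi.testclient import TestClient

from sentences_chunker import app  # module name taken from the changed file

def test_file_chunker_accepts_query_params():
    # Using TestClient as a context manager runs the lifespan hooks,
    # so model setup in lifespan() happens before the request.
    with TestClient(app) as client:
        resp = client.post(
            "/file-chunker/",
            params={"max_chunk_tokens": 200, "overlap_sentences": 0},
            files={"file": ("sample.txt", b"One sentence. Another sentence.", "text/plain")},
        )
        assert resp.status_code == 200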
