Skip to content

Error when the train set has 1 example #73

@LeoGrin

Description

@LeoGrin

Example to reproduce:

from tabpfn_client import TabPFNClassifier
import numpy as np
import pandas as pd

# Minimal reproduction: the training set holds exactly one row with two features.
X_train = pd.DataFrame(
    {
        "feature1": [0.5],
        "feature2": [0.7],
    }
)

# Ten random test rows using the same two feature columns.
n_test = 10
X_test = pd.DataFrame(
    {
        "feature1": np.random.rand(n_test),
        "feature2": np.random.rand(n_test),
    }
)

# One training label (a single class) and ten random binary test labels.
y_train = np.array([1])
y_test = np.random.randint(0, 2, size=n_test)

# Fit the classifier on the one-row training set. Per the traceback in this
# report, this is the call that fails (the server answers with status 500).
model = TabPFNClassifier()
model.fit(X_train, y_train)

# Predict hard labels and class probabilities for the test rows.
y_pred = model.predict(X_test)
y_pred_proba = model.predict_proba(X_test)

# Accuracy is the fraction of predictions that match the test labels.
matches = y_pred == y_test
accuracy = np.mean(matches)
print(f"Test accuracy: {accuracy:.4f}")

Traceback:

ERROR:tabpfn_client.client:Fail to call fit, response status: 500
Traceback (most recent call last):
  File "/scratch/lgrinszt/lm_tab/scripts/../test_one_example.py", line 24, in <module>
    model.fit(X_train, y_train)
  File "/scratch/lgrinszt/micromamba/envs/lm_tab/lib/python3.10/site-packages/tabpfn_client/estimator.py", line 146, in fit
    self.last_train_set_uid = InferenceClient.fit(X, y, config=estimator_param)
  File "/scratch/lgrinszt/micromamba/envs/lm_tab/lib/python3.10/site-packages/tabpfn_client/service_wrapper.py", line 225, in fit
    return ServiceClient.fit(X, y, config=config)
  File "/scratch/lgrinszt/micromamba/envs/lm_tab/lib/python3.10/site-packages/tabpfn_client/client.py", line 237, in fit
    cls._validate_response(response, "fit")
  File "/scratch/lgrinszt/micromamba/envs/lm_tab/lib/python3.10/site-packages/tabpfn_client/client.py", line 477, in _validate_response
    raise RuntimeError(
RuntimeError: Fail to call fit with error: 500, reason: Internal Server Error and text: Traceback (most recent call last):
  File "/usr/local/lib/python3.10/site-packages/starlette/middleware/errors.py", line 165, in __call__
    await self.app(scope, receive, _send)
  File "/usr/local/lib/python3.10/site-packages/starlette/middleware/exceptions.py", line 62, in __call__
    await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
  File "/usr/local/lib/python3.10/site-packages/starlette/_exception_handler.py", line 62, in wrapped_app
    raise exc
  File "/usr/local/lib/python3.10/site-packages/starlette/_exception_handler.py", line 51, in wrapped_app
    await app(scope, receive, sender)
  File "/usr/local/lib/python3.10/site-packages/starlette/routing.py", line 715, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/usr/local/lib/python3.10/site-packages/starlette/routing.py", line 735, in app
    await route.handle(scope, receive, send)
  File "/usr/local/lib/python3.10/site-packages/starlette/routing.py", line 288, in handle
    await self.app(scope, receive, send)
  File "/usr/local/lib/python3.10/site-packages/starlette/routing.py", line 76, in app
    await wrap_app_handling_exceptions(app, request)(scope, receive, send)
  File "/usr/local/lib/python3.10/site-packages/starlette/_exception_handler.py", line 62, in wrapped_app
    raise exc
  File "/usr/local/lib/python3.10/site-packages/starlette/_exception_handler.py", line 51, in wrapped_app
    await app(scope, receive, sender)
  File "/usr/local/lib/python3.10/site-packages/starlette/routing.py", line 73, in app
    response = await f(request)
  File "/usr/local/lib/python3.10/site-packages/fastapi/routing.py", line 301, in app
    raw_response = await run_endpoint_function(
  File "/usr/local/lib/python3.10/site-packages/fastapi/routing.py", line 212, in run_endpoint_function
    return await dependant.call(**values)
  File "/code/tabpfn-server/app/routers/fit.py", line 70, in fit
    train_set_schema = await upload_train_set(
  File "/code/tabpfn-server/app/routers/fit.py", line 39, in upload_train_set
    user_train_set_mapping = await dataset_serv.add_train_set(
  File "/code/tabpfn-server/app/services/dataset_repo_service.py", line 327, in add_train_set
    content[FileType.Y_TRAIN] = self.preprocess_y_train(content[FileType.Y_TRAIN])
  File "/code/tabpfn-server/app/services/dataset_repo_service.py", line 312, in preprocess_y_train
    return y_train.to_csv(index=False).encode()
AttributeError: 'numpy.int64' object has no attribute 'to_csv'

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions