diff --git a/.gitignore b/.gitignore index e8910808..2cc7466d 100644 --- a/.gitignore +++ b/.gitignore @@ -154,7 +154,7 @@ cython_debug/ **/experiments/* **/config.toml !**/.gitkeep - +test/ # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore diff --git a/config.template.toml b/config.template.toml index 89d71f20..ffa1f080 100644 --- a/config.template.toml +++ b/config.template.toml @@ -25,3 +25,21 @@ llm_api_key = "" # paper searching engine #engine = "semanticscholar" + +#################################### MCP #################################### +[mcp.servers] + +[mcp.servers.code_search] +command = "python" +args = ["-m", "tiny_scientist.mcp.code_search_server"] +cwd = "." + +[mcp.servers.paper_search] +command = "python" +args = ["-m", "tiny_scientist.mcp.paper_search_server"] +cwd = "." + +[mcp.servers.drawer] +command = "python" +args = ["-m", "tiny_scientist.mcp.drawer_server"] +cwd = "." diff --git a/poetry.lock b/poetry.lock index d8c2b5dd..905b3f78 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. 
[[package]] name = "aider-chat" @@ -363,7 +363,7 @@ description = "Timeout context manager for asyncio programs" optional = false python-versions = ">=3.7" groups = ["main"] -markers = "python_version < \"3.11\"" +markers = "python_version == \"3.10\"" files = [ {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"}, {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"}, @@ -1068,7 +1068,7 @@ description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" groups = ["main", "dev", "test"] -markers = "python_version < \"3.11\"" +markers = "python_version == \"3.10\"" files = [ {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, @@ -1107,6 +1107,30 @@ files = [ [package.extras] devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benchmark", "pytest-cache", "validictory"] +[[package]] +name = "fastmcp" +version = "1.0" +description = "A more ergonomic interface for MCP servers" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "fastmcp-1.0-py3-none-any.whl", hash = "sha256:88f0c5acc2af06f22cf46dd26c1a1c4c54f1479ef09e5f871fdfbade6defe3a6"}, + {file = "fastmcp-1.0.tar.gz", hash = "sha256:202f454e82cb68460a2b7372f975901e78e03b27734ce3a16c4d1d3e3cdbc519"}, +] + +[package.dependencies] +httpx = ">=0.26.0" +mcp = ">=1.0.0,<2.0.0" +pydantic = ">=2.5.3,<3.0.0" +pydantic-settings = ">=2.6.1" +python-dotenv = ">=1.0.1" +typer = ">=0.9.0" + +[package.extras] +dev = ["copychat (>=0.5.2)", "ipython (>=8.12.3)", "pdbpp (>=0.10.3)", "pre-commit", "pyright (>=1.1.389)", "pytest (>=8.3.3)", "pytest-asyncio (>=0.23.5)", "pytest-flakefinder", "pytest-xdist 
(>=3.6.1)", "ruff"] +tests = ["pre-commit", "pyright (>=1.1.389)", "pytest (>=8.3.3)", "pytest-asyncio (>=0.23.5)", "pytest-flakefinder", "pytest-xdist (>=3.6.1)", "ruff"] + [[package]] name = "filelock" version = "3.18.0" @@ -1715,6 +1739,18 @@ http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "httpx-sse" +version = "0.4.1" +description = "Consume Server-Sent Event (SSE) messages with HTTPX." +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "httpx_sse-0.4.1-py3-none-any.whl", hash = "sha256:cba42174344c3a5b06f255ce65b350880f962d99ead85e776f23c6618a377a37"}, + {file = "httpx_sse-0.4.1.tar.gz", hash = "sha256:8f44d34414bc7b21bf3602713005c5df4917884f76072479b21f68befa4ea26e"}, +] + [[package]] name = "huggingface-hub" version = "0.31.1" @@ -1905,6 +1941,22 @@ qtconsole = ["qtconsole"] test = ["pickleshare", "pytest (<7.1)", "pytest-asyncio (<0.22)", "testpath"] test-extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.22)", "pandas", "pickleshare", "pytest (<7.1)", "pytest-asyncio (<0.22)", "testpath", "trio"] +[[package]] +name = "isort" +version = "6.0.1" +description = "A Python utility / library to sort Python imports." 
+optional = false +python-versions = ">=3.9.0" +groups = ["dev"] +files = [ + {file = "isort-6.0.1-py3-none-any.whl", hash = "sha256:2dc5d7f65c9678d94c88dfc29161a320eec67328bc97aad576874cb4be1e9615"}, + {file = "isort-6.0.1.tar.gz", hash = "sha256:1cb5df28dfbc742e490c5e41bad6da41b805b0a8be7bc93cd0fb2a8a890ac450"}, +] + +[package.extras] +colors = ["colorama"] +plugins = ["setuptools"] + [[package]] name = "itsdangerous" version = "2.2.0" @@ -2419,6 +2471,35 @@ files = [ {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, ] +[[package]] +name = "mcp" +version = "1.10.1" +description = "Model Context Protocol SDK" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "mcp-1.10.1-py3-none-any.whl", hash = "sha256:4d08301aefe906dce0fa482289db55ce1db831e3e67212e65b5e23ad8454b3c5"}, + {file = "mcp-1.10.1.tar.gz", hash = "sha256:aaa0957d8307feeff180da2d9d359f2b801f35c0c67f1882136239055ef034c2"}, +] + +[package.dependencies] +anyio = ">=4.5" +httpx = ">=0.27" +httpx-sse = ">=0.4" +jsonschema = ">=4.20.0" +pydantic = ">=2.7.2,<3.0.0" +pydantic-settings = ">=2.5.2" +python-multipart = ">=0.0.9" +sse-starlette = ">=1.6.1" +starlette = ">=0.27" +uvicorn = {version = ">=0.23.1", markers = "sys_platform != \"emscripten\""} + +[package.extras] +cli = ["python-dotenv (>=1.0.0)", "typer (>=0.12.4)"] +rich = ["rich (>=13.9.4)"] +ws = ["websockets (>=15.0.1)"] + [[package]] name = "mdurl" version = "0.1.2" @@ -3586,6 +3667,30 @@ files = [ [package.dependencies] typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" +[[package]] +name = "pydantic-settings" +version = "2.10.1" +description = "Settings management using Pydantic" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "pydantic_settings-2.10.1-py3-none-any.whl", hash = "sha256:a60952460b99cf661dc25c29c0ef171721f98bfcb52ef8d9ea4c943d7c8cc796"}, + {file = "pydantic_settings-2.10.1.tar.gz", hash = 
"sha256:06f0062169818d0f5524420a360d632d5857b83cffd4d42fe29597807a1614ee"}, +] + +[package.dependencies] +pydantic = ">=2.7.0" +python-dotenv = ">=0.21.0" +typing-inspection = ">=0.4.0" + +[package.extras] +aws-secrets-manager = ["boto3 (>=1.35.0)", "boto3-stubs[secretsmanager]"] +azure-key-vault = ["azure-identity (>=1.16.0)", "azure-keyvault-secrets (>=4.8.0)"] +gcp-secret-manager = ["google-cloud-secret-manager (>=2.23.1)"] +toml = ["tomli (>=2.0.1)"] +yaml = ["pyyaml (>=6.0.1)"] + [[package]] name = "pydub" version = "0.25.1" @@ -3792,6 +3897,18 @@ files = [ [package.extras] cli = ["click (>=5.0)"] +[[package]] +name = "python-multipart" +version = "0.0.20" +description = "A streaming multipart parser for Python" +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "python_multipart-0.0.20-py3-none-any.whl", hash = "sha256:8a62d3a8335e06589fe01f2a3e178cdcc632f3fbe0d492ad9ee0ec35aab1f104"}, + {file = "python_multipart-0.0.20.tar.gz", hash = "sha256:8dd0cab45b8e23064ae09147625994d090fa46f5b0d1e13af944c331a7fa9d13"}, +] + [[package]] name = "pywin32" version = "310" @@ -4313,6 +4430,34 @@ files = [ [package.dependencies] pyasn1 = ">=0.1.3" +[[package]] +name = "ruff" +version = "0.12.2" +description = "An extremely fast Python linter and code formatter, written in Rust." 
+optional = false +python-versions = ">=3.7" +groups = ["dev"] +files = [ + {file = "ruff-0.12.2-py3-none-linux_armv6l.whl", hash = "sha256:093ea2b221df1d2b8e7ad92fc6ffdca40a2cb10d8564477a987b44fd4008a7be"}, + {file = "ruff-0.12.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:09e4cf27cc10f96b1708100fa851e0daf21767e9709e1649175355280e0d950e"}, + {file = "ruff-0.12.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:8ae64755b22f4ff85e9c52d1f82644abd0b6b6b6deedceb74bd71f35c24044cc"}, + {file = "ruff-0.12.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3eb3a6b2db4d6e2c77e682f0b988d4d61aff06860158fdb413118ca133d57922"}, + {file = "ruff-0.12.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:73448de992d05517170fc37169cbca857dfeaeaa8c2b9be494d7bcb0d36c8f4b"}, + {file = "ruff-0.12.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3b8b94317cbc2ae4a2771af641739f933934b03555e51515e6e021c64441532d"}, + {file = "ruff-0.12.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:45fc42c3bf1d30d2008023a0a9a0cfb06bf9835b147f11fe0679f21ae86d34b1"}, + {file = "ruff-0.12.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ce48f675c394c37e958bf229fb5c1e843e20945a6d962cf3ea20b7a107dcd9f4"}, + {file = "ruff-0.12.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:793d8859445ea47591272021a81391350205a4af65a9392401f418a95dfb75c9"}, + {file = "ruff-0.12.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6932323db80484dda89153da3d8e58164d01d6da86857c79f1961934354992da"}, + {file = "ruff-0.12.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:6aa7e623a3a11538108f61e859ebf016c4f14a7e6e4eba1980190cacb57714ce"}, + {file = "ruff-0.12.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2a4a20aeed74671b2def096bdf2eac610c7d8ffcbf4fb0e627c06947a1d7078d"}, + {file = "ruff-0.12.2-py3-none-musllinux_1_2_i686.whl", hash = 
"sha256:71a4c550195612f486c9d1f2b045a600aeba851b298c667807ae933478fcef04"}, + {file = "ruff-0.12.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:4987b8f4ceadf597c927beee65a5eaf994c6e2b631df963f86d8ad1bdea99342"}, + {file = "ruff-0.12.2-py3-none-win32.whl", hash = "sha256:369ffb69b70cd55b6c3fc453b9492d98aed98062db9fec828cdfd069555f5f1a"}, + {file = "ruff-0.12.2-py3-none-win_amd64.whl", hash = "sha256:dca8a3b6d6dc9810ed8f328d406516bf4d660c00caeaef36eb831cf4871b0639"}, + {file = "ruff-0.12.2-py3-none-win_arm64.whl", hash = "sha256:48d6c6bfb4761df68bc05ae630e24f506755e702d4fb08f08460be778c7ccb12"}, + {file = "ruff-0.12.2.tar.gz", hash = "sha256:d7b4f55cd6f325cb7621244f19c873c565a08aff5a4ba9c69aa7355f3f7afd3e"}, +] + [[package]] name = "scipy" version = "1.15.3" @@ -4714,6 +4859,27 @@ files = [ [package.dependencies] catalogue = ">=2.0.3,<2.1.0" +[[package]] +name = "sse-starlette" +version = "2.3.6" +description = "SSE plugin for Starlette" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "sse_starlette-2.3.6-py3-none-any.whl", hash = "sha256:d49a8285b182f6e2228e2609c350398b2ca2c36216c2675d875f81e93548f760"}, + {file = "sse_starlette-2.3.6.tar.gz", hash = "sha256:0382336f7d4ec30160cf9ca0518962905e1b69b72d6c1c995131e0a703b436e3"}, +] + +[package.dependencies] +anyio = ">=4.7.0" + +[package.extras] +daphne = ["daphne (>=4.2.0)"] +examples = ["aiosqlite (>=0.21.0)", "fastapi (>=0.115.12)", "sqlalchemy[asyncio,examples] (>=2.0.41)", "starlette (>=0.41.3)", "uvicorn (>=0.34.0)"] +granian = ["granian (>=2.3.1)"] +uvicorn = ["uvicorn (>=0.34.0)"] + [[package]] name = "stack-data" version = "0.6.3" @@ -4734,6 +4900,25 @@ pure-eval = "*" [package.extras] tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] +[[package]] +name = "starlette" +version = "0.47.1" +description = "The little ASGI library that shines." 
+optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "starlette-0.47.1-py3-none-any.whl", hash = "sha256:5e11c9f5c7c3f24959edbf2dffdc01bba860228acf657129467d8a7468591527"}, + {file = "starlette-0.47.1.tar.gz", hash = "sha256:aef012dd2b6be325ffa16698f9dc533614fb1cebd593a906b90dc1025529a79b"}, +] + +[package.dependencies] +anyio = ">=3.6.2,<5" +typing-extensions = {version = ">=4.10.0", markers = "python_version < \"3.13\""} + +[package.extras] +full = ["httpx (>=0.27.0,<0.29.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.18)", "pyyaml"] + [[package]] name = "tabulate" version = "0.9.0" @@ -4968,7 +5153,7 @@ description = "A lil' TOML parser" optional = false python-versions = ">=3.8" groups = ["main", "dev", "test"] -markers = "python_version < \"3.11\"" +markers = "python_version == \"3.10\"" files = [ {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"}, @@ -5316,6 +5501,27 @@ h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "uvicorn" +version = "0.35.0" +description = "The lightning-fast ASGI server." 
+optional = false +python-versions = ">=3.9" +groups = ["main"] +markers = "sys_platform != \"emscripten\"" +files = [ + {file = "uvicorn-0.35.0-py3-none-any.whl", hash = "sha256:197535216b25ff9b785e29a0b79199f55222193d47f820816e7da751e9bc8d4a"}, + {file = "uvicorn-0.35.0.tar.gz", hash = "sha256:bc662f087f7cf2ce11a1d7fd70b90c9f98ef2e2831556dd078d131b96cc94a01"}, +] + +[package.dependencies] +click = ">=7.0" +h11 = ">=0.8" +typing-extensions = {version = ">=4.0", markers = "python_version < \"3.11\""} + +[package.extras] +standard = ["colorama (>=0.4) ; sys_platform == \"win32\"", "httptools (>=0.6.3)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.15.1) ; sys_platform != \"win32\" and sys_platform != \"cygwin\" and platform_python_implementation != \"PyPy\"", "watchfiles (>=0.13)", "websockets (>=10.4)"] + [[package]] name = "virtualenv" version = "20.30.0" @@ -5732,4 +5938,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = ">=3.10, <3.12" -content-hash = "46cfcb39f360b7e1ba6a76f031e60fdcc603aa9c289de10cf525f12052db9b96" +content-hash = "4011c12d0bf99a57e589406755e1821ce2e6efeff776a355390434da1b61d804" diff --git a/pyproject.toml b/pyproject.toml index 81a79d98..e05aec41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,9 @@ cairosvg = "^2.7.1" together = "*" flask = "^3.0.0" flask-cors = "^4.0.0" +fastmcp = "*" +mcp = "*" +httpx = "*" [tool.poetry.group.dev.dependencies] pre-commit = "*" @@ -40,6 +43,8 @@ types-setuptools = "*" types-pyyaml = "^6.0.12.20250402" types-requests = "^2.31" types-toml = "^0.10" +isort = "^6.0.1" +ruff = "^0.12.2" [tool.poetry.group.test.dependencies] pytest = "*" diff --git a/tiny_scientist/coder.py b/tiny_scientist/coder.py index 71dcb90e..aba29738 100644 --- a/tiny_scientist/coder.py +++ b/tiny_scientist/coder.py @@ -29,6 +29,7 @@ def __init__( chat_history: Optional[str] = None, auto_install: bool = True, cost_tracker: Optional[CostTracker] = None, + mcp_client: Any = None, ): 
"""Initialize the ExperimentCoder with configuration and Aider setup.""" self.client, self.model = create_client(model) @@ -39,6 +40,7 @@ def __init__( self.auto_install = auto_install self.config = Config() self.cost_tracker = cost_tracker or CostTracker() + self.mcp_client = mcp_client # Load prompts self.prompts = self.config.prompt_template.coder_prompt @@ -76,9 +78,26 @@ def run( ) -> Tuple[bool, str, Optional[str]]: # Ensure a clean slate for every run print(f"[System] Cleaning experiment directory: {self.output_dir}") - if osp.exists(self.output_dir): - shutil.rmtree(self.output_dir) - os.makedirs(self.output_dir) + + # Save current working directory and switch to parent directory to avoid deletion issues + original_cwd = os.getcwd() + safe_cwd = osp.dirname(osp.abspath(self.output_dir)) + + try: + # Switch to safe directory before cleaning + os.chdir(safe_cwd) + + if osp.exists(self.output_dir): + shutil.rmtree(self.output_dir) + os.makedirs(self.output_dir) + + finally: + # Restore original working directory if it still exists, otherwise use safe directory + try: + os.chdir(original_cwd) + except (FileNotFoundError, OSError): + print(f"[System] Original working directory no longer exists, staying in {safe_cwd}") + os.chdir(safe_cwd) fnames = [ osp.join(self.output_dir, "experiment.py"), osp.join(self.output_dir, "notes.txt"), diff --git a/tiny_scientist/mcp/code_search_server.py b/tiny_scientist/mcp/code_search_server.py new file mode 100644 index 00000000..3e9ca3e0 --- /dev/null +++ b/tiny_scientist/mcp/code_search_server.py @@ -0,0 +1,202 @@ +import json +import os +import re +from typing import Any, Dict, List, Optional + +import httpx +import spacy +import toml +from mcp.server.fastmcp import FastMCP + +# Initialize FastMCP server +mcp = FastMCP("code_search") + +# Load config +config_path = os.path.join(os.path.dirname(__file__), "../..", "config.toml") +config = toml.load(config_path) if os.path.exists(config_path) else {"core": {}} + +# GitHub API 
configuration +GITHUB_API_BASE = "https://api.github.com" +GITHUB_TOKEN = config["core"].get("github_token", None) + + +async def make_github_request(url: str, params: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Make a request to the GitHub API with proper error handling.""" + headers = {"Accept": "application/vnd.github.v3+json"} + if GITHUB_TOKEN: + headers["Authorization"] = f"token {GITHUB_TOKEN}" + + async with httpx.AsyncClient() as client: + try: + response = await client.get(url, headers=headers, params=params, timeout=30.0) + response.raise_for_status() + result: Dict[str, Any] = response.json() + return result + except Exception as e: + print(f"GitHub API request failed: {e}") + return None + + +def format_github_repo_query(idea: Dict[str, Any], max_terms: int = 6, max_query_length: int = 250) -> str: + """Format a research idea into a GitHub search query.""" + title = idea.get("Title", "") + experiment = idea.get("Experiment", "") + combined_text = f"{title}. {experiment}" + + try: + nlp = spacy.load("en_core_web_sm") + doc = nlp(combined_text) + candidates = set() + + # Extract short noun phrases + for chunk in doc.noun_chunks: + phrase = chunk.text.strip().lower() + if 1 <= len(phrase.split()) <= 4: + candidates.add(phrase) + + # Add important standalone nouns and proper nouns + for token in doc: + if token.pos_ in {"NOUN", "PROPN"} and len(token.text) > 2: + candidates.add(token.text.lower()) + + # Clean and deduplicate + seen = set() + keywords = [] + for kw in candidates: + cleaned = re.sub(r"[^\w\s]", "", kw) + if cleaned not in seen: + seen.add(cleaned) + keywords.append(cleaned) + if len(keywords) >= max_terms: + break + + # Build query string + quoted_keywords = [f'"{kw}"' if " " in kw else kw for kw in keywords] + base_query = " ".join(quoted_keywords) + suffix = " in:file language:python" + full_query = f"{base_query} {suffix}" + + # Truncate if needed + if len(full_query) > max_query_length: + full_query = f"{' 
'.join(quoted_keywords[:max_terms//2])} {suffix}" + + return full_query + except Exception: + # Fallback to simple keyword extraction + return f"{title} {experiment} language:python" + + +def extract_github_repo_info(repos: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Extract relevant information from GitHub repository search results.""" + return [ + { + "name": repo["name"], + "owner": repo["owner"]["login"], + "stars": repo["stargazers_count"], + "forks": repo["forks_count"], + "url": repo["html_url"], + "description": repo["description"] or "No description provided.", + } + for repo in repos + ] + + +def extract_github_code_info(code_results: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Extract relevant information from GitHub code search results.""" + return [ + { + "file_name": item["name"], + "repository": item["repository"]["full_name"], + "url": item["html_url"], + } + for item in code_results + ] + + +@mcp.tool() +async def search_github_repositories(query: str, result_limit: int = 10) -> str: + """Search GitHub repositories. 
+ + Args: + query: Search query string or JSON string containing research idea + result_limit: Maximum number of results to return (default: 10) + """ + print(f"[GitHub API] Searching repositories with query: {query}") + + # Try to parse as JSON (research idea format) + try: + idea = json.loads(query) + if isinstance(idea, dict) and any(k in idea for k in ["Title", "Experiment"]): + formatted_query = format_github_repo_query(idea) + print(f"[GitHub API] Formatted query from idea: {formatted_query}") + else: + formatted_query = query + except (json.JSONDecodeError, TypeError): + formatted_query = query + + url = f"{GITHUB_API_BASE}/search/repositories" + params = { + "q": formatted_query, + "sort": "stars", + "order": "desc", + "per_page": min(result_limit, 100), + } + + data = await make_github_request(url, params) + if not data or "items" not in data: + return json.dumps({"error": "Unable to fetch repositories or no repositories found."}) + + repos = extract_github_repo_info(data["items"]) + + # Format results for return + results = {} + for i, repo in enumerate(repos): + results[str(i)] = { + "title": repo["name"], + "source": repo["url"], + "info": f"Stars: {repo['stars']}, Owner: {repo['owner']}", + "description": repo["description"] + } + + return json.dumps(results, indent=2) + + +@mcp.tool() +async def search_github_code(query: str, result_limit: int = 10) -> str: + """Search GitHub code files. 
+ + Args: + query: Search query string + result_limit: Maximum number of results to return (default: 10) + """ + print(f"[GitHub API] Searching code with query: {query}") + + url = f"{GITHUB_API_BASE}/search/code" + params = { + "q": query, + "sort": "indexed", + "order": "desc", + "per_page": min(result_limit, 100), + } + + data = await make_github_request(url, params) + if not data or "items" not in data: + return json.dumps({"error": "Unable to fetch code results or no code found."}) + + code_results = extract_github_code_info(data["items"]) + + # Format results for return + results = {} + for i, code in enumerate(code_results): + results[str(i)] = { + "title": code["file_name"], + "source": code["url"], + "info": f"Repository: {code['repository']}", + } + + return json.dumps(results, indent=2) + + +if __name__ == "__main__": + # Initialize and run the server + mcp.run(transport='stdio') \ No newline at end of file diff --git a/tiny_scientist/mcp/drawer_server.py b/tiny_scientist/mcp/drawer_server.py new file mode 100644 index 00000000..cdde6cc9 --- /dev/null +++ b/tiny_scientist/mcp/drawer_server.py @@ -0,0 +1,244 @@ +import json +import os +import re +from importlib import resources +from typing import Any, Dict, Optional + +import fitz +import httpx +import toml +from mcp.server.fastmcp import FastMCP + +from tiny_scientist.configs import Config + +# Initialize FastMCP server +mcp = FastMCP("drawer") + +# Load config +config_path = os.path.join(os.path.dirname(__file__), "../..", "config.toml") +config = toml.load(config_path) if os.path.exists(config_path) else {"core": {}} + +# LLM configuration +LLM_MODEL = config["core"].get("model", "gpt-4o-mini") +LLM_API_KEY = config["core"].get("llm_api_key", "") +LLM_TEMPERATURE = config["core"].get("temperature", 0.75) + +# Load prompt templates from the configs module +prompt_config = Config() +prompts = prompt_config.prompt_template.drawer_prompt + + +def escape_curly_braces(text: str) -> str: + """Escape curly 
braces in text to prevent format string issues.""" + return re.sub(r"({|})", r"{{\1}}", text) + + +def extract_pdf_text_from_resource(package: str, filename: str) -> str: + """Extract text from a PDF resource file.""" + with resources.files(package).joinpath(filename).open("rb") as f: + doc = fitz.open(stream=f.read(), filetype="pdf") + extracted = [page.get_text().strip() for page in doc] + return "\n\n".join(extracted) + + +def get_section_prompts(section_name: str, section_text: str) -> str: + """Get section-specific prompts.""" + section_prompt = prompts.section_prompt[section_name].format( + section_text=section_text + ) + return section_prompt + + +async def make_llm_request(prompt: str, system_message: str) -> Optional[str]: + """Make a request to the LLM API.""" + headers = { + "Authorization": f"Bearer {LLM_API_KEY}", + "Content-Type": "application/json" + } + + data = { + "model": LLM_MODEL, + "messages": [ + {"role": "system", "content": system_message}, + {"role": "user", "content": prompt} + ], + "temperature": LLM_TEMPERATURE + } + + async with httpx.AsyncClient() as client: + try: + response = await client.post( + "https://api.openai.com/v1/chat/completions", + headers=headers, + json=data, + timeout=60.0 + ) + response.raise_for_status() + result = response.json() + content = result["choices"][0]["message"]["content"] + return content if isinstance(content, str) else None + except Exception as e: + print(f"LLM API request failed: {e}") + return None + + +def extract_diagram_data(response: str) -> Dict[str, Any]: + """Extract diagram data from LLM response.""" + result = {"summary": "", "svg": "", "full_response": response} + + try: + parsed = json.loads(response) + summary = parsed["summary"] + svg = parsed["svg"] + except json.JSONDecodeError: + svg_match = re.search(r"<svg.*?</svg>", response, re.DOTALL) + svg = svg_match.group(0) if svg_match else "" + summary = ( + re.sub(r"<svg.*?</svg>", "", response, flags=re.DOTALL) + .strip() + .split("\n")[0] + ) + + if "<svg" in svg: 
+ result["summary"] = summary + result["svg"] = clean_svg(svg) + else: + print("[ERROR] SVG missing or too short.") + return result + + +def clean_svg(svg: str) -> str: + """Clean and format SVG content.""" + # Strip any outer code block delimiters + svg = svg.strip() + svg = re.sub(r"^```(?:svg)?", "", svg) + svg = re.sub(r"```$", "", svg) + + # Replace problematic ampersands + svg = svg.replace("&", "&amp;") + + # Ensure no double XML declarations + svg = re.sub(r"<\?xml.*?\?>", "", svg, count=1) + + # Remove extra whitespace lines + svg = "\n".join([line for line in svg.splitlines() if line.strip()]) + + return svg.strip() + + +# Initialize system prompt with sample data +def initialize_system_prompt() -> str: + """Initialize the system prompt with sample data.""" + try: + method_sample_raw = extract_pdf_text_from_resource( + "tiny_scientist.fewshot_sample", "framework.pdf" + ) + result_sample_raw = extract_pdf_text_from_resource( + "tiny_scientist.fewshot_sample", "result.pdf" + ) + + method_sample = escape_curly_braces(method_sample_raw) + result_sample = escape_curly_braces(result_sample_raw) + + return prompts.diagram_system_prompt.format( + method_sample=method_sample, + result_sample=result_sample, + ) + except Exception as e: + print(f"[WARNING] Failed to load sample data: {e}") + return "You are a diagram generation assistant. Generate SVG diagrams based on research paper sections." + + +SYSTEM_PROMPT = initialize_system_prompt() + + +@mcp.tool() +async def generate_diagram(section_name: str, section_content: str) -> str: + """Generate an SVG diagram for a research paper section. 
+ + Args: + section_name: Name of the paper section (e.g., "Method", "Results") + section_content: Content of the section to visualize + """ + print(f"[Drawer] Generating diagram for section: {section_name}") + + if not section_content.strip(): + return json.dumps({"error": "Section content cannot be empty"}) + + # Get section-specific prompts + section_prompt = get_section_prompts(section_name, section_content) + + # Generate diagram using LLM + llm_response = await make_llm_request(section_prompt, SYSTEM_PROMPT) + + if not llm_response: + return json.dumps({"error": "Failed to generate diagram from LLM"}) + + # Extract diagram data + diagram = extract_diagram_data(llm_response) + + # Format response + result = { + "diagram": { + "summary": diagram.get("summary", ""), + "svg": diagram.get("svg", ""), + } + } + + return json.dumps(result, indent=2) + + +@mcp.tool() +async def validate_svg(svg_content: str) -> str: + """Validate and clean SVG content. + + Args: + svg_content: SVG content to validate and clean + """ + print("[Drawer] Validating and cleaning SVG content") + + if not svg_content.strip(): + return json.dumps({"error": "SVG content cannot be empty"}) + + try: + cleaned_svg = clean_svg(svg_content) + + # Basic validation - check if it looks like valid SVG + if "<svg" in cleaned_svg: + result = { + "valid": True, + "cleaned_svg": cleaned_svg, + "message": "SVG is valid and has been cleaned" + } + else: + result = { + "valid": False, + "cleaned_svg": "", + "message": "SVG appears to be invalid or incomplete" + } + + return json.dumps(result, indent=2) + except Exception as e: + return json.dumps({ + "valid": False, + "cleaned_svg": "", + "message": f"Error validating SVG: {str(e)}" + }) + + +@mcp.tool() +async def get_supported_sections() -> str: + """Get list of supported section types for diagram generation.""" + supported_sections = list(prompts.section_prompt.keys()) + + result = { + "supported_sections": supported_sections, + "description": "These are the 
section types that have specialized prompts for diagram generation" + } + + return json.dumps(result, indent=2) + + +if __name__ == "__main__": + # Initialize and run the server + mcp.run(transport='stdio') \ No newline at end of file diff --git a/tiny_scientist/mcp/paper_search_server.py b/tiny_scientist/mcp/paper_search_server.py new file mode 100644 index 00000000..8ca6d416 --- /dev/null +++ b/tiny_scientist/mcp/paper_search_server.py @@ -0,0 +1,232 @@ +import asyncio +import json +import os +from typing import Any, Dict, List, Optional + +import httpx +import toml +from mcp.server.fastmcp import FastMCP + +# Initialize FastMCP server +mcp = FastMCP("paper_search") + +# Load config +config_path = os.path.join(os.path.dirname(__file__), "../..", "config.toml") +config = toml.load(config_path) if os.path.exists(config_path) else {"core": {}} + +# Semantic Scholar API configuration +S2_API_BASE = "https://api.semanticscholar.org/graph/v1" +S2_API_KEY = config["core"].get("s2_api_key", None) +SEARCH_ENGINE = config["core"].get("engine", "semanticscholar") + +# Debug: Print configuration status +print(f"[Paper Search] Config path: {config_path}") +print(f"[Paper Search] Config exists: {os.path.exists(config_path)}") +print(f"[Paper Search] API Key configured: {'Yes' if S2_API_KEY else 'No'}") +print(f"[Paper Search] Search engine: {SEARCH_ENGINE}") + + +async def make_s2_request(url: str, params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None) -> Optional[Dict[str, Any]]: + """Make a request to the Semantic Scholar API with proper error handling.""" + default_headers = {} + + # Temporarily disable API key due to invalid key issue + # TODO: Update with a valid API key when available + use_api_key = False # Set to True when you have a valid API key + + if S2_API_KEY and use_api_key: + default_headers["X-API-KEY"] = S2_API_KEY + print(f"[Paper Search] Using API key: {S2_API_KEY[:10]}...") + else: + print("[Paper Search] Using unauthenticated 
access (rate limited)") + + if headers: + default_headers.update(headers) + + async with httpx.AsyncClient() as client: + try: + response = await client.get(url, headers=default_headers, params=params, timeout=30.0) + print(f"[Paper Search] Response status: {response.status_code}") + response.raise_for_status() + result: Dict[str, Any] = response.json() + if result.get('data'): + print(f"[Paper Search] Found {len(result['data'])} papers") + return result + except Exception as e: + print(f"[Paper Search] Semantic Scholar API request failed: {e}") + if hasattr(e, 'response'): + print(f"[Paper Search] Response text: {e.response.text if e.response else 'No response'}") + return None + + +async def make_openalex_request(query: str, result_limit: int) -> Optional[List[Dict[str, Any]]]: + """Make a request to OpenAlex API.""" + try: + import pyalex + from pyalex import Works + + mail = os.environ.get("OPENALEX_MAIL_ADDRESS") + if mail: + pyalex.config.email = mail + else: + print("[WARNING] Please set OPENALEX_MAIL_ADDRESS for better API access") + + works = Works().search(query).get(per_page=result_limit) + if not works: + return None + + return [extract_openalex_work_info(work) for work in works] + except ImportError: + print("[ERROR] pyalex not installed, falling back to Semantic Scholar") + return None + except Exception as e: + print(f"OpenAlex API request failed: {e}") + return None + + +def extract_openalex_work_info(work: Dict[str, Any], max_abstract_length: int = 1000) -> Dict[str, str]: + """Extract relevant information from OpenAlex work data.""" + venue = next( + ( + loc["source"]["display_name"] + for loc in work["locations"] + if loc["source"] + ), + "Unknown", + ) + + authors_list = [ + author["author"]["display_name"] for author in work["authorships"] + ] + authors = ( + " and ".join(authors_list) + if len(authors_list) < 20 + else f"{authors_list[0]} et al." 
+ ) + + abstract = work.get("abstract", "") + if len(abstract) > max_abstract_length: + print(f"[WARNING] {work['title']}: Abstract is too long, truncating.") + abstract = abstract[:max_abstract_length] + + return { + "title": work["title"], + "authors": authors, + "venue": venue, + "year": str(work.get("publication_year", "Unknown")), + "abstract": abstract, + "citationCount": str(work.get("cited_by_count", 0)), + } + + +@mcp.tool() +async def search_papers(query: str, result_limit: int = 3) -> str: + """Search for academic papers using Semantic Scholar or OpenAlex. + + Args: + query: Search query string for papers + result_limit: Maximum number of papers to return (default: 3) + """ + print(f"[Paper Search] Searching for papers with query: {query}") + + if not query: + return json.dumps({"error": "No query provided"}) + + papers = None + + if SEARCH_ENGINE == "semanticscholar": + print(f"(Semantic Scholar API) Searching for papers with query: {query}") + papers = await search_semanticscholar(query, result_limit) + elif SEARCH_ENGINE == "openalex": + print(f"(OpenAlex API) Searching for papers with query: {query}") + papers = await make_openalex_request(query, result_limit) + else: + return json.dumps({"error": f"Unsupported search engine: {SEARCH_ENGINE}"}) + + if not papers: + return json.dumps({"error": "No papers found or API error"}) + + # Format papers and fetch bibtex for Semantic Scholar results + results = {} + for paper in papers: + paper_id = paper.get("paperId", None) + bibtex = "N/A" + + if SEARCH_ENGINE == "semanticscholar" and paper_id: + bibtex = await fetch_bibtex(paper_id) + + if bibtex and bibtex != "N/A": + title = paper.get("title", "Unknown Title") + results[title] = { + "title": title, + "bibtex": bibtex + } + + return json.dumps(results, indent=2) + + +async def search_semanticscholar(query: str, result_limit: int) -> Optional[List[Dict[str, Any]]]: + """Search Semantic Scholar for papers.""" + params = { + "query": query, + "limit": 
result_limit, + "fields": "title,authors,venue,year,abstract,citationStyles,citationCount,paperId", + } + + url = f"{S2_API_BASE}/paper/search" + data = await make_s2_request(url, params) + + if not data or not data.get("total"): + return None + + # Add a small delay to be respectful to the API + await asyncio.sleep(8.0) + result = data.get("data") + return result if isinstance(result, list) else None + + +@mcp.tool() +async def fetch_bibtex(paper_id: str) -> str: + """Fetch BibTeX citation for a paper by its Semantic Scholar ID. + + Args: + paper_id: Semantic Scholar paper ID + """ + print(f"[Paper Search] Fetching BibTeX for paper ID: {paper_id}") + + url = f"{S2_API_BASE}/paper/{paper_id}" + params = {"fields": "citationStyles"} + + data = await make_s2_request(url, params) + if not data: + return "N/A" + + citation_styles = data.get("citationStyles", {}) + bibtex = citation_styles.get("bibtex", "N/A") + return bibtex if isinstance(bibtex, str) else "N/A" + + +@mcp.tool() +async def get_paper_details(paper_id: str) -> str: + """Get detailed information about a paper by its Semantic Scholar ID. 
+ + Args: + paper_id: Semantic Scholar paper ID + """ + print(f"[Paper Search] Getting details for paper ID: {paper_id}") + + url = f"{S2_API_BASE}/paper/{paper_id}" + params = {"fields": "title,authors,venue,year,abstract,citationCount,citationStyles"} + + data = await make_s2_request(url, params) + if not data: + return json.dumps({"error": "Paper not found or API error"}) + + return json.dumps(data, indent=2) + + +# Import asyncio at the end to avoid issues + +if __name__ == "__main__": + # Initialize and run the server + mcp.run(transport='stdio') \ No newline at end of file diff --git a/tiny_scientist/tool.py b/tiny_scientist/mcp/tool.py similarity index 98% rename from tiny_scientist/tool.py rename to tiny_scientist/mcp/tool.py index 80ce1eec..ae48e016 100644 --- a/tiny_scientist/tool.py +++ b/tiny_scientist/mcp/tool.py @@ -11,10 +11,10 @@ import toml from rich import print -from .configs import Config -from .utils.cost_tracker import CostTracker -from .utils.error_handler import api_calling_error_exponential_backoff -from .utils.llm import create_client, get_response_from_llm +from ..configs import Config +from ..utils.cost_tracker import CostTracker +from ..utils.error_handler import api_calling_error_exponential_backoff +from ..utils.llm import create_client, get_response_from_llm # Load config config_path = os.path.join(os.path.dirname(__file__), "config.toml") diff --git a/tiny_scientist/reviewer.py b/tiny_scientist/reviewer.py index d6141043..1fab22b6 100644 --- a/tiny_scientist/reviewer.py +++ b/tiny_scientist/reviewer.py @@ -4,7 +4,7 @@ from rich import print from .configs import Config -from .tool import BaseTool, PaperSearchTool +from .mcp.tool import BaseTool, PaperSearchTool from .utils.cost_tracker import CostTracker from .utils.error_handler import api_calling_error_exponential_backoff from .utils.input_formatter import InputFormatter @@ -25,6 +25,7 @@ def __init__( temperature: float = 0.75, prompt_template_dir: Optional[str] = None, 
cost_tracker: Optional[CostTracker] = None, + mcp_client: Any = None, ): self.tools = tools self.num_reviews = num_reviews @@ -32,7 +33,9 @@ def __init__( self.client, self.model = create_client(model) self.temperature = temperature self.config = Config(prompt_template_dir) - self.searcher = PaperSearchTool() + self.mcp_client = mcp_client + # Fallback to traditional searcher if MCP is not available + self.searcher = PaperSearchTool() if not mcp_client else None self._query_cache: Dict[str, List[Dict[str, Any]]] = {} self.last_related_works_string = "" self.cost_tracker = cost_tracker or CostTracker() @@ -107,8 +110,86 @@ def _get_related_works(self, query: str) -> str: if query in self._query_cache: related_papers = self._query_cache[query] else: - results_dict = self.searcher.run(query) - related_papers = list(results_dict.values()) + if self.mcp_client: + # Use MCP client for paper search + import asyncio + + from .utils.mcp_client import search_papers + + try: + # Handle async function call properly to avoid event loop conflicts + import concurrent.futures + + def run_async_search() -> Optional[str]: + """Run the async search function in a new event loop.""" + return asyncio.run(search_papers(query, self.mcp_client)) + + # Always use ThreadPoolExecutor to avoid event loop conflicts + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(run_async_search) + results_json = future.result(timeout=30.0) # Add timeout + + if results_json: + import json + results_dict = json.loads(results_json) + if results_dict: + # Convert MCP format to expected format + related_papers = [] + for title, paper_data in results_dict.items(): + if isinstance(paper_data, dict): + # MCP format: {"title": ..., "bibtex": ...} + paper = { + "title": paper_data.get("title", title), + "source": "Unknown authors", # MCP doesn't return author info + "info": f"BibTeX available: {paper_data.get('bibtex', 'N/A') != 'N/A'}" + } + else: + # Fallback if unexpected format 
+ paper = { + "title": title, + "source": "Unknown authors", + "info": "Unknown venue" + } + related_papers.append(paper) + else: + related_papers = [] + else: + related_papers = [] + except Exception as e: + print(f"[WARNING] MCP search failed, falling back to traditional search: {e}") + if self.searcher: + results_dict = self.searcher.run(query) + related_papers = list(results_dict.values()) + else: + related_papers = [] + else: + # Use traditional searcher + if self.searcher: + results_dict = self.searcher.run(query) + if results_dict: + # Convert traditional format to expected format + related_papers = [] + for title, paper_data in results_dict.items(): + if isinstance(paper_data, dict): + # Traditional format: {"title": ..., "bibtex": ...} + paper = { + "title": paper_data.get("title", title), + "source": "Unknown authors", # Traditional tool doesn't return author info either + "info": f"BibTeX available: {paper_data.get('bibtex', 'N/A') != 'N/A'}" + } + else: + # Fallback if unexpected format + paper = { + "title": title, + "source": "Unknown authors", + "info": "Unknown venue" + } + related_papers.append(paper) + else: + related_papers = [] + else: + related_papers = [] + self._query_cache[query] = related_papers if related_papers else [] if related_papers: diff --git a/tiny_scientist/scientist.py b/tiny_scientist/scientist.py index 5c216ae1..b28957c5 100644 --- a/tiny_scientist/scientist.py +++ b/tiny_scientist/scientist.py @@ -1,3 +1,5 @@ +import datetime +import os from typing import Any, Dict, List, Optional, Tuple, Union from rich import print @@ -7,6 +9,7 @@ from .thinker import Thinker from .utils.cost_tracker import CostTracker from .utils.input_formatter import InputFormatter +from .utils.mcp_client import MCPClient from .writer import Writer @@ -18,22 +21,37 @@ def __init__( template: str = "acl", prompt_template_dir: Optional[str] = None, budget: Optional[float] = None, + use_mcp: bool = True, ): self.model = model - self.output_dir = output_dir + 
self.base_output_dir = output_dir # Store user's base directory + + # Create a unique experiment directory with timestamp + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + self.experiment_dir = os.path.join(output_dir, f"experiment_{timestamp}") + + # Ensure the experiment directory exists + os.makedirs(self.experiment_dir, exist_ok=True) + print(f"๐Ÿ”ฌ Created experiment directory: {self.experiment_dir}") + self.template = template self.prompt_template_dir = prompt_template_dir self.input_formatter = InputFormatter() + self.use_mcp = use_mcp self.cost = 0.0 + # Initialize MCP client if enabled + self.mcp_client = MCPClient() if use_mcp else None + # Naive budget split modules = ["thinker", "coder", "writer", "reviewer"] per_module_budget = budget / len(modules) if budget else None + # Use the unique experiment directory for all modules self.thinker = Thinker( model=model, - output_dir=output_dir, + output_dir=self.experiment_dir, prompt_template_dir=prompt_template_dir, tools=[], iter_num=3, @@ -41,23 +59,26 @@ def __init__( generate_exp_plan=True, enable_ethical_defense=False, cost_tracker=CostTracker(budget=per_module_budget), + mcp_client=self.mcp_client, ) self.coder = Coder( model=model, - output_dir=output_dir, + output_dir=self.experiment_dir, prompt_template_dir=prompt_template_dir, max_iters=4, max_runs=3, cost_tracker=CostTracker(budget=per_module_budget), + mcp_client=self.mcp_client, ) self.writer = Writer( model=model, - output_dir=output_dir, + output_dir=self.experiment_dir, prompt_template_dir=prompt_template_dir, template=template, cost_tracker=CostTracker(budget=per_module_budget), + mcp_client=self.mcp_client, ) self.reviewer = Reviewer( @@ -65,8 +86,35 @@ def __init__( prompt_template_dir=prompt_template_dir, tools=[], cost_tracker=CostTracker(budget=per_module_budget), + mcp_client=self.mcp_client, ) + async def initialize_mcp(self) -> None: + """Initialize MCP servers.""" + if self.mcp_client: + print("๐Ÿ”ง Initializing MCP 
servers...") + results = await self.mcp_client.start_all_servers() + for server_name, success in results.items(): + if success: + print(f"โœ… MCP server '{server_name}' started successfully") + else: + print(f"โŒ Failed to start MCP server '{server_name}'") + + async def cleanup_mcp(self) -> None: + """Clean up MCP servers.""" + if self.mcp_client: + print("๐Ÿงน Shutting down MCP servers...") + await self.mcp_client.stop_all_servers() + + async def __aenter__(self) -> "TinyScientist": + """Async context manager entry.""" + await self.initialize_mcp() + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """Async context manager exit.""" + await self.cleanup_mcp() + def think( self, intent: str, num_ideas: int = 1, pdf_content: Optional[str] = None ) -> Union[List[Dict[str, Any]], Dict[str, Any]]: @@ -93,11 +141,13 @@ def code( print(f"โŒ Experiment failed. Please check {exp_path} for details.") if error_details: print(f"Error details: {error_details}") - return status, exp_path + return status, self.experiment_dir - def write(self, idea: Dict[str, Any], experiment_dir: str) -> str: + def write(self, idea: Dict[str, Any], experiment_dir: Optional[str] = None) -> str: print("๐Ÿ“ Writing paper...") - pdf_path, paper_name = self.writer.run(idea=idea, experiment_dir=experiment_dir) + # Use the internal experiment directory if no specific directory is provided + exp_dir = experiment_dir if experiment_dir is not None else self.experiment_dir + pdf_path, paper_name = self.writer.run(idea=idea, experiment_dir=exp_dir) print( f"Check the generated paper named as {paper_name} and saved at {pdf_path}" ) diff --git a/tiny_scientist/thinker.py b/tiny_scientist/thinker.py index 0093c037..6ea8d91e 100644 --- a/tiny_scientist/thinker.py +++ b/tiny_scientist/thinker.py @@ -5,7 +5,7 @@ from rich import print from .configs import Config -from .tool import PaperSearchTool +from .mcp.tool import PaperSearchTool from .utils.cost_tracker 
import CostTracker from .utils.error_handler import api_calling_error_exponential_backoff from .utils.llm import ( @@ -28,6 +28,7 @@ def __init__( prompt_template_dir: Optional[str] = None, cost_tracker: Optional[CostTracker] = None, enable_ethical_defense: bool = False, + mcp_client: Any = None, ): self.tools = tools self.iter_num = iter_num @@ -35,7 +36,9 @@ def __init__( self.output_dir = output_dir self.temperature = temperature self.config = Config(prompt_template_dir) - self.searcher = PaperSearchTool() + self.mcp_client = mcp_client + # Fallback to traditional searcher if MCP is not available + self.searcher = PaperSearchTool() if not mcp_client else None self.search_papers = search_papers self.generate_exp_plan = generate_exp_plan self.prompts = self.config.prompt_template.thinker_prompt @@ -441,8 +444,87 @@ def _get_related_works(self, query: str) -> str: print("โœ… Using cached query results") else: print(f"Searching for papers with query: {query}") - results_dict = self.searcher.run(query) - related_papers = list(results_dict.values()) if results_dict else [] + + if self.mcp_client: + # Use MCP client for paper search + import asyncio + + from .utils.mcp_client import search_papers + + try: + # Handle async function call properly to avoid event loop conflicts + import concurrent.futures + + def run_async_search() -> Optional[str]: + """Run the async search function in a new event loop.""" + return asyncio.run(search_papers(query, self.mcp_client)) + + # Always use ThreadPoolExecutor to avoid event loop conflicts + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(run_async_search) + results_json = future.result(timeout=30.0) # Add timeout + + if results_json: + import json + results_dict = json.loads(results_json) + if results_dict: + # Convert MCP format to expected format + related_papers = [] + for title, paper_data in results_dict.items(): + if isinstance(paper_data, dict): + # MCP format: {"title": ..., "bibtex": 
...} + paper = { + "title": paper_data.get("title", title), + "source": "Unknown authors", # MCP doesn't return author info + "info": f"BibTeX available: {paper_data.get('bibtex', 'N/A') != 'N/A'}" + } + else: + # Fallback if unexpected format + paper = { + "title": title, + "source": "Unknown authors", + "info": "Unknown venue" + } + related_papers.append(paper) + else: + related_papers = [] + else: + related_papers = [] + except Exception as e: + print(f"[WARNING] MCP search failed, falling back to traditional search: {e}") + if self.searcher: + results_dict = self.searcher.run(query) + related_papers = list(results_dict.values()) if results_dict else [] + else: + related_papers = [] + else: + # Use traditional searcher + if self.searcher: + results_dict = self.searcher.run(query) + if results_dict: + # Convert traditional format to expected format + related_papers = [] + for title, paper_data in results_dict.items(): + if isinstance(paper_data, dict): + # Traditional format: {"title": ..., "bibtex": ...} + paper = { + "title": paper_data.get("title", title), + "source": "Unknown authors", # Traditional tool doesn't return author info either + "info": f"BibTeX available: {paper_data.get('bibtex', 'N/A') != 'N/A'}" + } + else: + # Fallback if unexpected format + paper = { + "title": title, + "source": "Unknown authors", + "info": "Unknown venue" + } + related_papers.append(paper) + else: + related_papers = [] + else: + related_papers = [] + self._query_cache[query] = related_papers if related_papers: diff --git a/tiny_scientist/utils/mcp_client.py b/tiny_scientist/utils/mcp_client.py new file mode 100644 index 00000000..57800aa9 --- /dev/null +++ b/tiny_scientist/utils/mcp_client.py @@ -0,0 +1,480 @@ +import json +import os +import subprocess +from typing import Any, Dict, List, Optional + +import toml +from rich import print + + +class MCPClient: + """Client for managing and communicating with MCP servers.""" + + def __init__(self, config_path: Optional[str] = 
None): + """Initialize MCP client with configuration. + + Args: + config_path: Path to configuration file containing MCP server settings + """ + self.config_path = config_path or self._get_default_config_path() + self.config = self._load_config() + self.servers: Dict[str, subprocess.Popen[str]] = {} + self.server_configs = self.config.get("mcp", {}).get("servers", {}) + + def _get_default_config_path(self) -> str: + """Get default config path.""" + this_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + return os.path.join(this_dir, "config.toml") + + def _load_config(self) -> Dict[str, Any]: + """Load configuration from TOML file.""" + try: + with open(self.config_path, 'r') as f: + return toml.load(f) + except FileNotFoundError: + print(f"[WARNING] Config file not found: {self.config_path}") + return {} + except Exception as e: + print(f"[ERROR] Failed to load config: {e}") + return {} + + async def start_server(self, server_name: str) -> bool: + """Start a specific MCP server. 
+ + Args: + server_name: Name of the server to start + + Returns: + bool: True if server started successfully + """ + if server_name in self.servers: + print(f"[MCP] Server {server_name} is already running") + return True + + server_config = self.server_configs.get(server_name) + if not server_config: + print(f"[ERROR] No configuration found for server: {server_name}") + return False + + try: + command = server_config.get("command", "") + args = server_config.get("args", []) + working_dir = server_config.get("cwd") + + # Build full command + full_command = [command] + args + + # Start the server process + process = subprocess.Popen( + full_command, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + cwd=working_dir, + text=True + ) + + self.servers[server_name] = process + + # Perform MCP initialization handshake + init_success = await self._initialize_server(server_name) + if not init_success: + await self.stop_server(server_name) + return False + + print(f"[MCP] Started server: {server_name}") + return True + + except Exception as e: + print(f"[ERROR] Failed to start server {server_name}: {e}") + return False + + async def _initialize_server(self, server_name: str) -> bool: + """Initialize MCP server with proper handshake. 
+ + Args: + server_name: Name of the server to initialize + + Returns: + bool: True if initialization successful + """ + if server_name not in self.servers: + return False + + try: + process = self.servers[server_name] + + # Send initialize request + init_request = { + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": { + "tools": {} + }, + "clientInfo": { + "name": "tiny-scientist-mcp-client", + "version": "1.0.0" + } + } + } + + request_json = json.dumps(init_request) + "\n" + if process.stdin is None: + print(f"[ERROR] No stdin available for server {server_name}") + return False + process.stdin.write(request_json) + process.stdin.flush() + + # Read initialization response + if process.stdout is None: + print(f"[ERROR] No stdout available for server {server_name}") + return False + response_line = process.stdout.readline() + if not response_line: + print(f"[ERROR] No initialization response from {server_name}") + return False + + response = json.loads(response_line.strip()) + + # Check for initialization success + if "error" in response: + print(f"[ERROR] Server initialization failed: {response['error']}") + return False + + # Send initialized notification + initialized_notification = { + "jsonrpc": "2.0", + "method": "notifications/initialized", + "params": {} + } + + notification_json = json.dumps(initialized_notification) + "\n" + if process.stdin is None: + print(f"[ERROR] No stdin available for server {server_name}") + return False + process.stdin.write(notification_json) + process.stdin.flush() + + return True + + except Exception as e: + print(f"[ERROR] Failed to initialize server {server_name}: {e}") + return False + + async def stop_server(self, server_name: str) -> bool: + """Stop a specific MCP server. 
+ + Args: + server_name: Name of the server to stop + + Returns: + bool: True if server stopped successfully + """ + if server_name not in self.servers: + print(f"[WARNING] Server {server_name} is not running") + return True + + try: + process = self.servers[server_name] + process.terminate() + + # Wait for process to terminate + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + process.wait() + + del self.servers[server_name] + print(f"[MCP] Stopped server: {server_name}") + return True + + except Exception as e: + print(f"[ERROR] Failed to stop server {server_name}: {e}") + return False + + async def start_all_servers(self) -> Dict[str, bool]: + """Start all configured MCP servers. + + Returns: + Dict mapping server names to success status + """ + results = {} + for server_name in self.server_configs.keys(): + results[server_name] = await self.start_server(server_name) + return results + + async def stop_all_servers(self) -> Dict[str, bool]: + """Stop all running MCP servers. + + Returns: + Dict mapping server names to success status + """ + results = {} + for server_name in list(self.servers.keys()): + results[server_name] = await self.stop_server(server_name) + return results + + async def call_tool(self, server_name: str, tool_name: str, **kwargs: Any) -> Optional[str]: + """Call a tool on a specific MCP server. 
+ + Args: + server_name: Name of the server to call + tool_name: Name of the tool to call + **kwargs: Tool parameters + + Returns: + Tool response as string, or None if error + """ + if server_name not in self.servers: + print(f"[ERROR] Server {server_name} is not running") + return None + + try: + process = self.servers[server_name] + + # Create tool call request + request = { + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": { + "name": tool_name, + "arguments": kwargs + } + } + + # Send request to server + request_json = json.dumps(request) + "\n" + if process.stdin is None: + print(f"[ERROR] No stdin available for server {server_name}") + return None + process.stdin.write(request_json) + process.stdin.flush() + + # Read response + if process.stdout is None: + print(f"[ERROR] No stdout available for server {server_name}") + return None + response_line = process.stdout.readline() + if not response_line: + print(f"[ERROR] No response from server {server_name}") + return None + + response = json.loads(response_line.strip()) + + # Check for errors + if "error" in response: + print(f"[ERROR] Tool call failed: {response['error']}") + return None + + # Extract result + result = response.get("result", {}) + if isinstance(result, dict) and "content" in result: + content = result["content"][0].get("text", "") + return content if isinstance(content, str) else str(content) + elif isinstance(result, str): + return result + else: + return json.dumps(result) + + except Exception as e: + print(f"[ERROR] Failed to call tool {tool_name} on {server_name}: {e}") + return None + + async def get_available_tools(self, server_name: str) -> Optional[List[Dict[str, Any]]]: + """Get list of available tools from a server. 
+ + Args: + server_name: Name of the server to query + + Returns: + List of tool definitions, or None if error + """ + if server_name not in self.servers: + print(f"[ERROR] Server {server_name} is not running") + return None + + try: + process = self.servers[server_name] + + # Create list tools request + request = { + "jsonrpc": "2.0", + "id": 1, + "method": "tools/list", + "params": {} + } + + # Send request to server + request_json = json.dumps(request) + "\n" + if process.stdin is None: + print(f"[ERROR] No stdin available for server {server_name}") + return None + process.stdin.write(request_json) + process.stdin.flush() + + # Read response + if process.stdout is None: + print(f"[ERROR] No stdout available for server {server_name}") + return None + response_line = process.stdout.readline() + if not response_line: + print(f"[ERROR] No response from server {server_name}") + return None + + response = json.loads(response_line.strip()) + + # Check for errors + if "error" in response: + print(f"[ERROR] Failed to list tools: {response['error']}") + return None + + # Extract tools + result = response.get("result", {}) + tools = result.get("tools", []) + return tools if isinstance(tools, list) else [] + + except Exception as e: + print(f"[ERROR] Failed to get tools from {server_name}: {e}") + return None + + def is_server_running(self, server_name: str) -> bool: + """Check if a server is currently running. + + Args: + server_name: Name of the server to check + + Returns: + True if server is running + """ + if server_name not in self.servers: + return False + + process = self.servers[server_name] + return process.poll() is None + + def get_running_servers(self) -> List[str]: + """Get list of currently running servers. + + Returns: + List of server names + """ + return [name for name in self.servers.keys() if self.is_server_running(name)] + + async def health_check(self) -> Dict[str, bool]: + """Perform health check on all configured servers. 
+ + Returns: + Dict mapping server names to health status + """ + results = {} + for server_name in self.server_configs.keys(): + if self.is_server_running(server_name): + # Try to get tools as a health check + tools = await self.get_available_tools(server_name) + results[server_name] = tools is not None + else: + results[server_name] = False + return results + + async def __aenter__(self) -> "MCPClient": + """Async context manager entry.""" + await self.start_all_servers() + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """Async context manager exit.""" + await self.stop_all_servers() + + +# Convenience functions for common operations +async def search_github_code(query: str, client: MCPClient, result_limit: int = 10) -> Optional[str]: + """Search GitHub code using MCP client. + + Args: + query: Search query + client: MCP client instance + result_limit: Maximum results to return + + Returns: + Search results as JSON string + """ + if not client.is_server_running("code_search"): + await client.start_server("code_search") + + return await client.call_tool( + "code_search", + "search_github_code", + query=query, + result_limit=result_limit + ) + + +async def search_github_repositories(query: str, client: MCPClient, result_limit: int = 10) -> Optional[str]: + """Search GitHub repositories using MCP client. + + Args: + query: Search query or JSON research idea + client: MCP client instance + result_limit: Maximum results to return + + Returns: + Search results as JSON string + """ + if not client.is_server_running("code_search"): + await client.start_server("code_search") + + return await client.call_tool( + "code_search", + "search_github_repositories", + query=query, + result_limit=result_limit + ) + + +async def search_papers(query: str, client: MCPClient, result_limit: int = 3) -> Optional[str]: + """Search papers using MCP client. 
+ + Args: + query: Search query + client: MCP client instance + result_limit: Maximum results to return + + Returns: + Search results as JSON string + """ + if not client.is_server_running("paper_search"): + await client.start_server("paper_search") + + return await client.call_tool( + "paper_search", + "search_papers", + query=query, + result_limit=result_limit + ) + + +async def generate_diagram(section_name: str, section_content: str, client: MCPClient) -> Optional[str]: + """Generate diagram using MCP client. + + Args: + section_name: Name of the paper section + section_content: Content of the section + client: MCP client instance + + Returns: + Diagram data as JSON string + """ + if not client.is_server_running("drawer"): + await client.start_server("drawer") + + return await client.call_tool( + "drawer", + "generate_diagram", + section_name=section_name, + section_content=section_content + ) \ No newline at end of file diff --git a/tiny_scientist/writer.py b/tiny_scientist/writer.py index 143767ec..b9200547 100644 --- a/tiny_scientist/writer.py +++ b/tiny_scientist/writer.py @@ -11,7 +11,7 @@ from rich import print from .configs import Config -from .tool import BaseTool, DrawerTool, PaperSearchTool +from .mcp.tool import BaseTool, DrawerTool, PaperSearchTool from .utils.cost_tracker import CostTracker from .utils.llm import ( create_client, @@ -35,13 +35,16 @@ def __init__( prompt_template_dir: Optional[str] = None, cost_tracker: Optional[CostTracker] = None, s2_api_key: Optional[str] = None, + mcp_client: Any = None, ) -> None: self.client, self.model = create_client(model) self.output_dir = output_dir self.template = template self.temperature = temperature - self.searcher: BaseTool = PaperSearchTool(s2_api_key=s2_api_key) - self.drawer: BaseTool = DrawerTool(model, prompt_template_dir, temperature) + self.mcp_client = mcp_client + # Fallback to traditional tools if MCP is not available + self.searcher: Optional[BaseTool] = 
PaperSearchTool(s2_api_key=s2_api_key) if not mcp_client else None + self.drawer: Optional[BaseTool] = DrawerTool(model, prompt_template_dir, temperature) if not mcp_client else None self.formatter: BaseOutputFormatter self.config = Config(prompt_template_dir) if self.template == "acl": @@ -158,10 +161,48 @@ def _generate_diagram_for_section(self) -> None: for section in ["Method", "Experimental_Setup", "Results"]: content = self.generated_sections[section] try: - query = json.dumps( - {"section_name": section, "section_content": content} - ) - diagram_result = self.drawer.run(query) + if self.mcp_client: + # Use MCP client for diagram generation + import asyncio + + from .utils.mcp_client import generate_diagram + + try: + # Handle async function call properly to avoid event loop conflicts + import concurrent.futures + + def run_async_diagram() -> Optional[str]: + """Run the async diagram function in a new event loop.""" + return asyncio.run(generate_diagram(section, content, self.mcp_client)) + + # Always use ThreadPoolExecutor to avoid event loop conflicts + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(run_async_diagram) + results_json = future.result(timeout=60.0) # Longer timeout for diagram generation + + if results_json: + import json + diagram_result = json.loads(results_json) + else: + diagram_result = {} + except Exception as e: + print(f"[WARNING] MCP diagram generation failed, falling back to traditional drawer: {e}") + if self.drawer: + query = json.dumps( + {"section_name": section, "section_content": content} + ) + diagram_result = self.drawer.run(query) + else: + diagram_result = {} + else: + # Use traditional drawer + if self.drawer: + query = json.dumps( + {"section_name": section, "section_content": content} + ) + diagram_result = self.drawer.run(query) + else: + diagram_result = {} if diagram_result and "diagram" in diagram_result: diagram = diagram_result["diagram"] @@ -261,12 +302,16 @@ def 
_write_section( elif section == "Analysis": # For non-experimental papers, use the research plan content research_plan = idea.get("ResearchPlan", experiment) + approach = idea.get("Approach", "No approach specified") section_prompt = self.prompts.section_prompt.get( section, self.prompts.section_prompt.get("Results", "") ).format( section_tips=self.prompts.section_tips.get( section, self.prompts.section_tips.get("Results", "") ), + problem=idea["Problem"], # Add the required problem field + approach=approach, # Add the required approach field + research_plan=research_plan, # Add the required research_plan field experiment=research_plan, baseline_results=baseline_result, experiment_results=experiment_result, @@ -334,20 +379,92 @@ def _search_reference(self, paper_list: List[str]) -> Dict[str, Any]: for paper_name in paper_list: try: - result = self.searcher.run(paper_name) + print(f"[Writer] Searching for paper: {paper_name}") + + if self.mcp_client: + # Use MCP client for paper search + import asyncio + + from .utils.mcp_client import search_papers + + try: + # Handle async function call properly to avoid event loop conflicts + import concurrent.futures + + def run_async_search() -> Optional[str]: + """Run the async search function in a new event loop.""" + return asyncio.run(search_papers(paper_name, self.mcp_client)) + + # Always use ThreadPoolExecutor to avoid event loop conflicts + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(run_async_search) + results_json = future.result(timeout=30.0) # Add timeout + + print(f"[Writer] MCP raw result: {results_json[:200] if results_json else 'None'}...") + + if results_json: + import json + result = json.loads(results_json) + print(f"[Writer] MCP JSON parsing successful: {type(result)}") + + # Check MCP return format and handle errors + if isinstance(result, dict): + if 'error' in result: + print(f"[Writer] MCP returned error: {result['error']}") + result = {} # Convert error to empty 
result + else: + # Validate and convert data format + formatted_result = {} + for title, meta in result.items(): + if isinstance(meta, dict) and 'bibtex' in meta: + # MCP format correct: {title: {title: "...", bibtex: "..."}} + formatted_result[title] = meta + print(f"[Writer] Valid format paper: {title}") + elif isinstance(meta, str): + print(f"[Writer] Invalid format, skipping: {title} -> {meta}") + else: + print(f"[Writer] Unknown format, skipping: {title} -> {type(meta)}") + result = formatted_result + else: + print("[Writer] MCP returned empty result") + result = {} + + except Exception as e: + print(f"[Writer] MCP search failed, falling back to traditional search: {e}") + if self.searcher: + result = self.searcher.run(paper_name) + print(f"[Writer] Traditional search result: {type(result)}, length: {len(result) if result else 0}") + else: + result = {} + else: + # Use traditional searcher + print("[Writer] Using traditional searcher") + if self.searcher: + result = self.searcher.run(paper_name) + print(f"[Writer] Traditional search result: {type(result)}, length: {len(result) if result else 0}") + else: + result = {} + # Process search results if result: + print(f"[Writer] Found search results, count: {len(result)}") if paper_name in result: results_dict[paper_name] = result[paper_name] + print(f"[Writer] Exact match: {paper_name}") else: + # Use first result first_key = next(iter(result)) results_dict[first_key] = result[first_key] + print(f"[Writer] Using first result: {first_key}") + else: + print(f"[Writer] No papers found for: {paper_name}") time.sleep(1.0) except Exception as e: - print(f"[ERROR] While processing '{paper_name}': {e}") + print(f"[Writer] Error while processing '{paper_name}': {e}") traceback.print_exc() + print(f"[Writer] Search completed, found {len(results_dict)} papers total") return results_dict def _write_related_work(self, idea: Dict[str, Any]) -> None: @@ -379,7 +496,17 @@ def _write_related_work(self, idea: Dict[str, Any]) -> 
None: ) for title, meta in paper_source.items(): - match = re.search(r"@\w+\{(.+?),", meta.get("bibtex", "")) + # Ensure meta is a dictionary before accessing 'bibtex' + if isinstance(meta, dict): + bibtex = meta.get("bibtex", "") + elif isinstance(meta, str): + print(f"[Writer] Warning: meta is string for {title}, skipping citation replacement") + continue + else: + print(f"[Writer] Warning: unexpected meta type {type(meta)} for {title}, skipping citation replacement") + continue + + match = re.search(r"@\w+\{(.+?),", bibtex) if match: try: bibtex_key = match.group(1) @@ -391,7 +518,7 @@ def _write_related_work(self, idea: Dict[str, Any]) -> None: relatedwork_content, ) except Exception: - print(f"[ERROR] Failed to replace citation for title: {title}") + print(f"[Writer] Failed to replace citation for title: {title}") traceback.print_exc() self.generated_sections["Related_Work"] = relatedwork_content @@ -538,7 +665,17 @@ def _add_citations(self, idea: Dict[str, Any]) -> None: print(f"Refined section for {section}: {refined_section}") for title, meta in paper_source.items(): - match = re.search(r"@\w+\{(.+?),", meta.get("bibtex", "")) + # Ensure meta is a dictionary before accessing 'bibtex' + if isinstance(meta, dict): + bibtex = meta.get("bibtex", "") + elif isinstance(meta, str): + print(f"[Writer] Warning: meta is string for {title}, skipping citation replacement") + continue + else: + print(f"[Writer] Warning: unexpected meta type {type(meta)} for {title}, skipping citation replacement") + continue + + match = re.search(r"@\w+\{(.+?),", bibtex) if match: bibtex_key = match.group(1) escaped_title = re.escape(title) @@ -551,5 +688,5 @@ def _add_citations(self, idea: Dict[str, Any]) -> None: self.generated_sections[section] = refined_section except Exception: - print(f"[ERROR] Failed to add citations to section: {section}") + print(f"[Writer] Failed to add citations to section: {section}") traceback.print_exc()