diff --git a/langchain/chains/hyde/base.py b/langchain/chains/hyde/base.py index 29ee31de99a7f..5bae745f0a102 100644 --- a/langchain/chains/hyde/base.py +++ b/langchain/chains/hyde/base.py @@ -52,7 +52,7 @@ def combine_embeddings(self, embeddings: List[List[float]]) -> List[float]: def embed_query(self, text: str) -> List[float]: """Generate a hypothetical document and embedded it.""" var_name = self.llm_chain.input_keys[0] - result = self.llm_chain.generate([{var_name: text}]) + result, _ = self.llm_chain.generate([{var_name: text}]) documents = [generation.text for generation in result.generations[0]] embeddings = self.embed_documents(documents) return self.combine_embeddings(embeddings) diff --git a/langchain/chains/llm.py b/langchain/chains/llm.py index 46b3cfb7ce2e4..d5d7f6a9f3a8b 100644 --- a/langchain/chains/llm.py +++ b/langchain/chains/llm.py @@ -3,13 +3,20 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union -from pydantic import BaseModel, Extra +from pydantic import BaseModel, Extra, Field from langchain.chains.base import Chain from langchain.input import get_colored_text +from langchain.output_parsers.base import OutputGuardrail from langchain.prompts.base import BasePromptTemplate from langchain.prompts.prompt import PromptTemplate -from langchain.schema import BaseLanguageModel, LLMResult, PromptValue +from langchain.schema import ( + BaseLanguageModel, + Guardrail, + LLMResult, + PromptValue, + ValidationError, +) class LLMChain(Chain, BaseModel): @@ -30,6 +37,8 @@ class LLMChain(Chain, BaseModel): """Prompt object to use.""" llm: BaseLanguageModel output_key: str = "text" #: :meta private: + output_parser: Optional[OutputGuardrail] = None + guardrails: List[Guardrail] = Field(default_factory=list) class Config: """Configuration for this pydantic object.""" @@ -56,15 +65,19 @@ def output_keys(self) -> List[str]: def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]: return self.apply([inputs])[0] - def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult: + def generate( + self, input_list: List[Dict[str, Any]] + ) -> Tuple[LLMResult, List[PromptValue]]: """Generate LLM result from inputs.""" prompts, stop = self.prep_prompts(input_list) - return self.llm.generate_prompt(prompts, stop) + return self.llm.generate_prompt(prompts, stop), prompts - async def agenerate(self, input_list: List[Dict[str, Any]]) -> LLMResult: + async def agenerate( + self, input_list: List[Dict[str, Any]] + ) -> Tuple[LLMResult, List[PromptValue]]: """Generate LLM result from inputs.""" prompts, stop = await self.aprep_prompts(input_list) - return await self.llm.agenerate_prompt(prompts, stop) + return await self.llm.agenerate_prompt(prompts, stop), prompts def prep_prompts( self, input_list: List[Dict[str, Any]] @@ -115,20 +128,37 @@ async def aprep_prompts( def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]: """Utilize the LLM generate method for speed gains.""" - response = self.generate(input_list) - return self.create_outputs(response) + response, prompts = self.generate(input_list) + return self.create_outputs(response, prompts) async def aapply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]: """Utilize the LLM generate method for speed gains.""" - response = await self.agenerate(input_list) - return self.create_outputs(response) - - def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]: + response, prompts = await self.agenerate(input_list) + return self.create_outputs(response, prompts) + + def _get_final_output(self, text: str, prompt_value: PromptValue) -> Any: + result: Any = text + for guardrail in self.guardrails: + if isinstance(guardrail, OutputGuardrail): + try: + result = guardrail.output_parser.parse(result) + error = None + except Exception as e: + error = ValidationError(text=e) + else: + error = guardrail.check(prompt_value, result) + if error is not None: + result = guardrail.fix(prompt_value, result, error) + return result + + def create_outputs( + self, response: LLMResult, prompts: List[PromptValue] + ) -> List[Dict[str, str]]: """Create outputs from response.""" return [ # Get the text of the top generated string. - {self.output_key: generation[0].text} - for generation in response.generations + {self.output_key: self._get_final_output(generation[0].text, prompts[i])} + for i, generation in enumerate(response.generations) ] async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]: diff --git a/langchain/chains/qa_generation/base.py b/langchain/chains/qa_generation/base.py index 66907befaec23..d6bd5e03c719a 100644 --- a/langchain/chains/qa_generation/base.py +++ b/langchain/chains/qa_generation/base.py @@ -47,7 +47,7 @@ def output_keys(self) -> List[str]: def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]: docs = self.text_splitter.create_documents([inputs[self.input_key]]) - results = self.llm_chain.generate([{"text": d.page_content} for d in docs]) + results, _ = self.llm_chain.generate([{"text": d.page_content} for d in docs]) qa = [json.loads(res[0].text) for res in results.generations] return {self.output_key: qa} diff --git a/langchain/output_parsers/base.py b/langchain/output_parsers/base.py index f35984160a8f7..17b88f7467918 100644 --- a/langchain/output_parsers/base.py +++ b/langchain/output_parsers/base.py @@ -1,10 +1,12 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Dict +from typing import Any, Dict, Optional from pydantic import BaseModel +from langchain.schema import Fixer, Guardrail, PromptValue, ValidationError + class BaseOutputParser(BaseModel, ABC): """Class to parse the output of an LLM call.""" @@ -26,3 +28,22 @@ def dict(self, **kwargs: Any) -> Dict: output_parser_dict = super().dict() output_parser_dict["_type"] = self._type return output_parser_dict + + +class OutputGuardrail(Guardrail, BaseModel): + output_parser: BaseOutputParser + fixer: Fixer + + def check( + self, prompt_value: PromptValue, result: Any + ) -> Optional[ValidationError]: + try: + self.output_parser.parse(result) + return None + except Exception as e: + return ValidationError(text=e) + + def fix( + self, prompt_value: PromptValue, result: Any, error: ValidationError + ) -> Any: + return self.fixer(prompt_value, result, error) diff --git a/langchain/schema.py b/langchain/schema.py index 286af79e72cfe..61735f59fb2d8 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -220,3 +220,29 @@ def clear(self) -> None: Memory = BaseMemory + + +class ValidationError(BaseModel): + error_message: str + + +class Guardrail(ABC): + @abstractmethod + def check( + self, prompt_value: PromptValue, result: Any + ) -> Optional[ValidationError]: + """Check whether there's a validation error.""" + + @abstractmethod + def fix( + self, prompt_value: PromptValue, result: Any, error: ValidationError + ) -> Any: + """""" + + +class Fixer(ABC): + @abstractmethod + def fix( + self, prompt_value: PromptValue, result: Any, error: ValidationError + ) -> Any: + """"""