-
Notifications
You must be signed in to change notification settings - Fork 0
Harrison/guardrails #4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -3,13 +3,20 @@ | |||||||||||||
|
|
||||||||||||||
| from typing import Any, Dict, List, Optional, Sequence, Tuple, Union | ||||||||||||||
|
|
||||||||||||||
| from pydantic import BaseModel, Extra | ||||||||||||||
| from pydantic import BaseModel, Extra, Field | ||||||||||||||
|
|
||||||||||||||
| from langchain.chains.base import Chain | ||||||||||||||
| from langchain.input import get_colored_text | ||||||||||||||
| from langchain.output_parsers.base import OutputGuardrail | ||||||||||||||
| from langchain.prompts.base import BasePromptTemplate | ||||||||||||||
| from langchain.prompts.prompt import PromptTemplate | ||||||||||||||
| from langchain.schema import BaseLanguageModel, LLMResult, PromptValue | ||||||||||||||
| from langchain.schema import ( | ||||||||||||||
| BaseLanguageModel, | ||||||||||||||
| Guardrail, | ||||||||||||||
| LLMResult, | ||||||||||||||
| PromptValue, | ||||||||||||||
| ValidationError, | ||||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| class LLMChain(Chain, BaseModel): | ||||||||||||||
|
|
@@ -30,6 +37,8 @@ class LLMChain(Chain, BaseModel): | |||||||||||||
| """Prompt object to use.""" | ||||||||||||||
| llm: BaseLanguageModel | ||||||||||||||
| output_key: str = "text" #: :meta private: | ||||||||||||||
| output_parser: Optional[OutputGuardrail] = None | ||||||||||||||
| guardrails: List[Guardrail] = Field(default_factory=list) | ||||||||||||||
|
|
||||||||||||||
| class Config: | ||||||||||||||
| """Configuration for this pydantic object.""" | ||||||||||||||
|
|
@@ -56,15 +65,19 @@ def output_keys(self) -> List[str]: | |||||||||||||
| def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]: | ||||||||||||||
| return self.apply([inputs])[0] | ||||||||||||||
|
|
||||||||||||||
| def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult: | ||||||||||||||
| def generate( | ||||||||||||||
| self, input_list: List[Dict[str, Any]] | ||||||||||||||
| ) -> Tuple[LLMResult, List[PromptValue]]: | ||||||||||||||
| """Generate LLM result from inputs.""" | ||||||||||||||
| prompts, stop = self.prep_prompts(input_list) | ||||||||||||||
| return self.llm.generate_prompt(prompts, stop) | ||||||||||||||
| return self.llm.generate_prompt(prompts, stop), prompts | ||||||||||||||
|
|
||||||||||||||
| async def agenerate(self, input_list: List[Dict[str, Any]]) -> LLMResult: | ||||||||||||||
| async def agenerate( | ||||||||||||||
| self, input_list: List[Dict[str, Any]] | ||||||||||||||
| ) -> Tuple[LLMResult, List[PromptValue]]: | ||||||||||||||
| """Generate LLM result from inputs.""" | ||||||||||||||
| prompts, stop = await self.aprep_prompts(input_list) | ||||||||||||||
| return await self.llm.agenerate_prompt(prompts, stop) | ||||||||||||||
| return await self.llm.agenerate_prompt(prompts, stop), prompts | ||||||||||||||
|
|
||||||||||||||
| def prep_prompts( | ||||||||||||||
| self, input_list: List[Dict[str, Any]] | ||||||||||||||
|
|
@@ -115,20 +128,37 @@ async def aprep_prompts( | |||||||||||||
|
|
||||||||||||||
| def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]: | ||||||||||||||
| """Utilize the LLM generate method for speed gains.""" | ||||||||||||||
| response = self.generate(input_list) | ||||||||||||||
| return self.create_outputs(response) | ||||||||||||||
| response, prompts = self.generate(input_list) | ||||||||||||||
| return self.create_outputs(response, prompts) | ||||||||||||||
|
|
||||||||||||||
| async def aapply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]: | ||||||||||||||
| """Utilize the LLM generate method for speed gains.""" | ||||||||||||||
| response = await self.agenerate(input_list) | ||||||||||||||
| return self.create_outputs(response) | ||||||||||||||
|
|
||||||||||||||
| def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]: | ||||||||||||||
| response, prompts = await self.agenerate(input_list) | ||||||||||||||
| return self.create_outputs(response, prompts) | ||||||||||||||
|
|
||||||||||||||
| def _get_final_output(self, text: str, prompt_value: PromptValue) -> Any: | ||||||||||||||
| result: Any = text | ||||||||||||||
| for guardrail in self.guardrails: | ||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For a chain constructed with output_parser=SomeParser() but no guardrails list, a trigger like model output 'not-json' is no longer parsed at all because create_outputs only consults self.guardrails and ignores the new output_parser field, breaking callers that expect parsed outputs after this refactor. Also reported at: Preserve previous/new contract by automatically running self.output_parser when set, or wrap it into guardrails during initialization so create_outputs always applies it. Prompt for AI assistanceCopy the prompt below and paste it into ChatGPT, Claude, or any LLM: |
||||||||||||||
| if isinstance(guardrail, OutputGuardrail): | ||||||||||||||
| try: | ||||||||||||||
| result = guardrail.output_parser.parse(result) | ||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When any parser or guardrail raises, _get_final_output constructs ValidationError(text=e) but langchain/schema.py now defines ValidationError.error_message, so a malformed output such as non-JSON text under an OutputGuardrail turns a recoverable validation failure into a pydantic initialization exception. Instantiate the new schema correctly in both call sites, e.g. Prompt for AI assistanceCopy the prompt below and paste it into ChatGPT, Claude, or any LLM: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If any guardrail parser/fixer raises (e.g. invalid output triggers OutputGuardrail on '{'), _get_final_output constructs ValidationError(text=e) but schema.ValidationError now requires error_message, so the guardrail path crashes instead of returning a fixed value. Instantiate ValidationError with the declared field, e.g. ValidationError(error_message=str(e)), and make OutputGuardrail.check do the same. Prompt for AI assistanceCopy the prompt below and paste it into ChatGPT, Claude, or any LLM: |
||||||||||||||
| error = None | ||||||||||||||
| except Exception as e: | ||||||||||||||
| error = ValidationError(text=e) | ||||||||||||||
|
Comment on lines
+146
to
+147
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any invalid model output that triggers an OutputGuardrail parse/check failure now crashes in create_outputs because both llm.py and output_parsers/base.py instantiate ValidationError with Also reported at: Construct ValidationError with the new field name everywhere, e.g. Prompt for AI assistanceCopy the prompt below and paste it into ChatGPT, Claude, or any LLM:
Comment on lines
+143
to
+147
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The broad exception handler hides parser and programmer errors; catch the parser's validation exception explicitly and let unexpected exceptions propagate. Suggested fix try:
result = guardrail.output_parser.parse(result)
error = None
except ValidationError as e:
error = ePrompt for AI assistanceCopy the prompt below and paste it into ChatGPT, Claude, or any LLM: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
Prompt for AI assistanceCopy the prompt below and paste it into ChatGPT, Claude, or any LLM: |
||||||||||||||
| else: | ||||||||||||||
| error = guardrail.check(prompt_value, result) | ||||||||||||||
| if error is not None: | ||||||||||||||
| result = guardrail.fix(prompt_value, result, error) | ||||||||||||||
| return result | ||||||||||||||
|
Comment on lines
+139
to
+152
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
When iterating |
||||||||||||||
|
|
||||||||||||||
| def create_outputs( | ||||||||||||||
| self, response: LLMResult, prompts: List[PromptValue] | ||||||||||||||
| ) -> List[Dict[str, str]]: | ||||||||||||||
| """Create outputs from response.""" | ||||||||||||||
| return [ | ||||||||||||||
| # Get the text of the top generated string. | ||||||||||||||
| {self.output_key: generation[0].text} | ||||||||||||||
| for generation in response.generations | ||||||||||||||
| {self.output_key: self._get_final_output(generation[0].text, prompts[i])} | ||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When an LLMChain is configured with an OutputGuardrail and the model returns invalid JSON like '{', create_outputs now stores the parsed/fixed object instead of the raw text, but cross-file callers such as QA generation still read res[0].text from generate() and therefore silently bypass guardrails/fixes on that trigger. Either apply guardrails inside generate()/agenerate() so callers consuming LLMResult see the fixed output too, or update all generate() callers to use apply()/create_outputs() instead of reading generation.text directly. Prompt for AI assistanceCopy the prompt below and paste it into ChatGPT, Claude, or any LLM: |
||||||||||||||
| for i, generation in enumerate(response.generations) | ||||||||||||||
| ] | ||||||||||||||
|
|
||||||||||||||
| async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]: | ||||||||||||||
|
|
||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,10 +1,12 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from __future__ import annotations | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from abc import ABC, abstractmethod | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Any, Dict | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Any, Dict, Optional | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from pydantic import BaseModel | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from langchain.schema import Fixer, Guardrail, PromptValue, ValidationError | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class BaseOutputParser(BaseModel, ABC): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Class to parse the output of an LLM call.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -26,3 +28,22 @@ def dict(self, **kwargs: Any) -> Dict: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_parser_dict = super().dict() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_parser_dict["_type"] = self._type | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return output_parser_dict | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class OutputGuardrail(Guardrail, BaseModel): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_parser: BaseOutputParser | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fixer: Fixer | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def check( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self, prompt_value: PromptValue, result: Any | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> Optional[ValidationError]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.output_parser.parse(result) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return ValidationError(text=e) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+40
to
+44
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Catching Exception here hides all parser and runtime errors; catch the parser-specific exception types you expect and let unexpected failures propagate. Suggested fix try:
self.output_parser.parse(result)
return None
except (ValueError, TypeError) as e:
return ValidationError(text=str(e))Prompt for AI assistanceCopy the prompt below and paste it into ChatGPT, Claude, or any LLM:
Comment on lines
+43
to
+44
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The exception object is passed directly into ValidationError.text, so convert it to a string before storing it. Suggested fix except (ValueError, TypeError) as e:
return ValidationError(text=str(e))Prompt for AI assistanceCopy the prompt below and paste it into ChatGPT, Claude, or any LLM: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ValidationError is constructed with
Suggested change
Prompt for AI assistanceCopy the prompt below and paste it into ChatGPT, Claude, or any LLM: |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def fix( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self, prompt_value: PromptValue, result: Any, error: ValidationError | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> Any: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return self.fixer(prompt_value, result, error) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+37
to
+49
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
Prompt for AI assistanceCopy the prompt below and paste it into ChatGPT, Claude, or any LLM: |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -220,3 +220,29 @@ def clear(self) -> None: | |||||||||
|
|
||||||||||
|
|
||||||||||
| Memory = BaseMemory | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class ValidationError(BaseModel): | ||||||||||
| error_message: str | ||||||||||
|
Comment on lines
+225
to
+226
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Comment on lines
+225
to
+226
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Defining a class named |
||||||||||
|
|
||||||||||
|
|
||||||||||
| class Guardrail(ABC): | ||||||||||
| @abstractmethod | ||||||||||
| def check( | ||||||||||
| self, prompt_value: PromptValue, result: Any | ||||||||||
| ) -> Optional[ValidationError]: | ||||||||||
| """Check whether there's a validation error.""" | ||||||||||
|
|
||||||||||
| @abstractmethod | ||||||||||
| def fix( | ||||||||||
| self, prompt_value: PromptValue, result: Any, error: ValidationError | ||||||||||
| ) -> Any: | ||||||||||
| """""" | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class Fixer(ABC): | ||||||||||
| @abstractmethod | ||||||||||
| def fix( | ||||||||||
| self, prompt_value: PromptValue, result: Any, error: ValidationError | ||||||||||
| ) -> Any: | ||||||||||
| """""" | ||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
output_parseris never readoutput_parser: Optional[OutputGuardrail] = Noneis declared onLLMChainbut is never consulted in_get_final_output,apply, or any other method. A user who sets this field expecting it to be applied will see it silently ignored. Either wire it into_get_final_output(e.g., prepend it to the guardrails loop) or remove the field.