Skip to content

Utilities - Prompt Embedding Function #33

@umar-anzar

Description

@umar-anzar

Use for embedding class definitions in prompts that request a JSON response.

import inspect

def class_to_string(cls):
    """
    Convert a Python class and its user-defined base classes to a string.

    Collects the source code of ``cls`` together with every ancestor class
    (direct parents, grandparents, ...) that is defined in the same module as
    ``cls``. Ancestors from other modules -- e.g. Pydantic's ``BaseModel`` or
    ``object`` -- are not followed, and neither are their own bases. Each
    class is emitted exactly once, even in diamond-shaped hierarchies, and
    every class appears before any class that inherits from it (bases are
    visited in declaration order).

    The sources are concatenated verbatim (each already ends with a newline,
    so no blank line is inserted between classes) and the whole text is
    wrapped in triple double-quote delimiters: an opening delimiter followed
    by a newline at the start, and a closing delimiter at the end.

    Args:
        cls (type): The class to render.

    Returns:
        str: The quoted class definitions, ancestors first, ``cls`` last.

    Raises:
        OSError: Propagated from ``inspect.getsource`` when a class's source
            cannot be retrieved (e.g. classes created dynamically).
        TypeError: Propagated from ``inspect.getsource`` for built-in classes.

    Example (the classes must live in a source file so ``inspect`` can find
    them)::

        class Parent:
            x: int

        class Child(Parent):
            y: str

        print(class_to_string(Child))

    prints the ``class Parent`` definition followed immediately by the
    ``class Child(Parent)`` definition, enclosed between the opening and
    closing triple-double-quote delimiters.

    Limitations:
        - Only classes defined in the same module as ``cls`` are included.
        - Classes referenced only through annotations (not inheritance) are
          not collected.
    """
    module = cls.__module__
    ordered = []

    def visit(klass):
        # Depth-first, post-order walk over the inheritance graph: every base
        # is appended before its subclasses. The membership test makes a
        # shared ancestor (diamond inheritance) appear only once.
        if klass in ordered:
            return
        for parent in klass.__bases__:
            # Only follow bases from cls's own module; this skips ``object``,
            # library bases such as BaseModel, and anything above them.
            if parent.__module__ == module:
                visit(parent)
        ordered.append(klass)

    visit(cls)
    class_definitions = [inspect.getsource(klass) for klass in ordered]

    # Join all class definitions and wrap in triple quotes
    return f'"""\n{"".join(class_definitions)}"""'
import inspect

from pydantic import BaseModel


def _iter_annotation_classes(annotation):
    """
    Yield every class referenced anywhere inside a type annotation.

    Parameterised generics (``list[X]``, ``dict[str, list[X]]``),
    ``Optional``/``Union``, PEP 604 unions (``X | None``) and
    ``Annotated[X, ...]`` are unwrapped recursively via ``typing.get_args``.
    Argument lists/tuples (as in ``Callable[[X], Y]``) are walked too.

    Parameterised aliases are never yielded themselves, only their arguments.
    On Python 3.9/3.10 ``list[int]`` passes ``inspect.isclass`` but makes
    ``issubclass`` raise ``TypeError``. Strings (unresolved forward
    references), TypeVars and other non-class objects yield nothing.
    """
    import typing

    if isinstance(annotation, (list, tuple)):
        for item in annotation:
            yield from _iter_annotation_classes(item)
        return
    args = typing.get_args(annotation)
    if args or typing.get_origin(annotation) is not None:
        for arg in args:
            yield from _iter_annotation_classes(arg)
    elif inspect.isclass(annotation):
        yield annotation


def get_related_classes(
    cls,
    seen: "set[type] | None" = None,
    ordered_classes: "list[type] | None" = None,
    *,
    base: "type | None" = None,
):
    """
    Recursively collect the user-defined classes related to ``cls``, dependencies first.

    A class counts as related when it is a strict subclass of ``base`` (``base``
    itself is never collected) and either:

    * it is referenced from one of a collected class's own annotations, at any
      nesting depth -- e.g. ``list[dict[str, X]]``, ``Optional[X]``,
      ``X | None`` or ``Annotated[X, ...]``; or
    * it is a direct base class of a collected class and is defined in the
      same module as that class, so inherited fields are part of the output.

    Dependencies are appended before the classes that use them, and ``cls``
    itself is appended after its dependencies. Cycles are cut by ``seen``.

    Args:
        cls (type): The class whose dependencies need to be collected. A
            non-class value contributes nothing.
        seen (set, optional): Classes already visited; shared across the
            recursion.
        ordered_classes (list, optional): Output list, appended to in place.
        base (type, optional): Root class that marks "interesting" classes.
            Defaults to pydantic's ``BaseModel``.

    Returns:
        list: ``ordered_classes`` with the related classes appended in
        dependency order.
    """
    if base is None:
        base = BaseModel
    if seen is None:
        seen = set()
    if ordered_classes is None:
        ordered_classes = []

    # Check isclass first: ``cls in seen`` would raise TypeError for an
    # unhashable non-class argument.
    if not inspect.isclass(cls) or cls in seen:
        return ordered_classes
    seen.add(cls)

    def is_related(candidate):
        # ``base`` itself is library code (pydantic's BaseModel by default),
        # so only strict subclasses are collected.
        return candidate is not base and issubclass(candidate, base)

    # Same-module parent models first: their fields are inherited by cls.
    for parent in cls.__bases__:
        if parent.__module__ == cls.__module__ and is_related(parent):
            get_related_classes(parent, seen, ordered_classes, base=base)

    # Then every model referenced from cls's own annotations, at any depth.
    for annotation in getattr(cls, "__annotations__", {}).values():
        for referenced in _iter_annotation_classes(annotation):
            if is_related(referenced):
                get_related_classes(referenced, seen, ordered_classes, base=base)

    # Add the class at the end to maintain the correct order
    if cls not in ordered_classes:
        ordered_classes.append(cls)

    return ordered_classes

def class_to_string(cls):
    """
    Render ``cls`` and the user-defined classes it depends on as one string.

    The classes to render come from ``get_related_classes(cls)``, and their
    order is kept as returned. Each class's source is fetched with
    ``inspect.getsource``. The pieces are concatenated with no separator and
    wrapped in triple double-quote delimiters: an opening delimiter plus a
    newline first, and a closing delimiter last.

    Args:
        cls (type): The main class to convert.

    Returns:
        str: The quoted source of every related class, in the order supplied
        by ``get_related_classes``.

    Raises:
        OSError: Propagated from ``inspect.getsource`` when a class's source
            cannot be retrieved.
        TypeError: Propagated from ``inspect.getsource`` for built-in classes.
    """
    sources = []
    for related in get_related_classes(cls):
        sources.append(inspect.getsource(related))
    body = "".join(sources)
    return '"""\n' + body + '"""'

Metadata

Metadata

Assignees

Labels

enhancement (New feature or request)

Type

No fields configured for Task.

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions