# How to Inspect Function and Class Signatures in Python?

[![Twitter Handle](https://img.shields.io/badge/Twitter-@gaohongnan-blue?style=social&logo=twitter)](https://twitter.com/gaohongnan)
[![LinkedIn Profile](https://img.shields.io/badge/@gaohongnan-blue?style=social&logo=linkedin)](https://linkedin.com/in/gao-hongnan)
[![GitHub Profile](https://img.shields.io/badge/GitHub-gao--hongnan-lightgrey?style=social&logo=github)](https://github.com/gao-hongnan)
![Tag](https://img.shields.io/badge/Tag-Brain_Dump-red)
[![Code](https://img.shields.io/badge/View-Code-blue?style=flat-square&logo=github)](https://github.com/gao-hongnan/omniverse/blob/main/omnivault/utils/inspector/core.py)

```{contents}
:local:
```

```python
import inspect
from dataclasses import field, make_dataclass
from inspect import Parameter, Signature
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, _GenericAlias, get_type_hints, overload

from pydantic import BaseModel
from rich.pretty import pprint
from transformers import GPT2LMHeadModel, Trainer, TrainingArguments
```

## Motivation

There are two motivations for why we want to inspect function and class
signatures in Python.

1.  **Code Introspection in Open Source Projects**: When we are working on open
    source projects, we often need to inspect the function and class signatures
    of the codebase. This is especially true when we are working on a large
    codebase with many functions and classes. In this case, we need to inspect
    the function and class signatures to understand how the codebase is
    structured and how the functions and classes are used.

    Sometimes there are nested abstractions: a child class $\mathcal{C}_N$
    inherits from a parent class $\mathcal{C}_{N-1}$, which in turn inherits
    from another parent class $\mathcal{C}_{N-2}$, and so on. The child class
    does not always show which arguments its constructor accepts (for example,
    when it simply forwards `**kwargs` to `super().__init__`). In this case, we
    need to inspect the parent classes to understand the full constructor
    signature of the child class.

2.  **Agent-Based Function Calling**: In agent-based programming, where
    automated agents are tasked with executing functions or interacting with
    libraries, the risk of 'hallucination' (an agent attempting to invoke
    non-existent methods or making improperly structured calls to an external
    module or library) is a notable concern. By equipping agents with the
    capability to query the actual signatures of a library's classes and
    functions, we can significantly mitigate this risk. Agents then operate on
    accurate, up-to-date information, which improves the reliability and
    effectiveness of automated tasks.


## Construct Hypothetical Function, Child and Parent Classes

```python
class ParentClass:
    """This is the parent class."""

    parent_class_attr = 'a parent class attribute'

    def __init__(self, parent_instance_attr: str) -> None:
        self.parent_instance_attr = parent_instance_attr

    def parent_method(self) -> str:
        """This is a method in the parent class."""
        return "Parent method called"

class ChildClass(ParentClass):
    """This is a subclass of ParentClass."""

    # Class attribute
    class_attr = 'a class attribute'

    # Private and protected attributes
    _protected_attr = 'a protected attribute'
    __private_attr = 'a private attribute'

    def __init__(self, instance_attr: str, parent_instance_attr: str) -> None:
        """Initialize the instance."""
        super().__init__(parent_instance_attr)
        # Instance attribute
        self.instance_attr = instance_attr
        self.instance_not_in_constructor_attr = 'an instance attribute not in the constructor'
        self._private_instance_attr = 'a private instance attribute'

    @property
    def read_only_attr(self) -> str:
        """This is a read-only attribute."""
        return 'You can read me, but you cannot change me.'

    def instance_method(self, arg: str) -> str:
        """This is an instance method."""
        return f'Instance method called with argument: {arg}'

    @classmethod
    def class_method(cls, arg: str) -> str:
        """This is a class method."""
        return f'Class method called with argument: {arg}'

    @staticmethod
    def static_method(arg: str) -> str:
        """This is a static method."""
        return f'Static method called with argument: {arg}'

    def __str__(self) -> str:
        """Return a string representation of the instance."""
        return f'ChildClass(instance_attr={self.instance_attr})'
```


```python
instance_child = ChildClass(instance_attr='an instance attribute', parent_instance_attr='a parent instance attribute')
class_child = ChildClass

instance_parent = ParentClass(parent_instance_attr='a parent instance attribute')
class_parent = ParentClass
```

```python
def func(a: int, b: str, c: List[int], d: Tuple[str, str], e: Union[int, str], **kwargs: Any) -> str:
    # Return a string so that the body agrees with the `-> str` annotation.
    return str((a, b, c, d, e, kwargs))
```

## Inspect All Members

```python
@overload
def get_members_of_function_or_method(
    func_or_class: Type[object], predicate: Optional[Callable[[Any], bool]] = None
) -> List[Tuple[str, Any]]:
    ...


@overload
def get_members_of_function_or_method(
    func_or_class: Callable[..., Any], predicate: Optional[Callable[[Any], bool]] = None
) -> List[Tuple[str, Any]]:
    ...


def get_members_of_function_or_method(
    func_or_class: Union[Type[object], Callable[..., Any]], predicate: Optional[Callable[[Any], bool]] = None
) -> List[Tuple[str, Any]]:
    return inspect.getmembers(func_or_class, predicate)

def loop_through_members(members: List[Tuple[str, Any]], filter: Optional[str] = None) -> None:
    if filter is not None:
        members = [member for member in members if filter in member[0]]
    for member in members:
        name, value = member
        print(f'{name}: {value}')
```

Our initial goal is to get all signatures and type annotations of a class or
function. We can use the `inspect` module to achieve this. The `getmembers`
function returns all members of an object (a class, module, function, or
instance) as `(name, value)` pairs sorted by name, optionally filtered by a
`predicate`. We can then filter for the functions and classes and inspect their
signatures.

However, for our purpose this may be overkill because the scope is very broad:
`getmembers` returns every attribute of the object, dunder attributes included.
For example, inspecting just the `func` defined above also returns its
`__globals__`, the entire namespace of the module in which `func` was defined,
which is rarely what we want.
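A minimal sketch of this behavior, using a small hypothetical function `add` that is not part of the original notebook: the pairs come back sorted by name, and the `__globals__` entry is the function's own module namespace rather than anything specific to `add`.

```python
import inspect


def add(x: int, y: int) -> int:
    """A hypothetical function used only for this illustration."""
    return x + y


# All members of the function object, as (name, value) pairs sorted by name.
members = inspect.getmembers(add)
member_names = [name for name, _ in members]

# `__globals__` is the namespace of the module where `add` was defined,
# so it holds every module-level name, not just information about `add`.
module_namespace = dict(members)["__globals__"]

print(len(member_names), "members, first few:", member_names[:3])
```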

```python
func_all_members = get_members_of_function_or_method(func, predicate=None)
loop_through_members(func_all_members)
```

```text
__annotations__: {'a': <class 'int'>, 'b': <class 'str'>, 'c': typing.List[int], 'd': typing.Tuple[str, str], 'e': typing.Union[int, str], 'kwargs': typing.Any, 'return': <class 'str'>}
__call__: <method-wrapper '__call__' of function object at 0x29bce6310>
__class__: <class 'function'>
__closure__: None
__code__: <code object func at 0x29bcfb2f0, file "/var/folders/l2/jjqj299126j0gycr9kkkt9xm0000gn/T/ipykernel_3165/2139551385.py", line 1>
__defaults__: None
__delattr__: <method-wrapper '__delattr__' of function object at 0x29bce6310>
__dict__: {}
__dir__: <built-in method __dir__ of function object at 0x29bce6310>
__doc__: None
__eq__: <method-wrapper '__eq__' of function object at 0x29bce6310>
__format__: <built-in method __format__ of function object at 0x29bce6310>
__ge__: <method-wrapper '__ge__' of function object at 0x29bce6310>
__get__: <method-wrapper '__get__' of function object at 0x29bce6310>
__getattribute__: <method-wrapper '__getattribute__' of function object at 0x29bce
```

To get only the type annotations, we can filter for `'__annotations__'`.
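Filtering `getmembers` is not strictly necessary here: the same dictionary is available directly as the function's `__annotations__` attribute (on Python 3.10+, `inspect.get_annotations` is the recommended accessor). A small sketch, redefining `func` so that the block runs on its own:

```python
from typing import Any, List, Tuple, Union


def func(a: int, b: str, c: List[int], d: Tuple[str, str], e: Union[int, str], **kwargs: Any) -> str:
    return str((a, b, c, d, e, kwargs))


# Parameter annotations in definition order, plus the return annotation under 'return'.
annotations = func.__annotations__
for name, hint in annotations.items():
    print(f"{name}: {hint}")
```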

```python
loop_through_members(func_all_members, filter='__annotations__')
```

```text
__annotations__: {'a': <class 'int'>, 'b': <class 'str'>, 'c': typing.List[int], 'd': typing.Tuple[str, str], 'e': typing.Union[int, str], 'kwargs': typing.Any, 'return': <class 'str'>}
```


```python
class_child_all_members = get_members_of_function_or_method(class_child, predicate=None)
loop_through_members(class_child_all_members)
```

```text
_ChildClass__private_attr: a private attribute
__class__: <class 'type'>
__delattr__: <slot wrapper '__delattr__' of 'object' objects>
__dict__: {'__module__': '__main__', '__doc__': 'This is a subclass of ParentClass.', 'class_attr': 'a class attribute', '_protected_attr': 'a protected attribute', '_ChildClass__private_attr': 'a private attribute', '__init__': <function ChildClass.__init__ at 0x29bce64c0>, 'read_only_attr': <property object at 0x29bcec040>, 'instance_method': <function ChildClass.instance_method at 0x29bce6700>, 'class_method': <classmethod object at 0x29bcd5580>, 'static_method': <staticmethod object at 0x29bcd5340>, '__str__': <function ChildClass.__str__ at 0x29bce68b0>}
__dir__: <method '__dir__' of 'object' objects>
__doc__: This is a subclass of ParentClass.
__eq__: <slot wrapper '__eq__' of 'object' objects>
__format__: <method '__format__' of 'object' objects>
__ge__: <slot wrapper '__ge__' of 'object' objects>
__getattribute__: <slot wrapper '__getattribute__
```

```python
instance_child_all_members = get_members_of_function_or_method(instance_child, predicate=None)
loop_through_members(instance_child_all_members)
```

```text
_ChildClass__private_attr: a private attribute
__class__: <class '__main__.ChildClass'>
__delattr__: <method-wrapper '__delattr__' of ChildClass object at 0x29bce59d0>
__dict__: {'parent_instance_attr': 'a parent instance attribute', 'instance_attr': 'an instance attribute', 'instance_not_in_constructor_attr': 'an instance attribute not in the constructor', '_private_instance_attr': 'a private instance attribute'}
__dir__: <built-in method __dir__ of ChildClass object at 0x29bce59d0>
__doc__: This is a subclass of ParentClass.
__eq__: <method-wrapper '__eq__' of ChildClass object at 0x29bce59d0>
__format__: <built-in method __format__ of ChildClass object at 0x29bce59d0>
__ge__: <method-wrapper '__ge__' of ChildClass object at 0x29bce59d0>
__getattribute__: <method-wrapper '__getattribute__' of ChildClass object at 0x29bce59d0>
__gt__: <method-wrapper '__gt__' of ChildClass object at 0x29bce59d0>
__hash__: <method-wrapper '__hash__' of ChildClass object at 0x29bce59d0>
__init__: <bound
```

```python
trainer_all_members = get_members_of_function_or_method(Trainer, predicate=None)
loop_through_members(trainer_all_members)
```

```text
__class__: <class 'type'>
__delattr__: <slot wrapper '__delattr__' of 'object' objects>
__dir__: <method '__dir__' of 'object' objects>
__doc__: 
    Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.

    Args:
        model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*):
            The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.

            <Tip>

            [`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use
            your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers
            models.

            </Tip>

        args ([`TrainingArguments`], *optional*):
            The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the
            `output_dir` set to a directory named *tmp_trainer* in the current directo
```

## Retrieve All Methods of a Class

There are a few ways to do it.

### Using `__dict__`

```python
child_class_methods_using_dict = list(ChildClass.__dict__.keys())
pprint(sorted(child_class_methods_using_dict))

assert 'parent_method' not in child_class_methods_using_dict
assert 'read_only_attr' in child_class_methods_using_dict
assert 'class_method' in child_class_methods_using_dict
```

Notice that the parent class methods are not included: a class's `__dict__`
only holds the attributes defined directly in that class's body.

```python
pprint(instance_child.__class__.__dict__.keys() == ChildClass.__dict__.keys())
```

### Using `vars`

`vars(obj)` returns `obj.__dict__`, so the two are equivalent. `vars` is
generally preferred because it is the public, idiomatic interface, whereas
dunder attributes are treated as implementation details; see the discussion
[in this post](https://stackoverflow.com/questions/21297203/use-dict-or-vars).
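A minimal sketch of the equivalence, with hypothetical `Point` and `SlottedPoint` classes: for an instance, `vars` hands back the very same dictionary object, for a class it returns an equal read-only mapping, and for an object without a `__dict__` (here, because of `__slots__`) it raises `TypeError`.

```python
class Point:
    """Hypothetical class with a regular instance __dict__."""

    def __init__(self, x: int, y: int) -> None:
        self.x = x
        self.y = y


class SlottedPoint:
    """Hypothetical class using __slots__, so instances have no __dict__."""

    __slots__ = ("x", "y")

    def __init__(self, x: int, y: int) -> None:
        self.x = x
        self.y = y


p = Point(1, 2)
print(vars(p))  # {'x': 1, 'y': 2}

same_object = vars(p) is p.__dict__  # vars(obj) returns obj.__dict__ itself
same_class_mapping = vars(Point) == Point.__dict__  # equal mappingproxy contents

try:
    vars(SlottedPoint(1, 2))
    slotted_error = None
except TypeError as exc:  # there is no __dict__ to return
    slotted_error = str(exc)
```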

```python
child_class_methods_using_vars = list(vars(ChildClass).keys())
pprint(sorted(child_class_methods_using_vars))

assert 'parent_method' not in child_class_methods_using_vars
assert 'read_only_attr' in child_class_methods_using_vars
assert 'class_method' in child_class_methods_using_vars

assert set(child_class_methods_using_dict) == set(child_class_methods_using_vars)
```

### Using `dir`

To include the methods inherited from base/parent classes, we can use `dir`
instead, which also collects the attributes of the base classes recursively.
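Because `dir` also looks through the base classes while `vars` only sees a class's own namespace, the difference between the two gives the names a class inherits. A minimal sketch with hypothetical `Parent` and `Child` classes:

```python
class Parent:
    def parent_method(self) -> str:
        return "parent"


class Child(Parent):
    def child_method(self) -> str:
        return "child"


own_names = set(vars(Child))  # defined directly in Child's body
all_names = set(dir(Child))  # also includes names from the base classes

# Public names that Child only has through inheritance.
inherited_public = sorted(n for n in all_names - own_names if not n.startswith("_"))
print(inherited_public)  # ['parent_method']
```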

```python
child_class_methods_using_dir = list(dir(ChildClass))
pprint(sorted(child_class_methods_using_dir))

assert 'parent_method' in child_class_methods_using_dir
assert 'read_only_attr' in child_class_methods_using_dir
assert 'class_method' in child_class_methods_using_dir
```

### Using `inspect.getmembers`

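We use `inspect.getmembers` to get all members of a class and then keep only
the routines via the predicate `inspect.isroutine`. This is a broader filter
than `inspect.isfunction` or `inspect.ismethod` alone: it accepts plain
functions, bound methods, built-ins, and method descriptors.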

We attach the source code of `inspect.isroutine` here for reference.

```python
def isroutine(object):
    """Return true if the object is any kind of function or method."""
    return (isbuiltin(object)
            or isfunction(object)
            or ismethod(object)
            or ismethoddescriptor(object))
```
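To see why the broader predicate matters, here is a small sketch on a hypothetical class `Example`. When members are fetched from the class, a `classmethod` comes back as a bound method, so `inspect.isfunction` misses it, while `inspect.isroutine` also keeps it (plus inherited C-level slots such as `__eq__`). A `property` is neither, so both predicates drop it.

```python
import inspect


class Example:
    """Hypothetical class mixing different kinds of members."""

    @property
    def prop(self) -> int:
        return 1

    def instance_method(self) -> None: ...

    @classmethod
    def class_method(cls) -> None: ...

    @staticmethod
    def static_method() -> None: ...


def member_names(predicate) -> set:
    return {name for name, _ in inspect.getmembers(Example, predicate)}


functions = member_names(inspect.isfunction)
routines = member_names(inspect.isroutine)

print(sorted(functions))  # ['instance_method', 'static_method']
print(sorted(name for name in routines if not name.startswith("_")))
# ['class_method', 'instance_method', 'static_method']
```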

```python
predicate = inspect.isroutine
child_class_methods_using_getmembers = list(get_members_of_function_or_method(ChildClass, predicate=predicate))

pprint(sorted(child_class_methods_using_getmembers))
```

Retrieving all methods is mostly a convenience: once we have every method, we
can iterate over them and inspect each method's signature in turn.
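A minimal sketch of that loop, on a hypothetical `Example` class that reuses the method shapes of `ChildClass`. Private and dunder names are skipped because `inspect.signature` may be unable to describe C-level routines.

```python
import inspect


class Example:
    """Hypothetical class reusing the method shapes of ChildClass above."""

    def instance_method(self, arg: str) -> str:
        return f"instance: {arg}"

    @classmethod
    def class_method(cls, arg: str) -> str:
        return f"class: {arg}"

    @staticmethod
    def static_method(arg: str) -> str:
        return f"static: {arg}"


# Map each public routine to its signature rendered as a string.
signatures = {
    name: str(inspect.signature(member))
    for name, member in inspect.getmembers(Example, inspect.isroutine)
    if not name.startswith("_")
}
for name, sig in signatures.items():
    print(f"{name}{sig}")
```

Note that the class method is fetched as a method bound to the class, so its signature no longer shows `cls`, whereas the plain function behind `instance_method` still lists `self`.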

### Method Resolution Order

The examples above do not handle more complicated cases, such as a class that
inherits from **multiple** classes. If we simply print the methods of such a
class, it is hard to tell which method comes from which class. Further
filtering can recover this, for example by checking which class in the method
resolution order first defines each name.
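A minimal sketch of such filtering, using hypothetical `Base`, `MixinA`, and `Child` classes: for every public routine, walk `cls.__mro__` and report the first class whose own namespace (`vars(klass)`) contains the name, mirroring how attribute lookup walks the MRO.

```python
import inspect
from typing import Dict


class Base:
    def greet(self) -> str:
        return "base"

    def shared(self) -> str:
        return "base"


class MixinA:
    def a_only(self) -> str:
        return "mixin"

    def shared(self) -> str:
        return "mixin"


class Child(MixinA, Base):
    def own(self) -> str:
        return "child"


def method_origins(cls: type) -> Dict[str, str]:
    """Map each public routine of `cls` to the name of the class that defines it."""
    origins: Dict[str, str] = {}
    for name, _ in inspect.getmembers(cls, inspect.isroutine):
        if name.startswith("_"):
            continue
        # The first class in the MRO that defines the name wins.
        for klass in cls.__mro__:
            if name in vars(klass):
                origins[name] = klass.__name__
                break
    return origins


print(method_origins(Child))
# {'a_only': 'MixinA', 'greet': 'Base', 'own': 'Child', 'shared': 'MixinA'}
```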

```python
predicate = inspect.isroutine
GPT2LMHeadModel_methods_using_getmembers = list(get_members_of_function_or_method(GPT2LMHeadModel, predicate=predicate))

pprint(sorted(GPT2LMHeadModel_methods_using_getmembers))
```

You can get the method resolution order (MRO) of a class via `cls.__mro__`, or
equivalently via `inspect.getmro(cls)`, which returns the same tuple.

```python
inspect.getmro(GPT2LMHeadModel)
```

```text
(transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel,
 transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel,
 transformers.modeling_utils.PreTrainedModel,
 torch.nn.modules.module.Module,
 transformers.modeling_utils.ModuleUtilsMixin,
 transformers.generation.utils.GenerationMixin,
 transformers.utils.hub.PushToHubMixin,
 transformers.integrations.peft.PeftAdapterMixin,
 object)
```

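A sketch that collects the constructor parameters of a class across its MRO is
as follows:

```python
import inspect
from typing import Dict, Type


def get_all_args(cls: Type[object]) -> Dict[str, inspect.Parameter]:
    """Collect `__init__` parameters from every class in the MRO."""
    mro = inspect.getmro(cls)
    all_args: Dict[str, inspect.Parameter] = {}
    for base_class in mro[::-1]:  # reverse to start from the topmost class
        if base_class is object:  # skip the base 'object' class
            continue
        if "__init__" not in vars(base_class):  # only classes defining their own __init__
            continue
        sig = inspect.signature(base_class.__init__)
        # More derived classes overwrite parameters with the same name.
        all_args.update((name, param) for name, param in sig.parameters.items() if name != "self")
    return all_args
```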


## Get Class and Instance Attributes

```python
pprint(list(class_child.__dict__.keys())) # class attributes
pprint(list(instance_child.__dict__.keys())) # instance attributes
```

```python
union_class_and_instance_attributes = list(set(class_child.__dict__.keys()).union(set(instance_child.__dict__.keys())))
pprint(union_class_and_instance_attributes)
```
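A class's `__dict__` mixes data attributes with methods, properties, and dunders, and it ignores attributes inherited from parents. A minimal sketch, with hypothetical `Parent`/`Child` classes and a hypothetical helper `class_data_attributes`, that keeps only plain class-level data across the MRO; instance attributes, by contrast, only exist once `__init__` has run.

```python
import inspect
from typing import Any, Dict


class Parent:
    parent_class_attr = "a parent class attribute"


class Child(Parent):
    class_attr = "a class attribute"

    def __init__(self, instance_attr: str) -> None:
        self.instance_attr = instance_attr
        self.set_in_init = "assigned inside __init__"

    @property
    def read_only_attr(self) -> str:
        return "read only"

    def method(self) -> None: ...


def class_data_attributes(cls: type) -> Dict[str, Any]:
    """Collect non-dunder, non-routine, non-property class attributes across the MRO."""
    attrs: Dict[str, Any] = {}
    # Walk from the topmost base (excluding `object`) down to `cls`,
    # so subclasses overwrite values inherited from their parents.
    for klass in reversed(cls.__mro__[:-1]):
        for name, value in vars(klass).items():
            if name.startswith("__"):
                continue
            # Raw namespace values: functions, classmethod/staticmethod objects, properties.
            if inspect.isroutine(value) or isinstance(value, property):
                continue
            attrs[name] = value
    return attrs


print(class_data_attributes(Child))
print(vars(Child("x")))  # instance attributes exist only after __init__ runs
```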

## Get Signature and Type Annotations of a Function

```python
func_sig: Signature = inspect.signature(func)
pprint(func_sig.parameters)
pprint(func_sig.return_annotation)
```
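Beyond reading parameters, a `Signature` can check a proposed call before it is made: `Signature.bind` raises `TypeError` when the arguments do not fit, which is one way to guard against the hallucinated calls mentioned in the motivation. A minimal sketch with a hypothetical target function and a hypothetical helper `validate_call`:

```python
import inspect
from typing import Any, Dict, Tuple, Union


def func(a: int, b: str, *, c: float = 1.0, **kwargs: Any) -> str:
    """Hypothetical target function with a keyword-only parameter."""
    return f"{a} {b} {c} {kwargs}"


def validate_call(fn: Any, *args: Any, **kwargs: Any) -> Tuple[bool, Union[Dict[str, Any], str]]:
    """Check a proposed call against fn's signature without executing fn."""
    try:
        bound = inspect.signature(fn).bind(*args, **kwargs)
    except TypeError as exc:  # missing, unexpected, or surplus arguments
        return False, str(exc)
    bound.apply_defaults()  # fill in c=1.0 and an empty kwargs dict
    return True, dict(bound.arguments)


print(validate_call(func, 1, "x"))
print(validate_call(func, 1))  # rejected: `b` is missing
print(validate_call(func, 1, "x", 2.0))  # rejected: `c` is keyword-only
```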

`Signature.parameters` is an ordered mapping from parameter names to
`Parameter` objects. Here are the four key properties of a `Parameter` object:

```python
@property
def name(self):
    return self._name

@property
def default(self):
    return self._default

@property
def annotation(self):
    return self._annotation

@property
def kind(self):
    return self._kind
```
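A small sketch of these four properties on a hypothetical function that covers all five parameter kinds (the `/` marker for positional-only parameters needs Python 3.8+):

```python
import inspect
from typing import Any, List, Tuple


def example(a: int, /, b: str, *args: Any, c: float, d: int = 4, **kwargs: Any) -> None:
    """Hypothetical function covering all five parameter kinds."""


def describe(fn: Any) -> List[Tuple[str, str, Any, Any]]:
    # (name, kind, default, annotation) for every parameter, in order.
    return [
        (p.name, p.kind.name, p.default, p.annotation)
        for p in inspect.signature(fn).parameters.values()
    ]


for row in describe(example):
    print(row)
```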

We will also use `get_type_hints` to get the type hints of a function instead
of reading `Parameter.annotation` from the `inspect.Signature`. The reason can
be found in the docstring of `get_type_hints` (the implicit `Optional[t]`
behavior it describes was removed in Python 3.11):

```python
def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
    """Return type hints for an object.

    This is often the same as obj.__annotations__, but it handles
    forward references encoded as string literals, adds Optional[t] if a
    default value equal to None is set and recursively replaces all
    'Annotated[T, ...]' with 'T' (unless 'include_extras=True').

    ...
    """
```
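A minimal sketch of the difference, assuming Python 3.9+ for `typing.Annotated`: the `Signature` keeps string (forward-reference) annotations as plain strings, while `get_type_hints` evaluates them and strips `Annotated` metadata unless `include_extras=True`. The function `greet` is hypothetical.

```python
import inspect
from typing import Annotated, get_type_hints


def greet(count: "int", name: Annotated[str, "user name"] = "bob") -> "str":
    """Hypothetical function with string and Annotated hints."""
    return name * count


# The Signature keeps annotations exactly as written, strings included.
raw = {n: p.annotation for n, p in inspect.signature(greet).parameters.items()}

# get_type_hints evaluates the strings and strips Annotated metadata by default.
resolved = get_type_hints(greet)
with_extras = get_type_hints(greet, include_extras=True)

print(raw)
print(resolved)
print(with_extras["name"])
```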

```python
def no_type_hints(a, b, c, d, e, **kwargs):
    return a, b, c, d, e, kwargs
```

```python
get_type_hints(no_type_hints), inspect.signature(no_type_hints)
```

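How do we know whether a parameter is optional? A parameter without a default
value has `param.default is inspect.Parameter.empty`, where `empty` is a
sentinel class. Note that `*args` and `**kwargs` also report `empty`, even
though they can always be omitted.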

```python
for name, value in inspect.signature(func).parameters.items():
    print(value.default)
    print(value.default is inspect.Parameter.empty)
```
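A minimal sketch that turns this check into a required/optional split, treating variadic parameters as optional even though their default is `empty`. The helper `split_params` and the function `g` are hypothetical.

```python
import inspect
from typing import Any, Callable, List, Tuple


def split_params(fn: Callable[..., Any]) -> Tuple[List[str], List[str]]:
    """Split parameter names into (required, optional)."""
    required: List[str] = []
    optional: List[str] = []
    for name, p in inspect.signature(fn).parameters.items():
        if p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD):
            optional.append(name)  # *args / **kwargs report `empty` but may be omitted
        elif p.default is p.empty:
            required.append(name)
        else:
            optional.append(name)
    return required, optional


def g(a, b=1, *args, c, d=None, **kwargs):
    """Hypothetical function mixing required, defaulted and variadic parameters."""


print(split_params(g))  # (['a', 'c'], ['b', 'args', 'd', 'kwargs'])
```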

We will also use `__mro__` to get the full method resolution order of a class,
because `__bases__` only returns the direct parent classes.

```python
ChildClass.__bases__, GPT2LMHeadModel.__bases__[0].__bases__[0].__bases__
```

```python
list(reversed(inspect.getmro(GPT2LMHeadModel)))
```
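A small sketch with a hypothetical diamond hierarchy `A`/`B`/`C`/`D`: `__bases__` stops at the direct parents, repeatedly following `__bases__` recovers every ancestor but loses the order, and `__mro__` gives the full ordering Python actually uses for lookup.

```python
from typing import Set


class A: ...
class B(A): ...
class C(A): ...
class D(B, C): ...


def ancestors_via_bases(cls: type) -> Set[type]:
    """Collect every ancestor by repeatedly following __bases__ (order is lost)."""
    seen: Set[type] = set()
    stack = list(cls.__bases__)
    while stack:
        base = stack.pop()
        if base not in seen:
            seen.add(base)
            stack.extend(base.__bases__)
    return seen


print(D.__bases__)  # only the direct parents: B and C
print(D.__mro__)  # D, B, C, A, object
```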

```python
def get_base_classes(cls: Type[Any], include_self: bool = False) -> Set[Type[Any]]:
    """
    Get all base classes of `cls` (its MRO without `object`) as an unordered set.

    `cls` itself is included only when `include_self=True`.
    """
    return set(cls.__mro__[0:-1] if include_self else cls.__mro__[1:-1])

pprint(get_base_classes(GPT2LMHeadModel, include_self=True))
```

```python
def get_default(param: Parameter) -> Any:
    """Return the parameter's default value or None if not specified."""
    return param.default if param.default is not param.empty else None

def get_field_annotations(func_or_method: Callable[..., Any]) -> Tuple[List[Tuple[str, Any, Any]], Dict[str, Any]]:
    if not inspect.isroutine(func_or_method):
        raise ValueError("Expected a function or method")

    required_fields = []
    optional_fields = []
    annotations = {}

    try:
        sig: Signature = inspect.signature(func_or_method)
        type_hints: Dict[str, Any] = get_type_hints(func_or_method)
    except (ValueError, TypeError, NameError):
        # `inspect.signature` raises ValueError/TypeError for unsupported objects;
        # `get_type_hints` raises NameError/TypeError for unresolvable hints.
        raise ValueError("Object does not support signature or type hints extraction.") from None

    for name, param in sig.parameters.items():
        if name == "self":
            continue

        type_hint = type_hints.get(name, Any)
        annotations[name] = type_hint
        if param.default is param.empty:
            # Ellipsis marks a required argument.
            required_fields.append((name, type_hint, Ellipsis))
        else:
            default_value = param.default
            optional_fields.append((name, type_hint, default_value))

    # Required fields first: a parameter without a default cannot follow one with a default.
    fields = required_fields + optional_fields
    return fields, annotations


# TODO: Tuple[str, Any, Any] should be Tuple[str, Any, ellipsis]
def get_constructor_field_annotations(
    cls: Type[Any], include_bases: bool = True
) -> Tuple[List[Tuple[str, Any, Any]], Dict[str, Any]]:
    fields = []
    annotations = {}

    # `get_base_classes` returns a set, so restore MRO order explicitly.
    bases = sorted(get_base_classes(cls, include_self=False), key=cls.__mro__.index)
    classes_to_inspect = [cls] + bases if include_bases else [cls]

    for c in reversed(classes_to_inspect):  # Topmost base first, `cls` last
        if hasattr(c, "__init__"):
            class_fields, class_annotations = get_field_annotations(c.__init__)
            # Keep the first occurrence of each field (from the topmost class
            # that declares it), while annotations are overwritten by more
            # derived classes.
            for field_ in class_fields:
                if field_[0] not in annotations:
                    fields.append(field_)  # noqa: PERF401
            annotations.update(class_annotations)

    return fields, annotations
```

```python
fields, annotations = get_constructor_field_annotations(TrainingArguments, include_bases=False)
# Use `field_` to avoid shadowing `dataclasses.field`, which is imported above.
for field_ in fields:
    assert len(field_) == 3
    name, type_hint, default = field_
    print(f"{name}, {type_hint}, {default}")

assert get_field_annotations(TrainingArguments.__init__) == (fields, annotations)
```

Warning: this approach does not play well with `dataclass` and `pydantic`
classes, because their generated `__init__` methods carry extra machinery such
as default factories. On the other hand, both libraries expose their fields
directly, so we can use `dataclasses.fields(cls)` or a Pydantic
[property](https://stackoverflow.com/questions/71183960/short-way-to-get-all-field-names-of-a-pydantic-class)
like `model_fields` (Pydantic v2) to get all fields and their types.

As we can see above, `lr_scheduler_kwargs` is not handled well:

```text
lr_scheduler_kwargs, typing.Optional[typing.Dict], <factory>
```

Here `<factory>` is what appears as the parameter's default, but it is not a
real value: it is the `repr` of a sentinel that the dataclass-generated
`__init__` uses for any field declared with `default_factory`. The actual
default (for example, an empty `dict`) is only created by calling the factory
when `__init__` runs.
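A minimal sketch of the problem and one way around it, using a hypothetical dataclass `Args` as a stand-in for `TrainingArguments`: the constructor signature only exposes the `<factory>` placeholder, whereas `dataclasses.fields` gives access to the real `default_factory`. The helper `resolve_defaults` is hypothetical and reuses this notebook's convention of `Ellipsis` for required fields.

```python
import dataclasses
import inspect
from typing import Any, Dict, Optional


@dataclasses.dataclass
class Args:
    """Hypothetical stand-in for a dataclass such as TrainingArguments."""

    output_dir: str
    learning_rate: float = 5e-5
    lr_scheduler_kwargs: Optional[Dict[str, Any]] = dataclasses.field(default_factory=dict)


# The generated __init__ only exposes a placeholder for factory-backed fields.
placeholder = inspect.signature(Args).parameters["lr_scheduler_kwargs"].default
print(repr(placeholder))  # <factory>


def resolve_defaults(cls: type) -> Dict[str, Any]:
    """Map each dataclass field to its real default, using Ellipsis for required fields."""
    resolved: Dict[str, Any] = {}
    for f in dataclasses.fields(cls):
        if f.default is not dataclasses.MISSING:
            resolved[f.name] = f.default
        elif f.default_factory is not dataclasses.MISSING:
            resolved[f.name] = f.default_factory()  # build a fresh default value
        else:
            resolved[f.name] = Ellipsis  # required field
    return resolved


print(resolve_defaults(Args))
```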

```python
def type_hint_to_str(type_hint: Any) -> str:
    """
    Convert a type hint into its string representation.
    """
    if hasattr(type_hint, '__name__'):
        return type_hint.__name__
    elif hasattr(type_hint, '_name') and type_hint._name is not None:
        return str(type_hint._name)
    elif isinstance(type_hint, _GenericAlias):  # For Python 3.8+
        # Handles complex types, e.g., List[int], Union[str, int]
        origin = type_hint_to_str(type_hint.__origin__)
        args = ', '.join(type_hint_to_str(arg) for arg in type_hint.__args__)
        return f"{origin}[{args}]"
    else:
        # Fallback for unhandled types
        return str(type_hint)

def create_config_class_str(fields: List[Tuple[str, Any, Any]]) -> str:
    lines = ["class Config:"]
    if not fields:
        lines.append("    ...")
    else:
        init_params = ["self"]
        init_body = []
        for name, type_hint, default in fields:
            type_hint_str = type_hint_to_str(type_hint)
            if default is Ellipsis:  # Required argument
                param_str = f"{name}: {type_hint_str}"
            elif default is field:
                # Never matches in practice: the `<factory>` sentinel used by
                # dataclass constructors is not `field`, so it falls through
                # to the `repr` branch below.
                param_str = f"{name}: {type_hint_str} = field(default_factory=dict)"
            else:
                param_str = f"{name}: {type_hint_str} = {repr(default)}"

            init_params.append(param_str)
            init_body.append(f"        self.{name} = {name}")

        lines.append(f"    def __init__({', '.join(init_params)}):")
        lines.extend(init_body)

    return '\n'.join(lines)

config_class_str = create_config_class_str(fields)
print(config_class_str)
```

Executing this string as-is would raise a `SyntaxError`, because `<factory>`
(the `repr` of the sentinel discussed above) is not valid Python. We can
instead apply it to a "normal" (non-dataclass) class such as `Trainer`.

```python
fields, annotations = get_constructor_field_annotations(Trainer, include_bases=False)
config_class_str = create_config_class_str(fields)
print(config_class_str)
```

```python
import transformers
import typing
import torch
from transformers import DataCollator

# These imports and `NoneType` make names that may appear in the generated
# annotations resolvable when the class definition is executed.
NoneType = type(None)

config_class_str = create_config_class_str(fields)

# Execute the generated class definition string
namespace = {}
exec(config_class_str, globals(), namespace)

# Extract the newly created class from the namespace
ConfigClass = namespace['Config']
```


```python
inspect.signature(ConfigClass.__init__)
```
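As an alternative to generating source code and running it through `exec`, the standard library's `dataclasses.make_dataclass` (imported at the top of this notebook but not used so far) builds a class directly from field specifications, so `<factory>` never has to be rendered as text. A minimal sketch, assuming `(name, type, default)` triples shaped like those from `get_field_annotations` and a hypothetical `factories` mapping supplied by hand; `build_config_class` is likewise hypothetical.

```python
import inspect
from dataclasses import field, make_dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple


def build_config_class(
    name: str,
    triples: List[Tuple[str, Any, Any]],
    factories: Optional[Dict[str, Callable[[], Any]]] = None,
) -> type:
    """Build a dataclass from (name, type, default) triples, Ellipsis meaning required.

    `factories` maps field names to zero-argument callables and takes precedence
    over the triple's default (useful where the extracted default was `<factory>`).
    Required fields must come before defaulted ones, as in a normal signature.
    """
    factories = factories or {}
    spec: List[Any] = []
    for field_name, type_hint, default in triples:
        if field_name in factories:
            spec.append((field_name, type_hint, field(default_factory=factories[field_name])))
        elif default is Ellipsis:
            spec.append((field_name, type_hint))
        else:
            spec.append((field_name, type_hint, field(default=default)))
    return make_dataclass(name, spec)


Config = build_config_class(
    "Config",
    [
        ("output_dir", str, Ellipsis),
        ("learning_rate", float, 5e-5),
        ("lr_scheduler_kwargs", Optional[Dict[str, Any]], None),  # placeholder, see factories
    ],
    factories={"lr_scheduler_kwargs": dict},
)
print(inspect.signature(Config))
print(Config(output_dir="out"))
```

Each instance gets its own fresh `dict` from the factory, which is exactly what the dataclass-generated `__init__` does behind the `<factory>` placeholder.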

## References and Further Readings

- [inspect: Inspect live objects](https://docs.python.org/3/library/inspect.html)
- [Getting attributes of a class](https://stackoverflow.com/questions/9058305/getting-attributes-of-a-class)
- [Use `__dict__` or `vars()`?](https://stackoverflow.com/questions/21297203/use-dict-or-vars)
- [How do I get list of methods in a Python class?](https://stackoverflow.com/questions/1911281/how-do-i-get-list-of-methods-in-a-python-class)