-
-
Notifications
You must be signed in to change notification settings - Fork 1.6k
/
Copy pathoutput_parser.py
100 lines (74 loc) · 2.62 KB
/
output_parser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""
Functions to retrieve the correct output parser and format instructions for the LLM model.
"""
from typing import Any, Callable, Dict, Type, Union
from langchain_core.output_parsers import JsonOutputParser
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
def get_structured_output_parser(
    schema: Union[Dict[str, Any], Type[BaseModelV1 | BaseModelV2], Type],
) -> Callable:
    """
    Get the correct output parser for the LLM model.

    Args:
        schema: The structured-output schema. May be a pydantic v1 model
            class, a pydantic v2 model class, or anything else (for example
            a JSON-schema dict or a TypedDict class).

    Returns:
        Callable: The output parser function. Pydantic v1 and v2 model
        classes get their dedicated parser; every other schema gets the
        pass-through dict parser.
    """
    # issubclass() raises TypeError when its first argument is not a class
    # (e.g. a JSON-schema dict), so only run the model checks on real classes
    # and let everything else fall through to the dict parser.
    if isinstance(schema, type):
        if issubclass(schema, BaseModelV1):
            return _base_model_v1_output_parser

        if issubclass(schema, BaseModelV2):
            return _base_model_v2_output_parser

    return _dict_output_parser
def get_pydantic_output_parser(
    schema: Union[Dict[str, Any], Type[BaseModelV1 | BaseModelV2], Type],
) -> JsonOutputParser:
    """
    Get the correct output parser for the LLM model.

    Args:
        schema: The structured-output schema. Only pydantic v2 model classes
            are accepted.

    Returns:
        JsonOutputParser: The output parser object, built with the schema as
        its ``pydantic_object``.

    Raises:
        ValueError: If the schema is a pydantic v1 model class, or if it is
            not a pydantic model class at all (including non-class values
            such as a JSON-schema dict).
    """
    # issubclass() raises TypeError for non-class arguments (e.g. a dict);
    # checking first lets such schemas reach the intended ValueError below.
    is_class = isinstance(schema, type)

    if is_class and issubclass(schema, BaseModelV1):
        raise ValueError(
            "pydantic.v1 and langchain_core.pydantic_v1\n"
            "are not supported with this LLM model. Please use pydantic v2 instead."
        )

    if is_class and issubclass(schema, BaseModelV2):
        return JsonOutputParser(pydantic_object=schema)

    raise ValueError(
        "The schema is not a pydantic subclass.\n"
        "With this LLM model you must use a pydantic schemas."
    )
def _base_model_v1_output_parser(x: BaseModelV1) -> dict:
    """
    Parse the output of an LLM when the schema is BaseModelv1.

    The model is first converted with ``.dict()``. The resulting mapping is
    then walked: any value that is still a BaseModelV1 instance is replaced
    in place by its own ``.dict()`` and walked in turn.

    Args:
        x (BaseModelV1): The output from the LLM model.

    Returns:
        dict: The parsed output.
    """

    def _convert_nested_models(mapping: dict) -> dict:
        # Only existing keys are reassigned, so the mapping's size never
        # changes while it is being iterated.
        for name in mapping:
            value = mapping[name]
            if not isinstance(value, BaseModelV1):
                continue
            nested = value.dict()
            mapping[name] = nested
            _convert_nested_models(nested)
        return mapping

    return _convert_nested_models(x.dict())
def _base_model_v2_output_parser(x: BaseModelV2) -> dict:
    """
    Parse the output of an LLM when the schema is BaseModelv2.

    Args:
        x (BaseModelV2): The output from the LLM model.

    Returns:
        dict: The result of calling ``model_dump()`` on the model.
    """
    dumped = x.model_dump()
    return dumped
def _dict_output_parser(x: dict) -> dict:
"""
Parse the output of an LLM when the schema is TypedDict or JsonSchema.
Args:
x (dict): The output from the LLM model.
Returns:
dict: The parsed output.
"""
return x