2
2
3
3
from typing import Any , Dict , List , Optional
4
4
5
- from langchain import BasePromptTemplate , LLMChain
6
5
from langchain .base_language import BaseLanguageModel
7
6
from langchain .callbacks .manager import (
8
7
AsyncCallbackManagerForChainRun ,
9
8
CallbackManagerForChainRun ,
10
9
)
10
+ from langchain .chains import LLMChain
11
11
from langchain .chains .base import Chain
12
12
from langchain .output_parsers import OutputFixingParser , PydanticOutputParser
13
13
from langchain .schema import BaseOutputParser
14
+ from langchain_core .prompts import BasePromptTemplate
14
15
from pydantic import Extra , Field
15
16
16
17
from codedog .chains .pr_summary .prompts import CODE_SUMMARY_PROMPT , PR_SUMMARY_PROMPT
20
21
PullRequestProcessor ,
21
22
)
22
23
24
# Single module-level PullRequestProcessor, built once when this module is
# imported; the PRSummaryChain methods in this file call it directly.
+ processor = PullRequestProcessor .build ()
25
+
23
26
24
27
class PRSummaryChain (Chain ):
25
28
"""Summarize a pull request.
@@ -32,17 +35,13 @@ class PRSummaryChain(Chain):
32
35
- code_summaries(Dict[str, str]): changed code file summarizations, key is file path.
33
36
"""
34
37
35
- # TODO: input keys validation
36
-
37
38
code_summary_chain : LLMChain = Field (exclude = True )
38
39
"""Chain to use to summarize code change."""
39
40
pr_summary_chain : LLMChain = Field (exclude = True )
40
41
"""Chain to use to summarize PR."""
41
42
42
43
parser : BaseOutputParser = Field (exclude = True )
43
44
"""Parse pr summarized result to PRSummary object."""
44
- processor : PullRequestProcessor = Field (exclude = True , default_factory = PullRequestProcessor .build )
45
- """PR data process."""
46
45
47
46
_input_keys : List [str ] = ["pull_request" ]
48
47
_output_keys : List [str ] = ["pr_summary" , "code_summaries" ]
@@ -78,15 +77,21 @@ def review(self, inputs, _run_manager) -> Dict[str, Any]:
78
77
79
78
code_summary_inputs = self ._process_code_summary_inputs (pr )
80
79
code_summary_outputs = (
81
- self .code_summary_chain .apply (code_summary_inputs , callbacks = _run_manager .get_child (tag = "CodeSummary" ))
80
+ self .code_summary_chain .apply (
81
+ code_summary_inputs , callbacks = _run_manager .get_child (tag = "CodeSummary" )
82
+ )
82
83
if code_summary_inputs
83
84
else []
84
85
)
85
86
86
- code_summaries = self .processor .build_change_summaries (code_summary_inputs , code_summary_outputs )
87
+ code_summaries = processor .build_change_summaries (
88
+ code_summary_inputs , code_summary_outputs
89
+ )
87
90
88
91
pr_summary_input = self ._process_pr_summary_input (pr , code_summaries )
89
- pr_summary_output = self .pr_summary_chain (pr_summary_input , callbacks = _run_manager .get_child (tag = "PRSummary" ))
92
+ pr_summary_output = self .pr_summary_chain (
93
+ pr_summary_input , callbacks = _run_manager .get_child (tag = "PRSummary" )
94
+ )
90
95
91
96
return self ._process_result (pr_summary_output , code_summaries )
92
97
@@ -95,26 +100,38 @@ async def areview(self, inputs, _run_manager) -> Dict[str, Any]:
95
100
96
101
code_summary_inputs = self ._process_code_summary_inputs (pr )
97
102
code_summary_outputs = (
98
- await self .code_summary_chain .aapply (code_summary_inputs , callbacks = _run_manager .get_child ())
103
+ await self .code_summary_chain .aapply (
104
+ code_summary_inputs , callbacks = _run_manager .get_child ()
105
+ )
99
106
if code_summary_inputs
100
107
else []
101
108
)
102
109
103
- code_summaries = self .processor .build_change_summaries (code_summary_inputs , code_summary_outputs )
110
+ code_summaries = processor .build_change_summaries (
111
+ code_summary_inputs , code_summary_outputs
112
+ )
104
113
105
114
pr_summary_input = self ._process_pr_summary_input (pr , code_summaries )
106
- pr_summary_output = await self .pr_summary_chain .acall (pr_summary_input , callbacks = _run_manager .get_child ())
115
+ pr_summary_output = await self .pr_summary_chain .ainvoke (
116
+ pr_summary_input , callbacks = _run_manager .get_child ()
117
+ )
107
118
108
119
return await self ._aprocess_result (pr_summary_output , code_summaries )
109
120
110
def _call(
    self,
    inputs: Dict[str, Any],
    run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
    """Run the chain synchronously.

    Args:
        inputs: chain inputs; ``inputs["pull_request"]`` must expose ``.json()``.
        run_manager: callback manager for this run; a no-op manager is used
            when none (or a falsy one) is supplied.

    Returns:
        Whatever ``self.review`` returns for these inputs.
    """
    manager = run_manager or CallbackManagerForChainRun.get_noop_manager()

    # Emit the serialized pull request to the callbacks before reviewing.
    serialized_pr = inputs["pull_request"].json()
    manager.on_text(serialized_pr + "\n ")

    result = self.review(inputs, manager)
    return result
115
130
116
131
async def _acall (
117
- self , inputs : Dict [str , Any ], run_manager : Optional [AsyncCallbackManagerForChainRun ] = None
132
+ self ,
133
+ inputs : Dict [str , Any ],
134
+ run_manager : Optional [AsyncCallbackManagerForChainRun ] = None ,
118
135
) -> Dict [str , Any ]:
119
136
_run_manager = run_manager or CallbackManagerForChainRun .get_noop_manager ()
120
137
await _run_manager .on_text (inputs ["pull_request" ].json () + "\n " )
@@ -123,28 +140,36 @@ async def _acall(
123
140
124
141
def _process_code_summary_inputs(self, pr: PullRequest) -> List[Dict[str, str]]:
    """Build one code-summary input dict per diff code file of ``pr``.

    The files come from ``processor.get_diff_code_files(pr)``. Each entry has:
      - "content": the first 2000 characters of the file's diff content,
      - "name": the file's full name,
      - "language": the ``SUFFIX_LANGUAGE_MAPPING`` entry for the file's
        suffix, or "" when the suffix is not in the mapping.

    Returns an empty list when there are no diff code files.
    """
    # TODO: handle long diff (content is currently truncated to 2000 chars)
    return [
        {
            "content": diff_file.diff_content.content[:2000],
            "name": diff_file.full_name,
            "language": SUFFIX_LANGUAGE_MAPPING.get(diff_file.suffix, ""),
        }
        for diff_file in processor.get_diff_code_files(pr)
    ]
136
155
137
def _process_pr_summary_input(
    self, pr: PullRequest, code_summaries: List[ChangeSummary]
) -> Dict[str, str]:
    """Assemble the input dict for the PR summary chain.

    Args:
        pr: the pull request being summarized.
        code_summaries: per-file change summaries produced earlier in the run.

    Returns:
        A dict with three materials generated by the module-level
        ``processor``: "change_files" (from ``pr.change_files``),
        "code_summaries" (from ``code_summaries``) and "metadata" (from ``pr``).
    """
    # Dict values are evaluated top to bottom, so the processor calls happen
    # in the order: change files, code summaries, PR metadata.
    return {
        "change_files": processor.gen_material_change_files(pr.change_files),
        "code_summaries": processor.gen_material_code_summaries(code_summaries),
        "metadata": processor.gen_material_pr_metadata(pr),
    }
146
169
147
- def _process_result (self , pr_summary_output : Dict [str , Any ], code_summaries : List [ChangeSummary ]) -> Dict [str , Any ]:
170
+ def _process_result (
171
+ self , pr_summary_output : Dict [str , Any ], code_summaries : List [ChangeSummary ]
172
+ ) -> Dict [str , Any ]:
148
173
return {
149
174
"pr_summary" : pr_summary_output ["text" ],
150
175
"code_summaries" : code_summaries ,
@@ -167,7 +192,16 @@ def from_llm(
167
192
pr_summary_prompt : BasePromptTemplate = PR_SUMMARY_PROMPT ,
168
193
** kwargs ,
169
194
) -> PRSummaryChain :
170
- parser = OutputFixingParser .from_llm (llm = pr_summary_llm , parser = PydanticOutputParser (pydantic_object = PRSummary ))
195
+ parser = OutputFixingParser .from_llm (
196
+ llm = pr_summary_llm , parser = PydanticOutputParser (pydantic_object = PRSummary )
197
+ )
171
198
code_summary_chain = LLMChain (llm = code_summary_llm , prompt = code_summary_prompt )
172
- pr_summary_chain = LLMChain (llm = pr_summary_llm , prompt = pr_summary_prompt , output_parser = parser )
173
- return cls (code_summary_chain = code_summary_chain , pr_summary_chain = pr_summary_chain , parser = parser , ** kwargs )
199
+ pr_summary_chain = LLMChain (
200
+ llm = pr_summary_llm , prompt = pr_summary_prompt , output_parser = parser
201
+ )
202
+ return cls (
203
+ code_summary_chain = code_summary_chain ,
204
+ pr_summary_chain = pr_summary_chain ,
205
+ parser = parser ,
206
+ ** kwargs ,
207
+ )
0 commit comments