-
-
Notifications
You must be signed in to change notification settings - Fork 1.6k
/
Copy pathdepth_search_graph.py
156 lines (133 loc) · 5.05 KB
/
depth_search_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""
depth search graph Module
"""
from typing import Optional, Type
from pydantic import BaseModel
from ..nodes import (
DescriptionNode,
FetchNodeLevelK,
GenerateAnswerNodeKLevel,
ParseNodeDepthK,
RAGNode,
)
from .abstract_graph import AbstractGraph
from .base_graph import BaseGraph
class DepthSearchGraph(AbstractGraph):
    """
    DepthSearchGraph is a scraping pipeline that crawls a source up to a
    configurable link depth, parses and summarizes the fetched documents,
    indexes them in a vector database, and finally answers the user prompt
    through retrieval-augmented generation.

    Attributes:
        prompt (str): The prompt for the graph.
        source (str): The source of the graph (URL or local directory).
        config (dict): Configuration parameters for the graph.
        schema (BaseModel): The schema for the graph output.
        llm_model: An instance of a language model client, configured for generating answers.
        embedder_model: An instance of an embedding model client,
            configured for generating embeddings.
        verbose (bool): A flag indicating whether to show print statements during execution.
        headless (bool): A flag indicating whether to run the graph in headless mode.

    Args:
        prompt (str): The prompt for the graph.
        source (str): The source of the graph.
        config (dict): Configuration parameters for the graph.
        schema (BaseModel): The schema for the graph output.

    Example:
        >>> search_graph = DepthSearchGraph(
        ...     "List me all the attractions in Chioggia.",
        ...     "https://en.wikipedia.org/wiki/Chioggia",
        ...     {"llm": {"model": "openai/gpt-3.5-turbo"}, "depth": 2}
        ... )
        >>> result = search_graph.run()
    """

    def __init__(
        self,
        prompt: str,
        source: str,
        config: dict,
        schema: Optional[Type[BaseModel]] = None,
    ):
        super().__init__(prompt, config, source, schema)

        # HTTP(S) sources are fetched over the network; anything else is
        # treated as a path to crawl on the local filesystem.
        self.input_key = "url" if source.startswith("http") else "local_dir"

    def _create_graph(self) -> BaseGraph:
        """
        Creates the graph of nodes representing the depth-search workflow:
        fetch (to depth K) -> parse -> describe -> index (RAG) -> answer.

        Returns:
            BaseGraph: A graph instance representing the depth-search workflow.
        """
        # Hoist repeatedly-read config values; defaults mirror the node
        # expectations (verbose/embedder default intentionally falsy).
        verbose = self.config.get("verbose", False)
        embedder_model = self.config.get("embedder_model", False)

        fetch_node_k = FetchNodeLevelK(
            input="url| local_dir",
            output=["docs"],
            node_config={
                "loader_kwargs": self.config.get("loader_kwargs", {}),
                "force": self.config.get("force", False),
                "cut": self.config.get("cut", True),
                "browser_base": self.config.get("browser_base"),
                "storage_state": self.config.get("storage_state"),
                # How many link levels to follow from the entry source.
                "depth": self.config.get("depth", 1),
                # Restrict crawling to links within the source's own domain.
                "only_inside_links": self.config.get("only_inside_links", False),
            },
        )

        parse_node_k = ParseNodeDepthK(
            input="docs",
            output=["docs"],
            node_config={"verbose": verbose},
        )

        description_node = DescriptionNode(
            input="docs",
            output=["docs"],
            node_config={
                "llm_model": self.llm_model,
                "verbose": verbose,
                "cache_path": self.config.get("cache_path", False),
            },
        )

        rag_node = RAGNode(
            input="docs",
            output=["vectorial_db"],
            node_config={
                "llm_model": self.llm_model,
                "embedder_model": embedder_model,
                "verbose": verbose,
            },
        )

        generate_answer_k = GenerateAnswerNodeKLevel(
            input="vectorial_db",
            output=["answer"],
            node_config={
                "llm_model": self.llm_model,
                "embedder_model": embedder_model,
                "verbose": verbose,
            },
        )

        # Linear pipeline: each node feeds the next, starting from the fetcher.
        return BaseGraph(
            nodes=[
                fetch_node_k,
                parse_node_k,
                description_node,
                rag_node,
                generate_answer_k,
            ],
            edges=[
                (fetch_node_k, parse_node_k),
                (parse_node_k, description_node),
                (description_node, rag_node),
                (rag_node, generate_answer_k),
            ],
            entry_point=fetch_node_k,
            graph_name=self.__class__.__name__,
        )

    def run(self) -> str:
        """
        Executes the depth-search scraping process and returns the answer
        generated for the user prompt.

        Returns:
            str: The generated answer, or "No answer" if none was produced.
        """
        inputs = {"user_prompt": self.prompt, self.input_key: self.source}
        self.final_state, self.execution_info = self.graph.execute(inputs)

        # The answer is produced by GenerateAnswerNodeKLevel; fall back to a
        # sentinel string when the pipeline yielded nothing.
        docs = self.final_state.get("answer", "No answer")

        return docs