-
-
Notifications
You must be signed in to change notification settings - Fork 1.6k
/
Copy pathabstract_graph_test.py
107 lines (93 loc) · 4.6 KB
/
abstract_graph_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""
Tests for the AbstractGraph.
"""
import pytest
from unittest.mock import patch
from scrapegraphai.graphs import AbstractGraph, BaseGraph
from scrapegraphai.nodes import (
FetchNode,
ParseNode
)
from scrapegraphai.models import OneApi, DeepSeek
from langchain_openai import ChatOpenAI, AzureChatOpenAI
from langchain_ollama import ChatOllama
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_aws import ChatBedrock
class TestGraph(AbstractGraph):
    """Minimal concrete AbstractGraph used to exercise the base-class machinery.

    Builds a two-node fetch -> parse pipeline; just enough implementation to
    let the tests drive LLM creation and graph execution.
    """

    def __init__(self, prompt: str, config: dict):
        super().__init__(prompt, config)

    def _create_graph(self) -> BaseGraph:
        """Assemble the fetch -> parse BaseGraph required by AbstractGraph."""
        cfg = self.config
        fetch = FetchNode(
            input="url| local_dir",
            output=["doc"],
            node_config={
                "llm_model": self.llm_model,
                "force": cfg.get("force", False),
                "cut": cfg.get("cut", True),
                "loader_kwargs": cfg.get("loader_kwargs", {}),
                "browser_base": cfg.get("browser_base"),
            },
        )
        parse = ParseNode(
            input="doc",
            output=["parsed_doc"],
            node_config={
                "llm_model": self.llm_model,
                "chunk_size": self.model_token,
            },
        )
        return BaseGraph(
            nodes=[fetch, parse],
            edges=[(fetch, parse)],
            entry_point=fetch,
            graph_name=type(self).__name__,
        )

    def run(self) -> str:
        """Execute the graph and return the final state's 'answer' field."""
        initial_state = {"user_prompt": self.prompt, self.input_key: self.source}
        self.final_state, self.execution_info = self.graph.execute(initial_state)
        return self.final_state.get("answer", "No answer found.")
class TestAbstractGraph:
    """Tests covering per-provider LLM instantiation and safe async execution."""

    # (llm config, expected chat-model class) pairs shared by the creation tests.
    _LLM_CASES = [
        ({"model": "openai/gpt-3.5-turbo", "openai_api_key": "sk-randomtest001"}, ChatOpenAI),
        (
            {
                "model": "azure_openai/gpt-3.5-turbo",
                "api_key": "random-api-key",
                "api_version": "no version",
                "azure_endpoint": "https://www.example.com/",
            },
            AzureChatOpenAI,
        ),
        ({"model": "google_genai/gemini-pro", "google_api_key": "google-key-test"}, ChatGoogleGenerativeAI),
        ({"model": "ollama/llama2"}, ChatOllama),
        ({"model": "oneapi/qwen-turbo", "api_key": "oneapi-api-key"}, OneApi),
        ({"model": "deepseek/deepseek-coder", "api_key": "deepseek-api-key"}, DeepSeek),
        ({"model": "bedrock/anthropic.claude-3-sonnet-20240229-v1:0", "region_name": "IDK"}, ChatBedrock),
    ]

    # The same providers, each with a rate limit attached to the config.
    _RATE_LIMITED_CASES = [
        ({**cfg, "rate_limit": {"requests_per_second": 1}}, model)
        for cfg, model in _LLM_CASES
    ]

    @pytest.mark.parametrize("llm_config, expected_model", _LLM_CASES)
    def test_create_llm(self, llm_config, expected_model):
        """Each provider config yields an instance of the matching chat-model class."""
        graph = TestGraph("Test prompt", {"llm": llm_config})
        assert isinstance(graph.llm_model, expected_model)

    @pytest.mark.parametrize("llm_config, expected_model", _RATE_LIMITED_CASES)
    def test_create_llm_with_rate_limit(self, llm_config, expected_model):
        """Adding a rate limit still yields the matching chat-model class."""
        graph = TestGraph("Test prompt", {"llm": llm_config})
        assert isinstance(graph.llm_model, expected_model)

    def test_create_llm_unknown_provider(self):
        """An unrecognized provider prefix raises ValueError."""
        with pytest.raises(ValueError):
            TestGraph("Test prompt", {"llm": {"model": "unknown_provider/model"}})

    @pytest.mark.asyncio
    async def test_run_safe_async(self):
        """run_safe_async delegates to run() and returns its result."""
        llm_cfg = {"llm": {"model": "openai/gpt-3.5-turbo", "openai_api_key": "sk-randomtest001"}}
        graph = TestGraph("Test prompt", llm_cfg)
        with patch.object(graph, "run", return_value="Async result") as mock_run:
            assert await graph.run_safe_async() == "Async result"
            mock_run.assert_called_once()