
Commit 6ee90bd

nefertitirogers (Nefertiti Rogers) authored

Additional test coverage (#318)

* init commit
* merge
* update
* async tests
* more updates to async
* lint fixes
* update mocks and lint fixes
* lint
* lint
* lint
* lint

Co-authored-by: Nefertiti Rogers <nefertitirogers@Nefertitis-MacBook-Pro.local>
Co-authored-by: Nefertiti Rogers <nefertitirogers@Nefertitis-MBP.localdomain>
1 parent d124f94 commit 6ee90bd

File tree

10 files changed: 332 additions, 1 deletion

tests/integration_tests/mock_llm_outputs.py

Lines changed: 1 addition & 0 deletions

@@ -28,6 +28,7 @@ def _invoke_llm(self, prompt, *args, **kwargs):
             pydantic.COMPILED_PROMPT_REASK_2: pydantic.LLM_OUTPUT_REASK_2,
             string.COMPILED_PROMPT: string.LLM_OUTPUT,
             string.COMPILED_PROMPT_REASK: string.LLM_OUTPUT_REASK,
+            string.COMPILED_LIST_PROMPT: string.LIST_LLM_OUTPUT,
             python_rail.VALIDATOR_PARALLELISM_PROMPT_1: python_rail.VALIDATOR_PARALLELISM_RESPONSE_1,  # noqa: E501
             python_rail.VALIDATOR_PARALLELISM_PROMPT_2: python_rail.VALIDATOR_PARALLELISM_RESPONSE_2,  # noqa: E501
             python_rail.VALIDATOR_PARALLELISM_PROMPT_3: python_rail.VALIDATOR_PARALLELISM_RESPONSE_3,  # noqa: E501
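
The mock resolves each compiled prompt to a canned completion by dictionary lookup, so registering the new `string.COMPILED_LIST_PROMPT: string.LIST_LLM_OUTPUT` pair is all the list test needs. A minimal sketch of that lookup pattern (illustrative only, not the repository's exact mock class):

class MockCompletionLLM:
    """Illustrative stand-in for the suite's mock LLM callable."""

    def __init__(self, responses):
        # responses: dict mapping compiled prompt text -> canned LLM output
        self.responses = responses

    def _invoke_llm(self, prompt, *args, **kwargs):
        # Return the canned output registered for this exact compiled prompt.
        try:
            return self.responses[prompt]
        except KeyError:
            raise ValueError("Compiled prompt not found in mock LLM outputs") from None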

tests/integration_tests/test_assets/string/__init__.py

Lines changed: 5 additions & 0 deletions

@@ -11,8 +11,11 @@
 
 COMPILED_INSTRUCTIONS = reader("compiled_instructions.txt")
 COMPILED_PROMPT = reader("compiled_prompt.txt")
+COMPILED_LIST_PROMPT = reader("compiled_list_prompt.txt")
 LLM_OUTPUT = reader("llm_output.txt")
+LIST_LLM_OUTPUT = reader("llm_list_output.txt")
 RAIL_SPEC_FOR_STRING = reader("string.rail")
+RAIL_SPEC_FOR_LIST = reader("list.rail")
 
 COMPILED_PROMPT_REASK = reader("compiled_prompt_reask.txt")
 RAIL_SPEC_FOR_STRING_REASK = reader("string_reask.rail")

@@ -30,6 +33,8 @@
 
 __all__ = [
     "COMPILED_PROMPT",
+    "RAIL_SPEC_FOR_LIST",
+    "LIST_LLM_OUTPUT",
     "LLM_OUTPUT",
     "RAIL_SPEC_FOR_STRING",
     "COMPILED_PROMPT_REASK",

tests/integration_tests/test_assets/string/compiled_list_prompt.txt

Lines changed: 24 additions & 0 deletions

@@ -0,0 +1,24 @@
+
+Generate a dataset of fake user orders. Each row of the dataset should be valid.
+
+
+Given below is XML that describes the information to extract from this document and the tags to extract it into.
+
+<output>
+    <list name="user_orders" description="Generate a list of user, and how many orders they have placed in the past." format="length: min=10 max=10">
+        <object>
+            <string name="user_id" description="The user's id." format="1-indexed"/>
+            <string name="user_name" description="The user's first name and last name" format="two-words"/>
+            <integer name="num_orders" description="The number of orders the user has placed" format="valid-range: min=0 max=50"/>
+            <date name="last_order_date" description="Date of last order"/>
+        </object>
+    </list>
+</output>
+
+
+ONLY return a valid JSON object (no other text is necessary), where the key of the field in JSON is the `name` attribute of the corresponding XML, and the value is of the type specified by the corresponding XML's tag. The JSON MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be correct and concise. If you are unsure anywhere, enter `null`.
+
+Here are examples of simple (XML, JSON) pairs that show the expected behavior:
+- `<string name='foo' format='two-words lower-case' />` => `{'foo': 'example one'}`
+- `<list name='bar'><string format='upper-case' /></list>` => `{"bar": ['STRING ONE', 'STRING TWO', etc.]}`
+- `<object name='baz'><string name="foo" format="capitalize two-words" /><integer name="index" format="1-indexed" /></object>` => `{'baz': {'foo': 'Some String', 'index': 1}}`

tests/integration_tests/test_assets/string/list.rail

Lines changed: 18 additions & 0 deletions

@@ -0,0 +1,18 @@
+<rail version="0.1">
+<output>
+    <list name="user_orders" description="Generate a list of user, and how many orders they have placed in the past." format="length: 10 10" on-fail-length="noop">
+        <object>
+            <string name="user_id" description="The user's id." format="1-indexed" />
+            <string name="user_name" description="The user's first name and last name" format="two-words" />
+            <integer name="num_orders" description="The number of orders the user has placed" format="valid-range: 0 50" />
+            <date name="last_order_date" description="Date of last order" />
+        </object>
+    </list>
+</output>
+
+<prompt>
+Generate a dataset of fake user orders. Each row of the dataset should be valid.
+
+${gr.complete_json_suffix}
+</prompt>
+</rail>
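
A quick sketch of how this spec relates to the prompt asset above; `Guard.from_rail` and `guard.prompt.format().source` are used the same way in the unit test added below, and the inspection here is only illustrative:

import guardrails as gd

# Load the new list spec by path (as laid out in this commit) and look at its
# prompt. The ${gr.complete_json_suffix} placeholder is what produces the
# JSON-output instructions captured in compiled_list_prompt.txt above.
guard = gd.Guard.from_rail("tests/integration_tests/test_assets/string/list.rail")
print(guard.prompt.format().source)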

tests/integration_tests/test_assets/string/llm_list_output.txt

Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+{
+    'user_orders': [
+        {'user_id': 1, 'user_name': 'John Smith', 'num_orders': 10, 'last_order_date': '2020-01-01'},
+        {'user_id': 2, 'user_name': 'Jane Doe', 'num_orders': 20, 'last_order_date': '2020-02-01'},
+        {'user_id': 3, 'user_name': 'Bob Jones', 'num_orders': 30, 'last_order_date': '2020-03-01'},
+        {'user_id': 4, 'user_name': 'Alice Smith', 'num_orders': 40, 'last_order_date': '2020-04-01'},
+        {'user_id': 5, 'user_name': 'John Doe', 'num_orders': 50, 'last_order_date': '2020-05-01'},
+        {'user_id': 6, 'user_name': 'Jane Jones', 'num_orders': 0, 'last_order_date': '2020-06-01'},
+        {'user_id': 7, 'user_name': 'Bob Smith', 'num_orders': 10, 'last_order_date': '2020-07-01'},
+        {'user_id': 8, 'user_name': 'Alice Doe', 'num_orders': 20, 'last_order_date': '2020-08-01'},
+        {'user_id': 9, 'user_name': 'John Jones', 'num_orders': 30, 'last_order_date': '2020-09-01'},
+        {'user_id': 10, 'user_name': 'Jane Smith', 'num_orders': 40, 'last_order_date': '2020-10-01'}
+    ]
+}
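
Note that this canned output uses Python-style single quotes, so it is a Python literal rather than strict JSON; a standalone illustration of the difference (not part of the test suite):

import ast
import json

canned = "{'user_orders': [{'user_id': 1, 'user_name': 'John Smith'}]}"  # abbreviated

print(ast.literal_eval(canned))  # parses as a Python literal
try:
    json.loads(canned)  # strict JSON requires double-quoted keys and strings
except json.JSONDecodeError as err:
    print("not strict JSON:", err)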

tests/integration_tests/test_async.py

Lines changed: 151 additions & 0 deletions

@@ -63,6 +63,157 @@ async def test_entity_extraction_with_reask(mocker, multiprocessing_validators:
     )
 
 
+@pytest.mark.asyncio
+async def test_entity_extraction_with_noop(mocker):
+    mocker.patch(
+        "guardrails.llm_providers.AsyncOpenAICallable",
+        new=MockAsyncOpenAICallable,
+    )
+    content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf")
+    guard = gd.Guard.from_rail_string(entity_extraction.RAIL_SPEC_WITH_NOOP)
+    _, final_output = await guard(
+        llm_api=openai.Completion.acreate,
+        prompt_params={"document": content[:6000]},
+        num_reasks=1,
+    )
+
+    # Assertions are made on the guard state object.
+    assert final_output == entity_extraction.VALIDATED_OUTPUT_NOOP
+
+    guard_history = guard.guard_state.most_recent_call.history
+
+    # Check that the guard state object has the correct number of re-asks.
+    assert len(guard_history) == 1
+
+    # For original prompt and output
+    assert guard_history[0].prompt == gd.Prompt(entity_extraction.COMPILED_PROMPT)
+    assert guard_history[0].output == entity_extraction.LLM_OUTPUT
+    assert guard_history[0].validated_output == entity_extraction.VALIDATED_OUTPUT_NOOP
+
+
+@pytest.mark.asyncio
+async def test_entity_extraction_with_noop_pydantic(mocker):
+    mocker.patch(
+        "guardrails.llm_providers.AsyncOpenAICallable",
+        new=MockAsyncOpenAICallable,
+    )
+    content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf")
+    guard = gd.Guard.from_pydantic(
+        entity_extraction.PYDANTIC_RAIL_WITH_NOOP, entity_extraction.PYDANTIC_PROMPT
+    )
+    _, final_output = await guard(
+        llm_api=openai.Completion.acreate,
+        prompt_params={"document": content[:6000]},
+        num_reasks=1,
+    )
+
+    # Assertions are made on the guard state object.
+    assert final_output == entity_extraction.VALIDATED_OUTPUT_NOOP
+
+    guard_history = guard.guard_state.most_recent_call.history
+
+    # Check that the guard state object has the correct number of re-asks.
+    assert len(guard_history) == 1
+
+    # For original prompt and output
+    assert guard_history[0].prompt == gd.Prompt(entity_extraction.COMPILED_PROMPT)
+    assert guard_history[0].output == entity_extraction.LLM_OUTPUT
+    assert guard_history[0].validated_output == entity_extraction.VALIDATED_OUTPUT_NOOP
+
+
+@pytest.mark.asyncio
+async def test_entity_extraction_with_filter(mocker):
+    """Test entity extraction with the filter on-fail policy."""
+    mocker.patch(
+        "guardrails.llm_providers.AsyncOpenAICallable",
+        new=MockAsyncOpenAICallable,
+    )
+
+    content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf")
+    guard = gd.Guard.from_rail_string(entity_extraction.RAIL_SPEC_WITH_FILTER)
+    _, final_output = await guard(
+        llm_api=openai.Completion.acreate,
+        prompt_params={"document": content[:6000]},
+        num_reasks=1,
+    )
+
+    # Assertions are made on the guard state object.
+    assert final_output == entity_extraction.VALIDATED_OUTPUT_FILTER
+
+    guard_history = guard.guard_state.most_recent_call.history
+
+    # Check that the guard state object has the correct number of re-asks.
+    assert len(guard_history) == 1
+
+    # For original prompt and output
+    assert guard_history[0].prompt == gd.Prompt(entity_extraction.COMPILED_PROMPT)
+    assert guard_history[0].output == entity_extraction.LLM_OUTPUT
+    assert (
+        guard_history[0].validated_output == entity_extraction.VALIDATED_OUTPUT_FILTER
+    )
+
+
+@pytest.mark.asyncio
+async def test_entity_extraction_with_fix(mocker):
+    """Test entity extraction with the fix on-fail policy."""
+    mocker.patch(
+        "guardrails.llm_providers.AsyncOpenAICallable",
+        new=MockAsyncOpenAICallable,
+    )
+
+    content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf")
+    guard = gd.Guard.from_rail_string(entity_extraction.RAIL_SPEC_WITH_FIX)
+    _, final_output = await guard(
+        llm_api=openai.Completion.acreate,
+        prompt_params={"document": content[:6000]},
+        num_reasks=1,
+    )
+
+    # Assertions are made on the guard state object.
+    assert final_output == entity_extraction.VALIDATED_OUTPUT_FIX
+
+    guard_history = guard.guard_state.most_recent_call.history
+
+    # Check that the guard state object has the correct number of re-asks.
+    assert len(guard_history) == 1
+
+    # For original prompt and output
+    assert guard_history[0].prompt == gd.Prompt(entity_extraction.COMPILED_PROMPT)
+    assert guard_history[0].output == entity_extraction.LLM_OUTPUT
+    assert guard_history[0].validated_output == entity_extraction.VALIDATED_OUTPUT_FIX
+
+
+@pytest.mark.asyncio
+async def test_entity_extraction_with_refrain(mocker):
+    """Test entity extraction with the refrain on-fail policy."""
+    mocker.patch(
+        "guardrails.llm_providers.AsyncOpenAICallable",
+        new=MockAsyncOpenAICallable,
+    )
+
+    content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf")
+    guard = gd.Guard.from_rail_string(entity_extraction.RAIL_SPEC_WITH_REFRAIN)
+    _, final_output = await guard(
+        llm_api=openai.Completion.acreate,
+        prompt_params={"document": content[:6000]},
+        num_reasks=1,
+    )
+    # Assertions are made on the guard state object.
+    assert final_output == entity_extraction.VALIDATED_OUTPUT_REFRAIN
+
+    guard_history = guard.guard_state.most_recent_call.history
+
+    # Check that the guard state object has the correct number of re-asks.
+    assert len(guard_history) == 1
+
+    # For original prompt and output
+    assert guard_history[0].prompt == gd.Prompt(entity_extraction.COMPILED_PROMPT)
+    assert guard_history[0].output == entity_extraction.LLM_OUTPUT
+    assert (
+        guard_history[0].validated_output == entity_extraction.VALIDATED_OUTPUT_REFRAIN
+    )
+
+
 @pytest.mark.asyncio
 async def test_rail_spec_output_parse(rail_spec, llm_output, validated_output):
     """Test that the rail_spec fixture is working."""

tests/integration_tests/test_guard.py

Lines changed: 25 additions & 1 deletion

@@ -399,7 +399,6 @@ def test_string_output(mocker):
         prompt_params={"ingredients": "tomato, cheese, sour cream"},
         num_reasks=1,
     )
-
     assert final_output == string.LLM_OUTPUT
 
     guard_history = guard.guard_state.most_recent_call.history

@@ -486,6 +485,31 @@ def test_skeleton_reask(mocker):
     )
 
 
+'''def test_json_output(mocker):
+    """Test list (JSON) output generation."""
+    mocker.patch(
+        "guardrails.llm_providers.openai_wrapper", new=openai_completion_create
+    )
+
+    guard = gd.Guard.from_rail_string(string.RAIL_SPEC_FOR_LIST)
+    _, final_output = guard(
+        llm_api=openai.Completion.create,
+        num_reasks=1,
+    )
+    assert final_output == string.LIST_LLM_OUTPUT
+
+    guard_history = guard.guard_state.most_recent_call.history
+
+    # Check that the guard state object has the correct number of re-asks.
+    assert len(guard_history) == 1
+
+    # For original prompt and output
+    #assert guard_history[0].prompt == gd.Prompt(string.COMPILED_PROMPT)
+    assert guard_history[0].output == string.LLM_OUTPUT
+
+'''
+
+
 @pytest.mark.parametrize(
     "rail,prompt,instructions,history,llm_api,expected_prompt,"
     "expected_instructions,expected_reask_prompt,expected_reask_instructions",

tests/unit_tests/test_assets/simple.rail

Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
+<rail version="0.1">
+<output>
+    <string name="test_string" description="A string for testing." />
+</output>
+<instructions>
+
+You are a helpful bot, who answers only with valid JSON
+
+</instructions>
+
+<prompt>
+
+Extract a string from the text
+
+</prompt>
+</rail>

tests/unit_tests/test_guard.py

Lines changed: 9 additions & 0 deletions

@@ -152,3 +152,12 @@ class EmptyModel(BaseModel):
 def test_configure(guard: Guard, expected_num_reasks: int, config_num_reasks: int):
     guard.configure(config_num_reasks)
     assert guard.num_reasks == expected_num_reasks
+
+
+def test_guard_init_from_rail():
+    guard = Guard.from_rail("tests/unit_tests/test_assets/simple.rail")
+    assert (
+        guard.instructions.format().source.strip()
+        == "You are a helpful bot, who answers only with valid JSON"
+    )
+    assert guard.prompt.format().source.strip() == "Extract a string from the text"
