@@ -63,6 +63,157 @@ async def test_entity_extraction_with_reask(mocker, multiprocessing_validators:
63
63
)
64
64
65
65
66
+ @pytest .mark .asyncio
67
+ async def test_entity_extraction_with_noop (mocker ):
68
+ mocker .patch (
69
+ "guardrails.llm_providers.AsyncOpenAICallable" ,
70
+ new = MockAsyncOpenAICallable ,
71
+ )
72
+ content = gd .docs_utils .read_pdf ("docs/examples/data/chase_card_agreement.pdf" )
73
+ guard = gd .Guard .from_rail_string (entity_extraction .RAIL_SPEC_WITH_NOOP )
74
+ _ , final_output = await guard (
75
+ llm_api = openai .Completion .acreate ,
76
+ prompt_params = {"document" : content [:6000 ]},
77
+ num_reasks = 1 ,
78
+ )
79
+
80
+ # Assertions are made on the guard state object.
81
+ assert final_output == entity_extraction .VALIDATED_OUTPUT_NOOP
82
+
83
+ guard_history = guard .guard_state .most_recent_call .history
84
+
85
+ # Check that the guard state object has the correct number of re-asks.
86
+ assert len (guard_history ) == 1
87
+
88
+ # For orginal prompt and output
89
+ assert guard_history [0 ].prompt == gd .Prompt (entity_extraction .COMPILED_PROMPT )
90
+ assert guard_history [0 ].output == entity_extraction .LLM_OUTPUT
91
+ assert guard_history [0 ].validated_output == entity_extraction .VALIDATED_OUTPUT_NOOP
92
+
93
+
94
+ @pytest .mark .asyncio
95
+ async def test_entity_extraction_with_noop_pydantic (mocker ):
96
+ mocker .patch (
97
+ "guardrails.llm_providers.AsyncOpenAICallable" ,
98
+ new = MockAsyncOpenAICallable ,
99
+ )
100
+ content = gd .docs_utils .read_pdf ("docs/examples/data/chase_card_agreement.pdf" )
101
+ guard = gd .Guard .from_pydantic (
102
+ entity_extraction .PYDANTIC_RAIL_WITH_NOOP , entity_extraction .PYDANTIC_PROMPT
103
+ )
104
+ _ , final_output = await guard (
105
+ llm_api = openai .Completion .acreate ,
106
+ prompt_params = {"document" : content [:6000 ]},
107
+ num_reasks = 1 ,
108
+ )
109
+
110
+ # Assertions are made on the guard state object.
111
+ assert final_output == entity_extraction .VALIDATED_OUTPUT_NOOP
112
+
113
+ guard_history = guard .guard_state .most_recent_call .history
114
+
115
+ # Check that the guard state object has the correct number of re-asks.
116
+ assert len (guard_history ) == 1
117
+
118
+ # For orginal prompt and output
119
+ assert guard_history [0 ].prompt == gd .Prompt (entity_extraction .COMPILED_PROMPT )
120
+ assert guard_history [0 ].output == entity_extraction .LLM_OUTPUT
121
+ assert guard_history [0 ].validated_output == entity_extraction .VALIDATED_OUTPUT_NOOP
122
+
123
+
124
+ @pytest .mark .asyncio
125
+ async def test_entity_extraction_with_filter (mocker ):
126
+ """Test that the entity extraction works with re-asking."""
127
+ mocker .patch (
128
+ "guardrails.llm_providers.AsyncOpenAICallable" ,
129
+ new = MockAsyncOpenAICallable ,
130
+ )
131
+
132
+ content = gd .docs_utils .read_pdf ("docs/examples/data/chase_card_agreement.pdf" )
133
+ guard = gd .Guard .from_rail_string (entity_extraction .RAIL_SPEC_WITH_FILTER )
134
+ _ , final_output = await guard (
135
+ llm_api = openai .Completion .acreate ,
136
+ prompt_params = {"document" : content [:6000 ]},
137
+ num_reasks = 1 ,
138
+ )
139
+
140
+ # Assertions are made on the guard state object.
141
+ assert final_output == entity_extraction .VALIDATED_OUTPUT_FILTER
142
+
143
+ guard_history = guard .guard_state .most_recent_call .history
144
+
145
+ # Check that the guard state object has the correct number of re-asks.
146
+ assert len (guard_history ) == 1
147
+
148
+ # For orginal prompt and output
149
+ assert guard_history [0 ].prompt == gd .Prompt (entity_extraction .COMPILED_PROMPT )
150
+ assert guard_history [0 ].output == entity_extraction .LLM_OUTPUT
151
+ assert (
152
+ guard_history [0 ].validated_output == entity_extraction .VALIDATED_OUTPUT_FILTER
153
+ )
154
+
155
+
156
+ @pytest .mark .asyncio
157
+ async def test_entity_extraction_with_fix (mocker ):
158
+ """Test that the entity extraction works with re-asking."""
159
+ mocker .patch (
160
+ "guardrails.llm_providers.AsyncOpenAICallable" ,
161
+ new = MockAsyncOpenAICallable ,
162
+ )
163
+
164
+ content = gd .docs_utils .read_pdf ("docs/examples/data/chase_card_agreement.pdf" )
165
+ guard = gd .Guard .from_rail_string (entity_extraction .RAIL_SPEC_WITH_FIX )
166
+ _ , final_output = await guard (
167
+ llm_api = openai .Completion .acreate ,
168
+ prompt_params = {"document" : content [:6000 ]},
169
+ num_reasks = 1 ,
170
+ )
171
+
172
+ # Assertions are made on the guard state object.
173
+ assert final_output == entity_extraction .VALIDATED_OUTPUT_FIX
174
+
175
+ guard_history = guard .guard_state .most_recent_call .history
176
+
177
+ # Check that the guard state object has the correct number of re-asks.
178
+ assert len (guard_history ) == 1
179
+
180
+ # For orginal prompt and output
181
+ assert guard_history [0 ].prompt == gd .Prompt (entity_extraction .COMPILED_PROMPT )
182
+ assert guard_history [0 ].output == entity_extraction .LLM_OUTPUT
183
+ assert guard_history [0 ].validated_output == entity_extraction .VALIDATED_OUTPUT_FIX
184
+
185
+
186
+ @pytest .mark .asyncio
187
+ async def test_entity_extraction_with_refrain (mocker ):
188
+ """Test that the entity extraction works with re-asking."""
189
+ mocker .patch (
190
+ "guardrails.llm_providers.AsyncOpenAICallable" ,
191
+ new = MockAsyncOpenAICallable ,
192
+ )
193
+
194
+ content = gd .docs_utils .read_pdf ("docs/examples/data/chase_card_agreement.pdf" )
195
+ guard = gd .Guard .from_rail_string (entity_extraction .RAIL_SPEC_WITH_REFRAIN )
196
+ _ , final_output = await guard (
197
+ llm_api = openai .Completion .acreate ,
198
+ prompt_params = {"document" : content [:6000 ]},
199
+ num_reasks = 1 ,
200
+ )
201
+ # Assertions are made on the guard state object.
202
+ assert final_output == entity_extraction .VALIDATED_OUTPUT_REFRAIN
203
+
204
+ guard_history = guard .guard_state .most_recent_call .history
205
+
206
+ # Check that the guard state object has the correct number of re-asks.
207
+ assert len (guard_history ) == 1
208
+
209
+ # For orginal prompt and output
210
+ assert guard_history [0 ].prompt == gd .Prompt (entity_extraction .COMPILED_PROMPT )
211
+ assert guard_history [0 ].output == entity_extraction .LLM_OUTPUT
212
+ assert (
213
+ guard_history [0 ].validated_output == entity_extraction .VALIDATED_OUTPUT_REFRAIN
214
+ )
215
+
216
+
66
217
@pytest .mark .asyncio
67
218
async def test_rail_spec_output_parse (rail_spec , llm_output , validated_output ):
68
219
"""Test that the rail_spec fixture is working."""
0 commit comments