Commit 60f2d16

Merge pull request #1257 from guardrails-ai/fix-async-streaming-context
Fix Async Stream Contexts
2 parents 8b71574 + ac016b5 commit 60f2d16

File tree

8 files changed, +327 -178 lines

guardrails/llm_providers.py

Lines changed: 6 additions & 0 deletions
@@ -218,6 +218,7 @@ def _invoke_llm(
             llm_response = cast(Iterator[str], response)
             return LLMResponse(
                 output="",
+                # FIXME: Why is this different from the async streaming implementation?
                 stream_output=llm_response,
             )

@@ -491,6 +492,7 @@ def _invoke_llm(self, *args, **kwargs) -> LLMResponse:
             llm_response = cast(Iterator[str], llm_response)
             return LLMResponse(
                 output="",
+                # FIXME: Why is this different from the async streaming implementation?
                 stream_output=llm_response,
             )

@@ -685,6 +687,8 @@ async def invoke_llm(
         # response = cast(AsyncIterator[str], response)
         return LLMResponse(
             output="",
+            # FIXME: Why is this different from the synchronous streaming implementation? ## noqa: E501
+            # This shouldn't be necessary: https://docs.litellm.ai/docs/completion/stream#async-streaming
             async_stream_output=response.completion_stream,  # pyright: ignore[reportGeneralTypeIssues]
         )

@@ -842,6 +846,8 @@ async def invoke_llm(self, *args, **kwargs) -> LLMResponse:
         # the callable returns a generator object
         return LLMResponse(
             output="",
+            # FIXME: Why is this different from the synchronous streaming implementation? ## noqa: E501
+            # This shouldn't be necessary: https://docs.litellm.ai/docs/completion/stream#async-streaming
             async_stream_output=output.completion_stream,
         )
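
The FIXME comments above flag an asymmetry: the synchronous paths hand the chunk iterator to stream_output as-is, while the async paths unwrap response.completion_stream before storing it in async_stream_output. Below is a minimal, self-contained sketch of how the two fields end up being consumed; FakeLLMResponse, sync_chunks, and async_chunks are illustrative stand-ins, not the guardrails classes.

import asyncio
from typing import AsyncIterator, Iterator, Optional


class FakeLLMResponse:
    """Stand-in for guardrails' LLMResponse: one field per streaming mode."""

    def __init__(
        self,
        output: str = "",
        stream_output: Optional[Iterator[str]] = None,
        async_stream_output: Optional[AsyncIterator[str]] = None,
    ) -> None:
        self.output = output
        self.stream_output = stream_output
        self.async_stream_output = async_stream_output


def sync_chunks() -> Iterator[str]:
    yield from ("Hello", ", ", "world")


async def async_chunks() -> AsyncIterator[str]:
    for piece in ("Hello", ", ", "world"):
        yield piece


def consume_sync(resp: FakeLLMResponse) -> str:
    # The sync runner can iterate the stored iterator directly.
    return "".join(resp.stream_output or [])


async def consume_async(resp: FakeLLMResponse) -> str:
    # The async runner must `async for` over the stored stream, which is
    # why whatever gets stored has to be a real AsyncIterator.
    assert resp.async_stream_output is not None
    return "".join([piece async for piece in resp.async_stream_output])


print(consume_sync(FakeLLMResponse(stream_output=sync_chunks())))
print(asyncio.run(consume_async(FakeLLMResponse(async_stream_output=async_chunks()))))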

guardrails/run/async_stream_runner.py

Lines changed: 182 additions & 140 deletions
Large diffs are not rendered by default.

guardrails/run/stream_runner.py

Lines changed: 63 additions & 23 deletions
@@ -17,6 +17,7 @@
 from guardrails.actions.reask import ReAsk, SkeletonReAsk
 from guardrails.constants import pass_status
 from guardrails.telemetry import trace_stream_step
+from guardrails.utils.safe_get import safe_get


 class StreamRunner(Runner):
@@ -240,35 +241,74 @@ def prepare_chunk_generator(stream) -> Iterator[Tuple[Any, bool]]:
     def is_last_chunk(self, chunk: Any, api: Union[PromptCallableBase, None]) -> bool:
         """Detect if chunk is final chunk."""
         try:
+            if (
+                not chunk.choices or len(chunk.choices) == 0
+            ) and chunk.usage is not None:
+                # This is the last extra chunk for usage statistics
+                return True
             finished = chunk.choices[0].finish_reason
             return finished is not None
         except (AttributeError, TypeError):
             return False

     def get_chunk_text(self, chunk: Any, api: Union[PromptCallableBase, None]) -> str:
-        """Get the text from a chunk."""
-        chunk_text = ""
-        try:
-            finished = chunk.choices[0].finish_reason
-            content = chunk.choices[0].delta.content
-            if not finished and content:
-                chunk_text = content
-        except Exception:
-            try:
-                finished = chunk.choices[0].finish_reason
-                content = chunk.choices[0].text
-                if not finished and content:
-                    chunk_text = content
-            except Exception:
-                try:
-                    chunk_text = chunk
-                except Exception as e:
-                    raise ValueError(
-                        f"Error getting chunk from stream: {e}. "
-                        "Non-OpenAI API callables expected to return "
-                        "a generator of strings."
-                    ) from e
-        return chunk_text
+        """Get the text from a chunk.
+
+        chunk is assumed to be an Iterator of either string or
+        ChatCompletionChunk
+
+        These types are not properly enforced upstream so we must use
+        reflection
+        """
+        # Safeguard against None
+        # which can happen when the user provides
+        # custom LLM wrappers
+        if not chunk:
+            return ""
+        elif isinstance(chunk, str):
+            # If chunk is a string, return it
+            return chunk
+        elif hasattr(chunk, "choices") and hasattr(chunk.choices, "__iter__"):
+            # If chunk is a ChatCompletionChunk, return the text
+            # from the first choice
+            chunk_text = ""
+            first_choice = safe_get(chunk.choices, 0)
+            if not first_choice:
+                return chunk_text
+
+            if hasattr(first_choice, "delta") and hasattr(
+                first_choice.delta, "content"
+            ):
+                chunk_text = first_choice.delta.content
+            elif hasattr(first_choice, "text"):
+                chunk_text = first_choice.text
+            else:
+                # If chunk is not a string or ChatCompletionChunk, raise an error
+                raise ValueError(
+                    "chunk.choices[0] does not have "
+                    "delta.content or text. "
+                    "Non-OpenAI compliant callables must return "
+                    "a generator of strings."
+                )
+
+            if not chunk_text:
+                # If chunk_text is empty, return an empty string
+                return ""
+            elif not isinstance(chunk_text, str):
+                # If chunk_text is not a string, raise an error
+                raise ValueError(
+                    "Chunk text is not a string. "
+                    "Non-OpenAI compliant callables must return "
+                    "a generator of strings."
+                )
+            return chunk_text
+        else:
+            # If chunk is not a string or ChatCompletionChunk, raise an error
+            raise ValueError(
+                "Chunk is not a string or ChatCompletionChunk. "
+                "Non-OpenAI compliant callables must return "
+                "a generator of strings."
+            )

     def parse(
         self, output: str, output_schema: Dict[str, Any], *, verified: set, **kwargs
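
For reference, here is a minimal sketch of the duck typing the rewritten get_chunk_text relies on, using stand-in dataclasses rather than the OpenAI SDK types: plain strings pass through, objects shaped like a ChatCompletionChunk yield choices[0].delta.content (or the legacy choices[0].text), and anything else is rejected.

from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class Delta:
    content: str


@dataclass
class Choice:
    delta: Delta


@dataclass
class FakeChunk:
    # Shaped like an OpenAI ChatCompletionChunk for the attributes that matter.
    choices: List[Choice] = field(default_factory=list)


def get_text(chunk: Any) -> str:
    if not chunk:
        return ""  # custom LLM wrappers may emit None
    if isinstance(chunk, str):
        return chunk
    if hasattr(chunk, "choices") and hasattr(chunk.choices, "__iter__"):
        first = chunk.choices[0] if chunk.choices else None
        if first is None:
            return ""
        if hasattr(first, "delta") and hasattr(first.delta, "content"):
            return first.delta.content or ""
        if hasattr(first, "text"):
            return first.text or ""
    raise ValueError("Non-OpenAI compliant callables must return a generator of strings.")


assert get_text("raw text") == "raw text"
assert get_text(FakeChunk([Choice(Delta("hi"))])) == "hi"
assert get_text(FakeChunk()) == ""
assert get_text(None) == ""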

guardrails/utils/safe_get.py

Lines changed: 5 additions & 2 deletions
@@ -12,9 +12,12 @@ def safe_get_with_brackets(
         return value
     except Exception as e:
         logger.debug(
-            f"Failed to get value for key: {key} out of container: {container}!"
+            f"""
+            Failed to get value for key: {key} out of container: {container}.
+            Reason: {e}
+            Fallbacking to default value...
+            """
         )
-        logger.debug(e)
         return default
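
The multi-line message folds the exception into a single debug record instead of two. For context, a sketch of the lookup pattern safe_get encapsulates, assuming a signature like safe_get(container, key, default=None); my_safe_get below is a hypothetical mirror, and the real helper also logs the failure.

from typing import Any, Optional


def my_safe_get(container: Any, key: Any, default: Optional[Any] = None) -> Any:
    # Mirrors the pattern: any lookup failure falls back to the default.
    try:
        return container[key]
    except Exception:
        return default


# The validator_base.py change below uses exactly this pattern:
is_minimum_length = my_safe_get("ab", 2) is not None  # False: "ab"[2] raises IndexError
assert is_minimum_length is False
assert my_safe_get("abc", 2) == "c"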

guardrails/validator_base.py

Lines changed: 37 additions & 9 deletions
@@ -4,7 +4,7 @@
 # - [ ] Remove validator_base.py in 0.6.x

 import asyncio
-import contextlib
+from contextvars import Context, ContextVar
 from functools import partial
 import inspect
 import logging
@@ -67,10 +67,8 @@ def split_sentence_word_tokenizers_jl_separator(
     # we check for a . to avoid wastefully calling the tokenizer

     # check at least 3 characters have been accumulated before splitting
-    is_minimum_length = False
-    with contextlib.suppress(IndexError):
-        chunk[2]
-        is_minimum_length = True
+    third_chunk = safe_get(chunk, 2)
+    is_minimum_length = third_chunk is not None

     # check for potential line endings, which is what split_sentences does
     chunk_with_potential_line_endings, count = re.subn(
@@ -292,7 +290,14 @@ def _chunking_function(self, chunk: str) -> List[str]:
         return split_sentence_word_tokenizers_jl_separator(chunk)

     def validate_stream(
-        self, chunk: Any, metadata: Dict[str, Any], **kwargs
+        self,
+        chunk: Any,
+        metadata: Dict[str, Any],
+        *,
+        property_path: Optional[str] = "$",
+        context_vars: Optional[ContextVar[Dict[str, ContextVar[List[str]]]]] = None,
+        context: Optional[Context] = None,
+        **kwargs,
     ) -> Optional[ValidationResult]:
         """Validates a chunk emitted by an LLM. If the LLM chunk is smaller
         than the validator's chunking strategy, it will be accumulated until it
@@ -307,8 +312,20 @@ def validate_stream(
             result.
         """
         # combine accumulated chunks and new [:-1]chunk
-        self.accumulated_chunks.append(chunk)
-        accumulated_text = "".join(self.accumulated_chunks)
+        accumulated_chunks = self.accumulated_chunks
+
+        # if context_vars is passed, use it to get the accumulated chunks
+        context_var: Optional[ContextVar[List[str]]] = None
+        ctx_var_map: Optional[Dict[str, ContextVar[List[str]]]] = None
+        context_key = f"{property_path}_{self.rail_alias}"
+        if context_vars and context:
+            ctx_var_map = context.run(context_vars.get)
+            context_var = ctx_var_map.get(context_key)
+            if context_var:
+                accumulated_chunks = context.run(context_var.get)
+
+        accumulated_chunks.append(chunk)
+        accumulated_text = "".join(accumulated_chunks)
         # check if enough chunks have accumulated for validation
         split_contents = self._chunking_function(accumulated_text)

@@ -318,9 +335,20 @@
             split_contents = [accumulated_text, ""]
         # if no chunks are returned, we haven't accumulated enough
         if len(split_contents) == 0:
+            if context_vars and context_var and context and ctx_var_map:
+                context.run(context_var.set, accumulated_chunks)
+                ctx_var_map[context_key] = context_var
+                context.run(context_vars.set, ctx_var_map)
+            else:
+                self.accumulated_chunks = accumulated_chunks
             return None
         [chunk_to_validate, new_accumulated_chunks] = split_contents
-        self.accumulated_chunks = [new_accumulated_chunks]
+        if context_vars and context_var and context and ctx_var_map:
+            context.run(context_var.set, [new_accumulated_chunks])
+            ctx_var_map[context_key] = context_var
+            context.run(context_vars.set, ctx_var_map)
+        else:
+            self.accumulated_chunks = [new_accumulated_chunks]
         # exclude last chunk, because it may not be a complete chunk
         validation_result = self.validate(chunk_to_validate, metadata)
         # if validate doesn't set validated chunk, we set it
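
This is the heart of the fix: a validator instance may be shared by several concurrent streams, so mutating self.accumulated_chunks lets chunks from one stream leak into another. When the runner supplies a Context and a ContextVar map, each stream's accumulated chunks live in that stream's Context instead. A minimal sketch of the isolation mechanism, with illustrative names rather than the guardrails API:

from contextvars import Context, ContextVar, copy_context
from typing import List

accumulated: ContextVar[List[str]] = ContextVar("accumulated")

# Each stream gets its own Context; the same ContextVar resolves to
# per-stream state depending on which Context it is read in.
stream_a: Context = copy_context()
stream_b: Context = copy_context()

stream_a.run(accumulated.set, ["Hello"])
stream_b.run(accumulated.set, ["Bonjour"])

# Appending in one stream's context never touches the other's chunks.
stream_a.run(accumulated.get).append(", world")
stream_b.run(accumulated.get).append(", monde")

assert "".join(stream_a.run(accumulated.get)) == "Hello, world"
assert "".join(stream_b.run(accumulated.get)) == "Bonjour, monde"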

guardrails/validator_service/async_validator_service.py

Lines changed: 6 additions & 0 deletions
@@ -84,6 +84,8 @@ async def run_validator(
         metadata: Dict,
         absolute_property_path: str,
         stream: Optional[bool] = False,
+        *,
+        reference_path: Optional[str] = None,
         **kwargs,
     ) -> ValidatorRun:
         validator_logs = self.before_run_validator(
@@ -96,6 +98,7 @@
             metadata,
             stream,
             validation_session_id=iteration.id,
+            reference_path=reference_path,
             **kwargs,
         )

@@ -111,6 +114,7 @@
             result.metadata or {},
             stream,
             validation_session_id=iteration.id,
+            reference_path=reference_path,
             **kwargs,
         )
         value = self.perform_correction(
@@ -160,6 +164,7 @@ async def run_validators(
                     metadata,
                     absolute_property_path,
                     stream=stream,
+                    reference_property_path=reference_property_path,
                     **kwargs,
                 )
             )
@@ -277,6 +282,7 @@ async def async_partial_validate(
                     metadata,
                     absolute_path,
                     stream=stream,
+                    reference_path=reference_path,
                     **kwargs,
                 )
             )
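
reference_path is introduced as keyword-only (note the bare `*,`), so existing positional call sites fail loudly instead of silently misbinding, and each layer forwards it explicitly alongside **kwargs. A small sketch of the pattern with stand-in function names:

from typing import Any, Optional


def run_validator(value: Any, *, reference_path: Optional[str] = None, **kwargs: Any) -> str:
    # Keyword-only: run_validator(value, "$.foo") would raise TypeError.
    return f"validated {value!r} at {reference_path or '$'}"


def partial_validate(value: Any, reference_path: Optional[str] = None, **kwargs: Any) -> str:
    # The outer layer threads the path down explicitly.
    return run_validator(value, reference_path=reference_path, **kwargs)


assert partial_validate("hi", "$.items[0]") == "validated 'hi' at $.items[0]"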

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -162,6 +162,7 @@ authorized_licenses = [
     "python software foundation",
     "python software foundation license",
     "zpl 2.1",
+    "mit and python-2.0"
 ]
 unauthorized_licenses = [
     "gpl v3",

tests/unit_tests/validator_service/test_async_validator_service.py

Lines changed: 27 additions & 4 deletions
@@ -503,7 +503,12 @@ async def test_pass_result(self, mocker):

         assert mock_run_validator_async.call_count == 1
         mock_run_validator_async.assert_called_once_with(
-            validator, "value", {}, False, validation_session_id=iteration.id
+            validator,
+            "value",
+            {},
+            False,
+            validation_session_id=iteration.id,
+            reference_path=None,
         )

         assert mock_after_run_validator.call_count == 1
@@ -562,7 +567,12 @@ async def test_pass_result_with_override(self, mocker):

         assert mock_run_validator_async.call_count == 1
         mock_run_validator_async.assert_called_once_with(
-            validator, "value", {}, False, validation_session_id=iteration.id
+            validator,
+            "value",
+            {},
+            False,
+            validation_session_id=iteration.id,
+            reference_path=None,
         )

         assert mock_after_run_validator.call_count == 1
@@ -625,7 +635,12 @@ async def test_fail_result(self, mocker):

         assert mock_run_validator_async.call_count == 1
         mock_run_validator_async.assert_called_once_with(
-            validator, "value", {}, False, validation_session_id=iteration.id
+            validator,
+            "value",
+            {},
+            False,
+            validation_session_id=iteration.id,
+            reference_path=None,
         )

         assert mock_after_run_validator.call_count == 1
@@ -699,13 +714,21 @@ async def test_fail_result_with_fix_reask(self, mocker):
         assert mock_run_validator_async.call_count == 2
         mock_run_validator_async.assert_has_calls(
             [
-                call(validator, "value", {}, False, validation_session_id=iteration.id),
+                call(
+                    validator,
+                    "value",
+                    {},
+                    False,
+                    validation_session_id=iteration.id,
+                    reference_path=None,
+                ),
                 call(
                     validator,
                     "fixed-value",
                     {},
                     False,
                     validation_session_id=iteration.id,
+                    reference_path=None,
                 ),
             ]
         )
