Skip to content

Commit 37ad2ee

Browse files
feat(search_family): Add LOAD_FROM option to the FT.AGGREGATE command
Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>
1 parent 719f3ea commit 37ad2ee

File tree

6 files changed

+876
-23
lines changed

6 files changed

+876
-23
lines changed

src/server/search/aggregator.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ using Value = ::dfly::search::SortableValue;
2727

2828
// DocValues sent through the pipeline
2929
// TODO: Replace DocValues with compact linear search map instead of hash map
30-
using DocValues = absl::flat_hash_map<std::string_view, Value>;
30+
using DocValues = absl::flat_hash_map<std::string, Value>;
3131

3232
struct AggregationResult {
3333
// Values to be passed to the next step

src/server/search/doc_index.cc

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,115 @@ vector<SearchDocData> ShardDocIndex::SearchForAggregator(
492492
return out;
493493
}
494494

495+
join::Vector<join::OwnedEntry> ShardDocIndex::PreagregateDataForJoin(
496+
const OpArgs& op_args, absl::Span<const std::string_view> join_fields,
497+
search::SearchAlgorithm* search_algo) const {
498+
auto search_results = search_algo->Search(&*indices_);
499+
500+
// First filter out sortable and non-sortable fields
501+
// We will load them in different ways
502+
const size_t fields_count = join_fields.size();
503+
std::vector<bool> is_sortable_field(fields_count);
504+
std::vector<FieldReference> basic_fields;
505+
basic_fields.reserve(fields_count);
506+
for (size_t i = 0; i < fields_count; ++i) {
507+
bool is_sortable = IsSortableField(join_fields[i], base_->schema);
508+
is_sortable_field[i] = is_sortable;
509+
if (!is_sortable) {
510+
basic_fields.emplace_back(join_fields[i]);
511+
}
512+
}
513+
514+
join::Vector<join::OwnedEntry> result;
515+
result.reserve(search_results.ids.size());
516+
517+
const ShardId shard_id = op_args.shard->shard_id();
518+
for (DocId doc : search_results.ids) {
519+
auto entry = LoadEntry(doc, op_args);
520+
if (!entry)
521+
continue;
522+
523+
auto& [key, accessor] = *entry;
524+
525+
SearchDocData loaded_basic_fields = accessor->Serialize(base_->schema, basic_fields);
526+
527+
bool insert_key = true;
528+
join::Vector<join::OwnedJoinableValue> join_fields_values(fields_count);
529+
for (size_t i = 0; i < fields_count; ++i) {
530+
search::SortableValue value;
531+
if (is_sortable_field[i]) {
532+
value = indices_->GetSortIndexValue(doc, join_fields[i]);
533+
} else {
534+
value = loaded_basic_fields[join_fields[i]];
535+
}
536+
537+
auto copy = [&](auto&& v) {
538+
using T = std::decay_t<decltype(v)>;
539+
if constexpr (!std::is_same_v<T, std::monostate>) {
540+
join_fields_values[i] = v;
541+
} else {
542+
// If the value is nil, we skip this key
543+
insert_key = false;
544+
}
545+
};
546+
547+
std::visit(std::move(copy), value);
548+
}
549+
550+
if (insert_key) {
551+
result.emplace_back(std::piecewise_construct, std::forward_as_tuple(shard_id, doc),
552+
std::forward_as_tuple(std::make_move_iterator(join_fields_values.begin()),
553+
std::make_move_iterator(join_fields_values.end())));
554+
}
555+
}
556+
557+
return result;
558+
}
559+
560+
ShardDocIndex::FieldsValuesPerDocId ShardDocIndex::LoadKeysData(
561+
const OpArgs& op_args, const absl::flat_hash_set<search::DocId>& doc_ids,
562+
absl::Span<const std::string_view> fields_to_load) const {
563+
const size_t fields_count = fields_to_load.size();
564+
std::vector<bool> is_sortable_field(fields_count);
565+
std::vector<FieldReference> basic_fields;
566+
basic_fields.reserve(fields_count);
567+
for (size_t i = 0; i < fields_count; ++i) {
568+
bool is_sortable = IsSortableField(fields_to_load[i], base_->schema);
569+
is_sortable_field[i] = is_sortable;
570+
if (!is_sortable) {
571+
basic_fields.emplace_back(fields_to_load[i]);
572+
}
573+
}
574+
575+
FieldsValuesPerDocId result;
576+
result.reserve(doc_ids.size());
577+
578+
for (DocId doc : doc_ids) {
579+
auto entry = LoadEntry(doc, op_args);
580+
if (!entry)
581+
continue;
582+
583+
auto& [key, accessor] = *entry;
584+
585+
SearchDocData loaded_basic_fields = accessor->Serialize(base_->schema, basic_fields);
586+
587+
FieldsValues fields_values(fields_count);
588+
for (size_t i = 0; i < fields_count; ++i) {
589+
if (is_sortable_field[i]) {
590+
fields_values[i] = indices_->GetSortIndexValue(doc, fields_to_load[i]);
591+
} else {
592+
fields_values[i] = loaded_basic_fields[fields_to_load[i]];
593+
}
594+
}
595+
596+
result.emplace(std::piecewise_construct, std::forward_as_tuple(doc),
597+
std::forward_as_tuple(std::make_move_iterator(fields_values.begin()),
598+
std::make_move_iterator(fields_values.end())));
599+
}
600+
601+
return result;
602+
}
603+
495604
DocIndexInfo ShardDocIndex::GetInfo() const {
496605
return {*base_, key_index_.Size()};
497606
}

src/server/search/doc_index.h

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "core/search/synonyms.h"
2121
#include "server/common.h"
2222
#include "server/search/aggregator.h"
23+
#include "server/search/index_join.h"
2324
#include "server/table.h"
2425

2526
namespace dfly {
@@ -66,6 +67,10 @@ struct FieldReference {
6667
return (is_json && IsJsonPath(name_)) ? name_ : schema.LookupAlias(name_);
6768
}
6869

70+
std::string_view Name() const {
71+
return name_;
72+
}
73+
6974
std::string_view OutputName() const {
7075
return alias_.empty() ? name_ : alias_;
7176
}
@@ -122,9 +127,31 @@ struct SearchParams {
122127
};
123128

124129
struct AggregateParams {
130+
struct JoinParams {
131+
// Fist field is the index name, second is the field name.
132+
using Field = std::pair<std::string, std::string>;
133+
134+
struct Condition {
135+
Condition(std::string_view field_, std::string_view foreign_index_,
136+
std::string_view foreign_field_)
137+
: field{field_}, foreign_field{Field{foreign_index_, foreign_field_}} {
138+
}
139+
140+
std::string field;
141+
Field foreign_field;
142+
};
143+
144+
std::string index;
145+
std::string index_alias;
146+
std::vector<Condition> conditions;
147+
std::string query = "*";
148+
};
149+
125150
std::string_view index, query;
126151
search::QueryParams params;
127152

153+
std::vector<JoinParams> joins;
154+
128155
std::optional<std::vector<FieldReference>> load_fields;
129156
std::vector<aggregate::AggregationStep> steps;
130157
};
@@ -160,6 +187,9 @@ class ShardDocIndex {
160187
friend class ShardDocIndices;
161188
using DocId = search::DocId;
162189

190+
// Used in FieldsValuesPerDocId to store values for each field per document
191+
using FieldsValues = absl::InlinedVector<search::SortableValue, 4>;
192+
163193
// DocKeyIndex manages mapping document keys to ids and vice versa through a simple interface.
164194
struct DocKeyIndex {
165195
DocId Add(std::string_view key);
@@ -188,6 +218,16 @@ class ShardDocIndex {
188218
const AggregateParams& params,
189219
search::SearchAlgorithm* search_algo) const;
190220

221+
// Methods needed for join operation
222+
join::Vector<join::OwnedEntry> PreagregateDataForJoin(
223+
const OpArgs& op_args, absl::Span<const std::string_view> join_fields,
224+
search::SearchAlgorithm* search_algo) const;
225+
226+
using FieldsValuesPerDocId = absl::flat_hash_map<DocId, FieldsValues>;
227+
FieldsValuesPerDocId LoadKeysData(const OpArgs& op_args,
228+
const absl::flat_hash_set<search::DocId>& doc_ids,
229+
absl::Span<const std::string_view> fields_to_load) const;
230+
191231
// Return whether base index matches
192232
bool Matches(std::string_view key, unsigned obj_code) const;
193233

src/server/search/index_join.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ using Key = std::pair<ShardId, search::DocId>;
4040
using Entry = std::pair<Key, Vector<JoinableValue> /*fields values of this key*/>;
4141
using EntriesPerIndex = absl::Span<const Vector<Entry> /*one index can store several keys*/>;
4242

43+
// TODO: comments
44+
using OwnedJoinableValue = std::variant<double, std::string>;
45+
using OwnedEntry = std::pair<Key, Vector<OwnedJoinableValue>>;
46+
4347
// Stores data for single join expression,
4448
// e.g. index1.field1 = index2.field2:
4549
// field - "field1", foreign_index - "index2", foreign_field - "field2"

0 commit comments

Comments
 (0)