Skip to content

Commit 2361453

Browse files
Adding 1024 unquantized
1 parent 876b53a commit 2361453

File tree

4 files changed

+141
-17
lines changed

4 files changed

+141
-17
lines changed

vectorlink/src/batch.rs

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,18 @@ fn perform_indexing(
505505
);
506506
HnswConfiguration::Quantized1024By16(model, quantized_hnsw)
507507
} else {
508-
panic!("No unquantized 1024 available");
508+
let comparator = OpenAIComparator::new(
509+
domain_obj.name().to_owned(),
510+
Arc::new(domain_obj.all_vecs()?),
511+
);
512+
let vids: Vec<_> = (0..domain_obj.num_vecs()).map(VectorId).collect();
513+
let hnsw = Hnsw::generate(
514+
comparator,
515+
vids,
516+
BuildParameters::default(),
517+
&mut SimpleProgressMonitor::default(),
518+
);
519+
HnswConfiguration::UnquantizedOpenAi(model, hnsw)
509520
};
510521
eprintln!("done generating hnsw");
511522
keepalive!(progress, hnsw.serialize(&staging_file))?;
@@ -656,8 +667,10 @@ mod tests {
656667

657668
impl Comparator for MemoryOpenAIComparator {
658669
type T = Embedding;
659-
type Borrowable<'a> = &'a Embedding
660-
where Self: 'a;
670+
type Borrowable<'a>
671+
= &'a Embedding
672+
where
673+
Self: 'a;
661674
fn lookup(&self, v: VectorId) -> &Embedding {
662675
&self.vectors[v.0]
663676
}

vectorlink/src/comparator.rs

Lines changed: 83 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,10 @@ impl DiskOpenAIComparator {
4747

4848
impl Comparator for DiskOpenAIComparator {
4949
type T = Embedding;
50-
type Borrowable<'a> = Box<Embedding>
51-
where Self: 'a;
50+
type Borrowable<'a>
51+
= Box<Embedding>
52+
where
53+
Self: 'a;
5254
fn lookup(&self, v: VectorId) -> Box<Embedding> {
5355
Box::new(self.vectors.vec(v.0).unwrap())
5456
}
@@ -158,8 +160,10 @@ impl Disk1024Comparator {
158160

159161
impl Comparator for Disk1024Comparator {
160162
type T = Embedding1024;
161-
type Borrowable<'a> = Box<Embedding1024>
162-
where Self: 'a;
163+
type Borrowable<'a>
164+
= Box<Embedding1024>
165+
where
166+
Self: 'a;
163167
fn lookup(&self, v: VectorId) -> Box<Embedding1024> {
164168
Box::new(self.vectors.vec(v.0).unwrap())
165169
}
@@ -269,8 +273,10 @@ pub struct ComparatorMeta {
269273

270274
impl Comparator for OpenAIComparator {
271275
type T = Embedding;
272-
type Borrowable<'a> = &'a Embedding
273-
where Self: 'a;
276+
type Borrowable<'a>
277+
= &'a Embedding
278+
where
279+
Self: 'a;
274280
fn lookup(&self, v: VectorId) -> &Embedding {
275281
&self.range[v.0]
276282
}
@@ -317,6 +323,73 @@ impl Serializable for OpenAIComparator {
317323
}
318324
}
319325

326+
/* Memory 1024 comparator */
327+
#[derive(Clone)]
328+
pub struct Memory1024Comparator {
329+
domain_name: String,
330+
range: Arc<LoadedSizedVectorRange<Embedding1024>>,
331+
}
332+
333+
impl Memory1024Comparator {
334+
pub fn new(domain_name: String, range: Arc<LoadedSizedVectorRange<Embedding1024>>) -> Self {
335+
Self { domain_name, range }
336+
}
337+
}
338+
339+
impl Comparator for Memory1024Comparator {
340+
type T = Embedding1024;
341+
type Borrowable<'a>
342+
= &'a Embedding1024
343+
where
344+
Self: 'a;
345+
fn lookup(&self, v: VectorId) -> &Embedding1024 {
346+
&self.range[v.0]
347+
}
348+
349+
fn compare_raw(&self, v1: &Embedding1024, v2: &Embedding1024) -> f32 {
350+
normalized_cosine_distance_1024(v1, v2)
351+
}
352+
}
353+
354+
impl Serializable for Memory1024Comparator {
355+
type Params = Arc<VectorStore>;
356+
fn serialize<P: AsRef<Path>>(&self, path: P) -> Result<(), SerializationError> {
357+
let mut comparator_file: std::fs::File = OpenOptions::new()
358+
.write(true)
359+
.create(true)
360+
.truncate(true)
361+
.open(path)?;
362+
eprintln!("opened comparator serialize file");
363+
// How do we get this value?
364+
let comparator = ComparatorMeta {
365+
domain_name: self.domain_name.clone(),
366+
size: self.range.len(),
367+
};
368+
let comparator_meta = serde_json::to_string(&comparator)?;
369+
eprintln!("serialized comparator");
370+
comparator_file.write_all(&comparator_meta.into_bytes())?;
371+
eprintln!("wrote comparator to file");
372+
Ok(())
373+
}
374+
375+
fn deserialize<P: AsRef<Path>>(
376+
path: P,
377+
store: Arc<VectorStore>,
378+
) -> Result<Self, SerializationError> {
379+
let mut comparator_file = OpenOptions::new().read(true).open(path)?;
380+
let mut contents = String::new();
381+
comparator_file.read_to_string(&mut contents)?;
382+
let ComparatorMeta { domain_name, .. } = serde_json::from_str(&contents)?;
383+
let domain = store.get_domain_sized(&domain_name, EMBEDDING_BYTE_LENGTH)?;
384+
Ok(Memory1024Comparator {
385+
domain_name,
386+
range: Arc::new(domain.all_vecs()?),
387+
})
388+
}
389+
}
390+
391+
/* End Memory comparator */
392+
320393
struct MemoizedPartialDistances {
321394
partial_distances: Vec<bf16>,
322395
partial_norms: Vec<bf16>,
@@ -593,7 +666,10 @@ impl<
593666
{
594667
type T = [f32; SIZE];
595668

596-
type Borrowable<'a> = &'a Self::T where QuantizedDistance: 'a;
669+
type Borrowable<'a>
670+
= &'a Self::T
671+
where
672+
QuantizedDistance: 'a;
597673

598674
fn lookup(&self, v: VectorId) -> Self::Borrowable<'_> {
599675
&self.centroids[v.0]

vectorlink/src/configuration.rs

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ use serde::{Deserialize, Serialize};
1313
use crate::{
1414
comparator::{
1515
Centroid16Comparator, Centroid16Comparator1024, Centroid32Comparator, Centroid4Comparator,
16-
Centroid8Comparator, Disk1024Comparator, DiskOpenAIComparator, OpenAIComparator,
17-
Quantized16Comparator, Quantized16Comparator1024, Quantized32Comparator,
16+
Centroid8Comparator, Disk1024Comparator, DiskOpenAIComparator, Memory1024Comparator,
17+
OpenAIComparator, Quantized16Comparator, Quantized16Comparator1024, Quantized32Comparator,
1818
Quantized4Comparator, Quantized8Comparator,
1919
},
2020
openai::Model,
@@ -28,6 +28,7 @@ use crate::{
2828
};
2929

3030
pub type OpenAIHnsw = Hnsw<OpenAIComparator>;
31+
pub type Memory1024Hnsw = Hnsw<Memory1024Comparator>;
3132

3233
#[derive(Serialize, Deserialize)]
3334
pub enum HnswConfigurationType {
@@ -36,6 +37,7 @@ pub enum HnswConfigurationType {
3637
SmallQuantizedOpenAi8,
3738
SmallQuantizedOpenAi4,
3839
UnquantizedOpenAi,
40+
Unquantized1024,
3941
Quantized1024,
4042
}
4143

@@ -92,6 +94,7 @@ pub enum HnswConfiguration {
9294
DiskOpenAIComparator,
9395
>,
9496
),
97+
Unquantized1024(Model, Memory1024Hnsw),
9598
UnquantizedOpenAi(Model, OpenAIHnsw),
9699
Quantized1024By16(
97100
Model,
@@ -115,6 +118,7 @@ impl HnswConfiguration {
115118
HnswConfiguration::SmallQuantizedOpenAi4(_, q) => Some(q.quantization_statistics()),
116119
HnswConfiguration::Quantized1024By16(_, q) => Some(q.quantization_statistics()),
117120
HnswConfiguration::UnquantizedOpenAi(_, _) => None,
121+
HnswConfiguration::Unquantized1024(_, _) => None,
118122
}
119123
}
120124

@@ -138,6 +142,9 @@ impl HnswConfiguration {
138142
HnswConfiguration::Quantized1024By16(model, _) => {
139143
(HnswConfigurationType::Quantized1024, model)
140144
}
145+
HnswConfiguration::Unquantized1024(model, _) => {
146+
(HnswConfigurationType::Unquantized1024, model)
147+
}
141148
};
142149
let version = 1;
143150

@@ -156,6 +163,7 @@ impl HnswConfiguration {
156163
HnswConfiguration::SmallQuantizedOpenAi8(m, _) => *m,
157164
HnswConfiguration::SmallQuantizedOpenAi4(m, _) => *m,
158165
HnswConfiguration::Quantized1024By16(m, _) => *m,
166+
HnswConfiguration::Unquantized1024(m, _) => *m,
159167
}
160168
}
161169

@@ -168,6 +176,7 @@ impl HnswConfiguration {
168176
HnswConfiguration::SmallQuantizedOpenAi8(_model, q) => q.vector_count(),
169177
HnswConfiguration::SmallQuantizedOpenAi4(_model, q) => q.vector_count(),
170178
HnswConfiguration::Quantized1024By16(_, q) => q.vector_count(),
179+
HnswConfiguration::Unquantized1024(_, q) => q.vector_count(),
171180
}
172181
}
173182

@@ -182,7 +191,8 @@ impl HnswConfiguration {
182191
HnswConfiguration::UnquantizedOpenAi(_model, h) => h.search(v, search_parameters),
183192
HnswConfiguration::SmallQuantizedOpenAi8(_, q) => q.search(v, search_parameters),
184193
HnswConfiguration::SmallQuantizedOpenAi4(_, q) => q.search(v, search_parameters),
185-
HnswConfiguration::Quantized1024By16(_, _q) => {
194+
HnswConfiguration::Quantized1024By16(_, _)
195+
| HnswConfiguration::Unquantized1024(_, _) => {
186196
panic!();
187197
}
188198
}
@@ -195,9 +205,8 @@ impl HnswConfiguration {
195205
) -> Vec<(VectorId, f32)> {
196206
match self {
197207
HnswConfiguration::Quantized1024By16(_, q) => q.search(v, search_parameters),
198-
_ => {
199-
panic!();
200-
}
208+
HnswConfiguration::Unquantized1024(_, h) => h.search(v, search_parameters),
209+
_ => panic!(),
201210
}
202211
}
203212

@@ -225,6 +234,7 @@ impl HnswConfiguration {
225234
HnswConfiguration::Quantized1024By16(_, q) => {
226235
q.improve_index(build_parameters, progress)
227236
}
237+
HnswConfiguration::Unquantized1024(_, h) => h.improve_index(build_parameters, progress),
228238
}
229239
}
230240

@@ -253,6 +263,9 @@ impl HnswConfiguration {
253263
HnswConfiguration::Quantized1024By16(_, q) => {
254264
q.improve_index_at(layer, build_parameters, progress)
255265
}
266+
HnswConfiguration::Unquantized1024(_, h) => {
267+
h.improve_index_at(layer, build_parameters, progress)
268+
}
256269
}
257270
}
258271

@@ -280,6 +293,9 @@ impl HnswConfiguration {
280293
HnswConfiguration::Quantized1024By16(_, q) => {
281294
q.improve_neighbors(optimization_parameters, last_recall)
282295
}
296+
HnswConfiguration::Unquantized1024(_, h) => {
297+
h.improve_neighbors(optimization_parameters, last_recall)
298+
}
283299
}
284300
}
285301

@@ -308,6 +324,9 @@ impl HnswConfiguration {
308324
HnswConfiguration::Quantized1024By16(_, q) => {
309325
q.promote_at_layer(layer_from_top, build_parameters, &mut progress)
310326
}
327+
HnswConfiguration::Unquantized1024(_, h) => {
328+
h.promote_at_layer(layer_from_top, build_parameters, &mut progress)
329+
}
311330
}
312331
}
313332

@@ -319,6 +338,7 @@ impl HnswConfiguration {
319338
HnswConfiguration::SmallQuantizedOpenAi8(_model, q) => q.zero_neighborhood_size(),
320339
HnswConfiguration::SmallQuantizedOpenAi4(_model, q) => q.zero_neighborhood_size(),
321340
HnswConfiguration::Quantized1024By16(_model, q) => q.zero_neighborhood_size(),
341+
HnswConfiguration::Unquantized1024(_, h) => h.zero_neighborhood_size(),
322342
}
323343
}
324344
pub fn threshold_nn(
@@ -346,7 +366,12 @@ impl HnswConfiguration {
346366
}
347367
HnswConfiguration::Quantized1024By16(_, q) => {
348368
Either::Right(Either::Right(Either::Right(Either::Right(Either::Right(
349-
q.threshold_nn(threshold, search_parameters),
369+
Either::Left(q.threshold_nn(threshold, search_parameters)),
370+
)))))
371+
}
372+
HnswConfiguration::Unquantized1024(_, h) => {
373+
Either::Right(Either::Right(Either::Right(Either::Right(Either::Right(
374+
Either::Right(h.threshold_nn(threshold, search_parameters)),
350375
)))))
351376
}
352377
}
@@ -372,6 +397,9 @@ impl HnswConfiguration {
372397
HnswConfiguration::Quantized1024By16(_, q) => {
373398
q.stochastic_recall(optimization_parameters)
374399
}
400+
HnswConfiguration::Unquantized1024(_, h) => {
401+
h.stochastic_recall(optimization_parameters)
402+
}
375403
}
376404
}
377405

@@ -387,6 +415,7 @@ impl HnswConfiguration {
387415
}
388416
HnswConfiguration::UnquantizedOpenAi(_, h) => h.build_parameters,
389417
HnswConfiguration::Quantized1024By16(_, q) => q.build_parameters_for_improve_index(),
418+
HnswConfiguration::Unquantized1024(_, h) => h.build_parameters,
390419
}
391420
}
392421

@@ -398,6 +427,7 @@ impl HnswConfiguration {
398427
| HnswConfiguration::SmallQuantizedOpenAi4(_, _)
399428
| HnswConfiguration::UnquantizedOpenAi(_, _) => 1536,
400429
HnswConfiguration::Quantized1024By16(_, _) => 1024,
430+
HnswConfiguration::Unquantized1024(_, _) => 1024,
401431
}
402432
}
403433
}
@@ -416,6 +446,7 @@ impl Serializable for HnswConfiguration {
416446
HnswConfiguration::SmallQuantizedOpenAi8(_, qhnsw) => qhnsw.serialize(&path)?,
417447
HnswConfiguration::SmallQuantizedOpenAi4(_, qhnsw) => qhnsw.serialize(&path)?,
418448
HnswConfiguration::Quantized1024By16(_, qhnsw) => qhnsw.serialize(&path)?,
449+
HnswConfiguration::Unquantized1024(_, hnsw) => hnsw.serialize(&path)?,
419450
}
420451
let state_path: PathBuf = path.as_ref().join("state.json");
421452
let mut state_file = OpenOptions::new()
@@ -473,6 +504,9 @@ impl Serializable for HnswConfiguration {
473504
QuantizedHnsw::deserialize(path, params)?,
474505
)
475506
}
507+
HnswConfigurationType::Unquantized1024 => {
508+
HnswConfiguration::Unquantized1024(state.model, Hnsw::deserialize(path, params)?)
509+
}
476510
})
477511
}
478512
}

vectorlink/src/main.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
459459
HnswConfiguration::SmallQuantizedOpenAi8(_, _) => todo!(),
460460
HnswConfiguration::SmallQuantizedOpenAi4(_, _) => todo!(),
461461
HnswConfiguration::UnquantizedOpenAi(_, _) => todo!(),
462+
HnswConfiguration::Unquantized1024(_, _) => todo!(),
462463
HnswConfiguration::Quantized1024By16(_, q) => q.compare(v1, v2),
463464
};
464465
eprintln!("result: {res:?}");

0 commit comments

Comments
 (0)