Skip to content

Commit d9c4cf6

Browse files
authored
Fix: Ensure proper embedding dimensions (#54)
* fix embedding dimensions for various models sizes Signed-off-by: dvaz-external <dvaz.external@epo.org> * update package version Signed-off-by: dvaz-external <dvaz.external@epo.org> * bump package version Signed-off-by: dvaz-external <dvaz.external@epo.org> * fix test Signed-off-by: dvaz-external <dvaz.external@epo.org> * change embbedings dimension size reference method Signed-off-by: dvaz-external <dvaz.external@epo.org> * fix new dimension parameter issue Signed-off-by: dvaz-external <dvaz.external@epo.org> --------- Signed-off-by: dvaz-external <dvaz.external@epo.org>
1 parent 823f197 commit d9c4cf6

12 files changed

Lines changed: 111 additions & 56 deletions

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,9 @@ Configuration is managed through two files:
145145
OPENAI_API_KEY="sk-..."
146146
OPENAI_MODEL="text-embedding-3-large" # Optional, defaults to text-embedding-3-large
147147
148+
# Optional: Embedding dimension size (defaults to 3072)
149+
EMBEDDING_DIMENSION="3072"
150+
148151
# Required: Your Azure OpenAI credentials (if using Azure provider)
149152
AZURE_OPENAI_KEY="your-azure-key"
150153
AZURE_OPENAI_ENDPOINT="https://your-resource.openai.azure.com"
@@ -223,13 +226,18 @@ Configuration is managed through two files:
223226
* `qdrant_port`: (Optional) Port for the Qdrant REST API. Defaults to `443` if `qdrant_url` starts with `https`, otherwise `6333`.
224227
* `collection_name`: (Optional) Name of the Qdrant collection to use. Defaults to `<product_name>_<version>` (lowercased, spaces replaced with underscores).
225228

229+
Optional embedding configuration:
230+
* `embedding.provider`: Provider for embeddings (`openai` or `azure`).
231+
* `embedding.dimension`: Embedding vector size. Defaults to `3072` when not set.
232+
226233
**Example (`config.yaml`):**
227234
```yaml
228235
# Optional: Configure embedding provider
229236
# Can also be set via EMBEDDING_PROVIDER environment variable
230237
# Defaults to OpenAI if not specified
231238
embedding:
232239
provider: 'openai' # or 'azure'
240+
dimension: 3072 # Optional, defaults to 3072
233241
openai:
234242
api_key: '${OPENAI_API_KEY}' # Optional, uses env var by default
235243
model: 'text-embedding-3-large' # Optional, defaults to text-embedding-3-large

config.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
11
# Doc2Vec Configuration
2+
embedding:
3+
provider: 'openai'
4+
dimension: 3072
5+
openai:
6+
model: 'text-embedding-3-large'
7+
28
sources:
39
# GitHub Sources
410
- type: github

content-processor.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,7 @@ export class ContentProcessor {
619619
let newLinksFound = 0;
620620

621621
for (const href of result.links) {
622-
const fullUrl = Utils.buildUrl(href, pageUrlForLinks);
622+
const fullUrl = Utils.buildUrl(href, pageUrlForLinks, logger);
623623
if (fullUrl.startsWith(sourceConfig.url)) {
624624
addReferrer(fullUrl, pageUrlForLinks);
625625
if (!visitedUrls.has(Utils.normalizeUrl(fullUrl))) {

database.ts

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import {
1919
export class DatabaseManager {
2020
private static columnCache: WeakMap<Database, { hasBranch: boolean; hasRepo: boolean }> = new WeakMap();
2121

22-
static async initDatabase(config: SourceConfig, parentLogger: Logger): Promise<DatabaseConnection> {
22+
static async initDatabase(config: SourceConfig, parentLogger: Logger, embeddingDimension: number): Promise<DatabaseConnection> {
2323
const logger = parentLogger.child('database');
2424
const dbConfig = config.database_config;
2525

@@ -32,10 +32,10 @@ export class DatabaseManager {
3232
const db = new BetterSqlite3(dbPath, { allowExtension: true } as any);
3333
sqliteVec.load(db);
3434

35-
logger.debug(`Creating vec_items table if it doesn't exist`);
35+
logger.debug(`Creating vec_items table if it doesn't exist (dimension: ${embeddingDimension})`);
3636
db.exec(`
3737
CREATE VIRTUAL TABLE IF NOT EXISTS vec_items USING vec0(
38-
embedding FLOAT[3072],
38+
embedding FLOAT[${embeddingDimension}],
3939
product_name TEXT,
4040
version TEXT,
4141
branch TEXT,
@@ -61,7 +61,7 @@ export class DatabaseManager {
6161
logger.info(`Connecting to Qdrant at ${qdrantUrl}:${qdrantPort}, collection: ${collectionName}`);
6262
const qdrantClient = new QdrantClient({ url: qdrantUrl, apiKey: process.env.QDRANT_API_KEY, port: qdrantPort });
6363

64-
await this.createCollectionQdrant(qdrantClient, collectionName, logger);
64+
await this.createCollectionQdrant(qdrantClient, collectionName, logger, embeddingDimension);
6565
logger.info(`Qdrant connection established successfully`);
6666
return { client: qdrantClient, collectionName, type: 'qdrant' };
6767
} else {
@@ -71,7 +71,7 @@ export class DatabaseManager {
7171
}
7272
}
7373

74-
static async createCollectionQdrant(qdrantClient: QdrantClient, collectionName: string, logger: Logger) {
74+
static async createCollectionQdrant(qdrantClient: QdrantClient, collectionName: string, logger: Logger, embeddingDimension: number) {
7575
try {
7676
logger.debug(`Checking if collection ${collectionName} exists`);
7777
const collections = await qdrantClient.getCollections();
@@ -84,10 +84,10 @@ export class DatabaseManager {
8484
return;
8585
}
8686

87-
logger.info(`Creating new collection ${collectionName}`);
87+
logger.info(`Creating new collection ${collectionName} with dimension ${embeddingDimension}`);
8888
await qdrantClient.createCollection(collectionName, {
8989
vectors: {
90-
size: 3072,
90+
size: embeddingDimension,
9191
distance: "Cosine",
9292
},
9393
});
@@ -177,7 +177,8 @@ export class DatabaseManager {
177177
dbConnection: DatabaseConnection,
178178
key: string,
179179
value: string,
180-
logger: Logger
180+
logger: Logger,
181+
embeddingDimension: number
181182
): Promise<void> {
182183
try {
183184
if (dbConnection.type === 'sqlite') {
@@ -189,8 +190,7 @@ export class DatabaseManager {
189190
logger.debug(`Updated metadata value for ${key}`);
190191
} else if (dbConnection.type === 'qdrant') {
191192
const metadataUUID = Utils.generateMetadataUUID(key);
192-
const dummyEmbeddingSize = 3072;
193-
const dummyEmbedding = new Array(dummyEmbeddingSize).fill(0);
193+
const dummyEmbedding = new Array(embeddingDimension).fill(0);
194194
const metadataPoint = {
195195
id: metadataUUID,
196196
vector: dummyEmbedding,
@@ -259,7 +259,12 @@ export class DatabaseManager {
259259
return defaultDate;
260260
}
261261

262-
static async updateLastRunDate(dbConnection: DatabaseConnection, repo: string, logger: Logger): Promise<void> {
262+
static async updateLastRunDate(
263+
dbConnection: DatabaseConnection,
264+
repo: string,
265+
logger: Logger,
266+
embeddingDimension: number
267+
): Promise<void> {
263268
const now = new Date().toISOString();
264269

265270
try {
@@ -279,8 +284,7 @@ export class DatabaseManager {
279284
logger.debug(`Using UUID: ${metadataUUID} for metadata`);
280285

281286
// Generate a dummy embedding (all zeros)
282-
const dummyEmbeddingSize = 3072; // Same size as your content embeddings
283-
const dummyEmbedding = new Array(dummyEmbeddingSize).fill(0);
287+
const dummyEmbedding = new Array(embeddingDimension).fill(0);
284288

285289
// Create a point with special metadata payload
286290
const metadataPoint = {

doc2vec.ts

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ import {
2525
ZendeskSourceConfig,
2626
DatabaseConnection,
2727
DocumentChunk,
28-
BrokenLink
28+
BrokenLink,
29+
EmbeddingConfig
2930
} from './types';
3031

3132
const GITHUB_TOKEN = process.env.GITHUB_PERSONAL_ACCESS_TOKEN;
@@ -37,6 +38,7 @@ export class Doc2Vec {
3738
private config: Config;
3839
private openai: OpenAI | AzureOpenAI;
3940
private embeddingModel: string;
41+
private embeddingDimension: number;
4042
private contentProcessor: ContentProcessor;
4143
private logger: Logger;
4244
private configDir: string;
@@ -58,6 +60,7 @@ export class Doc2Vec {
5860
// Check environment variable if not specified in config
5961
const embeddingProvider = this.config.embedding?.provider || (process.env.EMBEDDING_PROVIDER as 'openai' | 'azure') || 'openai';
6062
const embeddingConfig = this.config.embedding || { provider: embeddingProvider };
63+
this.embeddingDimension = this.resolveEmbeddingDimension(embeddingConfig);
6164

6265
if (embeddingProvider === 'azure') {
6366
const azureApiKey = embeddingConfig.azure?.api_key || process.env.AZURE_OPENAI_KEY;
@@ -77,7 +80,7 @@ export class Doc2Vec {
7780
apiVersion: azureApiVersion,
7881
});
7982
this.embeddingModel = azureDeploymentName;
80-
this.logger.info(`Using Azure OpenAI with deployment: ${azureDeploymentName}`);
83+
this.logger.info(`Using Azure OpenAI with deployment: ${azureDeploymentName} (${this.embeddingDimension} dimensions)`);
8184
} else {
8285
const openaiApiKey = embeddingConfig.openai?.api_key || process.env.OPENAI_API_KEY;
8386
const openaiModel = embeddingConfig.openai?.model || process.env.OPENAI_MODEL || 'text-embedding-3-large';
@@ -89,7 +92,7 @@ export class Doc2Vec {
8992

9093
this.openai = new OpenAI({ apiKey: openaiApiKey });
9194
this.embeddingModel = openaiModel;
92-
this.logger.info(`Using OpenAI with model: ${openaiModel}`);
95+
this.logger.info(`Using OpenAI with model: ${openaiModel} (${this.embeddingDimension} dimensions)`);
9396
}
9497

9598
this.contentProcessor = new ContentProcessor(this.logger);
@@ -138,6 +141,25 @@ export class Doc2Vec {
138141
}
139142
}
140143

144+
private resolveEmbeddingDimension(embeddingConfig: EmbeddingConfig | undefined): number {
145+
const defaultDimension = 3072;
146+
const rawConfigValue = embeddingConfig?.dimension;
147+
const rawEnvValue = process.env.EMBEDDING_DIMENSION;
148+
149+
const candidate = rawConfigValue ?? (rawEnvValue ? Number(rawEnvValue) : undefined);
150+
if (candidate === undefined) {
151+
return defaultDimension;
152+
}
153+
154+
const parsedValue = typeof candidate === 'string' ? Number(candidate) : candidate;
155+
if (!Number.isFinite(parsedValue) || parsedValue <= 0 || !Number.isInteger(parsedValue)) {
156+
this.logger.warn(`Invalid embedding dimension provided (${candidate}), falling back to ${defaultDimension}`);
157+
return defaultDimension;
158+
}
159+
160+
return parsedValue;
161+
}
162+
141163
public async run(): Promise<void> {
142164
this.logger.section('PROCESSING SOURCES');
143165

@@ -388,7 +410,7 @@ export class Doc2Vec {
388410
}
389411

390412
// Update the last run date in the database after processing all issues
391-
await DatabaseManager.updateLastRunDate(dbConnection, repo, logger);
413+
await DatabaseManager.updateLastRunDate(dbConnection, repo, logger, this.embeddingDimension);
392414

393415
logger.info(`Successfully processed ${issues.length} issues`);
394416
}
@@ -397,7 +419,7 @@ export class Doc2Vec {
397419
const logger = parentLogger.child('process');
398420
logger.info(`Starting processing for GitHub repo: ${config.repo}`);
399421

400-
const dbConnection = await DatabaseManager.initDatabase(config, logger);
422+
const dbConnection = await DatabaseManager.initDatabase(config, logger, this.embeddingDimension);
401423

402424
// Initialize metadata storage
403425
await DatabaseManager.initDatabaseMetadata(dbConnection, logger);
@@ -414,8 +436,8 @@ export class Doc2Vec {
414436
const logger = parentLogger.child('process');
415437
logger.info(`Starting processing for website: ${config.url}`);
416438

417-
const dbConnection = await DatabaseManager.initDatabase(config, logger);
418-
await DatabaseManager.initDatabaseMetadata(dbConnection, logger);
439+
const dbConnection = await DatabaseManager.initDatabase(config, logger, this.embeddingDimension);
440+
await DatabaseManager.initDatabaseMetadata(dbConnection, logger);
419441
const validChunkIds: Set<string> = new Set();
420442
const visitedUrls: Set<string> = new Set();
421443
const urlPrefix = Utils.getUrlPrefix(config.url);
@@ -446,7 +468,7 @@ export class Doc2Vec {
446468
return DatabaseManager.getMetadataValue(dbConnection, `etag:${url}`, undefined, logger);
447469
},
448470
set: async (url: string, etag: string): Promise<void> => {
449-
await DatabaseManager.setMetadataValue(dbConnection, `etag:${url}`, etag, logger);
471+
await DatabaseManager.setMetadataValue(dbConnection, `etag:${url}`, etag, logger, this.embeddingDimension);
450472
},
451473
};
452474

@@ -455,7 +477,7 @@ export class Doc2Vec {
455477
return DatabaseManager.getMetadataValue(dbConnection, `lastmod:${url}`, undefined, logger);
456478
},
457479
set: async (url: string, lastmod: string): Promise<void> => {
458-
await DatabaseManager.setMetadataValue(dbConnection, `lastmod:${url}`, lastmod, logger);
480+
await DatabaseManager.setMetadataValue(dbConnection, `lastmod:${url}`, lastmod, logger, this.embeddingDimension);
459481
},
460482
};
461483

@@ -539,7 +561,7 @@ export class Doc2Vec {
539561
const logger = parentLogger.child('process');
540562
logger.info(`Starting processing for local directory: ${config.path}`);
541563
542-
const dbConnection = await DatabaseManager.initDatabase(config, logger);
564+
const dbConnection = await DatabaseManager.initDatabase(config, logger, this.embeddingDimension);
543565
const validChunkIds: Set<string> = new Set();
544566
const processedFiles: Set<string> = new Set();
545567
@@ -611,7 +633,7 @@ export class Doc2Vec {
611633
const logger = parentLogger.child('process');
612634
logger.info(`Starting processing for code source (${config.source})`);
613635
614-
const dbConnection = await DatabaseManager.initDatabase(config, logger);
636+
const dbConnection = await DatabaseManager.initDatabase(config, logger, this.embeddingDimension);
615637
const validChunkIds: Set<string> = new Set();
616638
const processedFiles: Set<string> = new Set();
617639
@@ -765,10 +787,10 @@ export class Doc2Vec {
765787
}
766788
}
767789
768-
await DatabaseManager.setMetadataValue(dbConnection, fileListKey, JSON.stringify(currentList), logger);
790+
await DatabaseManager.setMetadataValue(dbConnection, fileListKey, JSON.stringify(currentList), logger, this.embeddingDimension);
769791
if (lastMtimeKey) {
770792
const nextMtime = maxObservedMtime > 0 ? maxObservedMtime : Date.now();
771-
await DatabaseManager.setMetadataValue(dbConnection, lastMtimeKey, `${nextMtime}`, logger);
793+
await DatabaseManager.setMetadataValue(dbConnection, lastMtimeKey, `${nextMtime}`, logger, this.embeddingDimension);
772794
}
773795
}
774796
} else {
@@ -785,7 +807,7 @@ export class Doc2Vec {
785807
const headSha = await this.getRepoHeadSha(basePath, logger);
786808
if (headSha) {
787809
const shaKey = this.buildCodeShaMetadataKey(config.repo as string, repoBranch);
788-
await DatabaseManager.setMetadataValue(dbConnection, shaKey, headSha, logger);
810+
await DatabaseManager.setMetadataValue(dbConnection, shaKey, headSha, logger, this.embeddingDimension);
789811
}
790812
}
791813
@@ -974,7 +996,7 @@ export class Doc2Vec {
974996
const logger = parentLogger.child('process');
975997
logger.info(`Starting processing for Zendesk: ${config.zendesk_subdomain}.zendesk.com`);
976998

977-
const dbConnection = await DatabaseManager.initDatabase(config, logger);
999+
const dbConnection = await DatabaseManager.initDatabase(config, logger, this.embeddingDimension);
9781000

9791001
// Initialize metadata storage
9801002
await DatabaseManager.initDatabaseMetadata(dbConnection, logger);
@@ -1180,7 +1202,7 @@ export class Doc2Vec {
11801202
}
11811203

11821204
// Update the last run date in the database
1183-
await DatabaseManager.updateLastRunDate(dbConnection, `zendesk_tickets_${config.zendesk_subdomain}`, logger);
1205+
await DatabaseManager.updateLastRunDate(dbConnection, `zendesk_tickets_${config.zendesk_subdomain}`, logger, this.embeddingDimension);
11841206

11851207
logger.info(`Successfully processed ${totalTickets} tickets`);
11861208
}

package-lock.json

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "doc2vec",
3-
"version": "2.4.0",
3+
"version": "2.5.0",
44
"type": "commonjs",
55
"description": "",
66
"main": "dist/doc2vec.js",

0 commit comments

Comments
 (0)