Skip to content

Commit 216f30b

Browse files
authored
Merge pull request #125 from majiayu000/fix/mhc-empty-and-low-similarity-issues-7-8
fix: skip empty MHC items and add similarity threshold (closes #7, #8)
2 parents 8e1cebe + 5f3bbad commit 216f30b

3 files changed

Lines changed: 246 additions & 8 deletions

File tree

src/harmony/matching/default_matcher.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ def match_instruments(
8080
texts_cached_vectors: dict[str, List[float]] = {}, batch_size: int = 1000, max_batches: int = 2000,
8181
is_negate: bool = True,
8282
clustering_algorithm: str = "affinity_propagation",
83-
num_clusters_for_kmeans: int = None
83+
num_clusters_for_kmeans: int = None,
84+
mhc_min_similarity: float = 0.0
8485
) -> MatchResult:
8586
for instrument in instruments:
8687
for question in instrument.questions:
@@ -98,5 +99,6 @@ def match_instruments(
9899
texts_cached_vectors=texts_cached_vectors,
99100
is_negate=is_negate,
100101
clustering_algorithm=clustering_algorithm,
101-
num_clusters_for_kmeans=num_clusters_for_kmeans
102+
num_clusters_for_kmeans=num_clusters_for_kmeans,
103+
mhc_min_similarity=mhc_min_similarity
102104
)

src/harmony/matching/matcher.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,8 @@ def match_instruments_with_function(
609609
texts_cached_vectors: dict[str, List[float]] = {},
610610
is_negate: bool = True,
611611
clustering_algorithm: ClusteringAlgorithm = ClusteringAlgorithm.affinity_propagation,
612-
num_clusters_for_kmeans: int = None
612+
num_clusters_for_kmeans: int = None,
613+
mhc_min_similarity: float = 0.0
613614
) -> MatchResult:
614615

615616
all_questions: List[Question] = []
@@ -675,17 +676,31 @@ def match_instruments_with_function(
675676
if vectors_pos.size > 0 and len(mhc_embeddings) > 0:
676677
similarities_mhc = cosine_similarity(vectors_pos, mhc_embeddings)
677678
ctrs = {}
678-
top_mhc_match_ids = np.argmax(similarities_mhc, axis=1)
679-
for idx, mhc_item_idx in enumerate(top_mhc_match_ids):
680-
question_text = mhc_questions[mhc_item_idx].question_text
681-
if not question_text or len(question_text.strip()) < 3:
679+
680+
# Build mask of valid MHC questions (non-empty text)
681+
valid_mhc_mask = np.array([
682+
bool(mhc_questions[i].question_text and len(mhc_questions[i].question_text.strip()) >= 3)
683+
for i in range(len(mhc_questions))
684+
])
685+
686+
for idx in range(len(all_questions)):
687+
# Get similarities for this question, masking out invalid MHC items
688+
masked_similarities = np.where(valid_mhc_mask, similarities_mhc[idx], -np.inf)
689+
mhc_item_idx = int(np.argmax(masked_similarities))
690+
strength_of_match = similarities_mhc[idx, mhc_item_idx]
691+
692+
# Skip if no valid MHC items or similarity is below threshold
693+
if masked_similarities[mhc_item_idx] == -np.inf:
682694
continue
695+
if strength_of_match < mhc_min_similarity:
696+
continue
697+
698+
question_text = mhc_questions[mhc_item_idx].question_text
683699
if all_questions[idx].instrument_id not in ctrs:
684700
ctrs[all_questions[idx].instrument_id] = Counter()
685701
for topic in mhc_all_metadatas[mhc_item_idx]["topics"]:
686702
ctrs[all_questions[idx].instrument_id][topic] += 1
687703
all_questions[idx].nearest_match_from_mhc_auto = mhc_questions[mhc_item_idx].model_dump()
688-
strength_of_match = similarities_mhc[idx, mhc_item_idx]
689704
all_questions[idx].topics_strengths = {topic: float(strength_of_match)}
690705

691706
instrument_to_category = {}

tests/test_mhc_filtering.py

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
'''
2+
MIT License
3+
4+
Copyright (c) 2023 Ulster University (https://www.ulster.ac.uk).
5+
Project: Harmony (https://harmonydata.ac.uk)
6+
Maintainer: Thomas Wood (https://fastdatascience.com)
7+
8+
Permission is hereby granted, free of charge, to any person obtaining a copy
9+
of this software and associated documentation files (the "Software"), to deal
10+
in the Software without restriction, including without limitation the rights
11+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12+
copies of the Software, and to permit persons to whom the Software is
13+
furnished to do so, subject to the following conditions:
14+
15+
The above copyright notice and this permission notice shall be included in all
16+
copies or substantial portions of the Software.
17+
18+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24+
SOFTWARE.
25+
26+
'''
27+
28+
import sys
29+
import unittest
30+
31+
sys.path.append("../src")
32+
33+
import numpy as np
34+
from sentence_transformers import SentenceTransformer
35+
36+
from harmony import match_instruments
37+
from harmony.schemas.requests.text import Instrument, Question
38+
39+
40+
# Shared sentence-embedding model: instantiated once at import time and reused by the tests in
# this module (via model.encode) to build the MHC embeddings.
model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
41+
42+
43+
def create_mhc_question_with_text(text):
    """Return a Question holding ``text`` without running model validation.

    ``Question.model_construct`` is used instead of the normal constructor so
    that empty text can be stored, simulating entries found in the MHC data.
    """
    return Question.model_construct(question_text=text)
47+
48+
49+
class TestMhcEmptyFiltering(unittest.TestCase):
    """Tests for Issue #7: Remove empty items from MHC.

    Each test supplies MHC questions where one entry has empty or
    whitespace-only text, and checks that the matcher attaches a valid,
    non-empty MHC item instead of the blank one.
    """

    def test_empty_mhc_questions_are_skipped(self):
        """Verify that empty MHC questions are not matched."""
        questions_en = [
            Question(question_text="I feel anxious and worried about things"),
            Question(question_text="I have trouble sleeping at night"),
        ]
        instrument_en = Instrument(questions=questions_en)

        # MHC data with empty questions (using model_construct to bypass validation)
        mhc_metadata = [
            {'topics': ['anxiety']},
            {'topics': ['empty_topic']},  # This has empty text
            {'topics': ['sleep disorders']},
        ]

        mhc_questions = [
            create_mhc_question_with_text("Do you feel nervous or anxious?"),
            create_mhc_question_with_text(""),  # Empty question - should be skipped
            create_mhc_question_with_text("Do you have difficulty sleeping?"),
        ]

        mhc_embeddings = model.encode(np.asarray([
            "Do you feel nervous or anxious?",
            "placeholder",  # Will be masked out
            "Do you have difficulty sleeping?",
        ]))

        match_response = match_instruments(
            [instrument_en],
            mhc_questions=mhc_questions,
            mhc_embeddings=mhc_embeddings,
            mhc_all_metadatas=mhc_metadata,
            mhc_min_similarity=0.3,
        )

        self.assertEqual(2, len(match_response.questions))

        for q in match_response.questions:
            with self.subTest(question=q.question_text):
                mhc_match = q.nearest_match_from_mhc_auto
                # Require a match first: otherwise the text check below would be
                # skipped and the test would pass even if nothing was matched.
                self.assertIsNotNone(mhc_match)
                # Strip before checking so whitespace-only text also counts as empty.
                matched_text = mhc_match.get("question_text") or ""
                self.assertNotEqual("", matched_text.strip())

    def test_whitespace_only_mhc_questions_are_skipped(self):
        """Verify that whitespace-only MHC questions are not matched."""
        questions_en = [Question(question_text="I feel depressed")]
        instrument_en = Instrument(questions=questions_en)

        mhc_metadata = [
            {'topics': ['whitespace']},
            {'topics': ['depression']},
        ]

        mhc_questions = [
            create_mhc_question_with_text(" "),  # Whitespace only - should be skipped
            create_mhc_question_with_text("Do you feel depressed or sad?"),
        ]

        mhc_embeddings = model.encode(np.asarray([
            "placeholder",
            "Do you feel depressed or sad?",
        ]))

        match_response = match_instruments(
            [instrument_en],
            mhc_questions=mhc_questions,
            mhc_embeddings=mhc_embeddings,
            mhc_all_metadatas=mhc_metadata,
            mhc_min_similarity=0.3,
        )

        # Should match to the valid depression question. Assert the match exists
        # so the test cannot pass vacuously when no MHC item was attached.
        mhc_match = match_response.questions[0].nearest_match_from_mhc_auto
        self.assertIsNotNone(mhc_match)
        matched_text = mhc_match.get("question_text") or ""
        self.assertIn("depressed", matched_text.lower())
126+
127+
128+
class TestMhcSimilarityThreshold(unittest.TestCase):
    """Tests for Issue #8: Don't match to MHC items if similarity is too low"""

    def _first_matched_question(self, question_text, mhc_texts, mhc_metadata, min_similarity):
        """Match one question against ``mhc_texts`` and return the first response question.

        The MHC embeddings are computed from ``mhc_texts`` with the shared
        module-level model, and ``min_similarity`` is forwarded to
        ``match_instruments`` as ``mhc_min_similarity``.
        """
        instrument = Instrument(questions=[Question(question_text=question_text)])
        embeddings = model.encode(np.asarray(mhc_texts))
        mhc_question_objects = [Question(question_text=text) for text in mhc_texts]

        response = match_instruments(
            [instrument],
            mhc_questions=mhc_question_objects,
            mhc_embeddings=embeddings,
            mhc_all_metadatas=mhc_metadata,
            mhc_min_similarity=min_similarity
        )
        return response.questions[0]

    def test_low_similarity_no_match(self):
        """Verify that questions with low similarity to MHC are not matched"""
        # An unrelated question combined with a high threshold must yield no match.
        result = self._first_matched_question(
            "I lost my car keys",
            ["Do you worry about your weight?", "Do you feel anxious?"],
            [{'topics': ['eating disorders']}, {'topics': ['anxiety']}],
            0.8,
        )

        self.assertIsNone(result.nearest_match_from_mhc_auto)

    def test_high_similarity_match(self):
        """Verify that questions with high similarity to MHC are matched"""
        # A closely related question with a low threshold must be matched.
        result = self._first_matched_question(
            "I feel nervous and anxious",
            ["Do you feel nervous or anxious?"],
            [{'topics': ['anxiety']}],
            0.3,
        )

        self.assertIsNotNone(result.nearest_match_from_mhc_auto)

    def test_threshold_filters_unrelated(self):
        """Verify that mhc_min_similarity threshold filters unrelated questions"""
        # Cooking and depression are unrelated, so an explicit 0.5 threshold filters the match.
        result = self._first_matched_question(
            "How do I make a chocolate cake?",
            ["Have you felt hopeless about the future?"],
            [{'topics': ['depression']}],
            0.5,
        )

        self.assertIsNone(result.nearest_match_from_mhc_auto)
218+
219+
220+
if __name__ == '__main__':
221+
unittest.main()

0 commit comments

Comments
 (0)