step0a_semantic_normalization.py

  1. """
  2. Step 0b: Semantic text normalization analysis using embeddings and LLM.
  3. Identifies unclear terms, unknown acronyms, and ambiguous words.
  4. """
  5. from typing import List, Dict, Set, Tuple
  6. from collections import Counter
  7. import pandas as pd
  8. import numpy as np
  9. import re
  10. from sentence_transformers import SentenceTransformer
  11. from sklearn.metrics.pairwise import cosine_similarity
  12. from pipeline.models.base import PipelineStep
  13. from pipeline.utils.text_utils import normalize_text
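
# NOTE (assumption): PipelineStep (pipeline.models.base) is expected to provide
# self.logger, self.output_dir (a pathlib.Path), and save_results(); those are
# the only base-class members this module relies on.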


class SemanticNormalizationAnalyzer(PipelineStep):
    """
    Analyze text using semantic methods to identify:
    1. Unclear/ambiguous terms (low semantic coherence)
    2. Unknown acronyms (uppercase patterns not in dictionary)
    3. Domain-specific jargon
    4. Abbreviations needing expansion
    """

    def __init__(
        self,
        min_frequency: int = 3,
        coherence_threshold: float = 0.4,
        output_dir: str = "./pipeline_output",
    ):
        super().__init__(output_dir)
        self.min_frequency = min_frequency
        self.coherence_threshold = coherence_threshold
        self.logger.info("Loading embedding model: all-mpnet-base-v2...")
        self.embedding_model = SentenceTransformer("all-mpnet-base-v2")
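        # Note: sentence-transformers typically fetches "all-mpnet-base-v2"
        # from the Hugging Face hub and caches it locally on first use, so the
        # first construction of this class can be slow and may need network access.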
        # Known medical/legal terms (high coherence expected)
        self.known_terms = {
            "doctor",
            "hospital",
            "treatment",
            "patient",
            "medical",
            "surgery",
            "appointment",
            "medication",
            "diagnosis",
            "procedure",
            "discrimination",
            "complaint",
            "lawsuit",
            "legal",
            "attorney",
        }
        # Known acronyms (to exclude from unknown list)
        self.known_acronyms = {
            "msk",
            "er",
            "icu",
            "ob",
            "gyn",
            "pcp",
            "np",
            "pa",
            "rn",
            "emr",
            "ehr",
            "hipaa",
            "lgbtq",
            "lgbt",
            "usa",
            "nyc",
        }

    def execute(self, df: pd.DataFrame) -> Dict[str, List[Dict]]:
        """
        Analyze text to identify unclear terms and unknown acronyms.

        Args:
            df: DataFrame with messages (expects a "message" column)

        Returns:
            Dictionary with keys "unknown_acronyms", "unclear_terms",
            "abbreviations", and "jargon", each mapping to a list of dicts.
        """
        self.logger.info("=" * 80)
        self.logger.info("SEMANTIC TEXT NORMALIZATION ANALYSIS")
        self.logger.info("=" * 80)
        self.logger.info(f"Analyzing {len(df):,} messages")

        # Extract words with metadata
        self.logger.info("\nExtracting words and computing frequencies...")
        word_data = self._extract_word_data(df)
        self.logger.info(f"Found {len(word_data):,} unique words")

        # Identify unknown acronyms
        self.logger.info("\nIdentifying unknown acronyms...")
        unknown_acronyms = self._identify_unknown_acronyms(word_data)
        self.logger.info(f"Found {len(unknown_acronyms)} unknown acronyms")

        # Identify unclear terms using semantic coherence
        self.logger.info("\nAnalyzing semantic coherence for unclear terms...")
        unclear_terms = self._identify_unclear_terms(word_data, df)
        self.logger.info(f"Found {len(unclear_terms)} unclear terms")

        # Identify abbreviations
        self.logger.info("\nIdentifying abbreviations...")
        abbreviations = self._identify_abbreviations(word_data)
        self.logger.info(f"Found {len(abbreviations)} abbreviations")

        # Identify domain-specific jargon
        self.logger.info("\nIdentifying domain-specific jargon...")
        jargon = self._identify_jargon(word_data)
        self.logger.info(f"Found {len(jargon)} jargon terms")

        # Compile results
        results = {
            "unknown_acronyms": unknown_acronyms,
            "unclear_terms": unclear_terms,
            "abbreviations": abbreviations,
            "jargon": jargon,
        }

        # Save results
        self._save_normalization_analysis(results)
        return results

    def _extract_word_data(self, df: pd.DataFrame) -> Dict[str, Dict]:
        """Extract words with frequency and context"""
        word_data = {}
        for message in df["message"].fillna(""):
            text = str(message)
            # Extract words with original casing
            words = re.findall(r"\b[a-zA-Z][a-zA-Z0-9]*\b", text)
            for word in words:
                word_lower = word.lower()
                if word_lower not in word_data:
                    word_data[word_lower] = {
                        "word": word_lower,
                        "frequency": 0,
                        "original_forms": set(),
                        "contexts": [],
                    }
                word_data[word_lower]["frequency"] += 1
                word_data[word_lower]["original_forms"].add(word)
                # Store up to 5 context snippets per word
                if len(word_data[word_lower]["contexts"]) < 5:
                    # Keep ~50 characters on either side of the word's first
                    # occurrence in this message
                    word_index = text.lower().find(word_lower)
                    if word_index != -1:
                        start = max(0, word_index - 50)
                        end = min(len(text), word_index + len(word_lower) + 50)
                        context = text[start:end]
                        word_data[word_lower]["contexts"].append(context)
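
        # Illustrative (hypothetical) entry after processing the message
        # "Saw my PCP about the referral":
        #   word_data["pcp"] = {
        #       "word": "pcp", "frequency": 1, "original_forms": {"PCP"},
        #       "contexts": ["Saw my PCP about the referral"],
        #   }
        # Low-frequency entries like this are dropped by the filter below.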

        # Filter by minimum frequency
        word_data = {
            w: data
            for w, data in word_data.items()
            if data["frequency"] >= self.min_frequency
        }
        return word_data

    def _identify_unknown_acronyms(self, word_data: Dict) -> List[Dict]:
        """Identify potential unknown acronyms"""
        unknown_acronyms = []
        for word, data in word_data.items():
            # Acronym pattern: 2-6 letters that appear fully uppercased in the
            # messages and are not already in the known-acronym list
            is_acronym = (
                len(word) >= 2
                and len(word) <= 6
                and word.upper() in data["original_forms"]
                and word not in self.known_acronyms
                and not word.isdigit()
            )
            if is_acronym:
                unknown_acronyms.append(
                    {
                        "acronym": word.upper(),
                        "frequency": data["frequency"],
                        "contexts": data["contexts"][:3],
                        "confidence": "high" if data["frequency"] >= 10 else "medium",
                    }
                )
        # Sort by frequency
        unknown_acronyms.sort(key=lambda x: x["frequency"], reverse=True)
        return unknown_acronyms

    def _identify_unclear_terms(self, word_data: Dict, df: pd.DataFrame) -> List[Dict]:
        """Identify unclear terms using semantic coherence"""
        unclear_terms = []
        # Sample words for analysis (focus on medium frequency)
        candidate_words = [
            w
            for w, data in word_data.items()
            if 5 <= data["frequency"] <= 100
            and len(w) >= 4
            and w not in self.known_terms
        ]
        if not candidate_words:
            return unclear_terms
        self.logger.info(f"  Analyzing {len(candidate_words)} candidate words...")

        # Compute embeddings for candidate words
        word_embeddings = self.embedding_model.encode(
            candidate_words, show_progress_bar=True, batch_size=32
        )
        # Compute embeddings for known terms
        known_embeddings = self.embedding_model.encode(
            list(self.known_terms), show_progress_bar=False
        )
        # Calculate semantic coherence (similarity to known terms)
        similarities = cosine_similarity(word_embeddings, known_embeddings)
        max_similarities = similarities.max(axis=1)
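        # similarities has shape (len(candidate_words), len(known_terms));
        # the row-wise max is each candidate's similarity to its closest known
        # term, used as a rough "coherence" score below.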

        # Identify words with low coherence
        for i, word in enumerate(candidate_words):
            coherence = float(max_similarities[i])
            if coherence < self.coherence_threshold:
                unclear_terms.append(
                    {
                        "term": word,
                        "frequency": word_data[word]["frequency"],
                        "coherence_score": coherence,
                        "contexts": word_data[word]["contexts"][:3],
                        "reason": "low_semantic_coherence",
                    }
                )
        # Sort by coherence (lowest first)
        unclear_terms.sort(key=lambda x: x["coherence_score"])
        return unclear_terms[:50]  # Top 50 most unclear

    def _identify_abbreviations(self, word_data: Dict) -> List[Dict]:
        """Identify potential abbreviations"""
        abbreviations = []
        # Common abbreviation patterns
        abbrev_patterns = [
            (r"^[a-z]{2,4}$", "short_word"),  # 2-4 letter words
            (r"^[a-z]+\.$", "period_ending"),  # Words ending in a period
            (r"^[a-z]\d+$", "letter_number"),  # Letter + number
        ]
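        # Note: _extract_word_data tokenizes with \b[a-zA-Z][a-zA-Z0-9]*\b, so
        # collected words never contain periods; in practice the period checks
        # below do not fire and only "short_word" matches are reported.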
        for word, data in word_data.items():
            for pattern, pattern_type in abbrev_patterns:
                if re.match(pattern, word):
                    # Check if any original form contains a period
                    has_period = any("." in form for form in data["original_forms"])
                    if has_period or pattern_type == "short_word":
                        abbreviations.append(
                            {
                                "abbreviation": word,
                                "frequency": data["frequency"],
                                "pattern_type": pattern_type,
                                "contexts": data["contexts"][:2],
                            }
                        )
                    break
        # Sort by frequency
        abbreviations.sort(key=lambda x: x["frequency"], reverse=True)
        return abbreviations[:30]  # Top 30

    def _identify_jargon(self, word_data: Dict) -> List[Dict]:
        """Identify domain-specific jargon"""
        jargon = []
        # Jargon indicators (treated as suffixes by the matching below)
        jargon_indicators = {
            "medical": ["ology", "itis", "ectomy", "oscopy", "therapy"],
            "legal": ["tion", "ment", "ance", "ence"],
            "technical": ["tech", "system", "process", "protocol"],
        }
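        # Note: broad suffixes such as "tion" and "ment" also match ordinary
        # English words (e.g. "information", "situation"), so the output is a
        # candidate list for manual review rather than confirmed jargon.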
        for word, data in word_data.items():
            if len(word) < 6:
                continue
            # Check for jargon patterns
            for domain, suffixes in jargon_indicators.items():
                if any(word.endswith(suffix) for suffix in suffixes):
                    if word not in self.known_terms:
                        jargon.append(
                            {
                                "term": word,
                                "frequency": data["frequency"],
                                "domain": domain,
                                "contexts": data["contexts"][:2],
                            }
                        )
                    break
        # Sort by frequency
        jargon.sort(key=lambda x: x["frequency"], reverse=True)
        return jargon[:20]  # Top 20

    def _save_normalization_analysis(self, results: Dict):
        """Save normalization analysis results"""
        # Save JSON
        json_results = {
            "method": "semantic_analysis",
            "statistics": {
                "unknown_acronyms": len(results["unknown_acronyms"]),
                "unclear_terms": len(results["unclear_terms"]),
                "abbreviations": len(results["abbreviations"]),
                "jargon": len(results["jargon"]),
            },
            "results": results,
        }
        self.save_results(json_results, "semantic_normalization_analysis.json")
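        # save_results() is assumed (from PipelineStep) to serialize to JSON;
        # the per-term dicts above contain only strings, ints, and floats, so
        # they serialize cleanly.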

        # Save human-readable text report
        text_output = []
        text_output.append("SEMANTIC TEXT NORMALIZATION ANALYSIS")
        text_output.append("=" * 80)
        text_output.append("")
        text_output.append(
            "This analysis identifies terms that may need clarification or expansion."
        )
        text_output.append("")

        # Unknown acronyms
        text_output.append("=" * 80)
        text_output.append("UNKNOWN ACRONYMS (Need Investigation)")
        text_output.append("=" * 80)
        text_output.append("")
        if results["unknown_acronyms"]:
            text_output.append(
                f"{'Acronym':<15} {'Frequency':<12} {'Confidence':<12} {'Sample Context'}"
            )
            text_output.append("-" * 80)
            for item in results["unknown_acronyms"][:20]:
                context = item["contexts"][0][:50] if item["contexts"] else "N/A"
                text_output.append(
                    f"{item['acronym']:<15} {item['frequency']:<12} "
                    f"{item['confidence']:<12} {context}..."
                )
        else:
            text_output.append("No unknown acronyms found.")
        text_output.append("")

        # Unclear terms
        text_output.append("=" * 80)
        text_output.append("UNCLEAR TERMS (Low Semantic Coherence)")
        text_output.append("=" * 80)
        text_output.append("")
        text_output.append(
            "These terms have low semantic similarity to known medical/legal terms."
        )
        text_output.append(
            "They may be typos, slang, or domain-specific terms needing clarification."
        )
        text_output.append("")
        if results["unclear_terms"]:
            text_output.append(
                f"{'Term':<20} {'Frequency':<12} {'Coherence':<12} {'Sample Context'}"
            )
            text_output.append("-" * 80)
            for item in results["unclear_terms"][:20]:
                context = item["contexts"][0][:40] if item["contexts"] else "N/A"
                text_output.append(
                    f"{item['term']:<20} {item['frequency']:<12} "
                    f"{item['coherence_score']:<12.3f} {context}..."
                )
        else:
            text_output.append("No unclear terms found.")
        text_output.append("")

        # Abbreviations
        text_output.append("=" * 80)
        text_output.append("ABBREVIATIONS (May Need Expansion)")
        text_output.append("=" * 80)
        text_output.append("")
        if results["abbreviations"]:
            text_output.append(
                f"{'Abbreviation':<20} {'Frequency':<12} {'Pattern':<15} {'Context'}"
            )
            text_output.append("-" * 80)
            for item in results["abbreviations"][:15]:
                context = item["contexts"][0][:40] if item["contexts"] else "N/A"
                text_output.append(
                    f"{item['abbreviation']:<20} {item['frequency']:<12} "
                    f"{item['pattern_type']:<15} {context}..."
                )
        else:
            text_output.append("No abbreviations found.")
        text_output.append("")

        # Jargon
        text_output.append("=" * 80)
        text_output.append("DOMAIN-SPECIFIC JARGON")
        text_output.append("=" * 80)
        text_output.append("")
        if results["jargon"]:
            text_output.append(f"{'Term':<25} {'Frequency':<12} {'Domain':<15}")
            text_output.append("-" * 80)
            for item in results["jargon"][:15]:
                text_output.append(
                    f"{item['term']:<25} {item['frequency']:<12} {item['domain']:<15}"
                )
        else:
            text_output.append("No jargon found.")
        text_output.append("")

        # Recommendations
        text_output.append("=" * 80)
        text_output.append("RECOMMENDATIONS")
        text_output.append("=" * 80)
        text_output.append("")
        text_output.append(
            "1. Investigate unknown acronyms - may be critical case-specific terms"
        )
        text_output.append("2. Review unclear terms - may be typos or need context")
        text_output.append("3. Expand abbreviations in TEXT_EXPANSIONS dictionary")
        text_output.append("4. Add jargon terms to KEY_TOPICS if relevant to case")

        filepath = self.output_dir / "semantic_normalization_analysis.txt"
        with open(filepath, "w") as f:
            f.write("\n".join(text_output))
        self.logger.info(f"\nSaved analysis to: {filepath}")


if __name__ == "__main__":
    df = pd.read_csv("../_sources/signal_messages.csv")
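    # The CSV loaded above must provide a "message" column, which execute() reads.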
    analyzer = SemanticNormalizationAnalyzer(min_frequency=2, coherence_threshold=0.4)
    results = analyzer.execute(df)

    print("\nSemantic normalization analysis complete:")
    print(f"  Unknown acronyms: {len(results['unknown_acronyms'])}")
    print(f"  Unclear terms: {len(results['unclear_terms'])}")
    print(f"  Abbreviations: {len(results['abbreviations'])}")
    print(f"  Jargon: {len(results['jargon'])}")