step0a1_semantic_normalization.py

  1. """
  2. Step 0b: Semantic text normalization analysis using embeddings and LLM.
  3. Identifies unclear terms, unknown acronyms, and ambiguous words.
  4. """
  5. from typing import List, Dict, Set, Tuple
  6. from collections import Counter
  7. import pandas as pd
  8. import numpy as np
  9. import re
  10. from sentence_transformers import SentenceTransformer
  11. from sklearn.metrics.pairwise import cosine_similarity
  12. from pipeline.models.base import PipelineStep
  13. from pipeline.utils.text_utils import normalize_text


class SemanticNormalizationAnalyzer(PipelineStep):
    """
    Analyze text using semantic methods to identify:
    1. Unclear/ambiguous terms (low semantic coherence)
    2. Unknown acronyms (uppercase patterns not in dictionary)
    3. Domain-specific jargon
    4. Abbreviations needing expansion
    """

    def __init__(
        self,
        min_frequency: int = 3,
        coherence_threshold: float = 0.4,
        output_dir: str = "./pipeline_output",
    ):
        super().__init__(output_dir)
        self.min_frequency = min_frequency
        self.coherence_threshold = coherence_threshold
        self.logger.info("Loading embedding model: all-mpnet-base-v2...")
        self.embedding_model = SentenceTransformer("all-mpnet-base-v2")
        # Known medical/legal terms (high coherence expected)
        self.known_terms = {
            "doctor",
            "hospital",
            "treatment",
            "patient",
            "medical",
            "surgery",
            "appointment",
            "medication",
            "diagnosis",
            "procedure",
            "discrimination",
            "complaint",
            "lawsuit",
            "legal",
            "attorney",
        }
        # Known acronyms (to exclude from unknown list)
        self.known_acronyms = {
            "msk",
            "er",
            "icu",
            "ob",
            "gyn",
            "pcp",
            "np",
            "pa",
            "rn",
            "emr",
            "ehr",
            "hipaa",
            "lgbtq",
            "lgbt",
            "usa",
            "nyc",
        }

    def execute(self, df: pd.DataFrame) -> Dict[str, List[Dict]]:
        """
        Analyze text to identify unclear terms and unknown acronyms.

        Args:
            df: DataFrame with messages

        Returns:
            Dictionary with unclear terms, unknown acronyms, and suggestions
        """
        self.logger.info("=" * 80)
        self.logger.info("SEMANTIC TEXT NORMALIZATION ANALYSIS")
        self.logger.info("=" * 80)
        self.logger.info(f"Analyzing {len(df):,} messages")

        # Extract words with metadata
        self.logger.info("\nExtracting words and computing frequencies...")
        word_data = self._extract_word_data(df)
        self.logger.info(f"Found {len(word_data):,} unique words")

        # Identify unknown acronyms
        self.logger.info("\nIdentifying unknown acronyms...")
        unknown_acronyms = self._identify_unknown_acronyms(word_data)
        self.logger.info(f"Found {len(unknown_acronyms)} unknown acronyms")

        # Identify unclear terms using semantic coherence
        self.logger.info("\nAnalyzing semantic coherence for unclear terms...")
        unclear_terms = self._identify_unclear_terms(word_data, df)
        self.logger.info(f"Found {len(unclear_terms)} unclear terms")

        # Identify abbreviations
        self.logger.info("\nIdentifying abbreviations...")
        abbreviations = self._identify_abbreviations(word_data)
        self.logger.info(f"Found {len(abbreviations)} abbreviations")

        # Identify domain-specific jargon
        self.logger.info("\nIdentifying domain-specific jargon...")
        jargon = self._identify_jargon(word_data)
        self.logger.info(f"Found {len(jargon)} jargon terms")

        # Compile results
        results = {
            "unknown_acronyms": unknown_acronyms,
            "unclear_terms": unclear_terms,
            "abbreviations": abbreviations,
            "jargon": jargon,
        }

        # Save results
        self._save_normalization_analysis(results)
        return results
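
    # Illustrative shape of the dictionary returned by execute(); the values
    # below are made-up placeholders, not real corpus data:
    #
    #     {
    #         "unknown_acronyms": [
    #             {"acronym": "XYZ", "frequency": 12, "contexts": ["..."], "confidence": "high"},
    #         ],
    #         "unclear_terms": [
    #             {"term": "example", "frequency": 7, "coherence_score": 0.31,
    #              "contexts": ["..."], "reason": "low_semantic_coherence"},
    #         ],
    #         "abbreviations": [
    #             {"abbreviation": "appt", "frequency": 40, "pattern_type": "short_word", "contexts": ["..."]},
    #         ],
    #         "jargon": [
    #             {"term": "colonoscopy", "frequency": 9, "domain": "medical", "contexts": ["..."]},
    #         ],
    #     }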

    def _extract_word_data(self, df: pd.DataFrame) -> Dict[str, Dict]:
        """Extract words with frequency and context"""
        word_data = {}
        for message in df["message"].fillna(""):
            text = str(message)
            # Extract words with original casing
            words = re.findall(r"\b[a-zA-Z][a-zA-Z0-9]*\b", text)
            for word in words:
                word_lower = word.lower()
                if word_lower not in word_data:
                    word_data[word_lower] = {
                        "word": word_lower,
                        "frequency": 0,
                        "original_forms": set(),
                        "contexts": [],
                    }
                word_data[word_lower]["frequency"] += 1
                word_data[word_lower]["original_forms"].add(word)
                # Store up to 5 sample contexts per word
                if len(word_data[word_lower]["contexts"]) < 5:
                    # Grab ~50 characters either side of the word's first
                    # occurrence in this message
                    word_index = text.lower().find(word_lower)
                    if word_index != -1:
                        start = max(0, word_index - 50)
                        end = min(len(text), word_index + len(word_lower) + 50)
                        context = text[start:end]
                        word_data[word_lower]["contexts"].append(context)

        # Filter by minimum frequency
        word_data = {
            w: data
            for w, data in word_data.items()
            if data["frequency"] >= self.min_frequency
        }
        return word_data
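
    # Each surviving word_data entry looks roughly like this (hypothetical
    # example for illustration only):
    #
    #     word_data["mri"] = {
    #         "word": "mri",
    #         "frequency": 14,
    #         "original_forms": {"MRI", "mri"},
    #         "contexts": ["... scheduled the MRI for next week ..."],
    #     }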

    def _identify_unknown_acronyms(self, word_data: Dict) -> List[Dict]:
        """Identify potential unknown acronyms"""
        unknown_acronyms = []
        for word, data in word_data.items():
            # Check if it's an acronym pattern
            is_acronym = (
                2 <= len(word) <= 6
                and any(form.isupper() for form in data["original_forms"])
                and word not in self.known_acronyms
                and data["frequency"] < 1500
                and not word.isdigit()
            )
            if is_acronym:
                unknown_acronyms.append(
                    {
                        "acronym": word.upper(),
                        "frequency": data["frequency"],
                        "contexts": data["contexts"][:3],
                        "confidence": "high" if data["frequency"] >= 10 else "medium",
                    }
                )
        # Sort by frequency
        unknown_acronyms.sort(key=lambda x: x["frequency"], reverse=True)
        return unknown_acronyms

    def _identify_unclear_terms(self, word_data: Dict, df: pd.DataFrame) -> List[Dict]:
        """Identify unclear terms using semantic coherence"""
        unclear_terms = []
        # Sample words for analysis (focus on medium frequency)
        candidate_words = [
            w
            for w, data in word_data.items()
            if 5 <= data["frequency"] <= 200
            and len(w) >= 4
            and w not in self.known_terms
        ]
        if not candidate_words:
            return unclear_terms
        self.logger.info(f"  Analyzing {len(candidate_words)} candidate words...")

        # Compute embeddings for candidate words
        word_embeddings = self.embedding_model.encode(
            candidate_words, show_progress_bar=True, batch_size=32
        )
        # Compute embeddings for known terms
        known_embeddings = self.embedding_model.encode(
            list(self.known_terms), show_progress_bar=False
        )
        # Calculate semantic coherence (similarity to known terms)
        similarities = cosine_similarity(word_embeddings, known_embeddings)
        max_similarities = similarities.max(axis=1)

        # Identify words with low coherence
        for i, word in enumerate(candidate_words):
            coherence = float(max_similarities[i])
            if coherence < self.coherence_threshold:
                unclear_terms.append(
                    {
                        "term": word,
                        "frequency": word_data[word]["frequency"],
                        "coherence_score": coherence,
                        "contexts": word_data[word]["contexts"][:3],
                        "reason": "low_semantic_coherence",
                    }
                )
        # Sort by coherence (lowest first)
        unclear_terms.sort(key=lambda x: x["coherence_score"])
        return unclear_terms[:200]  # Top 200 most unclear
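
    # Minimal standalone sketch of the coherence scoring used above (example
    # words are hypothetical; the model name and 0.4 threshold mirror this
    # class's defaults):
    #
    #     from sentence_transformers import SentenceTransformer
    #     from sklearn.metrics.pairwise import cosine_similarity
    #
    #     model = SentenceTransformer("all-mpnet-base-v2")
    #     known = ["doctor", "hospital", "treatment"]
    #     candidates = ["cardiology", "asdfgh"]
    #     sims = cosine_similarity(model.encode(candidates), model.encode(known))
    #     coherence = sims.max(axis=1)  # best match against any known term
    #     unclear = [w for w, c in zip(candidates, coherence) if c < 0.4]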

    def _identify_abbreviations(self, word_data: Dict) -> List[Dict]:
        """Identify potential abbreviations"""
        abbreviations = []
        # Common abbreviation patterns
        abbrev_patterns = [
            (r"^[a-z]{2,4}$", "short_word"),  # 2-4 letter words
            (r"^[a-z]+\.$", "period_ending"),  # Words ending in a period
            (r"^[a-z]\d+$", "letter_number"),  # Letter + number
        ]
        # Note: the word regex in _extract_word_data drops trailing periods,
        # so the period-based branches below only fire if that regex is
        # extended to keep them.
        for word, data in word_data.items():
            for pattern, pattern_type in abbrev_patterns:
                if re.match(pattern, word):
                    # Check if it has a period in any original form
                    has_period = any("." in form for form in data["original_forms"])
                    if (
                        has_period or pattern_type == "short_word"
                    ) and data["frequency"] < 1500:
                        abbreviations.append(
                            {
                                "abbreviation": word,
                                "frequency": data["frequency"],
                                "pattern_type": pattern_type,
                                "contexts": data["contexts"][:2],
                            }
                        )
                    break
        # Sort by frequency
        abbreviations.sort(key=lambda x: x["frequency"], reverse=True)
        return abbreviations[:100]  # Top 100
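
    # Examples of what each abbreviation pattern is meant to catch (assuming
    # the extracted form still carries its punctuation):
    #
    #     "appt"  -> short_word     (2-4 lowercase letters)
    #     "dept." -> period_ending  (lowercase letters followed by a period)
    #     "b12"   -> letter_number  (single letter followed by digits)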

    def _identify_jargon(self, word_data: Dict) -> List[Dict]:
        """Identify domain-specific jargon"""
        jargon = []
        # Jargon indicators
        jargon_indicators = {
            "medical": ["ology", "itis", "ectomy", "oscopy", "therapy"],
            "legal": ["tion", "ment", "ance", "ence"],
            "technical": ["tech", "system", "process", "protocol"],
        }
        for word, data in word_data.items():
            if len(word) < 6:
                continue
            # Check for jargon patterns
            for domain, suffixes in jargon_indicators.items():
                if any(word.endswith(suffix) for suffix in suffixes):
                    if word not in self.known_terms:
                        jargon.append(
                            {
                                "term": word,
                                "frequency": data["frequency"],
                                "domain": domain,
                                "contexts": data["contexts"][:2],
                            }
                        )
                    break
        # Sort by frequency
        jargon.sort(key=lambda x: x["frequency"], reverse=True)
        return jargon[:100]  # Top 100
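
    # Example hits for the suffix heuristic (hypothetical, not corpus data):
    #
    #     "colonoscopy" ends in "oscopy" -> medical
    #     "retaliation" ends in "tion"   -> legal
    #     "ecosystem"   ends in "system" -> technical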

    def _save_normalization_analysis(self, results: Dict):
        """Save normalization analysis results"""
        # Save JSON
        json_results = {
            "method": "semantic_analysis",
            "statistics": {
                "unknown_acronyms": len(results["unknown_acronyms"]),
                "unclear_terms": len(results["unclear_terms"]),
                "abbreviations": len(results["abbreviations"]),
                "jargon": len(results["jargon"]),
            },
            "results": results,
        }
        self.save_results(json_results, "semantic_normalization_analysis.json")

        # Save human-readable text
        text_output = []
        text_output.append("SEMANTIC TEXT NORMALIZATION ANALYSIS")
        text_output.append("=" * 80)
        text_output.append("")
        text_output.append(
            "This analysis identifies terms that may need clarification or expansion."
        )
        text_output.append("")

        # Unknown acronyms
        text_output.append("=" * 80)
        text_output.append("UNKNOWN ACRONYMS (Need Investigation)")
        text_output.append("=" * 80)
        text_output.append("")
        if results["unknown_acronyms"]:
            text_output.append(
                f"{'Acronym':<15} {'Frequency':<12} {'Confidence':<12} {'Sample Context'}"
            )
            text_output.append("-" * 80)
            for item in results["unknown_acronyms"][:20]:
                context = item["contexts"][0][:50] if item["contexts"] else "N/A"
                text_output.append(
                    f"{item['acronym']:<15} {item['frequency']:<12} "
                    f"{item['confidence']:<12} {context}..."
                )
        else:
            text_output.append("No unknown acronyms found.")
        text_output.append("")

        # Unclear terms
        text_output.append("=" * 80)
        text_output.append("UNCLEAR TERMS (Low Semantic Coherence)")
        text_output.append("=" * 80)
        text_output.append("")
        text_output.append(
            "These terms have low semantic similarity to known medical/legal terms."
        )
        text_output.append(
            "They may be typos, slang, or domain-specific terms needing clarification."
        )
        text_output.append("")
        if results["unclear_terms"]:
            text_output.append(
                f"{'Term':<20} {'Frequency':<12} {'Coherence':<12} {'Sample Context'}"
            )
            text_output.append("-" * 80)
            for item in results["unclear_terms"][:20]:
                context = item["contexts"][0][:40] if item["contexts"] else "N/A"
                text_output.append(
                    f"{item['term']:<20} {item['frequency']:<12} "
                    f"{item['coherence_score']:<12.3f} {context}..."
                )
        else:
            text_output.append("No unclear terms found.")
        text_output.append("")

        # Abbreviations
        text_output.append("=" * 80)
        text_output.append("ABBREVIATIONS (May Need Expansion)")
        text_output.append("=" * 80)
        text_output.append("")
        if results["abbreviations"]:
            text_output.append(
                f"{'Abbreviation':<20} {'Frequency':<12} {'Pattern':<15} {'Context'}"
            )
            text_output.append("-" * 80)
            for item in results["abbreviations"][:15]:
                context = item["contexts"][0][:40] if item["contexts"] else "N/A"
                text_output.append(
                    f"{item['abbreviation']:<20} {item['frequency']:<12} "
                    f"{item['pattern_type']:<15} {context}..."
                )
        else:
            text_output.append("No abbreviations found.")
        text_output.append("")

        # Jargon
        text_output.append("=" * 80)
        text_output.append("DOMAIN-SPECIFIC JARGON")
        text_output.append("=" * 80)
        text_output.append("")
        if results["jargon"]:
            text_output.append(f"{'Term':<25} {'Frequency':<12} {'Domain':<15}")
            text_output.append("-" * 80)
            for item in results["jargon"][:15]:
                text_output.append(
                    f"{item['term']:<25} {item['frequency']:<12} {item['domain']:<15}"
                )
        else:
            text_output.append("No jargon found.")
        text_output.append("")

        # Recommendations
        text_output.append("=" * 80)
        text_output.append("RECOMMENDATIONS")
        text_output.append("=" * 80)
        text_output.append("")
        text_output.append(
            "1. Investigate unknown acronyms - may be critical case-specific terms"
        )
        text_output.append("2. Review unclear terms - may be typos or need context")
        text_output.append("3. Expand abbreviations in TEXT_EXPANSIONS dictionary")
        text_output.append("4. Add jargon terms to KEY_TOPICS if relevant to case")

        filepath = self.output_dir / "semantic_normalization_analysis.txt"
        with open(filepath, "w") as f:
            f.write("\n".join(text_output))
        self.logger.info(f"\nSaved analysis to: {filepath}")


if __name__ == "__main__":
    df = pd.read_csv("../_sources/signal_messages.csv")
    analyzer = SemanticNormalizationAnalyzer(min_frequency=1, coherence_threshold=0.4)
    results = analyzer.execute(df)
    print("\nSemantic normalization analysis complete:")
    print(f"  Unknown acronyms: {len(results['unknown_acronyms'])}")
    print(f"  Unclear terms: {len(results['unclear_terms'])}")
    print(f"  Abbreviations: {len(results['abbreviations'])}")
    print(f"  Jargon: {len(results['jargon'])}")