# step0a_semantic_keyword_identification.py
"""
Step 0a: Semantic keyword identification using embeddings.
Identifies keywords semantically related to subpoena criteria.
"""
from collections import Counter
from typing import Dict, List, Set, Tuple

import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

from pipeline.common_defs import SUBPOENA_CRITERIA
from pipeline.models.base import PipelineStep
from pipeline.utils.text_utils import normalize_text
  14. class SemanticKeywordIdentifier(PipelineStep):
  15. """
  16. Identify keywords semantically related to subpoena criteria.
  17. Uses embedding similarity rather than frequency.
  18. """
  19. def __init__(
  20. self,
  21. similarity_threshold: float = 0.25,
  22. max_keywords_per_criterion: int = 80,
  23. min_word_length: int = 3,
  24. output_dir: str = "./pipeline_output",
  25. ):
  26. super().__init__(output_dir)
  27. self.similarity_threshold = similarity_threshold
  28. self.max_keywords_per_criterion = max_keywords_per_criterion
  29. self.min_word_length = min_word_length
  30. self.logger.info("Loading embedding model: all-mpnet-base-v2...")
  31. self.embedding_model = SentenceTransformer("all-mpnet-base-v2")
  32. def _load_embedding_model(self):
  33. """Load sentence transformer model"""
  34. return
  35. # if self.embedding_model is None:
  36. # self.logger.info("Loading embedding model: all-mpnet-base-v2...")
  37. # self.embedding_model = SentenceTransformer("all-mpnet-base-v2")
  38. def execute(self, df: pd.DataFrame) -> Dict[str, List[Dict]]:
  39. """Identify keywords semantically related to subpoena criteria"""
  40. self.logger.info("SEMANTIC KEYWORD IDENTIFICATION")
  41. self.logger.info(f"Analyzing {len(df):,} messages")
  42. self._load_embedding_model()
  43. # Extract unique words
  44. unique_words = self._extract_unique_words(df)
  45. self.logger.info(f"Found {len(unique_words):,} unique words")
  46. suspicious = [
  47. w for w in unique_words if w.startswith("medical") and len(w) > 10
  48. ]
  49. if suspicious:
  50. self.logger.error(f"SUSPICIOUS WORDS IN EXTRACTION: {suspicious}")
  51. self.logger.info(f"Found {len(unique_words):,} unique words")
  52. # Create criteria descriptions
  53. criteria_descriptions = self._create_criteria_descriptions()
  54. # Compute embeddings
  55. word_embeddings = self._compute_word_embeddings(unique_words)
  56. criteria_embeddings = self._compute_criteria_embeddings(criteria_descriptions)
  57. # Find similar keywords
  58. keywords_by_criterion = self._find_similar_keywords(
  59. unique_words, word_embeddings, criteria_descriptions, criteria_embeddings
  60. )
  61. # Add frequency info
  62. word_freq = self._compute_word_frequencies(df)
  63. keywords_by_criterion = self._add_frequency_info(keywords_by_criterion, word_freq)
  64. # Save results
  65. self._save_semantic_keywords(keywords_by_criterion, criteria_descriptions)
  66. return keywords_by_criterion
  67. def _extract_unique_words(self, df: pd.DataFrame) -> List[str]:
  68. """Extract unique words from messages"""
  69. words = set()
  70. for message in df["message"].fillna(""):
  71. normalized = normalize_text(str(message))
  72. tokens = [t for t in normalized.split() if len(t) >= self.min_word_length and t.isalpha()]
  73. words.update(tokens)
  74. return sorted(list(words))
  75. def _create_criteria_descriptions(self) -> Dict[int, str]:
  76. """Create detailed descriptions for each criterion"""
  77. return SUBPOENA_CRITERIA
  78. def _compute_word_embeddings(self, words: List[str]) -> np.ndarray:
  79. """Compute embeddings for words"""
  80. self.logger.info(f"Computing embeddings for {len(words):,} words...")
  81. return self.embedding_model.encode(words, show_progress_bar=True, batch_size=32)
  82. def _compute_criteria_embeddings(self, criteria_descriptions: Dict[int, str]) -> Dict[int, np.ndarray]:
  83. """Compute embeddings for criteria"""
  84. embeddings = {}
  85. for num, desc in criteria_descriptions.items():
  86. embeddings[num] = self.embedding_model.encode([desc])[0]
  87. return embeddings
  88. def _find_similar_keywords(self, words, word_embeddings, criteria_descriptions, criteria_embeddings):
  89. """Find keywords similar to each criterion"""
  90. keywords_by_criterion = {}
  91. for num, emb in criteria_embeddings.items():
  92. similarities = cosine_similarity(word_embeddings, emb.reshape(1, -1)).flatten()
  93. similar_indices = np.where(similarities >= self.similarity_threshold)[0]
  94. similar_indices = similar_indices[np.argsort(-similarities[similar_indices])]
  95. similar_indices = similar_indices[:self.max_keywords_per_criterion]
  96. keywords_by_criterion[num] = [
  97. {"word": words[idx], "similarity": float(similarities[idx]), "frequency": 0}
  98. for idx in similar_indices
  99. ]
  100. self.logger.info(f"Criterion {num}: {len(keywords_by_criterion[num])} keywords")
  101. return keywords_by_criterion
  102. def _compute_word_frequencies(self, df: pd.DataFrame) -> Counter:
  103. """Compute word frequencies"""
  104. word_freq = Counter()
  105. for message in df["message"].fillna(""):
  106. normalized = normalize_text(str(message))
  107. tokens = [t for t in normalized.split() if len(t) >= self.min_word_length and t.isalpha()]
  108. word_freq.update(tokens)
  109. return word_freq
  110. def _add_frequency_info(self, keywords_by_criterion, word_freq):
  111. """Add frequency information"""
  112. for keywords in keywords_by_criterion.values():
  113. for kw in keywords:
  114. kw["frequency"] = word_freq.get(kw["word"], 0)
  115. return keywords_by_criterion
  116. def _save_semantic_keywords(self, keywords_by_criterion, criteria_descriptions):
  117. """Save results"""
  118. results = {
  119. "method": "semantic_similarity",
  120. "criteria": {str(n): {"keywords": k} for n, k in keywords_by_criterion.items()}
  121. }
  122. self.save_results(results, "semantic_keywords.json")
  123. self.logger.info("Saved semantic keywords")