Coverage for src/mcp_server_langgraph/core/storage/conversation_store.py: 66%

113 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 00:43 +0000

1""" 

2Lightweight conversation metadata store 

3 

4Tracks conversation metadata for search functionality without requiring OpenFGA. 

5Provides a fallback for development environments where OpenFGA isn't running. 

6""" 

7 

import contextlib
import json
import time
from dataclasses import asdict, dataclass
from urllib.parse import urlsplit, urlunsplit

11 

# Optional dependency: the Redis backend is only usable when redis-py is installed.
try:
    import redis
    from redis import Redis
except ImportError:
    REDIS_AVAILABLE = False
    RedisType = None  # type: ignore[assignment,misc]
else:
    REDIS_AVAILABLE = True
    # Exported alias of the (unparameterized) Redis client class; None when
    # redis-py is missing. Note: types-redis >= 4.6.0.20241115 makes Redis
    # generic, so parameterize it at the annotation site (e.g. Redis[str] for
    # decode_responses=True) rather than here.
    RedisType = Redis

23 

24 

25@dataclass 

26class ConversationMetadata: 

27 """Metadata for a conversation""" 

28 

29 thread_id: str 

30 user_id: str 

31 created_at: float # Unix timestamp 

32 last_activity: float # Unix timestamp 

33 message_count: int 

34 title: str | None = None 

35 tags: list[str] | None = None 

36 

37 def __post_init__(self) -> None: 

38 if self.tags is None: 

39 self.tags = [] 

40 

41 

class ConversationStore:
    """
    Store for conversation metadata.

    Supports both in-memory and Redis backends for flexibility.
    Used as fallback when OpenFGA is not available.

    Behavior notes:
        - Any ``backend`` other than ``"redis"`` (case-insensitive) keeps
          metadata in an in-memory dict on the instance.
        - ``ttl_seconds`` is applied only by the Redis backend (via SETEX);
          in-memory entries never expire.
        - The Redis backend uses the synchronous redis-py client, so each
          ``async`` method blocks the event loop during its Redis round trips.
    """

    # Keys fetched per MGET round trip when listing conversations from Redis.
    _REDIS_BATCH_SIZE = 100

    def __init__(
        self, backend: str = "memory", redis_url: str = "redis://localhost:6379/2", ttl_seconds: int = 604800
    ) -> None:
        """
        Initialize conversation store.

        Args:
            backend: "memory" or "redis"
            redis_url: Redis connection URL (for redis backend)
            ttl_seconds: TTL for conversation metadata (default: 7 days)

        Raises:
            ImportError: ``backend`` is "redis" but redis-py is not installed.
            ConnectionError: The Redis server could not be reached. The message
                shows the URL with credentials and query string removed.
        """
        self.backend = backend.lower()
        self.ttl_seconds = ttl_seconds
        self._memory_store: dict[str, ConversationMetadata] = {}
        self._redis_client: Redis[str] | None = None  # type: ignore[type-arg]

        if self.backend == "redis":
            if not REDIS_AVAILABLE:
                msg = "Redis backend requires redis-py. Add 'redis' to pyproject.toml dependencies, then run: uv sync"
                raise ImportError(msg)

            client = None
            try:
                client = redis.from_url(redis_url, decode_responses=True)  # type: ignore[no-untyped-call]
                # Test connection
                client.ping()
            except Exception as e:
                if client is not None:
                    # Best-effort release of the pool from_url created; the
                    # ConnectionError below is what callers need to see.
                    with contextlib.suppress(Exception):
                        client.close()
                # Never echo the raw URL: it may carry a password.
                msg = f"Failed to connect to Redis at {self._redact_url(redis_url)}: {e}"
                raise ConnectionError(msg) from e
            self._redis_client = client

    @staticmethod
    def _redact_url(url: str) -> str:
        """
        Return ``url`` without userinfo or query string, for use in error messages.

        Redis URLs may carry credentials as ``redis://user:secret@host`` or as
        query options such as ``?password=...``; both are dropped.
        """
        try:
            parts = urlsplit(url)
        except ValueError:
            return "<unparseable URL>"
        host = parts.netloc.rpartition("@")[2]
        return urlunsplit((parts.scheme, host, parts.path, "", ""))

    @staticmethod
    def _deserialize(raw: object) -> ConversationMetadata:
        """Rebuild metadata from a raw Redis value (``str``, or bytes when responses are not decoded)."""
        if isinstance(raw, (bytes, bytearray)):
            text = raw.decode("utf-8")
        elif isinstance(raw, str):
            text = raw
        else:
            text = str(raw)
        return ConversationMetadata(**json.loads(text))

    def _redis_key(self, thread_id: str) -> str:
        """Generate Redis key for conversation"""
        return f"conversation:metadata:{thread_id}"

    def _scan_keys(self) -> set[str]:
        """
        Return every conversation metadata key in Redis, de-duplicated.

        SCAN may yield the same key more than once (e.g. while Redis rehashes),
        so the results are collected into a set. Returns an empty set when no
        Redis client is configured.
        """
        client = self._redis_client
        if client is None:
            return set()
        return set(client.scan_iter(match=self._redis_key("*"), count=100))

    async def record_conversation(
        self,
        thread_id: str,
        user_id: str,
        message_count: int = 1,
        title: str | None = None,
        tags: list[str] | None = None,
    ) -> None:
        """
        Record or update conversation metadata.

        A new thread is stored with ``created_at`` and ``last_activity`` both set
        to the current time. For an existing thread:

        - ``last_activity`` and ``message_count`` are overwritten.
        - ``user_id`` is ignored, so the original owner is kept.
        - ``title`` and ``tags`` are replaced only when truthy. Passing
          ``None``, ``""`` or ``[]`` keeps the stored values.

        Args:
            thread_id: Conversation thread ID
            user_id: User who owns the conversation
            message_count: Number of messages in conversation
            title: Optional conversation title
            tags: Optional tags for categorization
        """
        now = time.time()
        # Copy so later mutation of the caller's list cannot alter stored metadata.
        tags_copy = list(tags) if tags else []

        metadata = await self.get_conversation(thread_id)
        if metadata is not None:
            metadata.last_activity = now
            metadata.message_count = message_count
            if title:
                metadata.title = title
            if tags_copy:
                metadata.tags = tags_copy
        else:
            metadata = ConversationMetadata(
                thread_id=thread_id,
                user_id=user_id,
                created_at=now,
                last_activity=now,
                message_count=message_count,
                title=title,
                tags=tags_copy,
            )

        if self.backend == "redis" and self._redis_client:
            # SETEX rewrites the expiry on every update, so the TTL counts from the last write.
            self._redis_client.setex(self._redis_key(thread_id), self.ttl_seconds, json.dumps(asdict(metadata)))
        else:
            self._memory_store[thread_id] = metadata

    async def get_conversation(self, thread_id: str) -> ConversationMetadata | None:
        """
        Get conversation metadata.

        With the memory backend the stored object itself is returned, so
        mutating it changes the store. The Redis backend returns a fresh copy.

        Args:
            thread_id: Conversation thread ID

        Returns:
            ConversationMetadata or None if not found
        """
        if self.backend == "redis" and self._redis_client:
            raw = self._redis_client.get(self._redis_key(thread_id))
            return self._deserialize(raw) if raw else None
        return self._memory_store.get(thread_id)

    async def list_user_conversations(self, user_id: str, limit: int = 50) -> list[ConversationMetadata]:
        """
        List all conversations for a user.

        The Redis backend scans every conversation key and filters by owner,
        so cost grows with the total number of stored conversations.

        Args:
            user_id: User identifier
            limit: Maximum number of conversations to return

        Returns:
            List of conversation metadata, sorted by last_activity (descending)
        """
        if self.backend == "redis" and self._redis_client:
            client = self._redis_client
            keys = list(self._scan_keys())
            conversations: list[ConversationMetadata] = []
            batch = self._REDIS_BATCH_SIZE
            for start in range(0, len(keys), batch):
                # One MGET per batch instead of one GET per key.
                for raw in client.mget(keys[start : start + batch]):
                    # Keys that expired after the SCAN come back as None.
                    if not raw:
                        continue
                    metadata = self._deserialize(raw)
                    if metadata.user_id == user_id:
                        conversations.append(metadata)
        else:
            conversations = [c for c in self._memory_store.values() if c.user_id == user_id]

        conversations.sort(key=lambda c: c.last_activity, reverse=True)
        return conversations[:limit]

    async def search_conversations(self, user_id: str, query: str, limit: int = 10) -> list[ConversationMetadata]:
        """
        Search conversations for a user.

        Matching is a case-insensitive substring test against the thread ID,
        the title and each tag. Results keep recency order; there is no
        relevance ranking. Only the user's 1000 most recent conversations are
        searched.

        Args:
            user_id: User identifier
            query: Search query; an empty query returns the most recent conversations
            limit: Maximum number of results

        Returns:
            List of matching conversations, most recently active first
        """
        candidates = await self.list_user_conversations(user_id, limit=1000)

        if not query:
            # No query: return most recent
            return candidates[:limit]

        # Simple text matching (in production, use proper search index).
        # casefold() gives proper caseless matching for non-ASCII text.
        needle = query.casefold()
        matches = []
        for conv in candidates:
            if (
                needle in conv.thread_id.casefold()
                or (conv.title and needle in conv.title.casefold())
                or (conv.tags and any(needle in tag.casefold() for tag in conv.tags))
            ):
                matches.append(conv)

        return matches[:limit]

    async def delete_conversation(self, thread_id: str) -> bool:
        """
        Delete conversation metadata.

        Args:
            thread_id: Conversation thread ID

        Returns:
            True if deleted, False if not found
        """
        if self.backend == "redis" and self._redis_client:
            deleted = self._redis_client.delete(self._redis_key(thread_id))
            return int(deleted) > 0  # type: ignore[arg-type]
        if thread_id in self._memory_store:
            del self._memory_store[thread_id]
            return True
        return False

    async def get_stats(self) -> dict[str, object]:
        """
        Get store statistics.

        Returns:
            Dictionary with ``backend``, ``conversation_count`` and
            ``ttl_seconds``. ``ttl_seconds`` is None for the memory backend,
            which never expires entries.
        """
        if self.backend == "redis" and self._redis_client:
            return {"backend": "redis", "conversation_count": len(self._scan_keys()), "ttl_seconds": self.ttl_seconds}
        return {"backend": "memory", "conversation_count": len(self._memory_store), "ttl_seconds": None}

258 

259 

# Singleton instance
_conversation_store: ConversationStore | None = None


def get_conversation_store(
    backend: str = "memory", redis_url: str = "redis://localhost:6379/2", ttl_seconds: int = 604800
) -> ConversationStore:
    """
    Return the module-level conversation store, creating it on first use.

    Only the call that creates the store uses its arguments. Once the
    singleton exists, later calls return it unchanged and ignore ``backend``,
    ``redis_url`` and ``ttl_seconds``. If construction raises, nothing is
    cached and the next call tries again. No locking is done around creation.

    Args:
        backend: "memory" or "redis"
        redis_url: Redis connection URL
        ttl_seconds: TTL for conversation metadata

    Returns:
        ConversationStore instance
    """
    global _conversation_store

    if _conversation_store is not None:
        return _conversation_store

    _conversation_store = ConversationStore(backend=backend, redis_url=redis_url, ttl_seconds=ttl_seconds)
    return _conversation_store