openzeppelin_monitor/services/notification/
pool.rs

1use crate::services::blockchain::TransientErrorRetryStrategy;
2use crate::services::notification::SmtpConfig;
3use crate::utils::client_storage::ClientStorage;
4use crate::utils::{create_retryable_http_client, HttpRetryConfig};
5use lettre::transport::smtp::authentication::Credentials;
6use lettre::SmtpTransport;
7use reqwest::Client as ReqwestClient;
8use reqwest_middleware::ClientWithMiddleware;
9use std::sync::Arc;
10use std::time::Duration;
11use thiserror::Error;
12
13#[derive(Debug, Error)]
14pub enum NotificationPoolError {
15	#[error("Failed to create HTTP client: {0}")]
16	HttpClientBuildError(String),
17
18	#[error("Failed to create SMTP client: {0}")]
19	SmtpClientBuildError(String),
20}
21
22/// Notification client pool that manages HTTP and SMTP clients for sending notifications.
23///
24/// Provides a thread-safe way to access and create HTTP and SMTP clients
25/// for sending notifications. It uses a `ClientStorage` to hold the clients,
26/// allowing for efficient reuse and management of HTTP and SMTP connections.
27pub struct NotificationClientPool {
28	http_clients: ClientStorage<ClientWithMiddleware>,
29	smtp_clients: ClientStorage<SmtpTransport>,
30}
31
32impl NotificationClientPool {
33	pub fn new() -> Self {
34		Self {
35			http_clients: ClientStorage::new(),
36			smtp_clients: ClientStorage::new(),
37		}
38	}
39
40	/// Get or create an HTTP client with retry capabilities.
41	///
42	/// # Arguments
43	/// * `retry_policy` - Configuration for HTTP retry policy
44	/// # Returns
45	/// * `Result<Arc<ClientWithMiddleware>, NotificationPoolError>` - The HTTP client
46	///   wrapped in an `Arc` for shared ownership, or an error if client creation
47	///   fails.
48	pub async fn get_or_create_http_client(
49		&self,
50		retry_policy: &HttpRetryConfig,
51	) -> Result<Arc<ClientWithMiddleware>, NotificationPoolError> {
52		// Generate a unique key for the retry policy based on its configuration.
53		let key = format!("{:?}", retry_policy);
54
55		// Fast path: Read lock
56		if let Some(client) = self.http_clients.clients.read().await.get(key.as_str()) {
57			return Ok(client.clone());
58		}
59
60		// Slow path: Write lock
61		let mut clients = self.http_clients.clients.write().await;
62		// Double-check: Another thread might have created it
63		if let Some(client) = clients.get(&key) {
64			return Ok(client.clone());
65		}
66
67		// Create the new base client
68		let base_client = ReqwestClient::builder()
69			.pool_max_idle_per_host(10)
70			.pool_idle_timeout(Some(Duration::from_secs(90)))
71			.connect_timeout(Duration::from_secs(10))
72			.build()
73			.map_err(|e| NotificationPoolError::HttpClientBuildError(e.to_string()))?;
74
75		// Create the retryable client with the provided retry policy
76		let retryable_client = create_retryable_http_client(
77			retry_policy,
78			base_client,
79			Some(TransientErrorRetryStrategy),
80		);
81
82		let arc_client = Arc::new(retryable_client);
83		clients.insert(key.to_string(), arc_client.clone());
84		Ok(arc_client)
85	}
86
87	/// Get or create an SMTP client for sending emails.
88	/// # Arguments
89	/// * `smtp_config` - Configuration for the SMTP client, including host,
90	///   port, username, and password.
91	/// # Returns
92	/// * `Result<Arc<SmtpTransport>, NotificationPoolError>` - The SMTP client
93	///   wrapped in an `Arc` for shared ownership, or an error if client creation
94	///   fails.
95	pub async fn get_or_create_smtp_client(
96		&self,
97		smtp_config: &SmtpConfig,
98	) -> Result<Arc<SmtpTransport>, NotificationPoolError> {
99		// Generate a unique key for the retry policy based on its configuration.
100		let key = format!("{:?}", smtp_config);
101
102		// Fast path: Read lock to check for an existing client.
103		if let Some(client) = self.smtp_clients.clients.read().await.get(&key) {
104			return Ok(client.clone());
105		}
106
107		// Slow path: Write lock to create a new client if needed.
108		let mut clients = self.smtp_clients.clients.write().await;
109		// Double-check in case another thread created it while we waited for the lock.
110		if let Some(client) = clients.get(&key) {
111			return Ok(client.clone());
112		}
113
114		// Create the new SMTP client using the provided configuration.
115		let creds = Credentials::new(smtp_config.username.clone(), smtp_config.password.clone());
116		let client = SmtpTransport::relay(&smtp_config.host)
117			.map_err(|e| NotificationPoolError::SmtpClientBuildError(e.to_string()))?
118			.port(smtp_config.port)
119			.credentials(creds)
120			.build();
121
122		// Store the new client in the pool.
123		let arc_client = Arc::new(client);
124		clients.insert(key, arc_client.clone());
125
126		Ok(arc_client)
127	}
128
129	/// Get the number of active HTTP clients in the pool
130	#[cfg(test)]
131	pub async fn get_active_http_client_count(&self) -> usize {
132		self.http_clients.clients.read().await.len()
133	}
134
135	/// Get the number of active SMTP clients in the pool
136	#[cfg(test)]
137	pub async fn get_active_smtp_client_count(&self) -> usize {
138		self.smtp_clients.clients.read().await.len()
139	}
140}
141
142impl Default for NotificationClientPool {
143	fn default() -> Self {
144		Self::new()
145	}
146}
147
#[cfg(test)]
mod tests {
	use super::*;

	/// Builds a fresh, empty pool for a test.
	fn create_pool() -> NotificationClientPool {
		NotificationClientPool::new()
	}

	/// A newly created pool holds no clients of either kind.
	#[tokio::test]
	async fn test_pool_init_empty() {
		let pool = create_pool();
		let http_count = pool.get_active_http_client_count().await;
		let smtp_count = pool.get_active_smtp_client_count().await;

		assert_eq!(http_count, 0, "Pool should be empty initially");
		assert_eq!(smtp_count, 0, "Pool should be empty initially");
	}

	/// Requesting an HTTP client succeeds and registers exactly one client.
	#[tokio::test]
	async fn test_pool_get_or_create_http_client() {
		let pool = create_pool();
		let retry_config = HttpRetryConfig::default();
		let client = pool.get_or_create_http_client(&retry_config).await;

		assert!(
			client.is_ok(),
			"Should successfully create or get HTTP client"
		);

		assert_eq!(
			pool.get_active_http_client_count().await,
			1,
			"Pool should have one active HTTP client"
		);
	}

	/// Two requests with the same config yield the very same `Arc`.
	#[tokio::test]
	async fn test_pool_returns_same_client() {
		let pool = create_pool();
		let retry_config = HttpRetryConfig::default();
		let client1 = pool.get_or_create_http_client(&retry_config).await.unwrap();
		let client2 = pool.get_or_create_http_client(&retry_config).await.unwrap();

		assert!(
			Arc::ptr_eq(&client1, &client2),
			"Should return the same client instance"
		);
		assert_eq!(
			pool.get_active_http_client_count().await,
			1,
			"Pool should still have one active HTTP client"
		);
	}

	/// Concurrent requests with the same config all succeed, share one client
	/// instance, and leave exactly one entry in the pool.
	#[tokio::test]
	async fn test_pool_concurrent_access() {
		let pool = Arc::new(create_pool());
		let retry_config = HttpRetryConfig::default();

		let num_tasks = 10;
		let mut tasks = Vec::new();

		for _ in 0..num_tasks {
			let pool_clone = Arc::clone(&pool);
			let retry_config = retry_config.clone();
			// Each task hands its result back so the clients can be compared below.
			tasks.push(tokio::spawn(async move {
				pool_clone.get_or_create_http_client(&retry_config).await
			}));
		}

		let results = futures::future::join_all(tasks).await;

		let mut clients = Vec::new();
		for result in results {
			let client = result
				.expect("All tasks should complete successfully")
				.expect("Should successfully create or get HTTP client");
			clients.push(client);
		}

		assert_eq!(
			clients.len(),
			num_tasks,
			"Every task should have returned a client"
		);
		assert!(
			clients
				.windows(2)
				.all(|pair| Arc::ptr_eq(&pair[0], &pair[1])),
			"All concurrent callers should receive the same client instance"
		);
		assert_eq!(
			pool.get_active_http_client_count().await,
			1,
			"Concurrent access with one config should create exactly one HTTP client"
		);
	}

	/// A pool built through `Default` starts empty and can create clients.
	#[tokio::test]
	async fn test_pool_default() {
		let pool = NotificationClientPool::default();
		let retry_config = HttpRetryConfig::default();

		assert_eq!(
			pool.get_active_http_client_count().await,
			0,
			"Default pool should be empty initially"
		);

		assert_eq!(
			pool.get_active_smtp_client_count().await,
			0,
			"Default pool should be empty initially"
		);

		let client = pool.get_or_create_http_client(&retry_config).await;

		assert!(
			client.is_ok(),
			"Default pool should successfully create or get HTTP client"
		);

		assert_eq!(
			pool.get_active_http_client_count().await,
			1,
			"Default pool should have one active HTTP client"
		);
	}

	/// Different retry configs map to different HTTP clients; repeating a
	/// config returns the client created for it earlier.
	#[tokio::test]
	async fn test_pool_returns_different_http_clients_for_different_configs() {
		let pool = create_pool();

		// Config 1 (default)
		let retry_config_1 = HttpRetryConfig::default();

		// Config 2 (different retry count)
		let mut retry_config_2 = HttpRetryConfig::default();
		retry_config_2.max_retries = 5;

		// Get a client for each config
		let client1 = pool
			.get_or_create_http_client(&retry_config_1)
			.await
			.unwrap();
		let client2 = pool
			.get_or_create_http_client(&retry_config_2)
			.await
			.unwrap();

		// Pointers should NOT be equal, as they are different clients
		assert!(
			!Arc::ptr_eq(&client1, &client2),
			"Should return different client instances for different configurations"
		);

		// The pool should now contain two distinct clients
		assert_eq!(
			pool.get_active_http_client_count().await,
			2,
			"Pool should have two active HTTP clients"
		);

		// Getting the first client again should return the original one
		let client1_again = pool
			.get_or_create_http_client(&retry_config_1)
			.await
			.unwrap();
		assert!(
			Arc::ptr_eq(&client1, &client1_again),
			"Should return the same client instance when called again with the same config"
		);

		// Pool size should still be 2
		assert_eq!(
			pool.get_active_http_client_count().await,
			2,
			"Pool should still have two active HTTP clients after getting an existing one"
		);
	}

	/// Different SMTP configs map to different transports; repeating a config
	/// returns the transport created for it earlier.
	#[tokio::test]
	async fn test_pool_returns_different_smtp_clients_for_different_configs() {
		let pool = create_pool();

		// Config 1 (first set of credentials)
		let smtp_config_1 = SmtpConfig {
			host: "smtp.example.com".to_string(),
			port: 587,
			username: "user1".to_string(),
			password: "pass1".to_string(),
		};

		// Config 2 (same host and port, different credentials)
		let smtp_config_2 = SmtpConfig {
			host: "smtp.example.com".to_string(),
			port: 587,
			username: "user2".to_string(),
			password: "pass2".to_string(),
		};

		// Get a client for each config
		let client1 = pool
			.get_or_create_smtp_client(&smtp_config_1)
			.await
			.unwrap();
		let client2 = pool
			.get_or_create_smtp_client(&smtp_config_2)
			.await
			.unwrap();

		// Pointers should NOT be equal, as they are different clients
		assert!(
			!Arc::ptr_eq(&client1, &client2),
			"Should return different client instances for different configurations"
		);

		// The pool should now contain two distinct clients
		assert_eq!(
			pool.get_active_smtp_client_count().await,
			2,
			"Pool should have two active SMTP clients"
		);

		// Getting the first client again should return the original one
		let client1_again = pool
			.get_or_create_smtp_client(&smtp_config_1)
			.await
			.unwrap();

		assert!(
			Arc::ptr_eq(&client1, &client1_again),
			"Should return the same client instance when called again with the same config"
		);

		// Pool size should still be 2
		assert_eq!(
			pool.get_active_smtp_client_count().await,
			2,
			"Pool should still have two active SMTP clients after getting an existing one"
		);
	}
}