Logging when using multiprocessing with joblib

Snowhow1

Hi,

I'm using joblib for multiprocessing in one of our processes. Logging works well (apart from some weird py4j errors, which I suppress), except for messages logged inside the multiprocessing workers. Also, how do I suppress the other errors that I always get on Databricks? Is there a guide on this? Thanks

```
2023-04-17 11:12:32 [INFO] - Starting multiprocessing
2023-04-17 11:12:33 [INFO] - Exception while sending command.
Traceback (most recent call last):
  File "/databricks/spark/python/lib/py4j-0.10.9.5-src.zip/py4j/clientserver.py", line 503, in send_command
    self.socket.sendall(command.encode("utf-8"))
ConnectionResetError: [Errno 104] Connection reset by peer

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/databricks/spark/python/lib/py4j-0.10.9.5-src.zip/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
  File "/databricks/spark/python/lib/py4j-0.10.9.5-src.zip/py4j/clientserver.py", line 506, in send_command
    raise Py4JNetworkError(
py4j.protocol.Py4JNetworkError: Error while sending
Squares: [1, 4, 9, 16, 25]
```
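The "Exception while sending command." record above is printed with the same `%(asctime)s [%(levelname)s] - %(message)s` format as the other lines, so it reached the root logger's handler at INFO level. Setting `logging.getLogger("py4j").setLevel(logging.WARNING)` does not stop records from a child logger that has its own level set, because a logger's explicit level overrides its parent's. A minimal sketch of one alternative, assuming the record comes from a logger in the `py4j` hierarchy (the traceback points into py4j's `java_gateway.py` and `clientserver.py`): attach a filter to the root logger's handlers so they drop low-level `py4j` records whatever level the emitting logger has. `DropNoisyLoggers` is a hypothetical name for this sketch.

```python
import logging


class DropNoisyLoggers(logging.Filter):
    """Handler-level filter: drop records below `min_level` from loggers under `prefix`.

    Attached to a handler, it sees every record that reaches that handler,
    even when the emitting child logger (e.g. "py4j.<something>") has its
    own level set lower than the "py4j" parent.
    """

    def __init__(self, prefix="py4j", min_level=logging.WARNING):
        super().__init__()
        self.prefix = prefix
        self.min_level = min_level

    def filter(self, record):
        in_family = record.name == self.prefix or record.name.startswith(self.prefix + ".")
        # Keep everything outside the family; inside it, keep only min_level and above.
        return not in_family or record.levelno >= self.min_level


def setup_logging():
    # force=True (Python 3.8+) removes and closes handlers already on the root
    # logger; without it, basicConfig does nothing if any handler exists.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        force=True,
    )
    for handler in logging.getLogger().handlers:
        handler.addFilter(DropNoisyLoggers("py4j", logging.WARNING))
```

Because `force=True` discards whatever handlers were already attached to the root logger, it may also remove handlers the Databricks runtime set up; drop that argument if you only want to add the filter to existing handlers.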
```python
import logging
import os
from joblib import Parallel, delayed
from time import sleep

def setup_logging():
    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s [%(levelname)s] - %(message)s",
                        datefmt="%Y-%m-%d %H:%M:%S")
    logging.getLogger().setLevel(logging.INFO)
    pyspark_log = logging.getLogger('pyspark')
    pyspark_log.setLevel(logging.WARNING)
    logging.getLogger("py4j").setLevel(logging.WARNING)

def calculate_square(number):
    sleep(1)  # Simulate a time-consuming task
    result = number ** 2
    logging.info(f"Square of {number} is {result}")
    return result

setup_logging()
logging.info("Starting multiprocessing")
# List of numbers to calculate squares
numbers = [1, 2, 3, 4, 5]

# Using joblib for multiprocessing
n_jobs = os.cpu_count()
results = Parallel(n_jobs=n_jobs)(delayed(calculate_square)(num) for num in numbers)

print(f"Squares: {results}")
```
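The `Square of ... is ...` lines most likely never appear because joblib's default `loky` backend runs `calculate_square` in separate worker processes, and those processes never execute `setup_logging()`: their root logger has no handler and the default WARNING level, so `logging.info` is dropped. A minimal sketch of one common fix, the queue pattern from the Python logging cookbook: each worker sends its records to a queue through `logging.handlers.QueueHandler`, and a `logging.handlers.QueueListener` in the parent hands them to the parent's handlers. Assumptions: a `multiprocessing.Manager()` queue (its proxy can be pickled into loky workers), the hypothetical names `square_worker` and `run_with_queue_logging`, `n_jobs=2`, and that the parent's handlers are what a Databricks notebook cell displays.

```python
import logging
import logging.handlers
import multiprocessing
from time import sleep

from joblib import Parallel, delayed


def calculate_square(number, log_queue):
    """Runs inside a joblib (loky) worker process."""
    # Workers never run the parent's logging setup, so give this logger a
    # QueueHandler that ships records back to the parent.
    # "square_worker" is a hypothetical logger name chosen for this sketch.
    logger = logging.getLogger("square_worker")
    logger.setLevel(logging.INFO)
    logger.propagate = False
    # loky reuses workers across Parallel calls: replace rather than append,
    # so no stale handler keeps pointing at an old, closed queue.
    logger.handlers[:] = [logging.handlers.QueueHandler(log_queue)]

    sleep(0.1)  # simulate a time-consuming task
    result = number ** 2
    logger.info("Square of %s is %s", number, result)
    return result


def run_with_queue_logging(numbers, handlers, n_jobs=2):
    """Run the tasks with joblib and re-emit worker log records through `handlers`."""
    # A Manager queue proxy can be pickled as a task argument; a plain
    # multiprocessing.Queue raises an error when pickled that way.
    with multiprocessing.Manager() as manager:
        log_queue = manager.Queue()
        # The listener thread runs in the parent and passes each record to `handlers`.
        listener = logging.handlers.QueueListener(
            log_queue, *handlers, respect_handler_level=True
        )
        listener.start()
        try:
            results = Parallel(n_jobs=n_jobs)(
                delayed(calculate_square)(num, log_queue) for num in numbers
            )
        finally:
            # Handles the records already in the queue, then stops the thread.
            listener.stop()
    return results


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    logging.info("Starting multiprocessing")
    results = run_with_queue_logging([1, 2, 3, 4, 5], logging.getLogger().handlers)
    print(f"Squares: {results}")
```

If the work is mostly waiting (as with `sleep` here), a simpler option is `Parallel(n_jobs=n_jobs, prefer="threads")`: the tasks then run in threads of the same process, so the root logger configured by `setup_logging()` applies without any queue. CPU-bound pure-Python work gains little from threads because of the GIL, which is when the queue approach is worth the extra code.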