Skip to content

Instantly share code, notes, and snippets.

@dsevero
Last active March 25, 2021 05:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dsevero/8a885dbe0a547507a8e20ba922ffdbd6 to your computer and use it in GitHub Desktop.
Save dsevero/8a885dbe0a547507a8e20ba922ffdbd6 to your computer and use it in GitHub Desktop.
Time profiling with `contextlib.contextmanager`
import json
import logging
import sys
from contextlib import contextmanager
from time import perf_counter, time

import torch
# Send INFO-and-above records to stdout, one line per record.
# NOTE: basicConfig() does nothing if the root logger already has handlers.
logging.basicConfig(
    format='%(asctime)s %(name)s %(levelname)s:%(message)s',
    level=logging.INFO,
    stream=sys.stdout,
)
# No name given, so this is the root logger.
logger = logging.getLogger()
def log(data: dict):
    """Serialize *data* to a JSON string and emit it at INFO level via ``logger``.

    Args:
        data: JSON-serializable mapping; ``json.dumps`` raises ``TypeError``
            for values it cannot encode.
    """
    line = json.dumps(data)
    logger.info(line)
@contextmanager
def log_runtime(**kwargs):
    """Time the wrapped block on the host and log the elapsed time.

    Usable as ``with log_runtime(...):`` or as a decorator (objects built by
    ``contextlib.contextmanager`` are also ``ContextDecorator`` instances).
    Host-side timing does not wait for queued CUDA work; use
    ``log_cuda_runtime`` for GPU kernels.

    Args:
        **kwargs: Arbitrary labels merged into the logged record.

    Logs ``{**kwargs, 'dt': <seconds>, 'cuda': False}`` via ``log`` when the
    block exits normally; nothing is logged if the block raises.
    """
    # perf_counter() is monotonic and high-resolution. time() is the wall
    # clock and can jump on system clock adjustments (e.g. NTP), which can
    # produce wrong or even negative durations.
    start = perf_counter()
    yield
    log({**kwargs, 'dt': perf_counter() - start, 'cuda': False})
@contextmanager
def log_cuda_runtime(**kwargs):
    """Time the wrapped block with CUDA events and log the elapsed time.

    Records a start/end ``torch.cuda.Event`` pair around the block, waits for
    the GPU to finish, then logs ``{**kwargs, 'dt': <seconds>, 'cuda': True}``
    via ``log``. ``Event.elapsed_time`` returns milliseconds. ``dt`` is
    converted to seconds so it has the same unit as ``log_runtime``'s ``dt``.

    Usable as a context manager or as a decorator. Nothing is logged if the
    block raises.

    Args:
        **kwargs: Arbitrary labels merged into the logged record.
    """
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    yield
    end_event.record()
    # elapsed_time() requires both events to have completed on the device,
    # so block until all queued GPU work has finished.
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    log({**kwargs, 'dt': elapsed_time_ms / 1000.0, 'cuda': True})
# Examples. These run only when this file is executed as a script, so
# importing the module has no side effects.
if __name__ == "__main__":
    # Host-side timing of an arbitrary block.
    with log_runtime(foo='bar', something='else'):
        ...  # your code here

    # contextmanager-built helpers also work as decorators; a record is
    # logged on every call of the decorated function.
    @log_cuda_runtime(func='my_cuda_func')
    def my_cuda_func(*args, **kwargs):
        ...  # your CUDA code here

    # Wrapping an existing callable (e.g. a model) after the fact.
    # NOTE: the result is a plain wrapper function that calls `model`, not the
    # model object itself, so its methods (e.g. .eval(), .parameters()) are
    # not available on the returned value.
    model = ...  # trained model
    model = log_cuda_runtime(func='model')(model)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment