Skip to content

Utilities

seed

Functions:

  • seed_everything

    Sets the seed for the random number generators of PyTorch, NumPy and Python.

seed_everything

seed_everything(seed: int) -> None

Sets the seed for the random number generators of PyTorch, NumPy and Python.

Parameters:

  • seed (int) –

    The desired seed.

Notes
  • You may also want to set torch.backends.cudnn.deterministic = True and torch.backends.cudnn.benchmark = False for full determinism on GPU.
Source code in tgm/util/seed.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def seed_everything(seed: int) -> None:
    """Seed every random number generator the library relies on.

    The same ``seed`` is applied, in order, to Python's ``random`` module, NumPy's
    global RNG, the PyTorch CPU generator and the generators of all CUDA devices.

    Args:
        seed (int): The desired seed.

    Notes:
        - You may also want to set `torch.backends.cudnn.deterministic = True`
          and `torch.backends.cudnn.benchmark = False` for full determinism on GPU.
    """
    logger.debug('Seeding RNG with %d', seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

logging

Functions:

  • enable_logging

    Enable library-wide logging to the console (stderr) and optionally to a file.

  • log_gpu

    Function decorator to log GPU memory usage during a function call.

  • log_latency

    Function decorator to log latency at configurable log level.

  • log_metric

    Log a metric with optional epoch and structured JSON output.

  • log_metrics_dict

    Log a set of metrics with optional epoch and structured JSON output.

enable_logging

enable_logging(
    *,
    console_log_level: int = INFO,
    file_log_level: int = DEBUG,
    log_file_path: str | Path | None = None,
) -> None

Enable library-wide logging to the console (stderr) and optionally to a file.

Parameters:

  • console_log_level (int, default: INFO ) –

    Logging level for console stream handler (default = logging.INFO).

  • file_log_level (int, default: DEBUG ) –

    Logging level for file handler if configured (default = logging.DEBUG).

  • log_file_path (Optional[str | Path], default: None ) –

    Optional path to a log file.

Source code in tgm/util/logging.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def enable_logging(
    *,
    console_log_level: int = logging.INFO,
    file_log_level: int = logging.DEBUG,
    log_file_path: str | Path | None = None,
) -> None:
    """Enable library-wide logging to stdout and optionally to a file.

    Args:
        console_log_level (int): Logging level for console stream handler (default = logging.INFO).
        file_log_level (int): Logging level for file handler if configured (default = logging.DEBUG).
        log_file_path (Optional[str | Path]): Optional path to a log file.
    """
    global _TGM_LOGGING_ENABLED
    _TGM_LOGGING_ENABLED = True

    logger = logging.getLogger('tgm')
    logger.handlers.clear()  # Clear existing handlers, making this idempotent

    console_formatter = logging.Formatter(
        '[%(asctime)s.%(msecs)03d] %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )
    console_handler = logging.StreamHandler()
    console_handler.setLevel(console_log_level)
    console_handler.setFormatter(console_formatter)

    handlers: List[logging.Handler] = [console_handler]
    if log_file_path is not None:
        file_formatter = logging.Formatter(
            '[%(asctime)s.%(msecs)03d] %(name)s - %(levelname)s '
            '[%(processName)s %(threadName)s %(name)s.%(funcName)s:%(lineno)d] '
            '%(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
        )
        file_handler = logging.FileHandler(filename=log_file_path, mode='a')
        file_handler.setLevel(file_log_level)
        file_handler.setFormatter(file_formatter)
        handlers.append(file_handler)

    for handler in handlers:
        logger.addHandler(handler)

    logger.setLevel(min(console_log_level, file_log_level))
    logger.propagate = False  # Don't spam user's root logger

log_gpu

log_gpu(
    _func: Callable | None = None, *, level: int = INFO
) -> Any

Function decorator to log GPU memory usage during a function call.

Logs human-readable info at level, and JSON-formatted debug log at DEBUG.

Usage
  • @log_gpu # Logs at logging.INFO
  • @log_gpu() # Logs at logging.INFO
  • @log_gpu(level=logging.DEBUG) # Logs at DEBUG (JSON included)
Source code in tgm/util/logging.py
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
def log_gpu(_func: Callable | None = None, *, level: int = logging.INFO) -> Any:
    """Function decorator to log GPU memory usage during a function call.

    Logs human-readable info at `level`, and JSON-formatted debug log at DEBUG.

    Usage:
        - @log_gpu                       # Logs at logging.INFO
        - @log_gpu()                     # Logs at logging.INFO
        - @log_gpu(level=logging.DEBUG)  # Logs at DEBUG (JSON included)
    """

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if not _TGM_LOGGING_ENABLED:
                return func(*args, **kwargs)

            cuda_available = torch.cuda.is_available()
            if cuda_available:
                torch.cuda.reset_peak_memory_stats()
                start_mem = torch.cuda.memory_allocated() / (1024**2)
            else:
                start_mem = 0.0

            result = func(*args, **kwargs)

            if cuda_available:
                peak_mem = torch.cuda.max_memory_allocated() / (1024**2)
                mem_diff = peak_mem - start_mem
            else:
                peak_mem = mem_diff = 0.0

            util_logger.log(
                level,
                'Function %s GPU memory (CUDA available=%s) [MB]: peak=%.2f, alloc=%.2f',
                func.__name__,
                cuda_available,
                peak_mem,
                mem_diff,
            )

            if util_logger.isEnabledFor(logging.DEBUG):
                log_entry = {
                    'metric': f'{func.__name__} peak_gpu_mb',
                    'value': peak_mem,
                }
                util_logger.debug(json.dumps(log_entry))

                log_entry = {
                    'metric': f'{func.__name__} alloc_gpu_mb',
                    'value': mem_diff,
                }
                util_logger.debug(json.dumps(log_entry))
            return result

        return wrapper

    if _func is None:
        return decorator
    else:
        return decorator(_func)

log_latency

log_latency(
    _func: Callable | None = None, *, level: int = INFO
) -> Any

Function decorator to log latency at configurable log level.

Logs human-readable info at level, and JSON-formatted debug log at DEBUG.

Usage
  • @log_latency # Logs at logging.INFO
  • @log_latency() # Logs at logging.INFO
  • @log_latency(level=logging.DEBUG) # Logs at logging.DEBUG (JSON included)

Returns:

  • Any

    The output of calling func.

Source code in tgm/util/logging.py
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
def log_latency(_func: Callable | None = None, *, level: int = logging.INFO) -> Any:
    """Function decorator to log latency at configurable log level.

    Logs human-readable info at `level`, and JSON-formatted debug log at DEBUG.

    Usage:
        - @log_latency                      # Logs at logging.INFO
        - @log_latency()                    # Logs at logging.INFO
        - @log_latency=level=logging.DEBUG) # Logs at logging.DEBUG (JSON included)

    Returns:
        The output of calling func.
    """

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if not _TGM_LOGGING_ENABLED:
                return func(*args, **kwargs)

            start_time = time.perf_counter()
            result = func(*args, **kwargs)
            latency = time.perf_counter() - start_time
            util_logger.log(
                level, 'Function %s executed in %.4fs', func.__name__, latency
            )

            if util_logger.isEnabledFor(logging.DEBUG):
                log_entry = {
                    'metric': f'{func.__name__} latency',
                    'value': latency,
                }
                util_logger.debug(json.dumps(log_entry))
            return result

        return wrapper

    # If _func is None, decorator was called with parens
    if _func is None:
        return decorator
    else:
        # Decorator used without parens
        return decorator(_func)

log_metric

log_metric(
    metric_name: str,
    metric_value: Any,
    *,
    epoch: int | None = None,
    level: int = INFO,
    extra: Dict[str, Any] | None = None,
    logger: Logger | None = None,
) -> None

Log a metric with optional epoch and structured JSON output.

Logs human-readable info at level, and JSON-formatted debug log at DEBUG.

Parameters:

  • metric_name (str) –

    Name of the metric to log.

  • metric_value (Any) –

    Value of the metric to log.

  • epoch (Optional[int], default: None ) –

    Optional epoch number.

  • level (int, default: INFO ) –

    Logging level for human-readable log (default INFO)

  • extra (Dict[str, Any], default: None ) –

    Optional dictionary of extra metadata to include in JSON.

  • logger (Optional[Logger], default: None ) –

    Logger to log to, defaults to tgm.util logger.

Source code in tgm/util/logging.py
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
def log_metric(
    metric_name: str,
    metric_value: Any,
    *,
    epoch: int | None = None,
    level: int = logging.INFO,
    extra: Dict[str, Any] | None = None,
    logger: logging.Logger | None = None,
) -> None:
    """Log a metric with optional epoch and structured JSON output.

    Logs human-readable info at `level`, and JSON-formatted debug log at DEBUG.

    Args:
        metric_name (str): Name of the metric to log.
        metric_value (Any): Value of the metric to log.
        epoch (Optional[int]): Optional epoch number.
        level (int): Logging level for human-readable log (default INFO)
        extra (Dict[str, Any]): Optional dictionary of extra metadata to include in JSON.
        logger (Optional[logging.Logger]): Logger to log to, defaults to tgm.util logger.
    """
    if not _TGM_LOGGING_ENABLED:
        return

    logger = logger or util_logger

    display_value = (
        round(metric_value, 4) if isinstance(metric_value, float) else metric_value
    )
    parts = []
    if epoch is not None:
        parts.append(f'Epoch={epoch:02d}')
    parts.append(f'{metric_name}={display_value}')
    msg = ' '.join(parts)
    logger.log(level, msg)

    if logger.isEnabledFor(logging.DEBUG):
        if epoch is not None:
            metric_name += f' epoch {epoch}'
        log_entry = {'metric': metric_name, 'value': metric_value}
        if extra is not None:
            log_entry.update(extra)
        logger.debug(json.dumps(log_entry))

log_metrics_dict

log_metrics_dict(
    metrics_dict: Dict[str, Any],
    *,
    epoch: int | None = None,
    level: int = INFO,
    extra: Dict[str, Any] | None = None,
    logger: Logger | None = None,
) -> None

Log a set of metrics with optional epoch and structured JSON output.

Logs human-readable info at level, and JSON-formatted debug log at DEBUG.

Note: This is equivalent to calling log_metric for each key-value pair.

Parameters:

  • metrics_dict (Dict[str, Any]) –

    Dictionary of metric_name: metric_value pairs.

  • epoch (Optional[int], default: None ) –

    Optional epoch number.

  • level (int, default: INFO ) –

    Logging level for human-readable log (default INFO)

  • extra (Dict[str, Any], default: None ) –

    Optional dictionary of extra metadata to include in JSON.

  • logger (Optional[Logger], default: None ) –

    Logger to log to, defaults to tgm.util logger.

Source code in tgm/util/logging.py
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
def log_metrics_dict(
    metrics_dict: Dict[str, Any],
    *,
    epoch: int | None = None,
    level: int = logging.INFO,
    extra: Dict[str, Any] | None = None,
    logger: logging.Logger | None = None,
) -> None:
    """Log a set of metrics with optional epoch and structured JSON output.

    Logs human-readable info at `level`, and JSON-formatted debug log at DEBUG.

    Each ``name: value`` pair is forwarded to ``log_metric`` in the dictionary's
    iteration order. Every pair shares the same ``epoch``, ``level``, ``extra`` and
    ``logger``, so this is equivalent to calling ``log_metric`` once per pair.

    Args:
        metrics_dict (Dict[str, Any]): Dictionary of metric_name: metric_value pairs.
        epoch (Optional[int]): Optional epoch number.
        level (int): Logging level for human-readable log (default INFO)
        extra (Dict[str, Any]): Optional dictionary of extra metadata to include in JSON.
        logger (Optional[logging.Logger]): Logger to log to, defaults to tgm.util logger.
    """
    shared_options: Dict[str, Any] = {
        'epoch': epoch,
        'level': level,
        'extra': extra,
        'logger': logger,
    }
    for name, value in metrics_dict.items():
        log_metric(name, value, **shared_options)