# yolov1_basic.py

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. class Conv(nn.Module):
  5. def __init__(self, in_dim, out_dim, k, s=1, p=0, d=1, g=1, act=True):
  6. super(Conv, self).__init__()
  7. self.convs = nn.Sequential(
  8. nn.Conv2d(in_dim, out_dim, k, stride=s, padding=p, dilation=d, groups=g),
  9. nn.BatchNorm2d(out_dim),
  10. nn.LeakyReLU(0.1, inplace=True) if act else nn.Identity()
  11. )
  12. def forward(self, x):
  13. return self.convs(x)