changed discrim
This commit is contained in:
parent c74924ccea
commit 1a66f31048
model.py (4 changed lines)
@@ -25,10 +25,10 @@ class Check:
         self.flat = 128 * self.h * self.w
         self.e1 = nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1)
         self.e2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
-        self.out = nn.Linear(self.flat, 2)
+        self.out = nn.Linear(self.flat, 1)
 
     def __call__(self, x: Tensor) -> Tensor:
         x = self.e1(x).relu()
         x = self.e2(x).relu()
         x = x.reshape(x.shape[0], -1)
-        return self.out(x).sigmoid()
+        return self.out(x)#.sigmoid()
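Note on the model.py change: the discriminator head drops from two class outputs to a single logit, and the final .sigmoid() is commented out so the raw logit can be fed to a logits-based loss in train.py. A minimal sketch of the idea, assuming tinygrad's Tensor API (binary_crossentropy and binary_crossentropy_logits); shapes here are illustrative only:

from tinygrad.tensor import Tensor

logits  = Tensor.randn(32, 1)        # stand-in for the new single-logit discriminator output
targets = Tensor.ones((32, 1))       # 1.0 = "real"

# loss computed directly on raw logits (the numerically stable form)
stable   = logits.binary_crossentropy_logits(targets)
# equivalent to applying the sigmoid explicitly and then plain BCE
explicit = logits.sigmoid().binary_crossentropy(targets)

print(stable.numpy(), explicit.numpy())  # should agree up to floating-point error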
train.py (19 changed lines)
@@ -11,8 +11,9 @@ mlflow.set_tracking_uri("http://127.0.0.1:5000")
 mlflow.start_run(experiment_id=804883409598823668)
 #hyper
 BACH_SIZE=32
-glr=1e-3
-dlr=1e-3
+BATCH_SIZE=BACH_SIZE
+glr=2e-4
+dlr=1e-5
 epochs=100
 
 
@@ -43,12 +44,12 @@ difOpt = nn.optim.AdamW(nn.state.get_parameters(dif), lr=dlr)
 @TinyJit
 def step_dis(x:Tensor):
     Tensor.training = True
-    real = Tensor([1,0])
-    fake = Tensor([0,1])
+    real = Tensor.ones((BATCH_SIZE,1))
+    fake = Tensor.zeros((BACH_SIZE,1))
     noise = Tensor.randn(BACH_SIZE, gen.ld)
     fake_data = gen(noise).detach()
-    fake_loss = dif(fake_data).log_softmax().nll_loss(fake)
-    real_loss = dif(x).log_softmax().nll_loss(real)
+    fake_loss = dif(fake_data).binary_crossentropy_logits(fake)
+    real_loss = dif(x).binary_crossentropy_logits(real)
     loss = (fake_loss + real_loss)/2
     loss.backward()
     difOpt.step()
@@ -57,10 +58,10 @@ def step_dis(x:Tensor):
 @TinyJit
 def step_gen():
     Tensor.training = True
-    real = Tensor([1,0])
+    real = Tensor.ones((BATCH_SIZE,1))
     noise = Tensor.randn(BACH_SIZE, gen.ld)
     fake_data = gen(noise).detach()
-    loss = dif(fake_data).log_softmax().nll_loss(real)
+    loss = dif(fake_data).binary_crossentropy_logits(real)
     loss.backward()
     genOpt.step()
     return loss.numpy()
@@ -88,7 +89,7 @@ for e in range(0,epochs):
 
     dl /= (size/BACH_SIZE)
    gl /= (size/BACH_SIZE)
-    if e%4==0:
+    if e%5==0:
         noise = Tensor.randn(BACH_SIZE, gen.ld)
         show.logSpec(gen(noise).numpy()[0][0],e)
 #todo test on test data
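Note on the train.py change: both training steps now use BATCH_SIZE-by-1 real/fake label tensors and binary_crossentropy_logits, matching the new single-logit head, and BATCH_SIZE is simply an alias for the existing (misspelled) BACH_SIZE constant. A minimal sketch of the generator update under the same tinygrad API; gen, dif, and genOpt refer to the objects defined in train.py. The committed step_gen still calls .detach() on gen(noise), which would keep gradients from reaching the generator, so the sketch below leaves the output attached:

from tinygrad.tensor import Tensor

def step_gen_sketch(gen, dif, genOpt, batch_size):
    Tensor.training = True
    real = Tensor.ones((batch_size, 1))      # generator wants the discriminator to say "real"
    noise = Tensor.randn(batch_size, gen.ld)
    fake_data = gen(noise)                   # not detached: gradients must flow back into gen
    loss = dif(fake_data).binary_crossentropy_logits(real)
    genOpt.zero_grad()
    loss.backward()
    genOpt.step()
    return loss.numpy()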