changed discrim

k committed 2025-09-11 17:07:15 -04:00
parent c74924ccea
commit 1a66f31048
2 changed files with 12 additions and 11 deletions

(changed file 1 of 2: the discriminator model)

@@ -25,10 +25,10 @@ class Check:
         self.flat = 128 * self.h * self.w
         self.e1 = nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1)
         self.e2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
-        self.out = nn.Linear(self.flat, 2)
+        self.out = nn.Linear(self.flat, 1)
     def __call__(self, x: Tensor) -> Tensor:
         x = self.e1(x).relu()
         x = self.e2(x).relu()
         x = x.reshape(x.shape[0], -1)
-        return self.out(x).sigmoid()
+        return self.out(x)#.sigmoid()
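The head now emits a single raw logit per sample instead of two class scores, and the final sigmoid is commented out because the training script (second changed file) switches to binary_crossentropy_logits, which applies the sigmoid inside the loss. A quick sanity check of that equivalence, a minimal sketch assuming tinygrad's Tensor also exposes a plain binary_crossentropy:

from tinygrad import Tensor

logits = Tensor([[3.2], [-1.7]])   # raw discriminator outputs, no sigmoid
labels = Tensor([[1.0], [0.0]])    # 1 = real, 0 = fake
# Folding the sigmoid into the loss is numerically safer for large |logit|.
a = logits.binary_crossentropy_logits(labels)
b = logits.sigmoid().binary_crossentropy(labels)
print(a.numpy(), b.numpy())  # should agree up to float error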

(changed file 2 of 2: the training script)

@@ -11,8 +11,9 @@ mlflow.set_tracking_uri("http://127.0.0.1:5000")
 mlflow.start_run(experiment_id=804883409598823668)
 #hyper
 BACH_SIZE=32
-glr=1e-3
-dlr=1e-3
+BATCH_SIZE=BACH_SIZE
+glr=2e-4
+dlr=1e-5
 epochs=100
@@ -43,12 +44,12 @@ difOpt = nn.optim.AdamW(nn.state.get_parameters(dif), lr=dlr)
 @TinyJit
 def step_dis(x:Tensor):
     Tensor.training = True
-    real = Tensor([1,0])
-    fake = Tensor([0,1])
+    real = Tensor.ones((BATCH_SIZE,1))
+    fake = Tensor.zeros((BACH_SIZE,1))
     noise = Tensor.randn(BACH_SIZE, gen.ld)
     fake_data = gen(noise).detach()
-    fake_loss = dif(fake_data).log_softmax().nll_loss(fake)
-    real_loss = dif(x).log_softmax().nll_loss(real)
+    fake_loss = dif(fake_data).binary_crossentropy_logits(fake)
+    real_loss = dif(x).binary_crossentropy_logits(real)
     loss = (fake_loss + real_loss)/2
     loss.backward()
     difOpt.step()
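Both loss halves now compare (BATCH_SIZE, 1) logits against per-sample targets, ones for real data and zeros for generated data, averaged into one discriminator loss. For reference, a hand-rolled numpy version of the numerically stable logits form (bce_logits is a hypothetical helper sketching what such a loss computes, not tinygrad's exact code):

import numpy as np

def bce_logits(x: np.ndarray, y: np.ndarray) -> float:
    # -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))], rewritten as
    # max(x, 0) - x*y + log(1 + exp(-|x|)) so exp() never overflows.
    return float((np.maximum(x, 0) - x * y + np.log1p(np.exp(-np.abs(x)))).mean())

logits = np.array([[2.5], [-0.3]])
print(bce_logits(logits, np.ones_like(logits)))   # loss against real labels
print(bce_logits(logits, np.zeros_like(logits)))  # loss against fake labels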
@@ -57,10 +58,10 @@ def step_dis(x:Tensor):
 @TinyJit
 def step_gen():
     Tensor.training = True
-    real = Tensor([1,0])
+    real = Tensor.ones((BATCH_SIZE,1))
     noise = Tensor.randn(BACH_SIZE, gen.ld)
     fake_data = gen(noise).detach()
-    loss = dif(fake_data).log_softmax().nll_loss(real)
+    loss = dif(fake_data).binary_crossentropy_logits(real)
     loss.backward()
     genOpt.step()
     return loss.numpy()
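One thing this hunk leaves unchanged: step_gen still detaches fake_data before scoring it, which cuts the graph between the discriminator's output and the generator's weights, so genOpt.step() has no gradients to apply. A minimal sketch of the generator step without the detach, keeping gen, dif, genOpt, and BATCH_SIZE from the surrounding script (the zero_grad call is an assumption, not visible in this hunk):

from tinygrad import Tensor, TinyJit

@TinyJit
def step_gen():
    Tensor.training = True
    genOpt.zero_grad()                    # assumed; not shown in the hunk
    real = Tensor.ones((BATCH_SIZE, 1))   # generator wants fakes scored as real
    noise = Tensor.randn(BATCH_SIZE, gen.ld)
    fake_data = gen(noise)                # no detach: gradients must reach gen
    loss = dif(fake_data).binary_crossentropy_logits(real)
    loss.backward()
    genOpt.step()
    return loss.numpy()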
@@ -88,7 +89,7 @@ for e in range(0,epochs):
     dl /= (size/BACH_SIZE)
     gl /= (size/BACH_SIZE)
-    if e%4==0:
+    if e%5==0:
         noise = Tensor.randn(BACH_SIZE, gen.ld)
         show.logSpec(gen(noise).numpy()[0][0],e)
     #todo test on test data