Hello, I am trying to reproduce your great work, and I have two questions.
First, I am rewriting the training phase and have started training on wikiart with a content-dir of 'wikiart/Rococo' and a style-dir of 'wikiart/Symbolism', but my intermediate results are not as good as yours. Which content-dir and style-dir did you use on the wikiart dataset?
Second, my style distribution loss does not converge; it stays between 4.2 and 4.4. My code is below:
import random

import torch
import torch.nn as nn


class StyleDistLoss(nn.Module):
    '''
    style distribution loss of s and s'
    '''
    def __init__(self, pool_size):
        super(StyleDistLoss, self).__init__()
        self.pool_size = pool_size
        if self.pool_size > 0:
            # pool of detached style codes kept from earlier iterations
            self.num_style_batch = 0
            self.style_batches = []
        self.loss = nn.L1Loss()

    def __call__(self, sc, st):
        '''
        return the standard Gaussian distribution loss of input
        style source {sc} and style target {st}, which correspond to s and s' in the paper
        '''
        styles = []
        if self.pool_size == 0:
            styles.extend([sc, st])
        else:
            styles += self.style_batches
            styles.extend([sc, st])
            detach_sc = sc.clone().detach()
            detach_st = st.clone().detach()
            if self.num_style_batch + 2 < self.pool_size:
                # pool not full yet: append the new codes
                self.style_batches.extend([detach_sc, detach_st])
                self.num_style_batch += 2
            else:
                # pool full: overwrite two randomly chosen entries
                random_idx = [x for x in range(self.num_style_batch)]
                random.shuffle(random_idx)
                self.style_batches[random_idx[0]] = detach_sc
                self.style_batches[random_idx[1]] = detach_st
        tensor_styles = torch.squeeze(torch.cat(styles, 0))
        styles_mean = torch.mean(tensor_styles, dim=0)
        tminuss = tensor_styles - styles_mean
        cov = torch.mm(tminuss.t(), tminuss) / tensor_styles.shape[0]
        std_cov = cov.diag(diagonal=0)
        total_loss = self.loss(styles_mean, torch.zeros_like(styles_mean))
        total_loss += self.loss(cov, torch.ones_like(cov))
        total_loss += self.loss(std_cov, torch.ones_like(std_cov))
        return total_loss
Could you please give me some advice? Thanks!
Hi! I believe the issue is with the line total_loss += self.loss(cov, torch.ones_like(cov))
torch.ones_like(cov) returns a matrix filled with 1s, not an identity matrix. Therefore, your loss does not enforce the disentanglement of the components of the style vector. Try torch.eye instead.
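For illustration, here is a minimal sketch of the same three loss terms with an identity target for the covariance. The function name gaussian_prior_loss and the (N, D) input layout are assumptions made for this example, not code from the repository:

```python
import torch
import torch.nn as nn


def gaussian_prior_loss(styles: torch.Tensor) -> torch.Tensor:
    """L1 penalty pulling the batch mean to 0 and the covariance to identity.

    `styles` is assumed to have shape (N, D): N style codes of dimension D.
    """
    l1 = nn.L1Loss()
    mean = styles.mean(dim=0)
    centered = styles - mean
    cov = centered.t() @ centered / styles.shape[0]
    # Identity target: 1 on the diagonal, 0 elsewhere, so correlations
    # between different components are pushed towards zero.
    identity = torch.eye(cov.shape[0], dtype=cov.dtype, device=cov.device)
    var = cov.diag()
    loss = l1(mean, torch.zeros_like(mean))
    loss = loss + l1(cov, identity)
    loss = loss + l1(var, torch.ones_like(var))
    return loss
```

With an all-ones target, the off-diagonal entries of the covariance are pulled towards 1, so even perfectly decorrelated, unit-variance codes would still be penalised.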
Thank you for your advice, I will give it a try! I also have a question about training on the wikiart dataset. I noticed that in the inference step in your readme.md you use landscape images as content images and wikiart images as style images. Should I use the same strategy when training on wikiart?
Sorry for the delayed reply. For the style transfer model, both content and style images are sampled from the wikiart dataset. We observe that a model trained in this manner can also be applied to real images.
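One way to read this is that each training pair is drawn from a single pool of wikiart images. A rough sketch follows, assuming a flat list of image paths; the helper name sample_content_style_pair is hypothetical, and drawing two distinct images is an assumption of this example, so the actual data loader may differ:

```python
import random
from typing import List, Optional, Tuple


def sample_content_style_pair(
    image_paths: List[str], rng: Optional[random.Random] = None
) -> Tuple[str, str]:
    """Draw a (content, style) pair from the same pool of image paths."""
    rng = rng or random.Random()
    # random.sample draws without replacement, so the two paths are distinct.
    content_path, style_path = rng.sample(image_paths, 2)
    return content_path, style_path
```

In practice this amounts to pointing content-dir and style-dir at the same wikiart root rather than at two separate genre folders.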