Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/model/diffusion/sefi_image.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ namespace SefiImage {
auto semantic_embedder = std::dynamic_pointer_cast<SefiTimestepEmbedding>(blocks["semantic_embedder"]);
auto texture_embedder = std::dynamic_pointer_cast<SefiTimestepEmbedding>(blocks["texture_embedder"]);

auto sem_proj = ggml_ext_timestep_embedding(ctx->ggml_ctx, timestep_sem, timestep_guidance_in_dim, 10000, 1.f);
auto tex_proj = ggml_ext_timestep_embedding(ctx->ggml_ctx, timestep_tex, timestep_guidance_in_dim, 10000, 1.f);
auto sem_proj = ggml_ext_timestep_embedding(ctx->ggml_ctx, timestep_sem, (int)timestep_guidance_in_dim, 10000, 1.f);
auto tex_proj = ggml_ext_timestep_embedding(ctx->ggml_ctx, timestep_tex, (int)timestep_guidance_in_dim, 10000, 1.f);
auto sem_emb = semantic_embedder->forward(ctx, sem_proj);
auto tex_emb = texture_embedder->forward(ctx, tex_proj);
return ggml_concat(ctx->ggml_ctx, sem_emb, tex_emb, 0);
Expand Down
30 changes: 14 additions & 16 deletions src/model/vae/wan_vae.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,26 +113,24 @@ namespace WAN {
}
};


class Conv2dBut3d : public Conv2d {
public:
using Conv2d::Conv2d;

ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
ggml_tensor* x_swapped = ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2);
x_swapped = ggml_cont(ctx->ggml_ctx, x_swapped);
x_swapped = ggml_cont(ctx->ggml_ctx, x_swapped);

ggml_tensor* out = Conv2d::forward(ctx, x_swapped);

ggml_tensor* out_swapped = ggml_permute(ctx->ggml_ctx, out, 0, 1, 3, 2);

out_swapped = ggml_cont(ctx->ggml_ctx, out_swapped);

return out_swapped;
}
};


class Resample : public GGMLBlock {
protected:
int64_t dim;
Expand Down Expand Up @@ -397,7 +395,7 @@ namespace WAN {
ggml_tensor* h = x;
if (in_dim != out_dim) {
if (is_2D) {
auto shortcut = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["shortcut"]);
auto shortcut = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["shortcut"]);

h = shortcut->forward(ctx, x);
} else {
Expand Down Expand Up @@ -843,9 +841,9 @@ namespace WAN {
}

// init block
if(is_2D){
if (is_2D) {
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(z_dim, dims[0], {3, 3}, {1, 1}, {1, 1}));
}else{
} else {
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
}

Expand Down Expand Up @@ -890,7 +888,7 @@ namespace WAN {
}

// output blocks
blocks["head.0"] = std::shared_ptr<GGMLBlock>(new RMS_norm(out_dim));
blocks["head.0"] = std::shared_ptr<GGMLBlock>(new RMS_norm(out_dim));
int64_t final_dim = wan2_2 ? 12 : 3;
// head.1 is nn.SiLU()
if (is_2D) {
Expand Down Expand Up @@ -919,7 +917,7 @@ namespace WAN {
// conv1
if (is_2D) {
auto conv1 = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["conv1"]);
x = conv1->forward(ctx, x);
x = conv1->forward(ctx, x);
} else if (feat_cache.size() > 0) {
int idx = feat_idx;
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
Expand Down Expand Up @@ -1011,7 +1009,7 @@ namespace WAN {
int num_res_blocks = 2;
std::vector<bool> temperal_upsample = {true, true, false};
std::vector<bool> temperal_downsample = {false, true, true};
bool is_2D = false;
bool is_2D = false;

int _conv_num = 33;
int _conv_idx = 0;
Expand Down Expand Up @@ -1040,8 +1038,8 @@ namespace WAN {
_enc_conv_num = 26;
}

if(is_2D){
temperal_upsample = {false, false, false};
if (is_2D) {
temperal_upsample = {false, false, false};
temperal_downsample = {false, false, false};
}

Expand Down Expand Up @@ -1149,7 +1147,7 @@ namespace WAN {
auto conv1 = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["conv1"]);
out = conv1->forward(ctx, out);
} else {
out = conv1->forward(ctx, out);
out = conv1->forward(ctx, out);
}
auto mu = ggml_ext_chunk(ctx->ggml_ctx, out, 2, 3)[0];
// sd::ggml_graph_cut::mark_graph_cut(mu, "wan_vae.encode.final", "mu");
Expand All @@ -1170,9 +1168,9 @@ namespace WAN {

int64_t iter_ = z->ne[2];
auto x = z;
if(is_2D){
auto conv2 = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["conv2"]);
x = conv2->forward(ctx, z);
if (is_2D) {
auto conv2 = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["conv2"]);
x = conv2->forward(ctx, z);
} else {
x = conv2->forward(ctx, z);
}
Expand Down
3 changes: 1 addition & 2 deletions src/name_conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -913,8 +913,7 @@ std::string convert_diffusers_to_original_wan_vae(std::string name) {
if (name.find(".residual.") != std::string::npos) {
replace_with_name_map(name, resnet_name_map);
}



return name;
}

Expand Down
4 changes: 2 additions & 2 deletions src/runtime/denoiser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ struct LogitNormalScheduler : SigmaScheduler {
}
}
if (image_seq_len > 0 && resolution_aware) {
mean += 0.5 * std::log(static_cast<float>(image_seq_len) / static_cast<float>(known_seq_len));
mean += 0.5f * std::log(static_cast<float>(image_seq_len) / static_cast<float>(known_seq_len));
}
}

Expand Down Expand Up @@ -735,7 +735,7 @@ struct LogitNormalScheduler : SigmaScheduler {
float t = static_cast<float>(i) / static_cast<float>(n);

// ndtri(1-t) == -ndtri(t)
float z = -ndtri(t);
float z = static_cast<float>(-ndtri(t));

float y = mean + std * z;

Expand Down
Loading