diff --git a/get_init.py b/get_init.py
index f991f0caa2977b91dc197cf578f978b205c6d4a9..5d1aefa5dc6ce5c827e569a4f83a314d8468419f 100644
--- a/get_init.py
+++ b/get_init.py
@@ -105,11 +105,11 @@ def get_init(spd, T, SST, lat, P, Rl, Rs, cskin, gust, L, tol, meth, qmeth):
     elif ((cskin == None) and (meth == "C30" or meth == "C35" or meth == "C40"
                                or meth == "ERA5")):
         cskin = 1
-    if ((gust == None) and (meth == "C30" or meth == "C35" or meth == "C40")):
+    if (np.all(gust == None) and (meth == "C30" or meth == "C35" or meth == "C40")):
         gust = [1, 1.2, 600]
-    elif ((gust == None) and (meth == "UA" or meth == "ERA5")):
+    elif (np.all(gust == None) and (meth == "UA" or meth == "ERA5")):
         gust = [1, 1, 1000]
-    elif (gust == None):
+    elif np.all(gust == None):
         gust = [1, 1.2, 800]
     elif (np.size(gust) < 3):
         sys.exit("gust input must be a 3x1 array")